emmi_inference.models.modules.blocks.transformer_block

Classes

TransformerBlock

A transformer block with a single attention layer and a feedforward layer.

Module Contents

class emmi_inference.models.modules.blocks.transformer_block.TransformerBlock(dim, num_heads, attn_ctor=DotProductAttention)

Bases: torch.nn.Module

A transformer block with a single attention layer and a feedforward layer.

Parameters:
  • dim (int) – Hidden dimension of the transformer block.

  • num_heads (int) – Number of attention heads.

  • attn_ctor (type[torch.nn.Module]) – Constructor for the attention module. Defaults to DotProductAttention.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

norm1
attn
norm2
mlp
forward(x, attn_kwargs=None)

Forward pass of the transformer block.

Parameters:
  • x (torch.Tensor) – Input tensor with shape (batch_size, num_tokens, dim), where num_tokens is the sequence length.

  • attn_kwargs (dict[str, Any] | None) – Dict of keyword arguments forwarded to the attention module (such as the RoPE frequencies). Defaults to None.

Returns:

Output tensor with shape (batch_size, num_tokens, dim).

Return type:

torch.Tensor
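The attribute list (norm1, attn, norm2, mlp) and the forward signature above can be illustrated with a minimal sketch. This is not the emmi_inference implementation: the pre-norm residual layout, the DotProductAttention stand-in, and the MLP expansion factor are all assumptions made for illustration.

```python
import torch
from torch import nn


class DotProductAttention(nn.Module):
    # Hypothetical stand-in for the library's default attention constructor;
    # the real class likely has a different implementation.
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, **attn_kwargs) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


class TransformerBlock(nn.Module):
    """Sketch of a transformer block with one attention layer and one MLP.

    Assumes a pre-norm residual layout; the documented class may differ.
    """

    def __init__(self, dim: int, num_heads: int, attn_ctor=DotProductAttention):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = attn_ctor(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        # 4x expansion is a common choice, assumed here for illustration.
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x: torch.Tensor, attn_kwargs=None) -> torch.Tensor:
        # attn_kwargs (e.g. RoPE frequencies) are passed through to attention.
        attn_kwargs = attn_kwargs or {}
        x = x + self.attn(self.norm1(x), **attn_kwargs)
        x = x + self.mlp(self.norm2(x))
        return x


# Usage: shapes match the documented forward contract.
block = TransformerBlock(dim=64, num_heads=4)
x = torch.randn(2, 16, 64)  # (batch_size, num_tokens, dim)
y = block(x)                # (2, 16, 64)
```

Note that passing the attention constructor (rather than an instance) lets callers swap in alternative attention variants while the block owns construction with its own dim and num_heads.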