emmi.modules.blocks.transformer_block¶
Classes¶
TransformerBlock – A transformer block with a single attention layer and a feedforward layer.
Module Contents¶
- class emmi.modules.blocks.transformer_block.TransformerBlock(config)¶
Bases: torch.nn.Module
A transformer block with a single attention layer and a feedforward layer.
Initializes a transformer block.
- Parameters:
config (emmi.schemas.modules.blocks.TransformerBlockConfig) – Configuration of the transformer block.
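For orientation, a minimal construction sketch follows; the TransformerBlockConfig field names (hidden_dim, num_heads) are hypothetical, so consult emmi.schemas.modules.blocks for the actual schema:

```python
from emmi.modules.blocks.transformer_block import TransformerBlock
from emmi.schemas.modules.blocks import TransformerBlockConfig

# NOTE: the field names below are assumptions; the real
# TransformerBlockConfig schema may use different names.
config = TransformerBlockConfig(hidden_dim=256, num_heads=8)
block = TransformerBlock(config)
```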
- norm1¶
- attention_block¶
- ls1¶
- drop_path1¶
- norm2¶
- mlp¶
- ls2¶
- drop_path2¶
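These attributes suggest the familiar pre-norm residual layout (normalization, sub-layer, layer scale, drop path, residual add). A minimal sketch of how they might compose in the unconditioned case, assuming the standard pre-norm ordering; the actual forward additionally handles the condition-based modulation:

```python
import torch

def forward_sketch(block: "TransformerBlock", x: torch.Tensor) -> torch.Tensor:
    # Attention sub-layer: pre-norm -> attention -> layer scale -> drop path -> residual.
    x = x + block.drop_path1(block.ls1(block.attention_block(block.norm1(x))))
    # Feedforward sub-layer: pre-norm -> MLP -> layer scale -> drop path -> residual.
    x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
    return x
```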
- forward(x, condition=None, attn_kwargs=None)¶
Forward pass of the transformer block.
- Parameters:
x (torch.Tensor) – Input tensor with shape (batch_size, seqlen/num_tokens, hidden_dim).
condition (torch.Tensor | None) – Conditioning vector. If provided, the attention and MLP outputs are scaled, shifted, and gated feature-wise with values predicted from this vector. Defaults to None.
attn_kwargs (dict[str, Any] | None) – Dictionary of keyword arguments forwarded to the attention layer (such as the attention mask). Defaults to None.
- Returns:
Output tensor of the transformer block, with the same shape as x.
- Return type:
torch.Tensor
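A usage sketch, continuing from the construction example above; the (batch_size, hidden_dim) shape of condition is an assumption about the expected conditioning vector, typical for this kind of feature-wise modulation:

```python
import torch

batch_size, seqlen, hidden_dim = 4, 16, 256
x = torch.randn(batch_size, seqlen, hidden_dim)

# Unconditioned forward pass.
out = block(x)

# Conditioned forward pass; the (batch_size, hidden_dim) shape of
# `condition` is an assumption, not confirmed by the docstring.
condition = torch.randn(batch_size, hidden_dim)
out_cond = block(x, condition=condition)

assert out.shape == x.shape
```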