emmi.modules.blocks.transformer_block

Classes

TransformerBlock

A transformer block with a single attention layer and a feedforward layer.

Module Contents

class emmi.modules.blocks.transformer_block.TransformerBlock(config)

Bases: torch.nn.Module

A transformer block with a single attention layer and a feedforward layer.

Initializes a transformer block.

Parameters:

config (emmi.schemas.modules.blocks.TransformerBlockConfig)

Attributes:
  • norm1
  • attention_block
  • ls1
  • drop_path1
  • norm2
  • mlp
  • ls2
  • drop_path2
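The attribute names above suggest a standard pre-norm residual layout (normalization, attention/MLP, LayerScale, DropPath). The following is a minimal sketch of how such a block is typically wired, assuming that structure; it is illustrative only and may differ from the actual implementation.

  import torch
  import torch.nn as nn

  class PreNormBlockSketch(nn.Module):
      """Hypothetical pre-norm block mirroring the attribute names above."""

      def __init__(self, hidden_dim, num_heads):
          super().__init__()
          self.norm1 = nn.LayerNorm(hidden_dim)
          self.attention_block = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
          self.ls1 = nn.Identity()         # placeholder for LayerScale
          self.drop_path1 = nn.Identity()  # placeholder for stochastic depth
          self.norm2 = nn.LayerNorm(hidden_dim)
          self.mlp = nn.Sequential(
              nn.Linear(hidden_dim, 4 * hidden_dim),
              nn.GELU(),
              nn.Linear(4 * hidden_dim, hidden_dim),
          )
          self.ls2 = nn.Identity()
          self.drop_path2 = nn.Identity()

      def forward(self, x):
          # Residual attention sub-block
          h = self.norm1(x)
          h, _ = self.attention_block(h, h, h, need_weights=False)
          x = x + self.drop_path1(self.ls1(h))
          # Residual MLP sub-block
          x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
          return x
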
forward(x, condition=None, attn_kwargs=None)

Forward pass of the transformer block.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, num_tokens, hidden_dim), where num_tokens is the sequence length.

  • condition (torch.Tensor | None) – Conditioning vector. If provided, the attention and MLP outputs are scaled, shifted, and gated feature-wise with values predicted from this vector.

  • attn_kwargs (dict[str, Any] | None) – Dictionary of keyword arguments passed to the attention (such as the attention mask). Defaults to None.

Returns:

Tensor after the forward pass of the transformer block.

Return type:

torch.Tensor
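
A hedged usage example. The fields passed to TransformerBlockConfig below (hidden_dim, num_heads) and the attn_kwargs key ("attn_mask") are hypothetical placeholders; consult the config schema and the attention implementation for the actual names.

  import torch
  from emmi.schemas.modules.blocks import TransformerBlockConfig
  from emmi.modules.blocks.transformer_block import TransformerBlock

  # Config field names are illustrative only.
  config = TransformerBlockConfig(hidden_dim=256, num_heads=8)
  block = TransformerBlock(config)

  x = torch.randn(2, 16, 256)      # (batch_size, num_tokens, hidden_dim)
  condition = torch.randn(2, 256)  # optional conditioning vector
  mask = torch.ones(16, 16, dtype=torch.bool)

  # Key name in attn_kwargs depends on the attention implementation.
  out = block(x, condition=condition, attn_kwargs={"attn_mask": mask})
  print(out.shape)  # expected: torch.Size([2, 16, 256])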