emmi.modules.blocks.perceiver_block

Classes

PerceiverBlock

For a self-attention module, the input tensors for the query, key, and value are the same. The PerceiverBlock takes different input tensors for the query and the key/value.

Module Contents

class emmi.modules.blocks.perceiver_block.PerceiverBlock(config)

Bases: torch.nn.Module

For a self-attention module, the input tensors for the query, key, and value are the same. The PerceiverBlock takes different input tensors for the query and the key/value.

Perceiver-style cross-attention block.
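The difference between self-attention and cross-attention can be illustrated with plain PyTorch. The following is a minimal sketch that uses torch.nn.MultiheadAttention as a stand-in for the attention layer; it does not use PerceiverBlock itself, and the tensor sizes are arbitrary assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

hidden_dim, num_heads = 32, 4
# Stand-in attention layer (an assumption; not the layer used by PerceiverBlock).
attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

x = torch.randn(2, 100, hidden_dim)      # e.g. 100 input tokens
latents = torch.randn(2, 8, hidden_dim)  # e.g. 8 query tokens

# Self-attention: query, key, and value are the same tensor.
self_out, _ = attn(x, x, x)

# Cross-attention: the query tensor differs from the key/value tensor.
cross_out, _ = attn(latents, x, x)

self_shape = tuple(self_out.shape)
cross_shape = tuple(cross_out.shape)
print(self_shape, cross_shape)
```

In both cases the output has the sequence length of the query, which is why a cross-attention block can map a long key/value sequence onto a different number of query tokens.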

Parameters:

config (emmi.schemas.modules.blocks.PerceiverBlockConfig) – Configuration of the PerceiverBlock.

norm1q
norm1kv
attn
ls1
drop_path1
norm2
mlp
ls2
drop_path2
forward(q, kv, condition=None, attn_kwargs=None)

Forward pass of the PerceiverBlock.

Parameters:
  • q (torch.Tensor) – Input tensor with shape (batch_size, seqlen/num_tokens, hidden_dim) for the query representations.

  • kv (torch.Tensor) – Input tensor with shape (batch_size, seqlen/num_tokens, hidden_dim) for the key and value representations.

  • condition (torch.Tensor | None) – Conditioning vector. If provided, the attention and MLP are scaled, shifted, and gated feature-wise using values predicted from this vector. Defaults to None.

  • attn_kwargs (dict[str, Any] | None) – Dict with arguments for the attention (such as the attention mask). Defaults to None.

Returns:

Tensor after the forward pass of the PerceiverBlock.

Return type:

torch.Tensor
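The forward pass described above can be sketched as a pre-norm cross-attention block. This is an illustrative reconstruction under stated assumptions, not the emmi implementation: the class name SketchPerceiverBlock, the cond_dim argument, and the modulation layer are hypothetical, torch.nn.MultiheadAttention stands in for the attention layer, the conditioning is applied to the query path only, and the layer-scale (ls1, ls2) and drop-path (drop_path1, drop_path2) submodules are omitted.

```python
from typing import Optional

import torch
from torch import nn


class SketchPerceiverBlock(nn.Module):
    """Illustrative pre-norm cross-attention block (not the emmi implementation)."""

    def __init__(self, hidden_dim: int, num_heads: int, cond_dim: Optional[int] = None):
        super().__init__()
        self.norm1q = nn.LayerNorm(hidden_dim)
        self.norm1kv = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )
        # Hypothetical layer: predicts (scale, shift, gate) for both sub-layers.
        self.modulation = nn.Linear(cond_dim, 6 * hidden_dim) if cond_dim else None

    def forward(self, q, kv, condition=None, attn_kwargs=None):
        attn_kwargs = attn_kwargs or {}
        if condition is not None and self.modulation is not None:
            # (batch, 6 * hidden) -> six tensors of shape (batch, 1, hidden)
            mod = self.modulation(condition).unsqueeze(1)
            s1, b1, g1, s2, b2, g2 = mod.chunk(6, dim=-1)
        else:
            # Without conditioning: no scale/shift, gates fixed to 1.
            s1 = b1 = s2 = b2 = 0.0
            g1 = g2 = 1.0
        # Cross-attention: queries come from q, keys and values from kv.
        q_in = self.norm1q(q) * (1 + s1) + b1
        kv_in = self.norm1kv(kv)
        attn_out, _ = self.attn(q_in, kv_in, kv_in, **attn_kwargs)
        q = q + g1 * attn_out
        # Feed-forward sub-layer with a residual connection.
        q = q + g2 * self.mlp(self.norm2(q) * (1 + s2) + b2)
        return q


torch.manual_seed(0)
block = SketchPerceiverBlock(hidden_dim=32, num_heads=4, cond_dim=16)
q = torch.randn(2, 8, 32)
kv = torch.randn(2, 100, 32)

out = block(q, kv)
out_cond = block(q, kv, condition=torch.randn(2, 16))
# attn_kwargs are forwarded to the attention layer, e.g. a padding mask.
mask = torch.zeros(2, 100, dtype=torch.bool)
out_masked = block(q, kv, attn_kwargs={"key_padding_mask": mask})
```

In this sketch the returned tensor has the same shape as q, regardless of how many tokens kv contains. Which keys attn_kwargs accepts depends on the attention layer in use; key_padding_mask here is specific to torch.nn.MultiheadAttention.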