emmi_inference.models.modules.blocks.perceiver_block

Classes

PerceiverBlock

The PerceiverBlock is a cross-attention block that takes separate input tensors for the query and for the key/value.

Module Contents

class emmi_inference.models.modules.blocks.perceiver_block.PerceiverBlock(dim, num_heads)

Bases: torch.nn.Module

The PerceiverBlock is a cross-attention block that takes separate input tensors for the query and for the key/value.

Parameters:
  • dim (int) – Hidden dimension of the perceiver block.

  • num_heads (int) – Number of attention heads.


norm1q
norm1kv
attn
norm2
mlp
forward(q, kv, attn_kwargs=None)

Forward pass of the PerceiverBlock.

Parameters:
  • q (torch.Tensor) – Input tensor with shape (batch_size, num_q_tokens, dim) for the query representations.

  • kv (torch.Tensor) – Input tensor with shape (batch_size, num_kv_tokens, dim) for the key and value representations.

  • attn_kwargs (dict[str, Any] | None) – Dictionary of additional arguments for the attention (such as RoPE frequencies). Defaults to None.

Returns:

Output tensor with shape (batch_size, num_q_tokens, dim), i.e. the same shape as q.

Return type:

torch.Tensor
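
The following is a minimal sketch of how a block with these attributes could be wired together. It is not the emmi_inference implementation. The class name SketchPerceiverBlock is hypothetical, and the pre-norm residual layout, the use of torch.nn.MultiheadAttention for attn, and the 4x GELU MLP are all assumptions. The sketch also omits attn_kwargs, since the arguments accepted by the library's attention module are not documented here. It only illustrates the documented shape contract: q of shape (batch_size, num_q_tokens, dim) attends to kv of shape (batch_size, num_kv_tokens, dim), and the output keeps the shape of q.

```python
import torch
from torch import nn


class SketchPerceiverBlock(nn.Module):
    """Hypothetical pre-norm cross-attention block (illustration only)."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        # Separate normalization layers for the query and key/value inputs.
        self.norm1q = nn.LayerNorm(dim)
        self.norm1kv = nn.LayerNorm(dim)
        # Assumption: standard multi-head attention stands in for the
        # library's own attention module.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # Assumption: a two-layer MLP with 4x expansion and GELU.
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        # Queries come from q; keys and values both come from kv.
        kv_normed = self.norm1kv(kv)
        attn_out, _ = self.attn(
            self.norm1q(q), kv_normed, kv_normed, need_weights=False
        )
        # Residual connections around the attention and the MLP.
        q = q + attn_out
        return q + self.mlp(self.norm2(q))


torch.manual_seed(0)
block = SketchPerceiverBlock(dim=32, num_heads=4)
q = torch.randn(2, 5, 32)    # (batch_size, num_q_tokens, dim)
kv = torch.randn(2, 11, 32)  # (batch_size, num_kv_tokens, dim)
out = block(q, kv)
print(out.shape)  # torch.Size([2, 5, 32])
```

Note that num_q_tokens and num_kv_tokens may differ, which is what lets a block of this kind compress a long kv sequence into a shorter set of query tokens (or expand it into a longer one).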