emmi_inference.models.modules.blocks.perceiver_block
Classes
PerceiverBlock – Takes separate input tensors for the query and the key/value.
Module Contents
- class emmi_inference.models.modules.blocks.perceiver_block.PerceiverBlock(dim, num_heads)
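Bases: torch.nn.Module
The PerceiverBlock takes separate input tensors for the query and for the key/value, so a set of query tokens can attend to a different, and possibly longer, sequence of key/value tokens.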
- Parameters:
dim (int) – Hidden dimension of the query, key/value, and output tokens.
num_heads (int) – Number of attention heads.
- norm1q
- norm1kv
- attn
- norm2
- mlp
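The attribute names suggest a standard pre-norm cross-attention layout: normalize the query and key/value streams separately, attend from queries to key/values, add the result back to the queries, then apply a normalized MLP with a second residual connection. A minimal sketch, assuming exactly that layout with `nn.LayerNorm`, `torch.nn.MultiheadAttention`, and a 4x-wide GELU MLP as stand-ins. `PerceiverBlockSketch` is a hypothetical name, and the actual submodules of `PerceiverBlock` may differ.

```python
import torch
from torch import nn


class PerceiverBlockSketch(nn.Module):
    """Hypothetical pre-norm cross-attention block mirroring the attribute names.

    Assumptions (not confirmed by the reference): LayerNorm for all norms,
    nn.MultiheadAttention for ``attn``, and a 4x-wide GELU MLP.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.norm1q = nn.LayerNorm(dim)   # normalizes the query stream
        self.norm1kv = nn.LayerNorm(dim)  # normalizes the key/value stream
        # dim must be divisible by num_heads for nn.MultiheadAttention
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)    # normalizes before the MLP
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, q, kv, attn_kwargs=None):
        # attn_kwargs (e.g. RoPE frequencies) is ignored in this sketch because
        # nn.MultiheadAttention has no parameter for it.
        kv_normed = self.norm1kv(kv)
        attn_out, _ = self.attn(self.norm1q(q), kv_normed, kv_normed, need_weights=False)
        q = q + attn_out                 # residual on the query stream
        q = q + self.mlp(self.norm2(q))  # residual MLP
        return q


block = PerceiverBlockSketch(dim=64, num_heads=4)
q = torch.randn(2, 16, 64)    # (batch_size, num_q_tokens, dim)
kv = torch.randn(2, 128, 64)  # (batch_size, num_kv_tokens, dim)
out = block(q, kv)
print(out.shape)  # torch.Size([2, 16, 64])
```

Only the query stream carries residual connections, which is why the output keeps the query's token count.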
- forward(q, kv, attn_kwargs=None)
Forward pass of the PerceiverBlock.
- Parameters:
q (torch.Tensor) – Input tensor with shape (batch_size, num_q_tokens, dim) for the query representations.
kv (torch.Tensor) – Input tensor with shape (batch_size, num_kv_tokens, dim) for the key and value representations.
attn_kwargs (dict[str, Any] | None) – Additional arguments for the attention (such as RoPE frequencies). Defaults to None.
- Returns:
Output tensor with shape (batch_size, num_q_tokens, dim).
- Return type:
torch.Tensor
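The return shape follows the query, not the key/value: each of the num_q_tokens queries produces one output row, whatever num_kv_tokens is. This is what lets a small, fixed set of latent query tokens summarize a long input, as in Perceiver-style models. A parameter-free illustration of that shape contract, using `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+). The helper `cross_attention_shapes` is hypothetical and deliberately omits the projections, norms, and residuals of the real block.

```python
import torch
import torch.nn.functional as F


def cross_attention_shapes(q, kv, num_heads):
    """Hypothetical multi-head cross-attention without learned weights.

    Used only to show how output shape depends on q, not kv.
    """
    b, nq, d = q.shape
    _, nkv, _ = kv.shape
    assert d % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = d // num_heads
    # (batch, tokens, dim) -> (batch, heads, tokens, head_dim)
    qh = q.reshape(b, nq, num_heads, head_dim).transpose(1, 2)
    kvh = kv.reshape(b, nkv, num_heads, head_dim).transpose(1, 2)
    # Keys and values are the same tensor here; the real block projects them.
    out = F.scaled_dot_product_attention(qh, kvh, kvh)  # (b, heads, nq, head_dim)
    # Merge heads back: one output row per query token.
    return out.transpose(1, 2).reshape(b, nq, d)


latents = torch.randn(2, 16, 64)  # 16 query tokens
for num_kv_tokens in (10, 1000):
    kv = torch.randn(2, num_kv_tokens, 64)
    out = cross_attention_shapes(latents, kv, num_heads=4)
    print(num_kv_tokens, tuple(out.shape))  # 10 (2, 16, 64), then 1000 (2, 16, 64)
```

Because attention cost scales with num_q_tokens × num_kv_tokens, keeping the query set small keeps the block affordable even for long key/value sequences.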