emmi.modules.attention.perceiver_attention¶
Classes¶
Perceiver-style attention module. This module is similar to a cross-attention module.
Module Contents¶
- class emmi.modules.attention.perceiver_attention.PerceiverAttention(config)¶
Bases: torch.nn.Module
Perceiver-style attention module. This module is similar to a cross-attention module.
Initialize the PerceiverAttention module.
- Parameters:
config (emmi.schemas.modules.attention.AttentionConfig) – configuration of the attention module.
- num_heads = None¶
- head_dim¶
- init_weights = None¶
- use_rope = None¶
- kv¶
- q¶
- proj¶
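A minimal instantiation sketch. The AttentionConfig field names used here (hidden_dim, num_heads, init_weights, use_rope) are assumptions inferred from the attributes listed above; the actual schema in emmi.schemas.modules.attention may differ.

```python
import torch
from emmi.modules.attention.perceiver_attention import PerceiverAttention
from emmi.schemas.modules.attention import AttentionConfig

# Hypothetical field names inferred from the attribute list above;
# check the AttentionConfig schema for the real ones.
config = AttentionConfig(
    hidden_dim=256,
    num_heads=8,
    init_weights="truncnormal002",  # or "torch"
    use_rope=False,
)
attn = PerceiverAttention(config)
```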
- reset_parameters()¶
Reset the parameters of the PerceiverAttention module with a specific initialization method.
- Raises:
NotImplementedError – raised when the requested initialization method is not implemented. Use either “torch” or “truncnormal002”.
- Return type:
None
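A short continuation of the instantiation sketch above: once the module is built, reset_parameters() re-applies the configured initialization. The two scheme names come from the error description above.

```python
# Assuming `attn` from the instantiation sketch above was built with
# init_weights set to "torch" or "truncnormal002"; any other value
# makes reset_parameters() raise NotImplementedError.
attn.reset_parameters()
```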
- forward(q, kv, attn_mask=None, q_freqs=None, k_freqs=None)¶
Forward function of the PerceiverAttention module.
- Parameters:
q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, hidden_dim).
kv (torch.Tensor) – Key/value tensor, shape (batch size, number of latent tokens, hidden_dim).
attn_mask (torch.Tensor | None) – Attention mask, required when applying causal attention. Defaults to None.
q_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of queries. None if use_rope=False.
k_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of keys. None if use_rope=False.
- Returns:
The output of the perceiver attention module.
- Return type:
torch.Tensor
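A hedged forward-pass sketch using the shapes documented above. The config fields are the same assumptions as in the instantiation sketch, and the expected output shape follows the usual cross-attention convention (the output keeps the query shape), which this page does not state explicitly.

```python
import torch
from emmi.modules.attention.perceiver_attention import PerceiverAttention
from emmi.schemas.modules.attention import AttentionConfig

# Hypothetical config fields; see the instantiation sketch above.
config = AttentionConfig(hidden_dim=256, num_heads=8,
                         init_weights="torch", use_rope=False)
attn = PerceiverAttention(config)

q = torch.randn(4, 1024, 256)  # (batch size, number of points/tokens, hidden_dim)
kv = torch.randn(4, 64, 256)   # (batch size, number of latent tokens, hidden_dim)

# use_rope=False, so q_freqs and k_freqs stay None; no attn_mask is
# passed since this call is not causal.
out = attn(q, kv)
print(out.shape)  # typically (4, 1024, 256): the output keeps the query shape
```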