emmi_inference.models.modules.attention.perceiver_attention¶
Classes¶
Perceiver style attention module. This module is similar to a cross-attention module.
Module Contents¶
- class emmi_inference.models.modules.attention.perceiver_attention.PerceiverAttention(dim, num_heads=8)¶
Bases:
torch.nn.Module
Perceiver style attention module. This module is similar to a cross-attention module.
- Parameters:
dim (int) – Embedding dimension of the query and key/value inputs.
num_heads (int) – Number of attention heads. Defaults to 8.
- num_heads = 8¶
- head_dim¶
- q¶
- kv¶
- proj¶
- forward(q, kv, q_freqs, k_freqs)¶
Forward function of the PerceiverAttention module.
- Parameters:
q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, dim).
kv (torch.Tensor) – Key/value tensor, shape (batch size, number of latent tokens, dim).
q_freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of queries.
k_freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of keys.
- Returns:
Output tensor of shape (batch size, query sequence length, dim).
- Return type:
torch.Tensor
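The interface above can be sketched as a standard cross-attention block in which a set of queries attends to a separate set of key/value tokens. The internals below (projection layout, scaling, and the omission of the RoPE rotation driven by `q_freqs`/`k_freqs`) are assumptions for illustration, not the actual `emmi_inference` implementation:

```python
import torch
import torch.nn as nn


class PerceiverAttentionSketch(nn.Module):
    """Hypothetical sketch of a Perceiver-style cross-attention module
    matching the documented interface. The real PerceiverAttention may
    differ; in particular, RoPE application via q_freqs/k_freqs is
    omitted here."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Separate projections for queries and for fused keys/values,
        # mirroring the documented q / kv / proj attributes.
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, q, kv, q_freqs=None, k_freqs=None):
        b, n, dim = q.shape          # (batch, query tokens, dim)
        m = kv.shape[1]              # number of latent tokens
        # Split heads: (batch, heads, tokens, head_dim).
        q = self.q(q).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        k, v = (
            self.kv(kv)
            .view(b, m, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        # NOTE: a real implementation would rotate q and k here using
        # q_freqs and k_freqs (Rotary Positional Embedding).
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, n, dim)
        return self.proj(out)
```

A typical call mirrors the documented `forward` signature: queries of shape `(batch, num_points, dim)` cross-attend to latents of shape `(batch, num_latents, dim)`, and the output keeps the query sequence length.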