emmi_inference.models.modules.attention.anchor_attention¶
Classes¶
AnchorAttention – Scaled dot-product attention module.
Module Contents¶
- class emmi_inference.models.modules.attention.anchor_attention.AnchorAttention(dim, num_heads=8)¶
Bases: emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention

Scaled dot-product attention module.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
dim (int) – Dimension of the token embeddings.
num_heads (int) – Number of attention heads. Default: 8.
- forward(x, freqs, num_anchor_tokens=None)¶
Self-attention among anchor tokens; all other tokens (query tokens) have only cross-attention to the anchor tokens.
- Parameters:
x (torch.Tensor) – Tensor to apply self-attention over, shape (batch_size, sequence_length, dim).
freqs (torch.Tensor) – Frequencies for RoPE.
num_anchor_tokens (int | None) – Number of anchor tokens. If provided, the first num_anchor_tokens tokens of x are the anchors (full self-attention among themselves) and the remaining tokens are the queries (cross-attention to the anchor tokens only).
- Returns:
Output tensor of shape (batch_size, sequence_length, dim).
- Return type:
torch.Tensor
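The attention pattern amounts to a masked attention in which every token, anchor or query, may attend only to the anchor positions. Below is a minimal usage sketch under stated assumptions: the import path is taken from the module name above, but the expected shape and layout of freqs is an assumption and should be checked against the package's RoPE implementation.

```python
import torch

# Import path taken from this module's name.
from emmi_inference.models.modules.attention.anchor_attention import AnchorAttention

dim, num_heads = 256, 8
attn = AnchorAttention(dim, num_heads=num_heads)

batch_size, seq_len, num_anchors = 2, 64, 16
x = torch.randn(batch_size, seq_len, dim)

# RoPE frequencies. The (seq_len, head_dim) shape used here is a placeholder
# assumption; consult the RoPE code this package pairs with AnchorAttention.
head_dim = dim // num_heads
freqs = torch.randn(seq_len, head_dim)

# The first `num_anchors` tokens self-attend among themselves; the remaining
# tokens only cross-attend to those anchors.
out = attn(x, freqs, num_anchor_tokens=num_anchors)
assert out.shape == (batch_size, seq_len, dim)

# For intuition only: the pattern is equivalent to attention under a boolean
# mask where column j is attendable iff j is an anchor position.
mask = torch.arange(seq_len) < num_anchors          # (seq_len,) bool
attn_mask = mask.unsqueeze(0).expand(seq_len, -1)   # (seq_len, seq_len)
```

Restricting queries to cross-attention over a small anchor set reduces the attention cost from quadratic in seq_len to roughly seq_len × num_anchors, which is the usual motivation for this kind of pattern.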