emmi_inference.models.modules.attention.anchor_attention

Classes

AnchorAttention

Scaled dot-product attention in which anchor tokens attend to each other with full self-attention and all other tokens cross-attend only to the anchor tokens.

Module Contents

class emmi_inference.models.modules.attention.anchor_attention.AnchorAttention(dim, num_heads=8)

Bases: emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention

Scaled dot-product attention in which anchor tokens attend to each other with full self-attention and all other tokens cross-attend only to the anchor tokens.

Parameters:
  • dim (int) – Input dimension of the attention module.

  • num_heads (int) – Number of attention heads. Defaults to 8.

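A minimal usage sketch, assuming only the constructor signature documented above; the freqs layout expected by forward() depends on the RoPE implementation and is therefore indicated but not constructed here:

    import torch

    from emmi_inference.models.modules.attention.anchor_attention import AnchorAttention

    # 512-dimensional tokens split across 8 attention heads.
    attn = AnchorAttention(dim=512, num_heads=8)
    x = torch.randn(2, 10, 512)  # (batch_size, sequence_length, dim)
    # out = attn(x, freqs, num_anchor_tokens=4)  # freqs: RoPE frequencies (shape depends on the RoPE module)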

forward(x, freqs, num_anchor_tokens=None)

Self-attention among the anchor tokens; all other tokens (query tokens) have only cross-attention to the anchor tokens.

Parameters:
  • x (torch.Tensor) – Tensor to apply self-attention over, shape (batch_size, sequence_length, dim).

  • freqs (torch.Tensor) – Frequencies for rotary position embeddings (RoPE).

  • num_anchor_tokens (int | None) – Number of anchor tokens. If provided, the first num_anchor_tokens tokens of x act as anchors (full self-attention among themselves) and the remaining tokens act as queries (cross-attention to the anchor tokens only).

Returns:

Output tensor of shape (batch_size, sequence_length, dim).

Return type:

torch.Tensor
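The attention pattern above can be expressed as a boolean mask over key positions. The following is an illustrative sketch only, not the module's implementation: it omits the learned projections and the RoPE application via freqs, and anchor_attention_mask is a hypothetical helper introduced here for clarity.

    import torch
    import torch.nn.functional as F

    def anchor_attention_mask(seq_len: int, num_anchor_tokens: int) -> torch.Tensor:
        # Boolean mask; True marks key positions a given query position may attend to.
        mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        # Anchor tokens (first num_anchor_tokens rows) attend to all anchor tokens.
        mask[:num_anchor_tokens, :num_anchor_tokens] = True
        # Query tokens (remaining rows) attend only to the anchor tokens.
        mask[num_anchor_tokens:, :num_anchor_tokens] = True
        return mask

    # 4 anchor tokens followed by 6 query tokens; shapes: (batch, heads, seq, head_dim).
    q = k = v = torch.randn(1, 8, 10, 64)
    mask = anchor_attention_mask(seq_len=10, num_anchor_tokens=4)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([1, 8, 10, 64])

Note that with this mask the query tokens attend neither to themselves nor to each other, matching the behaviour described for forward().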