emmi.modules.attention.anchor_attention.mixed_attention

Classes

MixedAttention

Mixed attention with a selectable implementation for performance or readability.

Module Contents

class emmi.modules.attention.anchor_attention.mixed_attention.MixedAttention(config)

Bases: emmi.modules.attention.DotProductAttention

Mixed attention with a selectable implementation for performance or readability.

This module allows for structured attention patterns where different groups of tokens (defined by TokenSpec) have specific interaction patterns (defined by AttentionPattern). Instead of full self-attention, you can specify, for example, that one type of token can only attend to itself, while another can attend to all tokens.

This is achieved by splitting the main Q, K, V tensors based on the token specs and then performing separate attention computations for each pattern.

Example input structure (forward pass signature) for implementing Anchor Attention:

    x = torch.cat(
        [surface_anchors, surface_queries, volume_anchors, volume_queries],
        dim=1,  # sequence dim
    )
    token_specs = [
        TokenSpec("surface_anchors", 100),
        TokenSpec("surface_queries", 50),
        TokenSpec("volume_anchors", 80),
        TokenSpec("volume_queries", 60),
    ]
    attention_patterns = [
        AttentionPattern(query_tokens=["surface_anchors", "surface_queries"],
                         key_value_tokens=["surface_anchors"]),
        AttentionPattern(query_tokens=["volume_anchors", "volume_queries"],
                         key_value_tokens=["volume_anchors"]),
    ]
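
The sketch below puts the structure above into a complete call. It assumes MixedAttentionConfig exposes fields matching the constructor parameters listed below (dim, num_heads, use_rope, bias, init_weights, parallel); the concrete field values, in particular the init_weights string, are placeholders rather than documented defaults.

    import torch

    from emmi.modules.attention.anchor_attention.mixed_attention import MixedAttention
    from emmi.schemas.modules.attention.anchor_attention import (
        AttentionPattern,
        MixedAttentionConfig,
        TokenSpec,
    )

    # Field names taken from the parameter list below; values are placeholders.
    config = MixedAttentionConfig(
        dim=256,
        num_heads=8,
        use_rope=False,
        bias=True,
        init_weights="truncnormal",  # placeholder; valid values depend on the schema
        parallel=True,
    )
    attn = MixedAttention(config)

    # One sequence of 100 + 50 + 80 + 60 = 290 tokens, concatenated in token_specs order.
    x = torch.randn(2, 290, 256)
    out = attn(x, token_specs, attention_patterns)  # -> [2, 290, 256]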

Parameters:
  • dim (int) – Model dimension.

  • num_heads (int) – Number of attention heads.

  • use_rope (bool) – Whether to use rotary position embeddings.

  • bias (bool) – Whether to use bias in the linear projections.

  • init_weights (str) – Weight initialization method.

  • parallel (bool) – If True (default), uses an efficient implementation that batches compatible attention patterns. If False, uses a simpler, more readable sequential implementation.

  • config (emmi.schemas.modules.attention.anchor_attention.MixedAttentionConfig)

Initialize the DotProductAttention module.

Parameters:

config (emmi.schemas.modules.attention.anchor_attention.MixedAttentionConfig) – Configuration of the attention module.

parallel
forward(x, token_specs, attention_patterns, attention_mask=None, freqs=None)

Apply mixed attention with flexible token-name-based patterns.

Parameters:
  • x (torch.Tensor) – Input tensor [batch_size, n_tokens, dim]

  • token_specs (collections.abc.Sequence[emmi.schemas.modules.attention.anchor_attention.TokenSpec]) – Sequence of token specifications defining the input structure: the input x is assumed to be a concatenation of the token groups in the order given by token_specs.

  • attention_patterns (collections.abc.Sequence[emmi.schemas.modules.attention.anchor_attention.AttentionPattern]) – Sequence of attention patterns to apply. Each pattern defines which token groups (queries) attend to which other token groups (keys/values). The provided patterns must be exhaustive and non-overlapping. This means every token group defined in token_specs must be a query in exactly one pattern.

  • attention_mask (torch.Tensor | None) – Optional attention mask (not currently supported)

  • freqs (torch.Tensor | None) – RoPE frequencies for positional encoding

Return type:

torch.Tensor
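
To make "exhaustive and non-overlapping" concrete, below is a minimal reference sketch of the sequential strategy, not the module's actual implementation: it slices per-head Q/K/V by token group, runs standard scaled dot-product attention once per pattern, and writes each result back to the query rows it covers. The attribute names name and num_tokens on TokenSpec are assumptions inferred from the example above; query_tokens and key_value_tokens are the documented AttentionPattern fields.

    import torch
    import torch.nn.functional as F

    def mixed_attention_reference(q, k, v, token_specs, attention_patterns):
        """Readable per-pattern attention over [batch, heads, n_tokens, head_dim] tensors."""
        # Map each token group name to its [start, end) slice along the token axis.
        offsets, start = {}, 0
        for spec in token_specs:
            offsets[spec.name] = (start, start + spec.num_tokens)  # assumed field names
            start += spec.num_tokens

        out = torch.empty_like(q)
        for pattern in attention_patterns:
            q_idx = [i for name in pattern.query_tokens for i in range(*offsets[name])]
            kv_idx = [i for name in pattern.key_value_tokens for i in range(*offsets[name])]
            # Attention restricted to this pattern's queries and keys/values.
            out[:, :, q_idx] = F.scaled_dot_product_attention(
                q[:, :, q_idx], k[:, :, kv_idx], v[:, :, kv_idx]
            )
        # Because the patterns are exhaustive and non-overlapping, every query row
        # of `out` is written exactly once.
        return out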