emmi.modules.attention.anchor_attention.mixed_attention
========================================================

.. py:module:: emmi.modules.attention.anchor_attention.mixed_attention


Classes
-------

.. autoapisummary::

   emmi.modules.attention.anchor_attention.mixed_attention.MixedAttention


Module Contents
---------------

.. py:class:: MixedAttention(config)

   Bases: :py:obj:`emmi.modules.attention.DotProductAttention`

   Mixed attention with a selectable implementation for performance or readability.

   This module allows for structured attention patterns where different groups of
   tokens (defined by `TokenSpec`) have specific interaction patterns (defined by
   `AttentionPattern`). Instead of full self-attention, you can specify, for example,
   that one type of token can only attend to itself, while another can attend to all
   tokens. This is achieved by splitting the main Q, K, V tensors based on the token
   specs and then performing separate attention computations for each pattern.

   Example input structure (forward pass signature) for implementing Anchor Attention::

       x = torch.cat([surface_anchors, surface_queries, volume_anchors, volume_queries], dim=1)  # sequence dim
       token_specs = [
           TokenSpec("surface_anchors", 100),
           TokenSpec("surface_queries", 50),
           TokenSpec("volume_anchors", 80),
           TokenSpec("volume_queries", 60),
       ]
       attention_patterns = [
           AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["surface_anchors"]),
           AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["volume_anchors"]),
       ]

   :param config: Configuration for the MixedAttention module.


   .. py:method:: forward(x, token_specs, attention_patterns, attention_mask=None, freqs=None)

      Apply mixed attention with flexible token-name-based patterns.

      :param x: Input tensor [batch_size, n_tokens, dim]
      :param token_specs: Sequence of token specifications defining the input structure:
                          assumes that the input x is a concatenation of tokens in the
                          order of token_specs.
      :param attention_patterns: Sequence of attention patterns to apply. Each pattern
                                 defines which token groups (queries) attend to which
                                 other token groups (keys/values). The provided patterns
                                 must be exhaustive and non-overlapping. This means every
                                 token group defined in `token_specs` must be a query in
                                 exactly one pattern.
      :param attention_mask: Optional attention mask (not currently supported)
      :param freqs: RoPE frequencies for positional encoding
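
      The sketch below illustrates the split-and-attend semantics described above
      for a single-head case. The ``TokenSpec``/``AttentionPattern`` dataclasses
      and the function ``mixed_attention_reference`` are simplified stand-ins for
      illustration, not the actual emmi implementations::

          from dataclasses import dataclass
          from typing import Sequence

          import torch
          import torch.nn.functional as F


          @dataclass
          class TokenSpec:          # illustrative stand-in: (group name, number of tokens)
              name: str
              n_tokens: int


          @dataclass
          class AttentionPattern:   # illustrative stand-in: which groups attend to which
              query_tokens: Sequence[str]
              key_value_tokens: Sequence[str]


          def mixed_attention_reference(x, token_specs, attention_patterns):
              """Single-head reference: split x by token group, attend per pattern,
              and reassemble outputs in the original token order."""
              # Split the sequence dimension into named groups following token_specs.
              names = [s.name for s in token_specs]
              sizes = [s.n_tokens for s in token_specs]
              chunks = dict(zip(names, torch.split(x, sizes, dim=1)))

              outputs = {}
              for pattern in attention_patterns:
                  q = torch.cat([chunks[n] for n in pattern.query_tokens], dim=1)
                  kv = torch.cat([chunks[n] for n in pattern.key_value_tokens], dim=1)
                  # Plain scaled dot-product attention restricted to this pattern.
                  out = F.scaled_dot_product_attention(q, kv, kv)
                  # Scatter the result back onto the individual query groups.
                  offset = 0
                  for name in pattern.query_tokens:
                      n = chunks[name].shape[1]
                      outputs[name] = out[:, offset:offset + n]
                      offset += n

              # Because every group is a query in exactly one pattern, this
              # re-concatenation restores the original ordering and sequence length.
              return torch.cat([outputs[name] for name in names], dim=1)

      With the anchor-attention example above, this amounts to two independent
      attention calls: one where the surface anchors and surface queries attend
      only to the surface anchors, and one where the volume anchors and volume
      queries attend only to the volume anchors, with the results concatenated
      back in the original token order.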