emmi.modules.attention.anchor_attention.multi_branch_anchor_attention

Exceptions

MissingBranchTokensError

Raised when expected tokens for a configured branch are not present.

MissingAnchorTokenError

Raised when a required anchor token is not present.

UnexpectedTokenError

Raised when an unexpected token is present.

Classes

MultiBranchAnchorAttention

A base class for multi-branch anchor-based attention modules with shared parameters between branches.

Module Contents

exception emmi.modules.attention.anchor_attention.multi_branch_anchor_attention.MissingBranchTokensError

Bases: ValueError

Raised when expected tokens for a configured branch are not present.

exception emmi.modules.attention.anchor_attention.multi_branch_anchor_attention.MissingAnchorTokenError

Bases: ValueError

Raised when a required anchor token is not present.

exception emmi.modules.attention.anchor_attention.multi_branch_anchor_attention.UnexpectedTokenError

Bases: ValueError

Raised when an unexpected token is present.

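All three exceptions subclass ValueError, so they can be caught individually or together. The wrapper below is a hypothetical usage sketch (run_with_validation is not part of the library, and the exact point at which the module validates its token specifications is an implementation detail):

    from emmi.modules.attention.anchor_attention.multi_branch_anchor_attention import (
        MissingAnchorTokenError,
        MissingBranchTokensError,
        UnexpectedTokenError,
    )

    def run_with_validation(module, x, token_specs):
        """Call the module, surfacing token-layout mismatches with extra context."""
        try:
            return module(x, token_specs)
        except (MissingBranchTokensError, MissingAnchorTokenError, UnexpectedTokenError) as err:
            raise RuntimeError(f"token_specs do not match the configured branches: {err}") from err
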
class emmi.modules.attention.anchor_attention.multi_branch_anchor_attention.MultiBranchAnchorAttention(config)

Bases: torch.nn.Module

A base class for multi-branch anchor-based attention modules with shared parameters between branches.

Anchor attention restricts self-attention to the anchor tokens, while all other tokens exchange information through cross-attention. Branches for different modalities share the same linear-projection parameters. This base class provides a common constructor, validation logic, and forward implementation; subclasses only need to implement _create_attention_patterns to define their specific attention patterns.
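
To make the pattern concrete, here is a minimal, self-contained sketch of the boolean attention mask that one reading of this pattern implies. It is an illustration, not the library's implementation; the helper name and the keys-must-be-anchors convention are assumptions.

    import torch

    def anchor_attention_mask(token_names: list[str], anchor_suffix: str = "_anchor") -> torch.Tensor:
        """Illustrative (n, n) boolean mask: entry [q, k] is True if query q may attend to key k."""
        is_anchor = torch.tensor([name.endswith(anchor_suffix) for name in token_names])
        n = len(token_names)
        mask = torch.zeros(n, n, dtype=torch.bool)
        # Anchor tokens self-attend among themselves.
        mask[is_anchor.unsqueeze(1) & is_anchor.unsqueeze(0)] = True
        # Every other token cross-attends to the anchor tokens.
        mask[(~is_anchor).unsqueeze(1) & is_anchor.unsqueeze(0)] = True
        return mask

    # Two branches ("img", "txt"), each contributing one anchor token:
    names = ["img_anchor", "img_0", "img_1", "txt_anchor", "txt_0"]
    print(anchor_attention_mask(names).int())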

Parameters:
  • config (emmi.schemas.modules.attention.AttentionConfig) – Attention configuration with the following fields:

    • dim – Model dimension.

    • num_heads – Number of attention heads.

    • use_rope – Whether to use rotary position embeddings.

    • bias – Whether to use bias in the linear projections.

    • init_weights – Weight initialization method.

    • branches – A sequence of all participating branch names.

    • anchor_suffix – Suffix identifying anchor tokens.

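For orientation, a hypothetical construction of the configuration, assuming AttentionConfig accepts the fields above as keyword arguments; the field values are placeholders and the actual schema may differ:

    from emmi.schemas.modules.attention import AttentionConfig

    # Hypothetical: field names mirror the parameter list above; verify
    # against the actual AttentionConfig schema before relying on this.
    config = AttentionConfig(
        dim=512,
        num_heads=8,
        use_rope=True,
        bias=False,
        init_weights="truncated_normal",
        branches=("image", "text"),
        anchor_suffix="_anchor",
    )
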
mixed_attention

branches = None

anchor_suffix = None

forward(x, token_specs, freqs=None)

Apply attention using the patterns defined by the subclass.

Parameters:
  • x (torch.Tensor) – Input token embeddings.

  • token_specs (collections.abc.Sequence[emmi.schemas.modules.attention.anchor_attention.TokenSpec]) – Specifications of the token groups contained in x.

  • freqs (torch.Tensor | None) – Rotary position embedding frequencies, used when use_rope is enabled.

Return type:

torch.Tensor
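
A hedged end-to-end sketch of calling forward. MyAnchorAttention stands in for a concrete subclass, config is assumed to be built as shown earlier, and the TokenSpec constructor arguments (name, length) are assumptions about the schema, not documented fields:

    import torch
    from emmi.schemas.modules.attention.anchor_attention import TokenSpec

    module = MyAnchorAttention(config)  # hypothetical concrete subclass
    x = torch.randn(2, 5, 512)          # (batch, tokens, dim), dim matching config.dim
    token_specs = [
        TokenSpec(name="image_anchor", length=1),  # assumed TokenSpec fields
        TokenSpec(name="image", length=2),
        TokenSpec(name="text_anchor", length=1),
        TokenSpec(name="text", length=1),
    ]
    out = module(x, token_specs)  # pass freqs=... when use_rope is enabled
    assert out.shape == x.shape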