emmi.modules.attention.anchor_attention.cross_anchor_attention
Classes
Anchor attention across branches: each configured branch attends to the anchors of all other branches.
Module Contents
- class emmi.modules.attention.anchor_attention.cross_anchor_attention.CrossAnchorAttention(config)
Bases: emmi.modules.attention.anchor_attention.multi_branch_anchor_attention.MultiBranchAnchorAttention
Anchor attention across branches: each configured branch attends to the anchors of all other branches.
For a list of branches (e.g., A, B, C), this creates a pattern in which A attends to (B_anchors + C_anchors), B attends to (A_anchors + C_anchors), and C attends to (A_anchors + B_anchors). It requires all configured branches and their anchors to be present in the input.
Example: all surface tokens attend to volume_anchors and all volume tokens attend to surface_anchors. This is achieved via the following attention patterns:
AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["volume_anchors"])
AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["surface_anchors"])
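The pattern construction described above can be sketched in plain Python. This is an illustrative sketch only, not the library's implementation: the `AttentionPattern` dataclass below is a hypothetical stand-in for the real class, `cross_anchor_patterns` is a made-up helper name, and the `<branch>_<suffix>` naming with an underscore and a `queries` suffix is an assumption inferred from the example.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class AttentionPattern:
    """Hypothetical stand-in for the library's AttentionPattern."""

    query_tokens: list[str]
    key_value_tokens: list[str]


def cross_anchor_patterns(
    branches: Sequence[str],
    anchor_suffix: str = "anchors",
    query_suffix: str = "queries",  # assumed naming, taken from the example
) -> list[AttentionPattern]:
    """Build one pattern per branch: its tokens attend to all other branches' anchors."""
    patterns = []
    for branch in branches:
        # Key/value tokens are the anchors of every branch except the current one.
        other_anchors = [f"{b}_{anchor_suffix}" for b in branches if b != branch]
        patterns.append(
            AttentionPattern(
                query_tokens=[f"{branch}_{anchor_suffix}", f"{branch}_{query_suffix}"],
                key_value_tokens=other_anchors,
            )
        )
    return patterns


patterns = cross_anchor_patterns(["surface", "volume"])
for pattern in patterns:
    print(pattern)
```

With `["surface", "volume"]` this yields the two patterns listed in the example; with three branches, each branch's key/value list contains the anchors of the other two.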
- Parameters:
dim – Model dimension.
num_heads – Number of attention heads.
use_rope – Whether to use rotary position embeddings.
bias – Whether to use bias in the linear projections.
init_weights – Weight initialization method.
branches – A sequence of all participating branch names.
anchor_suffix – Suffix identifying anchor tokens.
config (emmi.schemas.modules.attention.anchor_attention.CrossAchorAttentionConfig)
Initialize internal Module state, shared by both nn.Module and ScriptModule.