emmi_inference.models.modules.attention.sharedweights_splitattn_attention¶
Classes¶
SharedweightsSplitattnAttention | Scaled dot-product attention module.
Module Contents¶
- class emmi_inference.models.modules.attention.sharedweights_splitattn_attention.SharedweightsSplitattnAttention(dim, num_heads=8)¶
Bases: emmi_inference.models.modules.attention.DotProductAttention

Scaled dot-product attention module.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

- Parameters:
dim – Embedding dimension of the input tokens.
num_heads – Number of attention heads. Defaults to 8.
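A minimal construction sketch; the import path follows the module name above, and dim=256 is an arbitrary placeholder value, not one taken from the library:

```python
from emmi_inference.models.modules.attention.sharedweights_splitattn_attention import (
    SharedweightsSplitattnAttention,
)

# dim and num_heads here are arbitrary example values.
attn = SharedweightsSplitattnAttention(dim=256, num_heads=8)
```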
- forward(x, split_size, freqs)¶
Attention between:
- q=surface_anchors -> kv=surface_anchors
- q=volume_anchors -> kv=volume_anchors
- q=surface_queries -> kv=surface_anchors
- q=volume_queries -> kv=volume_anchors
- Parameters:
x (torch.Tensor) – Tensor containing all anchors/queries (batch size, sequence length, dim).
split_size (list[int]) – How to split x:
len(split_size) == 2: (surface_anchors, volume_anchors)
len(split_size) == 4: (surface_anchors, surface_queries, volume_anchors, volume_queries)
freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of queries/keys. None if use_rope=False.
- Returns:
Output tensor of shape (batch size, sequence length, dim).
- Return type:
torch.Tensor
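Continuing the sketch above, a hedged example of calling forward with a four-way split. All batch and sequence sizes are made-up placeholders, and freqs is passed as None on the assumption that RoPE is disabled (use_rope=False):

```python
import torch

batch, dim = 2, 256
n_surf_anchors, n_surf_queries = 128, 32
n_vol_anchors, n_vol_queries = 512, 64

# x concatenates the four groups along the sequence axis in the documented
# order: (surface_anchors, surface_queries, volume_anchors, volume_queries).
x = torch.randn(
    batch,
    n_surf_anchors + n_surf_queries + n_vol_anchors + n_vol_queries,
    dim,
)
split_size = [n_surf_anchors, n_surf_queries, n_vol_anchors, n_vol_queries]

out = attn(x, split_size=split_size, freqs=None)  # freqs=None since use_rope=False
assert out.shape == x.shape  # (batch size, sequence length, dim)
```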