emmi_inference.models.modules.attention.sharedweights_splitattn_attention

Classes

SharedweightsSplitattnAttention

Scaled dot-product attention module with weights shared across surface/volume splits.

Module Contents

class emmi_inference.models.modules.attention.sharedweights_splitattn_attention.SharedweightsSplitattnAttention(dim, num_heads=8)

Bases: emmi_inference.models.modules.attention.DotProductAttention

Scaled dot-product attention module whose projection weights are shared across splits, with attention computed independently within each surface/volume split.

Parameters:
  • dim (int) – Input dimension of the attention module.

  • num_heads (int) – Number of attention heads. Defaults to 8.

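A minimal sketch of the shared-weights split-attention idea the class name describes. The layer names (qkv, proj) and the helper below are illustrative assumptions, not the actual implementation; the real class inherits from DotProductAttention and additionally applies RoPE, which is omitted here. The point it shows: one projection weight set serves every split, while attention itself runs independently per split.

    import torch
    import torch.nn.functional as F
    from torch import nn


    class SplitAttentionSketch(nn.Module):
        # Hypothetical illustration; layer names and structure are assumptions.

        def __init__(self, dim: int, num_heads: int = 8):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            # One weight set shared by every split ("shared weights").
            self.qkv = nn.Linear(dim, dim * 3)
            self.proj = nn.Linear(dim, dim)

        def _heads(self, t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, dim) -> (batch, num_heads, seq, head_dim)
            b, n, _ = t.shape
            return t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        def forward(self, x: torch.Tensor, split_size: list[int]) -> torch.Tensor:
            chunks = x.split(split_size, dim=1)
            if len(chunks) == 2:
                # (surface_anchors, volume_anchors): self-attention per split.
                pairs = [(chunks[0], chunks[0]), (chunks[1], chunks[1])]
            else:
                # (surface_anchors, surface_queries, volume_anchors, volume_queries):
                # anchors self-attend, queries cross-attend to their anchors.
                sa, sq, va, vq = chunks
                pairs = [(sa, sa), (sq, sa), (va, va), (vq, va)]
            outputs = []
            for q_src, kv_src in pairs:
                q = self._heads(self.qkv(q_src).chunk(3, dim=-1)[0])
                _, k, v = (self._heads(t) for t in self.qkv(kv_src).chunk(3, dim=-1))
                out = F.scaled_dot_product_attention(q, k, v)  # softmax(qk^T/sqrt(d)) v
                b, _, n, _ = out.shape
                outputs.append(out.transpose(1, 2).reshape(b, n, -1))
            # Re-concatenate splits in the original input order.
            return self.proj(torch.cat(outputs, dim=1))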

forward(x, split_size, freqs)

Attention is computed between the following query/key-value pairs:
  • q=surface_anchors -> kv=surface_anchors
  • q=volume_anchors -> kv=volume_anchors
  • q=surface_queries -> kv=surface_anchors
  • q=volume_queries -> kv=volume_anchors

Parameters:
  • x (torch.Tensor) – Tensor containing all anchors/queries (batch size, sequence length, dim).

  • split_size (list[int]) – How to split x along the sequence dimension:
    len(split_size) == 2: (surface_anchors, volume_anchors);
    len(split_size) == 4: (surface_anchors, surface_queries, volume_anchors, volume_queries).

  • freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of queries/keys. None if use_rope=False.

Returns:

Output tensor of shape (batch size, sequence length, dim).

Return type:

torch.Tensor
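A short usage example based on the documented signatures. The batch, dim, and split sizes are illustrative, and freqs=None assumes the module is used without RoPE (use_rope=False, per the parameter description above):

    import torch
    from emmi_inference.models.modules.attention.sharedweights_splitattn_attention import (
        SharedweightsSplitattnAttention,
    )

    attn = SharedweightsSplitattnAttention(dim=256, num_heads=8)

    # One sequence holding 100 surface anchors followed by 60 volume anchors.
    x = torch.randn(2, 160, 256)
    out = attn(x, split_size=[100, 60], freqs=None)  # freqs=None when use_rope=False
    assert out.shape == (2, 160, 256)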