emmi_inference.models.modules.attention.sharedweights_splitattn_attention
==========================================================================

.. py:module:: emmi_inference.models.modules.attention.sharedweights_splitattn_attention


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.attention.sharedweights_splitattn_attention.SharedweightsSplitattnAttention


Module Contents
---------------

.. py:class:: SharedweightsSplitattnAttention(dim, num_heads = 8)

   Bases: :py:obj:`emmi_inference.models.modules.attention.DotProductAttention`

   Scaled dot-product attention module.

   :param dim: Input dimension of the attention module.
   :param num_heads: Number of attention heads. Defaults to 8.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:method:: forward(x, split_size, freqs)

      Attention between:

      - q=surface_anchors -> kv=surface_anchors
      - q=volume_anchors -> kv=volume_anchors
      - q=surface_queries -> kv=surface_anchors
      - q=volume_queries -> kv=volume_anchors

      :param x: Tensor containing all anchors/queries (batch size, sequence length, dim).
      :param split_size: How to split x:

                         - len(split_size) == 2: (surface_anchors, volume_anchors)
                         - len(split_size) == 4: (surface_anchors, surface_queries, volume_anchors, volume_queries)
      :param freqs: Frequencies for Rotary Positional Embedding (RoPE) of queries/keys.
                    ``None`` if ``use_rope=False``.
      :returns: Tensor of shape (batch size, sequence length, dim).
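
A minimal usage sketch (not part of the autoapi output): it assumes the class can be
instantiated and called as documented above. The tensor sizes, the particular split
lengths, and passing ``freqs=None`` for the no-RoPE case are illustrative assumptions,
not taken from the library source.

.. code-block:: python

   import torch

   from emmi_inference.models.modules.attention.sharedweights_splitattn_attention import (
       SharedweightsSplitattnAttention,
   )

   dim, num_heads = 256, 8
   attn = SharedweightsSplitattnAttention(dim, num_heads=num_heads)

   # Hypothetical token counts per split, ordered as
   # (surface_anchors, surface_queries, volume_anchors, volume_queries).
   split_size = (128, 512, 64, 1024)

   # All anchors/queries concatenated along the sequence dimension.
   x = torch.randn(2, sum(split_size), dim)  # (batch size, sequence length, dim)

   # freqs=None assumes RoPE is disabled (use_rope=False); otherwise pass the
   # precomputed rotary frequencies for the queries/keys.
   out = attn(x, split_size=split_size, freqs=None)
   print(out.shape)  # torch.Size([2, 1728, 256])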