emmi_inference.models.modules.attention.anchor_attention
========================================================

.. py:module:: emmi_inference.models.modules.attention.anchor_attention


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.attention.anchor_attention.AnchorAttention


Module Contents
---------------

.. py:class:: AnchorAttention(dim, num_heads = 8)

   Bases: :py:obj:`emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention`

   Multi-head scaled dot-product attention with anchor tokens. When anchors are specified, the anchor tokens attend to each other and all remaining tokens attend only to the anchors (see :py:meth:`forward`).

   :param dim: Input dimension of the attention module.
   :param num_heads: Number of attention heads. Defaults to 8.


   .. py:method:: forward(x, freqs, num_anchor_tokens = None)

      Apply self-attention among the anchor tokens, while all other tokens (query tokens) only cross-attend to the anchor tokens.

      :param x: Tensor to apply attention over, of shape (batch_size, sequence_length, dim).
      :param freqs: Frequencies for the rotary position embedding (RoPE).
      :param num_anchor_tokens: Number of anchor tokens. If provided, the first ``num_anchor_tokens`` tokens of ``x`` are the anchors (full self-attention among themselves) and the remaining tokens are the queries (cross-attention to the anchor tokens only).
      :returns: Tensor of shape (batch_size, sequence_length, dim).
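The masking rule of ``forward`` can be illustrated with a minimal NumPy sketch. It is not the ``emmi_inference`` implementation. The function names, the fused projection weights ``w_qkv`` and ``w_out``, and the treatment of ``num_anchor_tokens=None`` as plain self-attention over all tokens are assumptions made for illustration, and the RoPE step driven by ``freqs`` is omitted. Because every token may attend only to the first ``num_anchor_tokens`` positions, every row of the mask is identical, so the result equals using the anchors alone as keys and values.

.. code-block:: python

   import numpy as np


   def anchor_attention_mask(seq_len: int, num_anchor_tokens: int) -> np.ndarray:
       """Boolean mask (seq_len, seq_len); True means query i may attend to key j.

       Every token, anchor or query, may attend only to the anchor positions.
       """
       mask = np.zeros((seq_len, seq_len), dtype=bool)
       mask[:, :num_anchor_tokens] = True
       return mask


   def anchor_attention(x, w_qkv, w_out, num_heads, num_anchor_tokens=None):
       """Hypothetical sketch: multi-head attention with anchor masking, no RoPE.

       x: (batch_size, sequence_length, dim)
       w_qkv: (dim, 3 * dim) assumed fused query/key/value projection
       w_out: (dim, dim) assumed output projection
       """
       batch, seq_len, dim = x.shape
       head_dim = dim // num_heads
       q, k, v = np.split(x @ w_qkv, 3, axis=-1)  # each (batch, seq_len, dim)

       def split_heads(t):
           # (batch, seq_len, dim) -> (batch, num_heads, seq_len, head_dim)
           return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

       q, k, v = split_heads(q), split_heads(k), split_heads(v)
       scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (batch, heads, seq, seq)
       if num_anchor_tokens is not None:
           # At least one anchor is needed, otherwise every row is fully masked.
           assert 1 <= num_anchor_tokens <= seq_len
           mask = anchor_attention_mask(seq_len, num_anchor_tokens)
           scores = np.where(mask, scores, -np.inf)  # mask broadcasts over batch/heads
       # Numerically stable softmax over the key axis; masked keys get weight 0.
       scores = scores - scores.max(axis=-1, keepdims=True)
       weights = np.exp(scores)
       weights = weights / weights.sum(axis=-1, keepdims=True)
       out = weights @ v  # (batch, heads, seq, head_dim)
       out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, dim)
       return out @ w_out


   rng = np.random.default_rng(0)
   dim, num_heads = 16, 8
   x = rng.standard_normal((2, 10, dim))
   w_qkv = rng.standard_normal((dim, 3 * dim)) / np.sqrt(dim)
   w_out = rng.standard_normal((dim, dim)) / np.sqrt(dim)
   y = anchor_attention(x, w_qkv, w_out, num_heads, num_anchor_tokens=4)
   print(y.shape)  # (2, 10, 16)

A consequence of this rule is that the outputs at the anchor positions do not depend on the query tokens at all, and each query token's output depends only on itself and the anchors. This lets the query tokens be processed independently of one another once the anchors are fixed.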