emmi_inference.models.modules.attention.dot_product_attention
==============================================================

.. py:module:: emmi_inference.models.modules.attention.dot_product_attention


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention


Module Contents
---------------

.. py:class:: DotProductAttention(dim, num_heads = 8)

   Bases: :py:obj:`torch.nn.Module`

   Scaled dot-product attention module.

   :param dim: Input dimension of the attention module.
   :param num_heads: Number of attention heads. Defaults to 8.


   .. py:attribute:: dim


   .. py:attribute:: num_heads
      :value: 8


   .. py:attribute:: head_dim


   .. py:attribute:: qkv


   .. py:attribute:: proj


   .. py:method:: forward(x, freqs)

      Forward function of the DotProductAttention module.

      :param x: Tensor to apply self-attention over, shape (batch_size, sequence_length, dim).
      :param freqs: Frequencies for the Rotary Positional Embedding (RoPE) applied to queries and keys.
      :returns: Output tensor of shape (batch_size, sequence_length, dim).
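
The following is a minimal sketch of how a module with this interface could be implemented. It is not the library's actual implementation: the ``apply_rope`` helper, the assumption that ``freqs`` is a complex-valued tensor of shape ``(sequence_length, head_dim // 2)``, and the use of ``torch.nn.functional.scaled_dot_product_attention`` (PyTorch >= 2.0) are all illustrative choices, not confirmed by the source.

.. code-block:: python

   import torch
   from torch import nn


   def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
       """Hypothetical RoPE helper: rotate feature pairs by complex frequencies.

       x:     (batch, heads, seq_len, head_dim), head_dim even
       freqs: (seq_len, head_dim // 2) complex tensor of unit rotations
       """
       # View adjacent feature pairs as complex numbers and rotate them.
       x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
       x_rotated = torch.view_as_real(x_complex * freqs).flatten(-2)
       return x_rotated.type_as(x)


   class DotProductAttention(nn.Module):
       """Sketch of scaled dot-product self-attention with RoPE (assumed design)."""

       def __init__(self, dim: int, num_heads: int = 8) -> None:
           super().__init__()
           assert dim % num_heads == 0, "dim must be divisible by num_heads"
           self.dim = dim
           self.num_heads = num_heads
           self.head_dim = dim // num_heads
           self.qkv = nn.Linear(dim, 3 * dim)   # fused query/key/value projection
           self.proj = nn.Linear(dim, dim)      # output projection

       def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
           b, s, _ = x.shape
           # (b, s, 3, heads, head_dim) -> three (b, heads, s, head_dim) tensors
           q, k, v = (
               self.qkv(x)
               .reshape(b, s, 3, self.num_heads, self.head_dim)
               .permute(2, 0, 3, 1, 4)
           )
           q, k = apply_rope(q, freqs), apply_rope(k, freqs)
           # softmax(Q K^T / sqrt(head_dim)) V, computed by PyTorch's fused kernel.
           out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
           return self.proj(out.transpose(1, 2).reshape(b, s, self.dim))


   # Usage sketch: build complex RoPE frequencies and run a forward pass.
   attn = DotProductAttention(dim=64, num_heads=8)
   seq_len, head_dim = 16, 64 // 8
   theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
   angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), theta)
   freqs = torch.polar(torch.ones_like(angles), angles)  # unit-magnitude rotations
   y = attn(torch.randn(2, seq_len, 64), freqs)          # -> (2, 16, 64)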