emmi.modules.attention.dot_product_attention
============================================

.. py:module:: emmi.modules.attention.dot_product_attention


Classes
-------

.. autoapisummary::

   emmi.modules.attention.dot_product_attention.DotProductAttention


Module Contents
---------------

.. py:class:: DotProductAttention(config)

   Bases: :py:obj:`torch.nn.Module`

   Scaled dot-product attention module.

   Initialize the DotProductAttention module.

   :param config: Configuration of the attention module.

   .. py:attribute:: num_heads
      :value: None

   .. py:attribute:: head_dim

   .. py:attribute:: init_weights
      :value: None

   .. py:attribute:: use_rope
      :value: None

   .. py:attribute:: dropout
      :value: None

   .. py:attribute:: proj_dropout

   .. py:attribute:: qkv

   .. py:attribute:: proj

   .. py:method:: forward(x, attn_mask=None, freqs=None)

      Forward function of the DotProductAttention module.

      :param x: Tensor to apply self-attention over, of shape (batch size, sequence length, hidden_dim).
      :param attn_mask: For causal attention (i.e., no attention over future tokens), an attention mask should be provided. Defaults to None.
      :param freqs: Frequencies for Rotary Positional Embedding (RoPE) of the queries/keys. None if ``use_rope`` is False.
      :returns: The output of the attention module.
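
   The attention computed is the standard scaled dot-product form,

   .. math::

      \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,

   where :math:`d_k` is ``head_dim``.

   Below is a minimal usage sketch. The exact ``config`` schema is not documented on this page, so the field names (``hidden_dim``, ``num_heads``, ``dropout``, ``proj_dropout``, ``use_rope``, ``init_weights``) are assumptions inferred from the documented attributes; the additive float mask convention is likewise an assumption.

   .. code-block:: python

      import torch
      from types import SimpleNamespace

      from emmi.modules.attention.dot_product_attention import DotProductAttention

      # Hypothetical config object; the real schema is defined elsewhere in emmi.
      config = SimpleNamespace(
          hidden_dim=256,    # model width; assumed divisible by num_heads
          num_heads=8,       # head_dim is then hidden_dim // num_heads = 32
          dropout=0.0,       # attention dropout probability
          proj_dropout=0.0,  # dropout after the output projection
          use_rope=False,    # no RoPE, so freqs stays None
          init_weights=None,
      )

      attn = DotProductAttention(config)

      batch, seq_len, hidden_dim = 2, 16, 256
      x = torch.randn(batch, seq_len, hidden_dim)

      # Additive causal mask (assumed convention): -inf above the diagonal
      # blocks attention to future tokens.
      attn_mask = torch.triu(
          torch.full((seq_len, seq_len), float("-inf")), diagonal=1
      )

      out = attn(x, attn_mask=attn_mask)  # same shape as x
      assert out.shape == (batch, seq_len, hidden_dim)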