emmi.modules.attention.perceiver_attention
==========================================

.. py:module:: emmi.modules.attention.perceiver_attention


Classes
-------

.. autoapisummary::

   emmi.modules.attention.perceiver_attention.PerceiverAttention


Module Contents
---------------

.. py:class:: PerceiverAttention(config)

   Bases: :py:obj:`torch.nn.Module`

   Perceiver-style attention module. This module is similar to a cross-attention module.

   Initialize the PerceiverAttention module.

   :param config: Configuration of the attention module.

   .. py:attribute:: num_heads
      :value: None

   .. py:attribute:: head_dim

   .. py:attribute:: init_weights
      :value: None

   .. py:attribute:: use_rope
      :value: None

   .. py:attribute:: kv

   .. py:attribute:: q

   .. py:attribute:: proj

   .. py:attribute:: dropout
      :value: None

   .. py:attribute:: proj_dropout

   .. py:method:: forward(q, kv, attn_mask=None, q_freqs=None, k_freqs=None)

      Forward function of the PerceiverAttention module.

      :param q: Query tensor, shape (batch size, number of points/tokens, hidden_dim).
      :param kv: Key/value tensor, shape (batch size, number of latent tokens, hidden_dim).
      :param attn_mask: Attention mask, required when applying causal attention. Defaults to None.
      :param q_freqs: Frequencies for Rotary Positional Embedding (RoPE) of the queries. None if ``use_rope=False``.
      :param k_freqs: Frequencies for Rotary Positional Embedding (RoPE) of the keys. None if ``use_rope=False``.
      :returns: The output of the perceiver attention module.
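
Example
-------

A minimal usage sketch. The ``config`` field names below (``hidden_dim``, ``num_heads``, ``dropout``, ``use_rope``, ``init_weights``) are assumptions inferred from the attributes documented above, not a confirmed schema; a ``types.SimpleNamespace`` stands in for the real configuration object.

.. code-block:: python

   import types

   import torch

   from emmi.modules.attention.perceiver_attention import PerceiverAttention

   # Hypothetical config: field names are guessed from the attributes
   # documented above (num_heads, dropout, use_rope, init_weights) plus an
   # assumed hidden_dim; the real config class may differ.
   config = types.SimpleNamespace(
       hidden_dim=256,
       num_heads=8,
       dropout=0.0,
       use_rope=False,  # skip RoPE so q_freqs/k_freqs can stay None
       init_weights=None,
   )

   attn = PerceiverAttention(config)

   batch_size, num_points, num_latents = 4, 1024, 64
   q = torch.randn(batch_size, num_points, config.hidden_dim)    # queries
   kv = torch.randn(batch_size, num_latents, config.hidden_dim)  # keys/values

   # With use_rope=False no frequency tensors are needed, and attn_mask is
   # only required when applying causal attention.
   out = attn(q, kv)
   print(out.shape)  # expected: (batch_size, num_points, hidden_dim)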