emmi_inference.models.modules.attention.perceiver_attention
===========================================================

.. py:module:: emmi_inference.models.modules.attention.perceiver_attention


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.attention.perceiver_attention.PerceiverAttention


Module Contents
---------------

.. py:class:: PerceiverAttention(dim, num_heads = 8)

   Bases: :py:obj:`torch.nn.Module`

   Perceiver-style attention module.

   This module is similar to a cross-attention module.

   :param dim: Hidden dimension of the layer/module.
   :param num_heads: Number of attention heads. Defaults to 8.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:attribute:: num_heads
      :value: 8

   .. py:attribute:: head_dim

   .. py:attribute:: q

   .. py:attribute:: kv

   .. py:attribute:: proj

   .. py:method:: forward(q, kv, q_freqs, k_freqs)

      Forward function of the PerceiverAttention module.

      :param q: Query tensor, shape (batch size, number of points/tokens, dim).
      :param kv: Key/value tensor, shape (batch size, number of latent tokens, dim).
      :param q_freqs: Frequencies for Rotary Positional Embedding (RoPE) of queries.
      :param k_freqs: Frequencies for Rotary Positional Embedding (RoPE) of keys.

      :returns: Output tensor, shape (batch size, query sequence length, dim).
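
   The attribute and signature listing above omits the implementation. The
   following is a minimal sketch of how such a Perceiver-style cross-attention
   layer is commonly built; it is an assumption based on the attribute names
   (``q`` as a single query projection, ``kv`` as a fused key/value projection,
   ``proj`` as the output projection), and the ``apply_rope`` helper is
   hypothetical, standing in for the module's actual RoPE implementation.

   .. code-block:: python

      import torch
      import torch.nn.functional as F
      from torch import nn


      def apply_rope(x, freqs):
          # Hypothetical RoPE helper (half-split variant): `x` has shape
          # (B, heads, N, head_dim); `freqs` broadcasts to (..., head_dim // 2).
          x1, x2 = x.chunk(2, dim=-1)
          cos, sin = freqs.cos(), freqs.sin()
          return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


      class PerceiverAttention(nn.Module):
          """Perceiver-style cross-attention: queries attend to latent key/values."""

          def __init__(self, dim, num_heads=8):
              super().__init__()
              self.num_heads = num_heads
              self.head_dim = dim // num_heads
              self.q = nn.Linear(dim, dim)       # query projection
              self.kv = nn.Linear(dim, dim * 2)  # fused key/value projection
              self.proj = nn.Linear(dim, dim)    # output projection

          def forward(self, q, kv, q_freqs, k_freqs):
              B, Nq, D = q.shape
              Nkv = kv.shape[1]
              # Project and split into heads: (B, heads, N, head_dim).
              q = self.q(q).view(B, Nq, self.num_heads, self.head_dim).transpose(1, 2)
              k, v = (
                  self.kv(kv)
                  .view(B, Nkv, 2, self.num_heads, self.head_dim)
                  .permute(2, 0, 3, 1, 4)
              )
              # Rotate queries and keys with their respective RoPE frequencies.
              q = apply_rope(q, q_freqs)
              k = apply_rope(k, k_freqs)
              # Scaled dot-product cross-attention, then merge heads back.
              out = F.scaled_dot_product_attention(q, k, v)
              out = out.transpose(1, 2).reshape(B, Nq, D)
              return self.proj(out)

   Under these assumptions, a call looks like
   ``out = PerceiverAttention(dim=256)(q, kv, q_freqs, k_freqs)`` with ``q`` of
   shape ``(B, Nq, 256)`` and ``kv`` of shape ``(B, Nkv, 256)``; the output has
   the same shape as ``q``.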