emmi_inference.models.modules.blocks.perceiver_block
====================================================

.. py:module:: emmi_inference.models.modules.blocks.perceiver_block


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.blocks.perceiver_block.PerceiverBlock


Module Contents
---------------

.. py:class:: PerceiverBlock(dim, num_heads)

   Bases: :py:obj:`torch.nn.Module`


   The PerceiverBlock takes different input tensors for the query and the key/value.

   :param dim: Hidden dimension of the perceiver block.
   :param num_heads: Number of attention heads.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: norm1q


   .. py:attribute:: norm1kv


   .. py:attribute:: attn


   .. py:attribute:: norm2


   .. py:attribute:: mlp


   .. py:method:: forward(q, kv, attn_kwargs = None)

      Forward pass of the PerceiverBlock.

      :param q: Input tensor with shape (batch_size, num_q_tokens, dim) for the query representations.
      :param kv: Input tensor with shape (batch_size, num_kv_tokens, dim) for the key and value representations.
      :param attn_kwargs: Dict with arguments for the attention (such as rope frequencies). Defaults to None.

      :returns: (batch_size, num_q_tokens, dim)
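
The shape contract of ``forward`` can be illustrated with a small stand-in module. This is a minimal sketch, not the ``emmi_inference`` implementation. The class name ``SketchPerceiverBlock`` is hypothetical, and so are the pre-norm residual layout, the use of ``torch.nn.MultiheadAttention``, and the 4x MLP expansion. Only the attribute names, the constructor arguments, and the input/output shapes are taken from the reference above. ``attn_kwargs`` is accepted for signature parity but ignored here.

```python
import torch
from torch import nn


class SketchPerceiverBlock(nn.Module):
    """Hypothetical stand-in for PerceiverBlock; internals are assumptions."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        # Separate normalization layers for the query and key/value inputs.
        self.norm1q = nn.LayerNorm(dim)
        self.norm1kv = nn.LayerNorm(dim)
        # Assumption: standard multi-head attention used as cross-attention.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # Assumption: two-layer MLP with 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, q, kv, attn_kwargs=None):
        # attn_kwargs (e.g. rope frequencies) is ignored in this sketch.
        kv_normed = self.norm1kv(kv)
        # Queries attend to the key/value tokens; output keeps the query length.
        attn_out, _ = self.attn(
            self.norm1q(q), kv_normed, kv_normed, need_weights=False
        )
        q = q + attn_out
        return q + self.mlp(self.norm2(q))


block = SketchPerceiverBlock(dim=32, num_heads=4)
q = torch.randn(2, 5, 32)    # (batch_size, num_q_tokens, dim)
kv = torch.randn(2, 11, 32)  # (batch_size, num_kv_tokens, dim)
out = block(q, kv)
```

In this sketch ``out`` has the shape of ``q``, i.e. ``(batch_size, num_q_tokens, dim)``, regardless of ``num_kv_tokens``. This matches the documented return shape.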