emmi_inference.models.modules.attention.perceiver_attention

Classes

PerceiverAttention

Perceiver-style attention module. This module is similar to a cross-attention module.

Module Contents

class emmi_inference.models.modules.attention.perceiver_attention.PerceiverAttention(dim, num_heads=8)

Bases: torch.nn.Module

Perceiver-style attention module. This module is similar to a cross-attention module.

Parameters:
  • dim (int) – Hidden dimension of the layer/module.

  • num_heads (int) – Number of attention heads. Defaults to 8.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_heads = 8
head_dim
q
kv
proj
forward(q, kv, q_freqs, k_freqs)

Forward function of the PerceiverAttention module.

Parameters:
  • q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, dim).

  • kv (torch.Tensor) – Key/value tensor, shape (batch size, number of latent tokens, dim).

  • q_freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of queries.

  • k_freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of keys.

Returns:

Output tensor of shape (batch size, query sequence length, dim).

Return type:

torch.Tensor
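The listed attributes (q, kv, proj, head_dim) and the forward() signature suggest a standard cross-attention layout: separate query and fused key/value projections, RoPE applied to queries and keys, then scaled dot-product attention. The following is a minimal sketch under that assumption, not the actual emmi_inference implementation; the class name, the fused kv projection, and the RoPE helper are all hypothetical.

```python
import torch
import torch.nn as nn


class PerceiverAttentionSketch(nn.Module):
    """Hypothetical sketch of a Perceiver-style cross-attention module."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q = nn.Linear(dim, dim)        # query projection
        self.kv = nn.Linear(dim, dim * 2)   # fused key/value projection (assumed)
        self.proj = nn.Linear(dim, dim)     # output projection

    @staticmethod
    def _rope(x, freqs):
        # Standard RoPE: rotate channel pairs by the given angles.
        # x: (batch, heads, seq, head_dim), freqs: (seq, head_dim // 2).
        x1, x2 = x.chunk(2, dim=-1)
        cos, sin = freqs.cos(), freqs.sin()
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def forward(self, q, kv, q_freqs, k_freqs):
        b, nq, _ = q.shape
        nk = kv.shape[1]
        # Project and split into heads: (batch, heads, seq, head_dim).
        q = self.q(q).view(b, nq, self.num_heads, self.head_dim).transpose(1, 2)
        k, v = self.kv(kv).chunk(2, dim=-1)
        k = k.view(b, nk, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, nk, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply rotary embeddings to queries and keys only.
        q = self._rope(q, q_freqs)
        k = self._rope(k, k_freqs)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # Merge heads back: (batch, query seq, dim).
        out = out.transpose(1, 2).reshape(b, nq, -1)
        return self.proj(out)


# Usage: 16 query tokens attend over 32 latent tokens.
attn = PerceiverAttentionSketch(dim=64, num_heads=8)
q = torch.randn(2, 16, 64)
kv = torch.randn(2, 32, 64)
# Illustrative RoPE frequencies: position times inverse frequency,
# one angle per channel pair (head_dim // 2 = 4).
inv = 1.0 / (10000 ** (torch.arange(4).float() / 4))
q_freqs = torch.arange(16).float()[:, None] * inv
k_freqs = torch.arange(32).float()[:, None] * inv
out = attn(q, kv, q_freqs, k_freqs)
print(out.shape)  # torch.Size([2, 16, 64])
```

Note that the output keeps the query sequence length, as documented: the kv sequence only supplies keys and values, so its length never appears in the output shape.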