emmi.modules.attention.perceiver_attention

Classes

PerceiverAttention

Perceiver style attention module. This module is similar to a cross-attention module.

Module Contents

class emmi.modules.attention.perceiver_attention.PerceiverAttention(config)

Bases: torch.nn.Module

Perceiver style attention module. This module is similar to a cross-attention module.

Initialize the PerceiverAttention module.

Parameters:

config (emmi.schemas.modules.attention.AttentionConfig) – Configuration of the attention module.

num_heads = None
head_dim
init_weights = None
use_rope = None
kv
q
proj
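
A minimal construction sketch. The AttentionConfig field names below (hidden_dim, num_heads, init_weights, use_rope) are assumptions inferred from the attributes listed above, not a confirmed schema:

    from emmi.modules.attention.perceiver_attention import PerceiverAttention
    from emmi.schemas.modules.attention import AttentionConfig

    # Hypothetical field values; the config schema is assumed, not confirmed.
    config = AttentionConfig(
        hidden_dim=256,
        num_heads=8,
        init_weights="truncnormal002",
        use_rope=False,
    )
    attn = PerceiverAttention(config)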
reset_parameters()

Reset the parameters of the PerceiverAttention module with a specific initialization method.

Raises:

NotImplementedError – when the requested initialization method is not implemented. Use either “torch” or “truncnormal002”.

Return type:

None
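
A sketch of the dispatch this implies, assuming q, kv, and proj are torch.nn.Linear layers and that “truncnormal002” denotes a truncated normal with std 0.02; the actual implementation may differ:

    import torch.nn as nn

    def reset_parameters(self) -> None:
        # Illustrative only: dispatch on the configured init method.
        if self.init_weights == "torch":
            # Keep PyTorch's default nn.Linear initialization.
            for layer in (self.q, self.kv, self.proj):
                layer.reset_parameters()
        elif self.init_weights == "truncnormal002":
            # Truncated normal with std=0.02 and zero bias (ViT-style init).
            for layer in (self.q, self.kv, self.proj):
                nn.init.trunc_normal_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
        else:
            raise NotImplementedError(
                f"Unknown init_weights '{self.init_weights}'; "
                "use 'torch' or 'truncnormal002'."
            )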

forward(q, kv, attn_mask=None, q_freqs=None, k_freqs=None)

Forward function of the PerceiverAttention module.

Parameters:
  • q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, hidden_dim).

  • kv (torch.Tensor) – Key/value tensor, shape (batch size, number of latent tokens, hidden_dim).

  • attn_mask (torch.Tensor | None) – Attention mask, required when applying causal attention. Defaults to None.

  • q_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of queries. None if use_rope=False.

  • k_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of keys. None if use_rope=False.

Returns:

Output of the perceiver attention module, with the same shape as the query tensor q.

Return type:

torch.Tensor
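
A usage sketch continuing the construction example above (shapes are illustrative; with use_rope=True, q_freqs and k_freqs would also need to be passed):

    import torch

    batch_size, num_points, num_latents, hidden_dim = 4, 1024, 64, 256

    q = torch.randn(batch_size, num_points, hidden_dim)    # query tensor
    kv = torch.randn(batch_size, num_latents, hidden_dim)  # key/value tensor

    out = attn(q, kv)  # use_rope=False, so no RoPE frequencies are needed
    assert out.shape == q.shape  # attention output keeps the query shape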