emmi.modules.attention.dot_product_attention

Classes

DotProductAttention

Scaled dot-product attention module.

Module Contents

class emmi.modules.attention.dot_product_attention.DotProductAttention(config)

Bases: torch.nn.Module

Scaled dot-product attention module.

Initialize the DotProductAttention module.

Parameters:

config (emmi.schemas.modules.attention.AttentionConfig) – configuration of the attention module.

num_heads = None
head_dim
init_weights = None
use_rope = None
qkv
proj
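
The attributes above are populated from the config at construction time. A minimal sketch of the constructor, assuming AttentionConfig exposes hidden_dim, num_heads, init_weights, and use_rope fields (field names inferred from the attributes above, not confirmed by the source):

   import torch


   class DotProductAttention(torch.nn.Module):
       def __init__(self, config):
           super().__init__()
           self.num_heads = config.num_heads
           # Each head attends over an equal slice of the hidden dimension.
           self.head_dim = config.hidden_dim // config.num_heads
           self.init_weights = config.init_weights
           self.use_rope = config.use_rope
           # Fused projection producing queries, keys, and values in one matmul.
           self.qkv = torch.nn.Linear(config.hidden_dim, 3 * config.hidden_dim)
           # Output projection applied after the heads are recombined.
           self.proj = torch.nn.Linear(config.hidden_dim, config.hidden_dim)
           # reset_parameters(), documented below, would typically be called here.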
reset_parameters()

Reset the parameters of the DotProductAttention module using the initialization method specified in the config.

Raises:

NotImplementedError – if the requested initialization method is not implemented. Use either “torch” or “truncnormal002”.

Return type:

None
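
A possible shape of the dispatch, continuing the constructor sketch above; the truncated-normal std of 0.02 is inferred from the name “truncnormal002” and is an assumption:

   # Method of the DotProductAttention sketch above.
   def reset_parameters(self) -> None:
       if self.init_weights == "torch":
           # Fall back to PyTorch's default Linear initialization.
           self.qkv.reset_parameters()
           self.proj.reset_parameters()
       elif self.init_weights == "truncnormal002":
           # Truncated normal with std 0.02 (assumed from the method name).
           for layer in (self.qkv, self.proj):
               torch.nn.init.trunc_normal_(layer.weight, std=0.02)
               torch.nn.init.zeros_(layer.bias)
       else:
           raise NotImplementedError(
               f"Unknown init method {self.init_weights!r}; "
               'use "torch" or "truncnormal002".'
           )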

forward(x, attn_mask=None, freqs=None)

Forward function of the DotProductAttention module.

Parameters:
  • x (torch.Tensor) – Tensor to apply self-attention over, of shape (batch size, sequence length, hidden_dim).

  • attn_mask (torch.Tensor | None) – For causal attention (i.e., no attention over future tokens), an attention mask should be provided. Defaults to None.

  • freqs (torch.Tensor | None) – Frequencies for the Rotary Positional Embedding (RoPE) applied to the queries/keys. None if use_rope=False. Defaults to None.

Returns:

The output of the attention module, a tensor of the same shape as x.

Return type:

torch.Tensor
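
A usage sketch with a causal mask; the AttentionConfig field names and the boolean mask convention (True = may attend, as in torch.nn.functional.scaled_dot_product_attention) are assumptions:

   import torch

   from emmi.modules.attention.dot_product_attention import DotProductAttention
   from emmi.schemas.modules.attention import AttentionConfig

   # Field names are assumptions inferred from the documented attributes.
   config = AttentionConfig(hidden_dim=64, num_heads=4, use_rope=False)
   attn = DotProductAttention(config)

   batch_size, seq_len = 2, 16
   x = torch.randn(batch_size, seq_len, 64)

   # Boolean causal mask: position i may only attend to positions j <= i.
   attn_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

   out = attn(x, attn_mask=attn_mask)  # shape: (2, 16, 64)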