emmi_inference.models.modules.attention.dot_product_attention

Classes

DotProductAttention

Scaled dot-product attention module.

Module Contents

class emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention(dim, num_heads=8)

Bases: torch.nn.Module

Scaled dot-product attention module.
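For reference, standard scaled dot-product attention computes softmax(Q Kᵀ / √head_dim) V for each head. The sketch below writes this out by hand and compares it with PyTorch's fused torch.nn.functional.scaled_dot_product_attention, which requires PyTorch 2.0 or later. It assumes this module uses the standard 1/√head_dim scaling. This page does not say which kernel the implementation actually calls.

```python
import math

import torch
import torch.nn.functional as F


def manual_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention written out explicitly.

    q, k, v: (batch, heads, seq_len, head_dim)
    """
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale  # (batch, heads, seq_len, seq_len)
    weights = scores.softmax(dim=-1)            # each row of weights sums to 1
    return weights @ v                          # (batch, heads, seq_len, head_dim)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 8, 5, 16) for _ in range(3))
out_manual = manual_sdpa(q, k, v)
# The fused kernel's default scale is also 1 / sqrt(head_dim).
out_fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_manual, out_fused, atol=1e-5))
```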

Parameters:
  • dim (int) – Input dimension of the attention module.

  • num_heads (int) – Number of attention heads. Defaults to 8.

dim
num_heads = 8
head_dim
qkv
proj
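A minimal sketch of how the attributes listed above could fit together. It rests on four assumptions:

  • qkv is a single nn.Linear(dim, 3 * dim).
  • proj is nn.Linear(dim, dim).
  • head_dim = dim // num_heads.
  • freqs is a complex tensor of shape (sequence_length, head_dim // 2).

The class name DotProductAttentionSketch and the helper _rope are hypothetical. This is an illustration, not the library's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DotProductAttentionSketch(nn.Module):
    """Hypothetical re-implementation; layer layouts and freqs format are assumptions."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Assumed: one fused projection producing queries, keys and values.
        self.qkv = nn.Linear(dim, 3 * dim)
        # Assumed: output projection back to the model dimension.
        self.proj = nn.Linear(dim, dim)

    @staticmethod
    def _rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        # Assumed: freqs is complex with shape (seq_len, head_dim // 2).
        # Treat consecutive channel pairs as complex numbers and rotate them.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2).contiguous()
        rotated = torch.view_as_complex(pairs) * freqs
        return torch.view_as_real(rotated).flatten(-2).type_as(x)

    def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        # (b, n, 3 * dim) -> three tensors of shape (b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # RoPE is applied to queries and keys only, not values.
        q, k = self._rope(q, freqs), self._rope(k, freqs)
        out = F.scaled_dot_product_attention(q, k, v)
        # (b, heads, n, head_dim) -> (b, n, dim)
        out = out.transpose(1, 2).reshape(b, n, self.dim)
        return self.proj(out)


torch.manual_seed(0)
attn = DotProductAttentionSketch(dim=64, num_heads=8)
x = torch.randn(2, 10, 64)
half = attn.head_dim // 2
freqs = torch.polar(torch.ones(10, half), torch.rand(10, half))
y = attn(x, freqs)
print(y.shape)  # torch.Size([2, 10, 64])
```

If the real module follows this common layout, the fused qkv layer costs a single matrix multiply per forward pass. Reshaping to (batch, heads, sequence, head_dim) lets each head attend independently before proj mixes the heads back together.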
forward(x, freqs)

Apply multi-head scaled dot-product self-attention to x.

Parameters:
  • x (torch.Tensor) – Input tensor to apply self-attention over, of shape (batch_size, sequence_length, dim).

  • freqs (torch.Tensor) – Frequencies for the Rotary Positional Embedding (RoPE) applied to the queries and keys.
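This page does not specify the layout of freqs. One common convention, seen for example in LLaMA-style code, stores one unit-magnitude complex rotation per position and channel pair. The hypothetical helper rope_freqs below builds such a tensor using the conventional base of 10000. Both the representation and the base are assumptions. Some codebases store separate cosine and sine tables instead.

```python
import torch


def rope_freqs(seq_len: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: complex RoPE rotations of shape (seq_len, head_dim // 2)."""
    assert head_dim % 2 == 0, "RoPE rotates channel pairs, so head_dim must be even"
    # One rotation frequency per channel pair: base ** (-2i / head_dim).
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)            # (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex values


freqs = rope_freqs(seq_len=10, head_dim=8)
print(freqs.shape, freqs.dtype)  # torch.Size([10, 4]) torch.complex64
```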

Returns:

Output tensor of shape (batch_size, sequence_length, dim).

Return type:

torch.Tensor