emmi_inference.models.modules.attention.dot_product_attention¶
Classes¶
| DotProductAttention | Scaled dot-product attention module. |
Module Contents¶
- class emmi_inference.models.modules.attention.dot_product_attention.DotProductAttention(dim, num_heads=8)¶
Bases: torch.nn.Module

Scaled dot-product attention module.
- Parameters:
dim (int) – Embedding dimension of the input.
num_heads (int) – Number of attention heads (default: 8).
- dim¶
- num_heads = 8¶
- head_dim¶
- qkv¶
- proj¶
- forward(x, freqs)¶
Forward function of the DotProductAttention module.
- Parameters:
x (torch.Tensor) – Tensor to apply self-attention over, shape (batch_size, sequence_length, dim).
freqs (torch.Tensor) – Frequencies for Rotary Positional Embedding (RoPE) of queries/keys.
- Returns:
Attention output tensor of shape (batch_size, sequence_length, dim).
- Return type:
torch.Tensor
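To make the shapes and the role of freqs concrete, here is a hedged sketch of what a module with this interface might look like. It is not the actual emmi_inference implementation; the exact RoPE layout (here: rotation of adjacent channel pairs, with freqs holding per-position angles of shape (sequence_length, head_dim / 2)) and the use of a fused qkv projection are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DotProductAttention(nn.Module):
    """Illustrative sketch of scaled dot-product attention with RoPE.

    Hypothetical reimplementation; the real module may differ in detail.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(dim, 3 * dim)
        # Output projection back to the model dimension.
        self.proj = nn.Linear(dim, dim)

    @staticmethod
    def _apply_rope(t, freqs):
        # t: (B, H, S, head_dim); freqs: (S, head_dim / 2) rotation angles.
        # Rotate each adjacent channel pair (x1, x2) by the per-position angle.
        x1, x2 = t[..., ::2], t[..., 1::2]
        cos, sin = freqs.cos(), freqs.sin()
        out = torch.empty_like(t)
        out[..., ::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out

    def forward(self, x, freqs):
        B, S, _ = x.shape
        # (B, S, 3, H, head_dim) -> three tensors of shape (B, H, S, head_dim).
        qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        # RoPE is applied to queries and keys only, never to values.
        q = self._apply_rope(q, freqs)
        k = self._apply_rope(k, freqs)
        attn = F.scaled_dot_product_attention(q, k, v)
        # Merge heads and project back to (B, S, dim).
        return self.proj(attn.transpose(1, 2).reshape(B, S, self.dim))


# Usage: dim=64 with num_heads=8 gives head_dim=8, so freqs has 4 angle columns.
attn = DotProductAttention(dim=64, num_heads=8)
x = torch.randn(2, 10, 64)
freqs = torch.randn(10, 4)
y = attn(x, freqs)
```

The output y has the documented shape (batch_size, sequence_length, dim), here (2, 10, 64), so the module can be stacked residually like any standard transformer attention block.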