emmi.modules.attention.dot_product_attention¶
Classes¶
DotProductAttention – Scaled dot-product attention module.
Module Contents¶
- class emmi.modules.attention.dot_product_attention.DotProductAttention(config)¶
Bases: torch.nn.Module
Scaled dot-product attention module.
Initialize the DotProductAttention module.
- Parameters:
config (emmi.schemas.modules.attention.AttentionConfig) – configuration of the attention module.
- num_heads = None¶
- head_dim¶
- init_weights = None¶
- use_rope = None¶
- qkv¶
- proj¶
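A minimal usage sketch, assuming AttentionConfig exposes fields matching the attribute names above (hidden_dim, num_heads, init_weights, use_rope); the actual schema in emmi.schemas.modules.attention.AttentionConfig may differ.

```python
import torch

from emmi.modules.attention.dot_product_attention import DotProductAttention
from emmi.schemas.modules.attention import AttentionConfig

# Hypothetical field names; check AttentionConfig for the actual schema.
config = AttentionConfig(
    hidden_dim=256,
    num_heads=8,
    init_weights="truncnormal002",
    use_rope=False,
)

attn = DotProductAttention(config)

x = torch.randn(2, 16, 256)  # (batch size, sequence length, hidden_dim)
out = attn(x)                # self-attention output
```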
- reset_parameters()¶
Reset the parameters of the DotProductAttention module with a specific initialization method.
- Raises:
NotImplementedError – raised when the specified initialization method is not implemented. Use either “torch” or “truncnormal002”.
- Return type:
None
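The name “truncnormal002” suggests a truncated-normal initialization with standard deviation 0.02. The sketch below is one plausible reading, not the module's actual implementation; the helper name _truncnormal002_ is hypothetical.

```python
import torch.nn as nn

def _truncnormal002_(linear: nn.Linear) -> None:
    """Truncated-normal init with std 0.02 and zero bias (assumed reading of
    "truncnormal002"; the real reset_parameters may differ)."""
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)
```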
- forward(x, attn_mask=None, freqs=None)¶
Forward function of the DotProductAttention module.
- Parameters:
x (torch.Tensor) – Tensor to apply self-attention over, shape (batch size, sequence length, hidden_dim).
attn_mask (torch.Tensor | None) – For causal attention (i.e., no attention over future tokens), an attention mask should be provided. Defaults to None.
freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of queries/keys. None if use_rope=False.
- Returns:
The output of the attention module.
- Return type:
torch.Tensor
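A sketch of a forward call with a causal mask, reusing attn from the instantiation sketch above. It assumes a boolean mask where True marks positions that may be attended to (the convention accepted by torch.nn.functional.scaled_dot_product_attention); the module's actual mask convention may differ.

```python
import torch

# Reuses `attn` from the instantiation sketch above (hidden_dim=256).
batch, seq_len, hidden_dim = 2, 16, 256
x = torch.randn(batch, seq_len, hidden_dim)

# Boolean causal mask: position i may attend only to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

out = attn(x, attn_mask=causal_mask)
```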