emmi.modules.attention.transolver_attention¶
Classes¶
Module Contents¶
- class emmi.modules.attention.transolver_attention.TransolverAttention(config)¶
Bases: torch.nn.Module

Adapted from https://github.com/thuml/Transolver/blob/main/Car-Design-ShapeNetCar/models/Transolver.py with the following changes:

- Readable reshaping operations via einops
- Merged qkv linear layer for higher GPU utilization (see the sketch below)
- F.scaled_dot_product_attention instead of the slower native PyTorch attention
- Possibility to mask tokens (required to process variable-sized inputs)
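A minimal sketch of the last two points, using assumed sizes; the packing order inside the merged layer and the tensor shapes are assumptions, not this module's actual code:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

hidden_dim, num_heads = 256, 8  # assumed sizes for illustration

# One merged linear produces Q, K and V in a single matmul.
qkv = torch.nn.Linear(hidden_dim, hidden_dim * 3)

slice_tokens = torch.randn(2, 32, hidden_dim)  # (batch, num_slices, hidden_dim)
# Split the packed projection into per-head Q, K, V (packing order is assumed).
q, k, v = rearrange(qkv(slice_tokens), "b n (three h d) -> three b h n d", three=3, h=num_heads)
# Fused attention kernel instead of a hand-written softmax(QK^T)V.
out = F.scaled_dot_product_attention(q, k, v)
out = rearrange(out, "b h n d -> b n (h d)")  # back to (batch, num_slices, hidden_dim)
```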
Initialize the Transolver attention module.
- Parameters:
config (emmi.schemas.modules.attention.AttentionConfig) – configuration of the attention module.
- num_heads = None¶
- dropout = None¶
- temperature¶
- in_project_x¶
- in_project_fx¶
- in_project_slice¶
- qkv¶
- proj¶
- proj_dropout¶
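A minimal instantiation sketch; the AttentionConfig field names used here (hidden_dim, num_heads, num_slices, dropout) are assumptions about the schema, not confirmed by this page:

```python
import torch
from emmi.modules.attention.transolver_attention import TransolverAttention
from emmi.schemas.modules.attention import AttentionConfig

# Field names below are assumed; consult AttentionConfig for the actual schema.
config = AttentionConfig(hidden_dim=256, num_heads=8, num_slices=32, dropout=0.0)
attn = TransolverAttention(config)

x = torch.randn(4, 1024, 256)  # (batch_size, seqlen, hidden_dim)
out = attn(x)                  # output keeps the input shape
```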
- create_slices(x, num_input_points, attn_mask=None)¶
Given a set of points, project them to a fixed number of slices using the computed slice weights per token.
- Parameters:
x (torch.Tensor) – Input tensor with shape (batch_size, num_input_points, hidden_dim).
num_input_points (int) – Number of input points.
attn_mask (torch.Tensor | None) – Mask to exclude certain tokens from the attention. Defaults to None.
- Returns:
Tensor with the projected slice tokens and the slice weights.
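Conceptually, each token receives a temperature-scaled softmax distribution over a fixed number of slices, and each slice token is the weight-normalized average of the token features assigned to it. A sketch modeled on the upstream Transolver reference; the shapes, epsilon, and mask convention are assumptions, not this module's exact code:

```python
import torch

def create_slices_sketch(fx, slice_logits, temperature, attn_mask=None):
    """Illustrative sketch of the slicing step (not this module's exact code).

    fx:           (batch, heads, tokens, head_dim) projected point features
    slice_logits: (batch, heads, tokens, num_slices) raw per-token slice scores
    attn_mask:    (batch, tokens) boolean, True = valid token (assumed convention)
    """
    # Softmax over slices, sharpened or smoothed by the learnable temperature.
    slice_weights = torch.softmax(slice_logits / temperature, dim=-1)   # (b, h, n, s)
    if attn_mask is not None:
        # Masked-out tokens contribute zero weight to every slice.
        slice_weights = slice_weights * attn_mask[:, None, :, None]
    # Each slice token is the weight-normalized average of its point features.
    slice_norm = slice_weights.sum(dim=2)                               # (b, h, s)
    slice_tokens = torch.einsum("bhns,bhnd->bhsd", slice_weights, fx)
    slice_tokens = slice_tokens / (slice_norm[..., None] + 1e-5)        # (b, h, s, d)
    return slice_tokens, slice_weights
```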
- forward(x, attn_mask=None)¶
Forward pass of the Transolver attention module.
- Parameters:
x (torch.Tensor) – Input tensor with shape (batch_size, seqlen, hidden_dim).
attn_mask (torch.Tensor | None) – Attention mask tensor with shape (batch_size). Defaults to None.
- Returns:
Tensor after applying the transolver attention mechanism.
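A usage sketch for variable-sized inputs; the mask shape and the True-means-valid convention are assumptions (the docstring above states shape (batch_size), so verify the exact convention against the implementation):

```python
import torch

batch_size, seqlen, hidden_dim = 2, 500, 256
x = torch.randn(batch_size, seqlen, hidden_dim)

# Assumed convention: True marks valid points, False marks padding.
attn_mask = torch.ones(batch_size, seqlen, dtype=torch.bool)
attn_mask[1, 300:] = False  # second sample only has 300 valid points

out = attn(x, attn_mask=attn_mask)  # `attn` from the instantiation sketch above
assert out.shape == x.shape
```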