emmi.modules.attention.transolver_plusplus_attention

Classes

TransolverPlusPlusAttention – Transolver++ attention module.

Module Contents

class emmi.modules.attention.transolver_plusplus_attention.TransolverPlusPlusAttention(config)

Bases: torch.nn.Module

Transolver++ attention module, following the reference implementation at https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py

Initialize the TransolverPlusPlusAttention module.

Parameters:

config (emmi.schemas.modules.attention.AttentionConfig) – Configuration object for the attention module.

dim_head
num_heads = None
scale
softmax
dropout
bias
proj_temperature
in_project_x
in_project_slice
qkv
to_out
forward(x, attn_mask=None)

Forward pass of the Transolver++ attention module.

Parameters:
  • x (torch.Tensor) – Input tensor with shape (batch_size, seqlen, hidden_dim).

  • attn_mask (torch.Tensor | None) – Attention mask tensor with shape (batch_size). Defaults to None.

Returns:

Output tensor after applying the Transolver++ attention mechanism.
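
The following is a minimal sketch of the kind of slice-based attention this module wraps, written as a standalone class so it runs without emmi. It is not the emmi implementation: the class name SliceAttentionSketch, its constructor arguments (hidden_dim, num_heads, num_slices), and the exact layer shapes are assumptions made for illustration. The attribute names mirror those listed above, and their roles are inferred from the upstream Transolver++ code rather than confirmed by this page. The attn_mask argument is left out because the documentation does not say how the mask is applied.

```python
import torch
from torch import nn


class SliceAttentionSketch(nn.Module):
    """Illustrative stand-in for slice-based attention (NOT the emmi class)."""

    def __init__(self, hidden_dim=32, num_heads=4, num_slices=8):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.dim_head = hidden_dim // num_heads
        self.scale = self.dim_head ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        # Learnable per-head offset added to the predicted temperature.
        self.bias = nn.Parameter(torch.full((1, num_heads, 1, 1), 0.5))
        # Predicts one temperature value per point and head.
        self.proj_temperature = nn.Sequential(
            nn.Linear(self.dim_head, num_slices),
            nn.GELU(),
            nn.Linear(num_slices, 1),
            nn.GELU(),
        )
        self.in_project_x = nn.Linear(hidden_dim, hidden_dim)
        self.in_project_slice = nn.Linear(self.dim_head, num_slices)
        self.qkv = nn.Linear(self.dim_head, 3 * self.dim_head, bias=False)
        self.to_out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        b, n, _ = x.shape
        # (b, n, hidden_dim) -> (b, heads, n, dim_head)
        x_mid = self.in_project_x(x).view(b, n, self.num_heads, self.dim_head)
        x_mid = x_mid.transpose(1, 2)

        # Per-point temperature, clamped so the division below stays stable.
        temperature = (self.proj_temperature(x_mid) + self.bias).clamp(min=0.01)

        # Gumbel-softmax assignment of every point to the slices.
        logits = self.in_project_slice(x_mid)
        u = torch.rand_like(logits)
        gumbel = -torch.log(-torch.log(u + 1e-8) + 1e-8)
        weights = self.softmax((logits + gumbel) / temperature)  # (b, h, n, g)

        # Slice tokens: weighted average of the points assigned to each slice.
        tokens = torch.einsum("bhnc,bhng->bhgc", x_mid, weights)
        tokens = tokens / (weights.sum(dim=2).unsqueeze(-1) + 1e-5)

        # Plain scaled dot-product attention among the slice tokens.
        q, k, v = self.qkv(tokens).chunk(3, dim=-1)
        attn = self.softmax(q @ k.transpose(-2, -1) * self.scale)
        tokens = attn @ v  # (b, h, g, dim_head)

        # Broadcast the updated slice tokens back to the points.
        out = torch.einsum("bhgc,bhng->bhnc", tokens, weights)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)


x = torch.randn(2, 100, 32)  # (batch_size, seqlen, hidden_dim)
y = SliceAttentionSketch()(x)
print(y.shape)  # torch.Size([2, 100, 32])
```

In this sketch the output has the same (batch_size, seqlen, hidden_dim) shape as the input. Attention is computed among num_slices tokens rather than among all seqlen points, so the attention matrix is num_slices by num_slices per head regardless of sequence length.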