emmi.modules.attention.transolver_attention
===========================================

.. py:module:: emmi.modules.attention.transolver_attention


Classes
-------

.. autoapisummary::

   emmi.modules.attention.transolver_attention.TransolverAttention


Module Contents
---------------

.. py:class:: TransolverAttention(config)

   Bases: :py:obj:`torch.nn.Module`

   Adapted from https://github.com/thuml/Transolver/blob/main/Car-Design-ShapeNetCar/models/Transolver.py

   - Readable reshaping operations via einops
   - Merged qkv linear layer for higher GPU utilization
   - F.scaled_dot_product_attention instead of a slow manual PyTorch attention implementation
   - Possibility to mask tokens (required to process variable-sized inputs)

   An illustrative usage sketch is given in the Examples section at the end of this page.

   Initialize the Transolver attention module.

   :param config: Configuration of the attention module.

   .. py:attribute:: num_heads
      :value: None

   .. py:attribute:: dropout
      :value: None

   .. py:attribute:: temperature

   .. py:attribute:: in_project_x

   .. py:attribute:: in_project_fx

   .. py:attribute:: in_project_slice

   .. py:attribute:: qkv

   .. py:attribute:: proj

   .. py:attribute:: proj_dropout

   .. py:method:: create_slices(x, num_input_points, attn_mask = None)

      Given a set of points, project them onto a fixed number of slices using the slice weights computed per token.

      :param x: Input tensor with shape (batch_size, num_input_points, hidden_dim).
      :param num_input_points: Number of input points.
      :param attn_mask: Mask to exclude certain tokens from the attention. Defaults to None.

      :returns: Tensor with the projected slice tokens and the slice weights.

   .. py:method:: forward(x, attn_mask = None)

      Forward pass of the Transolver attention module.

      :param x: Input tensor with shape (batch_size, seqlen, hidden_dim).
      :param attn_mask: Attention mask tensor with shape (batch_size). Defaults to None.

      :returns: Tensor after applying the Transolver attention mechanism.
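
Examples
--------

The exact fields expected on ``config`` are defined elsewhere in the emmi codebase and are not documented on this page. The sketch below is therefore a hedged illustration: the field names ``hidden_dim``, ``num_heads``, ``num_slices``, and ``dropout`` are assumptions chosen for readability, not confirmed attributes of the real config class.

.. code-block:: python

   # Minimal usage sketch. The config fields below (hidden_dim, num_heads,
   # num_slices, dropout) are ASSUMED names, not the canonical emmi config.
   from types import SimpleNamespace

   import torch

   from emmi.modules.attention.transolver_attention import TransolverAttention

   config = SimpleNamespace(
       hidden_dim=256,  # assumed: model width
       num_heads=8,     # assumed: number of attention heads
       num_slices=32,   # assumed: number of slice tokens
       dropout=0.0,     # assumed: dropout probability
   )
   attn = TransolverAttention(config)

   x = torch.randn(4, 1024, 256)  # (batch_size, seqlen, hidden_dim)
   out = attn(x)                  # output has the same shape as x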
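
``create_slices`` implements the slicing step of Transolver's physics attention: each token receives a softmax weight over a fixed number of slices, and slice tokens are formed as weight-normalized sums of the token features, so attention can later run over the (few) slices instead of the (many) input points. The standalone function below is a simplified re-implementation of that idea under assumed tensor layouts; it is a sketch, not the module's actual code.

.. code-block:: python

   # Simplified illustration of Transolver-style slicing; the tensor layout
   # (batch, heads, points, dim) is an assumption made for this sketch.
   import torch
   import torch.nn.functional as F


   def slice_tokens(x, slice_logits, temperature=0.5, attn_mask=None):
       # x:            (batch, heads, num_points, head_dim)
       # slice_logits: (batch, heads, num_points, num_slices)
       # attn_mask:    (batch, num_points) bool, True for valid tokens
       weights = F.softmax(slice_logits / temperature, dim=-1)
       if attn_mask is not None:
           # Padded tokens contribute to no slice.
           weights = weights * attn_mask[:, None, :, None]
       # Weighted sum of token features per slice, normalized by total weight.
       norm = weights.sum(dim=2).unsqueeze(-1)  # (batch, heads, num_slices, 1)
       tokens = torch.einsum("bhnc,bhns->bhsc", x, weights)
       return tokens / (norm + 1e-5), weights


   x = torch.randn(2, 4, 100, 16)
   logits = torch.randn(2, 4, 100, 8)
   # Second sample has only 60 valid points; the rest are padding.
   mask = torch.arange(100)[None, :] < torch.tensor([[100], [60]])
   tokens, weights = slice_tokens(x, logits, attn_mask=mask)
   print(tokens.shape)  # torch.Size([2, 4, 8, 16])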