attention.config

Classes

AttentionConfig

Configuration for an attention module. Since we can have many different attention implementations, we allow extra fields so that the same schema can be used for all attention modules.

DotProductAttentionConfig

Configuration for the Dot Product attention module.

TransolverAttentionConfig

Configuration for the Transolver attention module.

TransolverPlusPlusAttentionConfig

Configuration for the Transolver++ attention module.

IrregularNatAttentionConfig

Configuration for the Irregular Neighbourhood Attention Transformer (NAT) attention module.

PerceiverAttentionConfig

Configuration for the Perceiver attention module.

Module Contents

class attention.config.AttentionConfig(/, **data)

Bases: pydantic.BaseModel

Configuration for an attention module. Since we can have many different attention implementations, we allow extra fields so that the same schema can be used for all attention modules.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

class Config

extra = 'allow'

hidden_dim: int = None

Dimensionality of the hidden features.

num_heads: int = None

Number of attention heads.

use_rope: bool = None

Whether to use Rotary Positional Embeddings (RoPE).

dropout: float = None

Dropout rate for the attention weights and output projection.

init_weights: emmi.types.InitWeightsMode = None

Weight initialization strategy.

bias: bool = None

Whether to use bias terms in linear layers.

head_dim: int | None = None

Dimensionality of each attention head.

validate_hidden_dim_and_num_heads()
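
A minimal usage sketch, assuming attention.config is importable as shown; all values are illustrative, not documented defaults. Because Config.extra = 'allow', fields outside the base schema are retained rather than rejected (the name of validate_hidden_dim_and_num_heads() suggests it checks consistency between hidden_dim and num_heads, but the exact rule is not documented here).

```python
from attention.config import AttentionConfig

# Illustrative values; the schema's defaults are not documented here.
cfg = AttentionConfig(
    hidden_dim=256,  # dimensionality of the hidden features
    num_heads=8,     # number of attention heads
    use_rope=True,   # enable Rotary Positional Embeddings
    dropout=0.1,     # dropout for attention weights and output projection
    bias=True,       # bias terms in linear layers
)

# extra = 'allow' keeps unknown fields, so one schema can serve all
# attention implementations. 'window_size' is a hypothetical extra field.
cfg_extra = AttentionConfig(hidden_dim=256, num_heads=8, window_size=16)
print(cfg_extra.window_size)  # 16
```
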
class attention.config.DotProductAttentionConfig(/, **data)

Bases: AttentionConfig

Configuration for the Dot Product attention module.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

class Config

extra = None

class attention.config.TransolverAttentionConfig(/, **data)

Bases: AttentionConfig

Configuration for the Transolver attention module.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

num_slices: int = None

Number of slices to project the input tokens to.

class Config

extra = None
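
A hedged sketch for Transolver attention; num_slices is the only field added over AttentionConfig, and the values are illustrative.

```python
from attention.config import TransolverAttentionConfig

cfg = TransolverAttentionConfig(
    hidden_dim=256,
    num_heads=8,
    num_slices=32,  # number of slices the input tokens are projected to
)
```
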
class attention.config.TransolverPlusPlusAttentionConfig(/, **data)

Bases: TransolverAttentionConfig

Configuration for the Transolver++ attention module.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

class Config

extra = None

use_overparameterization: bool = None

Whether to use overparameterization for the slice projection.

use_adaptive_temperature: bool = None

Whether to use an adaptive temperature for the slice selection.

temperature_activation: Literal['sigmoid', 'softplus', 'exp'] | None = None

Activation function for the adaptive temperature.

use_gumbel_softmax: bool = None

Whether to use Gumbel-Softmax for the slice selection.
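
A sketch combining the Transolver++ options above; the flag combination and values are illustrative, not recommended settings.

```python
from attention.config import TransolverPlusPlusAttentionConfig

cfg = TransolverPlusPlusAttentionConfig(
    hidden_dim=256,
    num_heads=8,
    num_slices=32,                      # inherited from TransolverAttentionConfig
    use_overparameterization=True,      # overparameterized slice projection
    use_adaptive_temperature=True,      # adaptive temperature for slice selection
    temperature_activation="softplus",  # one of 'sigmoid', 'softplus', 'exp'
    use_gumbel_softmax=False,           # no Gumbel-Softmax slice selection
)
```
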

class attention.config.IrregularNatAttentionConfig(/, **data)

Bases: AttentionConfig

Configuration for the Irregular Neighbourhood Attention Transformer (NAT) attention module.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

class Config

extra = None

input_dim: int = None

Dimensionality of the input features.

radius: float = None

Radius for the radius graph.

max_degree: int = None

Maximum number of neighbors per point.

relpos_mlp_hidden_dim: int = None

Hidden dimensionality of the relative position bias MLP.

relpos_mlp_dropout: float = None

Dropout rate for the relative position bias MLP.
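
A sketch for irregular (e.g. point-cloud) inputs; input_dim=3 assumes 3-D coordinates, and every value is illustrative.

```python
from attention.config import IrregularNatAttentionConfig

cfg = IrregularNatAttentionConfig(
    hidden_dim=128,
    num_heads=4,
    input_dim=3,               # assumed 3-D point coordinates
    radius=0.05,               # radius used to build the radius graph
    max_degree=32,             # cap on the number of neighbors per point
    relpos_mlp_hidden_dim=64,  # hidden width of the relative position bias MLP
    relpos_mlp_dropout=0.0,
)
```
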

class attention.config.PerceiverAttentionConfig(/, **data)

Bases: AttentionConfig

Configuration for the Perceiver attention module.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

class Config

extra = None

kv_dim: int | None = None

Dimensionality of the key/value features. If None, use hidden_dim.

set_kv_dim()
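
A sketch of the kv_dim behaviour documented above: leaving kv_dim unset falls back to hidden_dim (per the field docs and set_kv_dim()), while an explicit kv_dim lets keys/values come from a context stream of a different width.

```python
from attention.config import PerceiverAttentionConfig

# kv_dim omitted: per the field docs, it falls back to hidden_dim.
latent_cfg = PerceiverAttentionConfig(hidden_dim=256, num_heads=8)

# Explicit kv_dim: keys/values are projected from a 512-dim context.
cross_cfg = PerceiverAttentionConfig(hidden_dim=256, num_heads=8, kv_dim=512)
```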