emmi.modules.layers.scalar_conditioner

Classes

ScalarsConditioner

Embeds multiple scalars into a single conditioning vector.

Module Contents

class emmi.modules.layers.scalar_conditioner.ScalarsConditioner(hidden_dim, num_scalars, condition_dim=None, init_weights='truncnormal')

Bases: torch.nn.Module



Embeds num_scalars scalars into a single conditioning vector by first encoding each scalar with a sine-cosine embedding and passing it through a per-scalar MLP. The resulting per-scalar vectors are then concatenated and projected down to condition_dim with a shared MLP.

Parameters:
  • hidden_dim (int) – Dimension for embedding the scalars and the per-scalar MLP.

  • num_scalars (int) – How many scalars are embedded.

  • condition_dim (int | None) – Dimension of the final conditioning vector. Defaults to 4 * hidden_dim if condition_dim is None.

  • init_weights (str) – Weight initialization for MLPs.
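As an illustration of the first stage, the sine-cosine encoding of a single scalar can be sketched as follows. This is a minimal sketch assuming transformer-style geometrically spaced frequencies; the exact frequencies and normalization used by emmi are not specified here.

```python
import math
import torch


def sincos_embed(x: torch.Tensor, hidden_dim: int) -> torch.Tensor:
    # Embed scalars of shape (batch_size,) or (batch_size, 1) into hidden_dim
    # channels using sine-cosine features; hidden_dim must be even.
    # NOTE: frequency spacing is an assumption, mirroring transformer
    # positional encodings, not necessarily the emmi implementation.
    x = x.view(-1, 1).float()
    half = hidden_dim // 2
    # geometrically spaced frequencies from 1 down to ~1/10000
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = x * freqs  # (batch_size, half)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
```

The per-scalar MLP then operates on this (batch_size, hidden_dim) embedding rather than on the raw scalar, which gives the network a richer, multi-frequency view of each value.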

hidden_dim
num_scalars
condition_dim
embed
mlps
shared_mlp
forward(*args, **kwargs)

Embeds scalars into a single conditioning vector. Scalars can be passed as *args or as **kwargs; each scalar is a tensor of shape (batch_size,) or (batch_size, 1). Passing scalars as kwargs is recommended to avoid bugs caused by supplying them in a different order at different call sites. Recommended usage: condition = conditioner(geometry_angle=75.3, friction_angle=24.6)

Returns:

Conditioning vector with shape (batch_size, condition_dim)

Parameters:
  • args (tuple[torch.Tensor, Ellipsis]) – Scalars in tensor representation with shape (batch_size,) or (batch_size, 1).

  • kwargs (dict[str, torch.Tensor]) – Scalars in tensor representation with shape (batch_size,) or (batch_size, 1).

Return type:

torch.Tensor