emmi.modules.layers.scalar_conditioner
Classes
ScalarsConditioner | Embeds num_scalars scalars into a single conditioning vector.
Module Contents
- class emmi.modules.layers.scalar_conditioner.ScalarsConditioner(hidden_dim, num_scalars, condition_dim=None, init_weights='truncnormal')
Bases: torch.nn.Module
Embeds num_scalars scalars into a single conditioning vector. Each scalar is first encoded with a sine-cosine embedding and then passed through its own MLP. The resulting per-scalar vectors are concatenated and projected down to condition_dim by a final MLP.
- Parameters:
hidden_dim (int) – Dimension of the sine-cosine scalar embeddings and of the per-scalar MLPs.
num_scalars (int) – Number of scalars to embed.
condition_dim (int | None) – Dimension of the final conditioning vector. Defaults to 4 * hidden_dim if None.
init_weights (str) – Weight initialization scheme for the MLPs.
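To make the data flow concrete, here is a minimal PyTorch sketch of the architecture described above. It is not the emmi implementation. The sine-cosine frequencies, the GELU MLPs, the name `proj` for the final projection, and the use of PyTorch's default initialization (instead of init_weights='truncnormal') are all assumptions. Only the constructor arguments and the attribute names num_scalars, condition_dim, embed and mlps mirror this page.

```python
from typing import Optional

import torch
from torch import nn


class SinCosEmbed(nn.Module):
    """Sine-cosine embedding of scalars into `dim` features (assumed transformer-style frequencies)."""

    def __init__(self, dim: int, max_wavelength: float = 10_000.0) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError("dim must be even")
        k = torch.arange(dim // 2, dtype=torch.float32)
        # frequencies max_wavelength ** (-2k / dim), from 1 down towards 1 / max_wavelength
        self.register_buffer("freqs", max_wavelength ** (-2 * k / dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # accept (batch_size,) or (batch_size, 1); angles has shape (batch_size, dim // 2)
        angles = x.float().reshape(-1, 1) * self.freqs
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (batch_size, dim)


class ScalarsConditionerSketch(nn.Module):
    """Hypothetical re-implementation; layer choices and the `proj` name are assumptions."""

    def __init__(self, hidden_dim: int, num_scalars: int, condition_dim: Optional[int] = None) -> None:
        super().__init__()
        self.num_scalars = num_scalars
        self.condition_dim = condition_dim if condition_dim is not None else 4 * hidden_dim
        self.embed = SinCosEmbed(hidden_dim)
        # one MLP per scalar, each mapping hidden_dim -> hidden_dim
        self.mlps = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim))
            for _ in range(num_scalars)
        )
        # projects the concatenated per-scalar vectors down to condition_dim
        self.proj = nn.Sequential(
            nn.Linear(num_scalars * hidden_dim, self.condition_dim),
            nn.GELU(),
            nn.Linear(self.condition_dim, self.condition_dim),
        )

    def forward(self, *scalars: torch.Tensor) -> torch.Tensor:
        if len(scalars) != self.num_scalars:
            raise ValueError(f"expected {self.num_scalars} scalars, got {len(scalars)}")
        # embed each scalar and pass it through its own MLP: each (batch_size, hidden_dim)
        per_scalar = [mlp(self.embed(s)) for mlp, s in zip(self.mlps, scalars)]
        # concatenate to (batch_size, num_scalars * hidden_dim), then project to condition_dim
        return self.proj(torch.cat(per_scalar, dim=-1))


model = ScalarsConditionerSketch(hidden_dim=16, num_scalars=2)
geometry_angle = torch.tensor([75.3, 60.0])      # shape (batch_size,)
friction_angle = torch.tensor([[24.6], [30.0]])  # shape (batch_size, 1)
condition = model(geometry_angle, friction_angle)
print(condition.shape)  # torch.Size([2, 64])
```

With condition_dim left as None, the sketch falls back to 4 * hidden_dim = 64, matching the documented default.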
- num_scalars
- condition_dim
- embed
- mlps
- forward(*args, **kwargs)
Embeds scalars into a single conditioning vector. Scalars can be passed as *args or as **kwargs. Passing them as kwargs is recommended, because it avoids bugs that arise when the same scalars are passed in a different order at two places in the code. Recommended usage:
condition = conditioner(geometry_angle=75.3, friction_angle=24.6)
- Parameters:
*args – Scalars as tensors of shape (batch_size,) or (batch_size, 1).
**kwargs – Scalars as tensors of shape (batch_size,) or (batch_size, 1).
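The kwargs recommendation can be illustrated with a hypothetical helper, `collect_scalars`, which is not part of emmi. It assumes keyword scalars are ordered by name, which makes the result independent of the order in which a call site lists them, and it accepts both documented shapes, (batch_size,) and (batch_size, 1). How ScalarsConditioner itself maps keyword names to its per-scalar MLPs is not stated on this page.

```python
import torch


def collect_scalars(*args: torch.Tensor, **kwargs: torch.Tensor) -> torch.Tensor:
    """Stack scalars into a (batch_size, num_scalars) tensor.

    Hypothetical helper: positional scalars keep their call order, while keyword
    scalars are ordered by name, so call sites need not agree on argument order.
    """
    if args and kwargs:
        raise ValueError("pass scalars either positionally or by keyword, not both")
    tensors = [kwargs[name] for name in sorted(kwargs)] if kwargs else list(args)
    if not tensors:
        raise ValueError("no scalars given")
    # accept both (batch_size,) and (batch_size, 1) by flattening each to (batch_size,)
    return torch.stack([t.reshape(-1) for t in tensors], dim=1)


geometry_angle = torch.tensor([75.3, 60.0])      # shape (batch_size,)
friction_angle = torch.tensor([[24.6], [30.0]])  # shape (batch_size, 1)

a = collect_scalars(geometry_angle=geometry_angle, friction_angle=friction_angle)
b = collect_scalars(friction_angle=friction_angle, geometry_angle=geometry_angle)
print(torch.equal(a, b))  # True
print(a.shape)            # torch.Size([2, 2])
```

Swapping the positional arguments, by contrast, silently swaps the columns, which is exactly the class of bug the docstring warns about.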