emmi.modules.layers.transformer_batchnorm

Classes

TransformerBatchNorm

Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch.

Module Contents

class emmi.modules.layers.transformer_batchnorm.TransformerBatchNorm(num_features, eps=1e-05, elementwise_affine=True, bias=True)

Bases: torch.nn.Module

Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch. Additionally remaps affine to elementwise_affine and supports disabling bias to comply with the torch.nn.LayerNorm interface. Does not use any nn.BatchNorm1d modules, to avoid errors with nn.SyncBatchNorm.

Parameters:
num_features – Number of features in the input (the size of the last dimension).
eps = 1e-05 – Value added to the variance for numerical stability.
elementwise_affine = True – If True, apply a learnable per-feature scale.
bias = True – If True, apply a learnable per-feature bias.
forward(x)

BatchNorm1d where all tokens of a single sample correspond to a full batch.

Parameters:

x (torch.Tensor) – Tensor of shape (batch_size, seqlen, dim).

Returns:

Normalized x of shape (batch_size, seqlen, dim).

Return type:

torch.Tensor
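The behaviour described above can be sketched by computing the normalization statistics over the token axis directly, so that the tokens of each sample act as a BatchNorm1d batch. This is a minimal illustrative sketch, not the package's actual implementation: it assumes training-mode statistics only (no running mean/variance) and biased variance, and the parameter handling mirrors the torch.nn.LayerNorm-style interface the docstring describes.

```python
import torch
import torch.nn as nn


class TransformerBatchNormSketch(nn.Module):
    """Hypothetical sketch: normalize each sample's tokens as a batch."""

    def __init__(self, num_features, eps=1e-05, elementwise_affine=True, bias=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features)) if bias else None
        else:
            self.weight = None
            self.bias = None

    def forward(self, x):
        # x: (batch_size, seqlen, dim). Statistics are computed over the
        # token axis (dim=1) independently per sample and per feature,
        # i.e. each sample's tokens form the "batch" of a BatchNorm1d.
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:
            x = x * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x
```

Note the contrast with torch.nn.LayerNorm, which shares the same constructor interface but normalizes over the feature dimension of each token rather than over the tokens of each sample.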