emmi.modules.layers.transformer_batchnorm
=========================================

.. py:module:: emmi.modules.layers.transformer_batchnorm


Classes
-------

.. autoapisummary::

   emmi.modules.layers.transformer_batchnorm.TransformerBatchNorm


Module Contents
---------------

.. py:class:: TransformerBatchNorm(num_features, eps = 1e-05, elementwise_affine = True, bias = True)

   Bases: :py:obj:`torch.nn.Module`

   Wrapper around :py:class:`torch.nn.BatchNorm1d` that treats all tokens of a single sample as the full batch. Additionally remaps ``affine`` to ``elementwise_affine`` and supports disabling the bias to comply with the :py:class:`torch.nn.LayerNorm` interface. Does not instantiate any ``nn.BatchNorm1d`` modules internally, which avoids errors with ``nn.SyncBatchNorm``.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.

   .. py:attribute:: num_features


   .. py:attribute:: eps
      :value: 1e-05


   .. py:attribute:: elementwise_affine
      :value: True


   .. py:attribute:: bias
      :value: True


   .. py:method:: forward(x)

      BatchNorm1d where all tokens of a single sample correspond to the full batch.

      :param x: Tensor of shape ``(batch_size, seqlen, dim)``.
      :returns: Normalized ``x`` of shape ``(batch_size, seqlen, dim)``.
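   .. rubric:: Example

   Per the docstring, the forward pass amounts to computing batch-norm statistics per sample over the token dimension. The sketch below illustrates that behavior under this assumption and is not the module's actual implementation; the helper name ``per_sample_batchnorm`` is hypothetical.

   .. code-block:: python

      import torch

      def per_sample_batchnorm(x, weight=None, bias=None, eps=1e-5):
          # x: (batch_size, seqlen, dim). For each sample, treat its tokens
          # as the "batch" and normalize every feature channel over them.
          mean = x.mean(dim=1, keepdim=True)                # (batch_size, 1, dim)
          var = x.var(dim=1, unbiased=False, keepdim=True)  # biased, as in batch norm
          x = (x - mean) / torch.sqrt(var + eps)
          if weight is not None:  # elementwise_affine scale, shape (dim,)
              x = x * weight
          if bias is not None:    # optional shift, shape (dim,)
              x = x + bias
          return x

      # Using the documented module itself:
      from emmi.modules.layers.transformer_batchnorm import TransformerBatchNorm

      norm = TransformerBatchNorm(num_features=768)
      x = torch.randn(4, 128, 768)  # (batch_size, seqlen, dim)
      y = norm(x)                   # normalized, same shape

   Because the statistics are computed independently per sample, the result does not depend on the other samples in the batch, which is why the functional form sidesteps ``nn.SyncBatchNorm`` conversion issues.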