emmi.modules.layers.layer_scale

Classes

LayerScale

LayerScale module scales the input tensor by a learnable parameter gamma.

Module Contents

class emmi.modules.layers.layer_scale.LayerScale(config)

Bases: torch.nn.Module

LayerScale module scales the input tensor by a learnable parameter gamma.

Initialize the LayerScale module.

hidden_dim – Number of dimensions of the input tensor to be scaled.

init_scale – Initial gamma scale value. Defaults to 1e-5.

Parameters:

config (emmi.schemas.modules.layers.LayerScaleConfig)
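For orientation, the following is a minimal sketch of the LayerScale technique, not the emmi source. It assumes that gamma is a per-channel vector of length hidden_dim, initialized to init_scale, that multiplies the last dimension of the input. The class name LayerScaleSketch and its plain constructor arguments are hypothetical. The real class takes a LayerScaleConfig instead.

```python
import torch
from torch import nn


class LayerScaleSketch(nn.Module):
    """Hypothetical re-implementation of LayerScale (not the emmi code).

    Assumes gamma has shape (hidden_dim,) and scales the last input dimension.
    """

    def __init__(self, hidden_dim: int, init_scale: float = 1e-5) -> None:
        super().__init__()
        # Learnable per-channel scale, initialized to a small constant.
        self.gamma = nn.Parameter(torch.full((hidden_dim,), init_scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcasting multiplies each channel of the last dim by its gamma entry.
        return x * self.gamma


layer = LayerScaleSketch(hidden_dim=4)
x = torch.ones(2, 3, 4)
y = layer(x)
print(y.shape)  # torch.Size([2, 3, 4])
```

Because gamma is an nn.Parameter, it is registered with the module and updated by the optimizer like any other weight. The small default of 1e-5 means the scaled output starts out almost zero.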

forward(x)

Forward function of the LayerScale module.

Parameters:

x (torch.Tensor) – Input tensor to be scaled.

Returns:

Tensor scaled by the gamma parameter.

Return type:

torch.Tensor
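In transformer-style blocks, a LayerScale-like scaling is typically applied to the output of a residual branch before that output is added back to the input. The sketch below shows this pattern with a bare per-channel gamma. The LayerNorm and Linear layers are illustrative stand-ins, not part of emmi, and the per-channel shape of gamma is an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_dim = 8

# Hypothetical pre-norm residual branch; only the gamma scaling mirrors LayerScale.
norm = nn.LayerNorm(hidden_dim)
mlp = nn.Linear(hidden_dim, hidden_dim)
gamma = nn.Parameter(torch.full((hidden_dim,), 1e-5))  # assumed per-channel, init_scale=1e-5

x = torch.randn(2, 5, hidden_dim)
# Residual update: the scaled branch output is added back to the input.
out = x + gamma * mlp(norm(x))

# With a tiny gamma, the block starts out close to the identity map.
print(torch.allclose(out, x, atol=1e-3))  # True
```

Starting each residual branch near the identity is the usual motivation for initializing gamma to a small value such as 1e-5 rather than 1.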