emmi.functional.modulation

Functions

modulate_scale_shift(x, scale, shift)

Scales and shifts the input x feature-wise with scale and shift.

modulate_gate(x, gate)

Gates the input x feature-wise with gate.

Module Contents

emmi.functional.modulation.modulate_scale_shift(x, scale, shift)

Scales and shifts the input x feature-wise with scale and shift. The scale defaults to 1, and the scale tensor is the offset from that default (i.e., the effective scale factor is 1 + scale), so if scale == 0 and shift == 0 this method is equivalent to the identity.

Parameters:
  • x (torch.Tensor) – Input tensor (e.g., the input to a transformer block, with shape (batch_size, sequence_length, dim)).

  • scale (torch.Tensor) – Scale tensor with shape (batch_size, dim) or (batch_size, 1, dim).

  • shift (torch.Tensor) – Shift tensor with shape (batch_size, dim) or (batch_size, 1, dim).

Return type:

torch.Tensor
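The description above implies the computation x * (1 + scale) + shift, with conditioning tensors broadcast over the sequence axis. The following is a minimal, hypothetical sketch of that behavior, not the library's actual implementation. It assumes that a (batch_size, dim) scale or shift is expanded to (batch_size, 1, dim) so that it applies equally to every token, and that scaling happens before shifting.

```python
import torch


def modulate_scale_shift_sketch(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: x * (1 + scale) + shift, broadcast over the sequence axis."""
    # Assumption: (batch_size, dim) conditioning tensors get a singleton sequence
    # axis so they broadcast against x of shape (batch_size, sequence_length, dim).
    if scale.ndim == 2:
        scale = scale.unsqueeze(1)
    if shift.ndim == 2:
        shift = shift.unsqueeze(1)
    # scale is an offset from the default factor of 1, so zero scale and zero
    # shift leave x unchanged.
    return x * (1 + scale) + shift


# Usage: zero conditioning reproduces the input (the identity case described above).
x = torch.randn(2, 5, 8)
zeros = torch.zeros(2, 8)
out = modulate_scale_shift_sketch(x, zeros, zeros)
```

Under this sketch, passing scale as (batch_size, dim) or as (batch_size, 1, dim) gives the same result, which matches the two accepted shapes listed under Parameters.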

emmi.functional.modulation.modulate_gate(x, gate)

Gates the input x feature-wise with gate.

Parameters:
  • x (torch.Tensor) – Input tensor (e.g., the input to a transformer block, with shape (batch_size, sequence_length, dim)).

  • gate (torch.Tensor) – Gate tensor with shape (batch_size, dim) or (batch_size, 1, dim).

Return type:

torch.Tensor
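The following is a minimal, hypothetical sketch of feature-wise gating, not the library's actual implementation. Unlike the scale in modulate_scale_shift, no default offset is documented for gate, so this sketch assumes a plain element-wise product gate * x. It also assumes that a (batch_size, dim) gate is expanded to (batch_size, 1, dim) so that it broadcasts over the sequence axis.

```python
import torch


def modulate_gate_sketch(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: feature-wise gating gate * x, broadcast over the sequence axis."""
    # Assumption: a (batch_size, dim) gate gets a singleton sequence axis so it
    # broadcasts against x of shape (batch_size, sequence_length, dim).
    if gate.ndim == 2:
        gate = gate.unsqueeze(1)
    # Each feature channel of every token is multiplied by its per-sample gate value.
    return gate * x


# Usage: a zero gate suppresses the input entirely, and a gate of ones passes it through.
x = torch.randn(2, 5, 8)
closed = modulate_gate_sketch(x, torch.zeros(2, 8))
opened = modulate_gate_sketch(x, torch.ones(2, 1, 8))
```

Gates of this kind typically scale a residual branch in transformer blocks, e.g. x + modulate_gate(branch(x), gate). With a gate of zeros, the branch contributes nothing.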