emmi.functional.modulation
Functions
modulate_scale_shift(x, scale, shift) | Scales and shifts the input x feature-wise with scale and shift.
modulate_gate(x, gate) | Gates the input x feature-wise with gate.
Module Contents
- emmi.functional.modulation.modulate_scale_shift(x, scale, shift)
Scales and shifts the input x feature-wise with scale and shift. The scale tensor is an offset from a default scale of 1, so the output is x * (1 + scale) + shift. In particular, if scale == 0 and shift == 0, this method is equivalent to the identity.
- Parameters:
x (torch.Tensor) – Input tensor (e.g., the input to a transformer block, with shape (batch_size, sequence_length, dim)).
scale (torch.Tensor) – Scale tensor with shape (batch_size, dim) or (batch_size, 1, dim).
shift (torch.Tensor) – Shift tensor with shape (batch_size, dim) or (batch_size, 1, dim).
- Return type:
torch.Tensor
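The description above implies the formula x * (1 + scale) + shift. The following is a minimal PyTorch sketch of that behavior, written for illustration only; it is not the emmi source. The name `modulate_scale_shift_sketch` is hypothetical, and broadcasting a 2-D `(batch_size, dim)` tensor over the sequence axis via `unsqueeze(1)` is an assumption about how the two accepted shapes are handled.

```python
import torch


def modulate_scale_shift_sketch(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch of feature-wise scale-and-shift (not the emmi code)."""
    # Assumption: a (batch_size, dim) tensor is lifted to (batch_size, 1, dim)
    # so it broadcasts over the sequence_length axis of x.
    if scale.ndim == 2:
        scale = scale.unsqueeze(1)
    if shift.ndim == 2:
        shift = shift.unsqueeze(1)
    # scale is an offset from 1: scale == 0 and shift == 0 return x unchanged.
    return x * (1 + scale) + shift


x = torch.randn(2, 5, 8)  # (batch_size, sequence_length, dim)
zeros = torch.zeros(2, 8)  # (batch_size, dim)
identity_out = modulate_scale_shift_sketch(x, zeros, zeros)
```

Storing the scale as an offset from 1 means that a conditioning network whose output layer is initialized to zero starts out as the identity map, which is a common motivation for this parameterization in adaLN-style conditioning.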
- emmi.functional.modulation.modulate_gate(x, gate)
Gates the input x feature-wise with gate.
- Parameters:
x (torch.Tensor) – Input tensor (e.g., the input to a transformer block, with shape (batch_size, sequence_length, dim)).
gate (torch.Tensor) – Gate tensor with shape (batch_size, dim) or (batch_size, 1, dim).
- Return type:
torch.Tensor
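A matching sketch for gating follows. It assumes, without confirmation from the source, that gating is plain element-wise multiplication by gate (with no offset from 1, unlike the scale in modulate_scale_shift) and that a 2-D gate is broadcast over the sequence axis. The name `modulate_gate_sketch` is hypothetical.

```python
import torch


def modulate_gate_sketch(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of feature-wise gating (not the emmi code)."""
    # Assumption: a (batch_size, dim) gate is lifted to (batch_size, 1, dim)
    # so it broadcasts over the sequence_length axis of x.
    if gate.ndim == 2:
        gate = gate.unsqueeze(1)
    # Assumption: gating is a plain element-wise product.
    return x * gate


x = torch.randn(2, 5, 8)  # (batch_size, sequence_length, dim)
closed = modulate_gate_sketch(x, torch.zeros(2, 8))
opened = modulate_gate_sketch(x, torch.ones(2, 1, 8))
```

Under this assumption a zero gate suppresses x entirely, so in a residual update of the form h + modulate_gate(block(h), gate) the block initially contributes nothing.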