emmi_inference.models.modules.rope_frequency

Classes

RopeFrequency

Creates frequencies for rotary embeddings (RoPE) from https://arxiv.org/abs/2104.09864 for variable positions.

Functions

maybe_autocast(device, enabled[, dtype])

Module Contents

emmi_inference.models.modules.rope_frequency.maybe_autocast(device, enabled, dtype=None)
Parameters:
  • device (torch.device)

  • enabled (bool)

  • dtype (torch.dtype | None)
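The reference above gives only the signature of maybe_autocast. Its name and parameters suggest it returns a context that turns on torch.autocast only when enabled is True, but that is an assumption. The sketch below is a hypothetical stand-in: maybe_autocast_sketch is not part of emmi_inference. It is built only on the real torch.autocast and contextlib.nullcontext.

```python
from __future__ import annotations

import contextlib

import torch


def maybe_autocast_sketch(device: torch.device, enabled: bool, dtype: torch.dtype | None = None):
    """Hypothetical stand-in for maybe_autocast (assumed behavior, not the library's code)."""
    if not enabled:
        # No-op context: the wrapped code runs in its normal precision.
        return contextlib.nullcontext()
    # torch.autocast takes a device *type* string such as "cuda" or "cpu".
    return torch.autocast(device_type=device.type, dtype=dtype)


x = torch.randn(4, 4)
with maybe_autocast_sketch(torch.device("cpu"), enabled=False):
    y = x @ x
print(y.dtype)  # torch.float32
```

With dtype=None, torch.autocast uses the default autocast dtype for the device type: bfloat16 on CPU and float16 on CUDA.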

class emmi_inference.models.modules.rope_frequency.RopeFrequency(dim, ndim, max_wavelength=10000.0, assert_positive=True)

Bases: torch.nn.Module

Creates frequencies for rotary embeddings (RoPE) from https://arxiv.org/abs/2104.09864 for variable positions.

Parameters:
  • dim (int) – Dimensionality of frequencies (in transformers this should be the head dimension).

  • ndim (int) – Dimensionality of the coordinates (e.g., 2 for 2D coordinates, 3 for 3D coordinates).

  • max_wavelength (float) – Theta parameter (base) of the sine/cosine frequencies used by the rotary embedding. Default: 10000.0

  • assert_positive (bool) – If True, asserts that the coordinates were rescaled to contain only positive values. Default: True
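For reference, the RoPE paper linked above gives a d-dimensional head d/2 rotation frequencies θ_i. It rotates the i-th feature pair of a token at position m by the angle mθ_i, using base 10000, which this class exposes as max_wavelength. This reference does not document how RopeFrequency splits dim across the ndim coordinate axes. The second line is an assumption: each axis receives d' = dim / ndim features with its own set of frequencies, driven by that axis' coordinate x_a.

```latex
\theta_i = b^{-2(i-1)/d}, \qquad \phi_i(m) = m\,\theta_i, \qquad i = 1, \dots, d/2, \qquad b = \texttt{max\_wavelength}

\phi_{a,i}(\mathbf{x}) = x_a \, b^{-2(i-1)/d'}, \qquad a = 1, \dots, n, \quad i = 1, \dots, d'/2, \quad d' = d/n, \quad n = \texttt{ndim}
```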


dim
ndim
ndim_padding
sincos_padding
max_wavelength = 10000.0
padding
assert_positive = True
forward(coords)
Parameters:
  • coords (torch.Tensor)

Return type:
  torch.Tensor
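Below is a minimal sketch of what forward(coords) might compute. It relies on several assumptions that this reference does not confirm:

  • coords has shape (..., ndim).
  • The output holds one rotation angle per frequency, with shape (..., dim // 2).
  • Leftover frequencies are zero-padded when dim is not a multiple of 2 * ndim. The padding attributes hint at something similar.

The real module may return sines/cosines or complex values instead. rope_frequencies_sketch is a hypothetical name, not part of emmi_inference.

```python
import torch
import torch.nn.functional as F


def rope_frequencies_sketch(
    coords: torch.Tensor,
    dim: int,
    max_wavelength: float = 10000.0,
    assert_positive: bool = True,
) -> torch.Tensor:
    """Hypothetical: map coords of shape (..., ndim) to RoPE angles of shape (..., dim // 2)."""
    if assert_positive:
        # Mirrors assert_positive; this sketch treats 0 as an allowed value.
        assert torch.all(coords >= 0), "coords should be rescaled to be positive"
    ndim = coords.shape[-1]
    # Assumption: each coordinate axis gets an equal number of (sin, cos) frequency pairs.
    pairs_per_axis = dim // (2 * ndim)
    # Standard RoPE inverse frequencies b^(-2i/d') with d' = 2 * pairs_per_axis, i = 0, 1, ...
    exponents = torch.arange(pairs_per_axis, dtype=torch.float32) / pairs_per_axis
    inv_freq = max_wavelength ** -exponents  # shape (pairs_per_axis,)
    # Per-axis outer product: (..., ndim, 1) * (pairs_per_axis,) -> (..., ndim, pairs_per_axis).
    angles = coords.float().unsqueeze(-1) * inv_freq
    angles = angles.flatten(start_dim=-2)  # (..., ndim * pairs_per_axis)
    # Assumption: zero-pad (angle 0, i.e. no rotation) up to dim // 2 entries.
    padding = dim // 2 - angles.shape[-1]
    return F.pad(angles, (0, padding))


coords = torch.rand(2, 100, 3) * 100  # 2 batches of 100 points in 3D, already non-negative
angles = rope_frequencies_sketch(coords, dim=64)
print(angles.shape)  # torch.Size([2, 100, 32])
```

With dim=64 and three axes, each axis gets 64 // 6 = 10 frequency pairs. That yields 30 angles, which are zero-padded to 32.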