emmi.functional.rope

Functions

rope(x, freqs)

Applies Rotary Position Embeddings (RoPE) to x.

Module Contents

emmi.functional.rope.rope(x, freqs)

Applies Rotary Position Embeddings (RoPE) to x.

Parameters:
  • x (torch.Tensor) – Vector to rotate (e.g., queries or keys of a transformer). Shape=(batch_size, num_heads, seqlen, head_dim).

  • freqs (tuple) – Sine/cosine frequencies for rotating x. For 1D positions, freqs is a tuple of length 1 whose entry has shape (batch_size, num_heads, num_dim_to_rotate), where num_dim_to_rotate is the number of dimensions to rotate. If positions are higher dimensional, the nth entry of the tuple corresponds to the frequencies of the nth axis for rotation.

Returns:

Rotated x.

Return type:

torch.Tensor
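
Example:

The following is a minimal sketch of what rotating x by sine/cosine frequencies means for 1D positions. It is not the emmi implementation and does not call emmi.functional.rope.rope. The helpers make_sin_cos and rope_sketch, the base of 10000, the half-split pairing of channels, and the (seqlen, head_dim // 2) layout of the sine/cosine tables are all assumptions made for illustration; the library may build and lay out freqs differently.

```python
import torch


def make_sin_cos(seqlen: int, head_dim: int, base: float = 10000.0):
    """Hypothetical helper: sine/cosine tables of shape (seqlen, head_dim // 2)."""
    # One inverse frequency per rotated pair of channels.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seqlen, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (seqlen, head_dim // 2)
    return angles.sin(), angles.cos()


def rope_sketch(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate the channel pairs (x[..., i], x[..., i + head_dim // 2]).

    Assumed convention (half-split pairing); emmi may pair channels differently.
    """
    x1, x2 = x.chunk(2, dim=-1)  # each (batch_size, num_heads, seqlen, head_dim // 2)
    # Plain 2D rotation per pair; sin/cos broadcast over batch and heads.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


batch_size, num_heads, seqlen, head_dim = 2, 4, 8, 16
x = torch.randn(batch_size, num_heads, seqlen, head_dim)  # e.g. queries or keys
sin, cos = make_sin_cos(seqlen, head_dim)
rotated = rope_sketch(x, sin, cos)
```

In this sketch the output has the same shape as x, each vector keeps its norm because every channel pair undergoes a pure rotation, and position 0 is left unchanged because its rotation angle is zero.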