emmi.modules.decoder.deep_perceiver_decoder

Contains the DeepPerceiverDecoder implementation.

Classes

DeepPerceiverDecoder

A deep Perceiver decoder module with a configurable number of layers and hidden dimensions.

Module Contents

class emmi.modules.decoder.deep_perceiver_decoder.DeepPerceiverDecoder(config)

Bases: torch.nn.Module

A deep Perceiver decoder module with a configurable number of layers and hidden dimensions. Note that this module is not a full-fledged Perceiver: it only has a cross-attention mechanism, so it lacks the latent self-attention layers of a complete Perceiver.
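To make the "cross-attention only" distinction concrete, here is a minimal PyTorch sketch of a decoder built from a stack of cross-attention blocks. It is not the emmi implementation: the class names CrossAttentionBlock and TinyDeepCrossAttentionDecoder, the pre-norm layout, the MLP and the output head are all assumptions, and the queries are assumed to be already embedded to the latent dimension. Each block lets the queries attend to the latent tokens (kv); the latents themselves are never updated, which is what separates this design from a full Perceiver.

```python
import torch
from torch import nn


class CrossAttentionBlock(nn.Module):
    """Hypothetical pre-norm block: queries cross-attend to kv, then pass through an MLP."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, queries: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        kv_normed = self.norm_kv(kv)
        # Cross-attention only: queries read from the latents, latents stay unchanged.
        attn_out, _ = self.attn(self.norm_q(queries), kv_normed, kv_normed, need_weights=False)
        queries = queries + attn_out
        # Position-wise MLP with a residual connection.
        return queries + self.mlp(self.norm_mlp(queries))


class TinyDeepCrossAttentionDecoder(nn.Module):
    """Hypothetical stack of `depth` cross-attention blocks followed by a linear head."""

    def __init__(self, dim: int, num_heads: int, depth: int, out_dim: int):
        super().__init__()
        self.blocks = nn.ModuleList(CrossAttentionBlock(dim, num_heads) for _ in range(depth))
        self.head = nn.Linear(dim, out_dim)

    def forward(self, kv: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        x = queries
        for block in self.blocks:
            x = block(x, kv)  # every block attends to the same, fixed latents
        return self.head(x)


torch.manual_seed(0)
decoder = TinyDeepCrossAttentionDecoder(dim=32, num_heads=4, depth=3, out_dim=5)
kv = torch.randn(2, 16, 32)        # (batch_size, num_latent_tokens, dim)
queries = torch.randn(2, 100, 32)  # (batch_size, num_output_pos, dim), already embedded
out = decoder(kv, queries)
print(out.shape)  # torch.Size([2, 100, 5])
```

Because the queries never attend to each other in this sketch, each output position depends only on its own query and the latents, so a large set of query positions can be decoded in independent chunks.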

Initialize the DeepPerceiverDecoder.

Parameters:

config (emmi.schemas.modules.decoder.DeepPerceiverDecoderConfig) – Configuration for the DeepPerceiverDecoder module.

blocks
forward(kv, queries, unbatch_mask=None, attn_kwargs=None, condition=None)

Forward pass of the decoder.

Parameters:
  • kv (torch.Tensor) – Latent tokens as dense tensor (batch_size, num_latent_tokens, dim).

  • queries (torch.Tensor) – Query positions (batch_size, num_output_pos, pos_dim).

  • unbatch_mask (torch.Tensor | None) – Optional unbatch mask.

  • attn_kwargs (dict[str, Any] | None) – Additional arguments for the attention blocks.

  • condition (torch.Tensor | None) – Optional conditioning tensor.

Returns:

The predictions as a sparse tensor (batch_size * num_output_pos, num_out_values).

Return type:

torch.Tensor
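The return shape flattens the batch and query dimensions into a single row axis. The sketch below shows that conversion in plain PyTorch. The helper dense_to_sparse is hypothetical, and treating the mask as a boolean (batch_size, num_output_pos) selector of valid query rows is only an assumed reading of unbatch_mask, not its documented behavior.

```python
from typing import Optional

import torch


def dense_to_sparse(preds: torch.Tensor, valid_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: flatten (B, N, C) predictions into (B * N, C) rows.

    If valid_mask (B, N, bool) is given, only rows marked True are kept
    (an assumed way of dropping padded query positions).
    """
    if valid_mask is None:
        # Row b * N + n of the result holds preds[b, n].
        return preds.reshape(-1, preds.shape[-1])
    # Boolean indexing keeps the selected (b, n) rows in row-major order.
    return preds[valid_mask]


preds = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # (B=2, N=3, C=4)
flat = dense_to_sparse(preds)
print(flat.shape)  # torch.Size([6, 4])

mask = torch.tensor([[True, True, False], [True, False, False]])
packed = dense_to_sparse(preds, mask)
print(packed.shape)  # torch.Size([3, 4])
```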