emmi.modules.decoder.deep_perceiver_decoder
Contains a DeepPerceiverDecoder implementation.
Classes
DeepPerceiverDecoder – A deep Perceiver decoder module that can be configured with different numbers of layers and hidden dimensions.
Module Contents
- class emmi.modules.decoder.deep_perceiver_decoder.DeepPerceiverDecoder(config)
Bases: torch.nn.Module
A deep Perceiver decoder module. It can be configured with different numbers of layers and hidden dimensions. Note, however, that this module is not a full-fledged Perceiver, since it only has a cross-attention mechanism.
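This page does not show the internals of `blocks`. As a rough illustration of a decoder that "only has a cross-attention mechanism", the sketch below stacks pre-norm cross-attention blocks in which the queries attend to `kv` but never to each other. It assumes the queries are already embedded to the hidden dimension. `CrossAttentionBlock`, `CrossAttentionOnlyDecoder`, `depth` and `out_dim` are hypothetical names chosen for this example and are not part of emmi.

```python
import torch
from torch import nn


class CrossAttentionBlock(nn.Module):
    """Hypothetical pre-norm block: queries attend to kv, then pass through an MLP."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        kv = self.norm_kv(kv)
        # Cross-attention only: queries come from q, keys and values from kv.
        attn_out, _ = self.attn(self.norm_q(q), kv, kv, need_weights=False)
        q = q + attn_out
        return q + self.mlp(self.norm_mlp(q))


class CrossAttentionOnlyDecoder(nn.Module):
    """Hypothetical stack of cross-attention blocks with no self-attention step."""

    def __init__(self, dim: int, depth: int, num_heads: int, out_dim: int):
        super().__init__()
        self.blocks = nn.ModuleList(CrossAttentionBlock(dim, num_heads) for _ in range(depth))
        self.head = nn.Linear(dim, out_dim)

    def forward(self, kv: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        x = queries
        for block in self.blocks:
            x = block(x, kv)  # the latent tokens (kv) stay fixed across blocks
        return self.head(x)


decoder = CrossAttentionOnlyDecoder(dim=64, depth=3, num_heads=4, out_dim=2)
kv = torch.randn(2, 16, 64)        # (batch_size, num_latent_tokens, dim)
queries = torch.randn(2, 100, 64)  # (batch_size, num_output_pos, dim), assumed pre-embedded
out = decoder(kv, queries)
print(out.shape)  # torch.Size([2, 100, 2])
```

Because this sketch has no self-attention among the queries, each output position is decoded independently given the latent tokens, so the number of queries can change freely between calls.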
Initialize the DeepPerceiverDecoder.
- Parameters:
config (emmi.schemas.modules.decoder.DeepPerceiverDecoderConfig) – Configuration for the DeepPerceiverDecoder module.
- blocks
- forward(kv, queries, unbatch_mask=None, attn_kwargs=None, condition=None)
Forward pass of the model.
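- Parameters:
kv (torch.Tensor) – Latent tokens as a dense tensor (batch_size, num_latent_tokens, dim), used as the keys and values of the cross-attention.
queries (torch.Tensor) – Queries for the output positions, with num_output_pos queries per sample.
unbatch_mask (torch.Tensor | None) – Optional mask used to unbatch the dense predictions into the sparse output format.
attn_kwargs – Additional keyword arguments for the attention blocks.
condition (torch.Tensor | None) – Optional conditioning tensor.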
- Returns:
The predictions as a sparse tensor of shape (batch_size * num_output_pos, num_out_values).
- Return type:
torch.Tensor
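The return value is described as a sparse tensor of shape (batch_size * num_output_pos, num_out_values). The following minimal sketch shows how dense per-sample predictions can be flattened into that layout. It assumes, hypothetically, that `unbatch_mask` is a boolean tensor of shape (batch_size, num_output_pos) in which True marks real, non-padded positions; the convention emmi actually uses may differ.

```python
import torch

batch_size, num_output_pos, num_out_values = 2, 3, 4
dense = torch.arange(batch_size * num_output_pos * num_out_values, dtype=torch.float32)
dense = dense.view(batch_size, num_output_pos, num_out_values)

# Case 1: every query is a real output position, so the batch dimension is flattened.
sparse_full = dense.reshape(-1, num_out_values)
print(sparse_full.shape)  # torch.Size([6, 4]), i.e. (batch_size * num_output_pos, num_out_values)

# Case 2: hypothetical mask convention, where True marks real (non-padded) positions.
unbatch_mask = torch.tensor([[True, True, True],
                             [True, False, False]])
sparse = dense[unbatch_mask]  # boolean indexing keeps rows in row-major order
print(sparse.shape)  # torch.Size([4, 4])
```

With padding, the first dimension equals the number of True entries in the mask instead of batch_size * num_output_pos.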