emmi.modules.blocks.perceiver_block
Classes
- PerceiverBlock: In a self-attention module, the query, key, and value are computed from the same input tensor. The PerceiverBlock instead takes different input tensors for the query and the key/value.
Module Contents
- class emmi.modules.blocks.perceiver_block.PerceiverBlock(config)
Bases: torch.nn.Module
In a self-attention module, the query, key, and value are computed from the same input tensor. The PerceiverBlock instead takes different input tensors for the query and the key/value.
Perceiver-style cross-attention block.
- Parameters:
config (emmi.schemas.modules.blocks.PerceiverBlockConfig) – Configuration of the PerceiverBlock.
- norm1q
- norm1kv
- attn
- ls1
- drop_path1
- norm2
- mlp
- ls2
- drop_path2
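The attribute names above suggest a pre-norm cross-attention layout: separate norms for the query and key/value streams, an attention branch, and an MLP branch, each added back through a residual connection. The following is a minimal sketch of that data flow, assuming the usual timm-style meaning of `ls*` (layer scale) and `drop_path*` (stochastic depth), both replaced by identities here. `PerceiverBlockSketch`, its constructor arguments, and the use of `torch.nn.MultiheadAttention` are illustrative assumptions, not emmi's API; the real block is built from a `PerceiverBlockConfig` and also accepts `condition` and `attn_kwargs`.

```python
import torch
from torch import nn


class PerceiverBlockSketch(nn.Module):
    """Hypothetical pre-norm cross-attention block (not emmi's implementation)."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        # Separate norms for the query stream and the key/value stream.
        self.norm1q = nn.LayerNorm(dim)
        self.norm1kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Assumption: layer scale and stochastic depth, stubbed out as identities.
        self.ls1 = nn.Identity()
        self.drop_path1 = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.ls2 = nn.Identity()
        self.drop_path2 = nn.Identity()

    def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        # Cross-attention: queries come from q, keys and values from kv.
        kv_normed = self.norm1kv(kv)
        attn_out, _ = self.attn(self.norm1q(q), kv_normed, kv_normed, need_weights=False)
        q = q + self.drop_path1(self.ls1(attn_out))
        # Position-wise MLP, applied to the query stream only.
        q = q + self.drop_path2(self.ls2(self.mlp(self.norm2(q))))
        return q


block = PerceiverBlockSketch(dim=64, num_heads=4)
q = torch.randn(2, 16, 64)     # e.g. 16 latent tokens
kv = torch.randn(2, 1000, 64)  # e.g. 1000 input tokens
out = block(q, kv)
print(out.shape)  # torch.Size([2, 16, 64])
```

In this sketch the sequence length of `kv` can differ from that of `q`, and the output keeps the shape of `q`. This is what lets a Perceiver-style model compress a long input into a small, fixed set of latent tokens.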
- forward(q, kv, condition=None, attn_kwargs=None)
Forward pass of the PerceiverBlock.
- Parameters:
q (torch.Tensor) – Input tensor with shape (batch_size, num_query_tokens, hidden_dim) for the query representations.
kv (torch.Tensor) – Input tensor with shape (batch_size, num_kv_tokens, hidden_dim) for the key and value representations. The number of key/value tokens may differ from the number of query tokens.
condition (torch.Tensor | None) – Conditioning vector. If provided, the attention and MLP branches are scaled, shifted, and gated feature-wise using values predicted from this vector.
attn_kwargs (dict[str, Any] | None) – Dict with arguments for the attention (such as the attention mask). Defaults to None.
- Returns:
Tensor after the forward pass of the PerceiverBlock.
- Return type:
torch.Tensor
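The description of `condition` matches adaptive layer norm (adaLN) modulation as used in DiT-style transformers. The following is a minimal sketch of one modulated residual branch, assuming an affine-free LayerNorm and a single SiLU + linear projection that predicts a shift, scale, and gate from the condition. `ModulatedResidualSketch`, `modulation`, and `sub_block` are hypothetical names, and `sub_block` is a plain linear layer standing in for the attention or MLP; emmi's actual modulation may differ (for example, in how the condition is projected, or whether the key/value stream is modulated as well).

```python
import torch
from torch import nn


class ModulatedResidualSketch(nn.Module):
    """Hypothetical adaLN-style scale/shift/gate around one sub-block."""

    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # Affine-free norm: scale and shift are predicted from the condition instead.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.sub_block = nn.Linear(dim, dim)  # stands in for attention or the MLP
        # Predicts shift, scale, and gate (3 * dim values) from the condition.
        self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 3 * dim))

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # condition: (batch_size, cond_dim) -> three tensors of shape (batch_size, 1, dim)
        shift, scale, gate = self.modulation(condition).unsqueeze(1).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale) + shift
        return x + gate * self.sub_block(h)


torch.manual_seed(0)
block = ModulatedResidualSketch(dim=64, cond_dim=32)
x = torch.randn(2, 16, 64)
condition = torch.randn(2, 32)
print(block(x, condition).shape)  # torch.Size([2, 16, 64])

# adaLN-Zero style init: zeroing the modulation projection makes the gate 0,
# so the branch starts out as an identity mapping.
nn.init.zeros_(block.modulation[1].weight)
nn.init.zeros_(block.modulation[1].bias)
print(torch.equal(block(x, condition), x))  # True
```

Gating the residual branch, rather than only scaling and shifting its input, lets the condition control how much each branch contributes. The zero-initialized variant shown at the end is a common way to make a deep conditioned stack start close to the identity.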