emmi_inference.models.modules.blocks.transformer_block
======================================================

.. py:module:: emmi_inference.models.modules.blocks.transformer_block

Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.blocks.transformer_block.TransformerBlock

Module Contents
---------------

.. py:class:: TransformerBlock(dim, num_heads, attn_ctor = DotProductAttention)

   Bases: :py:obj:`torch.nn.Module`

   A transformer block with a single attention layer and a feedforward layer.

   :param dim: Hidden dimension of the transformer block.
   :param num_heads: Number of attention heads.

   .. py:attribute:: norm1

   .. py:attribute:: attn

   .. py:attribute:: norm2

   .. py:attribute:: mlp

   .. py:method:: forward(x, attn_kwargs = None)

      Forward pass of the transformer block.

      :param x: Input tensor of shape (batch_size, num_tokens, dim).
      :param attn_kwargs: Dict with keyword arguments forwarded to the attention layer (such as the RoPE frequencies). Defaults to None.
      :returns: Output tensor of shape (batch_size, num_tokens, dim).
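
The layout above (``norm1``/``attn``/``norm2``/``mlp``) suggests a standard pre-norm block. Below is a minimal, hypothetical sketch of such a block, not the actual ``emmi_inference`` implementation: it assumes a pre-norm residual layout, a GELU MLP with a 4x hidden expansion, and uses ``torch.nn.MultiheadAttention`` as a stand-in for the ``attn_ctor`` (the real ``DotProductAttention`` class is not shown in these docs).

.. code-block:: python

   ```python
   import torch
   import torch.nn as nn


   class TransformerBlockSketch(nn.Module):
       """Hypothetical sketch of the documented TransformerBlock.

       Assumptions not confirmed by the docs: pre-norm layout, GELU MLP
       with 4x expansion, and nn.MultiheadAttention standing in for the
       pluggable attn_ctor.
       """

       def __init__(self, dim, num_heads):
           super().__init__()
           self.norm1 = nn.LayerNorm(dim)
           self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
           self.norm2 = nn.LayerNorm(dim)
           self.mlp = nn.Sequential(
               nn.Linear(dim, 4 * dim),
               nn.GELU(),
               nn.Linear(4 * dim, dim),
           )

       def forward(self, x, attn_kwargs=None):
           # attn_kwargs holds extra keyword arguments for the attention
           # layer (e.g. masks); it defaults to an empty dict.
           attn_kwargs = attn_kwargs or {}
           # Attention sub-layer with a residual connection (pre-norm).
           h = self.norm1(x)
           attn_out, _ = self.attn(h, h, h, need_weights=False, **attn_kwargs)
           x = x + attn_out
           # Feedforward sub-layer with a residual connection.
           x = x + self.mlp(self.norm2(x))
           return x


   block = TransformerBlockSketch(dim=32, num_heads=4)
   out = block(torch.randn(2, 16, 32))  # (batch_size, num_tokens, dim)
   print(tuple(out.shape))  # shape is preserved: (2, 16, 32)
   ```

Note that both residual branches map ``dim -> dim``, so the output shape always matches the input shape, as the ``forward`` docstring states.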