emmi_inference.models.modules.supernode_pooling_posonly
=======================================================

.. py:module:: emmi_inference.models.modules.supernode_pooling_posonly


Classes
-------

.. autoapisummary::

   emmi_inference.models.modules.supernode_pooling_posonly.SupernodePoolingPosonly


Module Contents
---------------

.. py:class:: SupernodePoolingPosonly(hidden_dim, ndim, radius = None, k = None, max_degree = 32, mode = 'relpos')

   Bases: :py:obj:`torch.nn.Module`

   Supernode pooling layer.

   The permutation of the supernodes is preserved through the message passing (contrary to the (GP-)UPT code).
   Additionally, ``radius`` is used instead of ``radius_graph``, which is more efficient.

   :param hidden_dim: Hidden dimension of the positional embeddings, the messages and the resulting output vector.
   :param ndim: Number of positional dimensions (e.g., ``ndim=2`` for 2D positions, ``ndim=3`` for 3D positions).
   :param radius: Radius around each supernode. Messages are passed to the supernode from all points within this radius.
   :param k: Number of neighbors for each supernode. Messages are passed to the supernode from its k nearest neighbors.
   :param max_degree: Maximum degree of the radius graph, i.e., the maximum number of points that pass messages to one supernode. Defaults to 32.
   :param mode: Whether positions are embedded in absolute space (``"abspos"``) or relative space (``"relpos"``). Defaults to ``"relpos"``. ``readd_supernode_pos`` always uses the absolute position.
   :param readd_supernode_pos: If true, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to ``hidden_dim``. Defaults to True.


   .. py:attribute:: radius
      :value: None


   .. py:attribute:: k
      :value: None


   .. py:attribute:: max_degree
      :value: 32


   .. py:attribute:: hidden_dim


   .. py:attribute:: ndim


   .. py:attribute:: mode
      :value: 'relpos'


   .. py:attribute:: pos_embed


   .. py:attribute:: message


   .. py:attribute:: proj


   .. py:attribute:: output_dim


   .. py:method:: compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx = None)

      Compute the source and destination indices for the message passing to the supernodes.

      :param input_pos: Sparse tensor with shape ``(batch_size * number_of_points_per_sample, ndim)``, representing the input geometries.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :returns: Source and destination indices for the message passing into the supernodes.


   .. py:method:: create_messages(input_pos, src_idx, dst_idx, supernode_idx)

      Create the messages for the message passing to the supernodes, using the positional encoding selected by ``mode``.

      :param input_pos: Tensor of shape ``(batch_size * number_of_points_per_sample, ndim)`` with ``ndim`` in {2, 3}, representing the point cloud of the input geometry.
      :param src_idx: Indices of the source nodes in ``input_pos``.
      :param dst_idx: Indices of the destination nodes in ``input_pos``. These should be the matching supernode indices.
      :param supernode_idx: Indices of the nodes in ``input_pos`` that are considered supernodes.
      :raises NotImplementedError: Raised if ``mode`` is not implemented. Only ``"abspos"`` and ``"relpos"`` are allowed.
      :returns: The messages for the message passing into the supernodes, and the embedded coordinates of the supernodes.


   .. py:method:: accumulate_messages(x, dst_idx, supernode_idx, batch_idx = None)
      :staticmethod:

      Accumulate the messages of neighboring points into the supernodes.

      :param x: Tensor containing the message of each neighbor.
      :param dst_idx: Index of the destination (i.e., the supernode) that each message is sent to.
      :param supernode_idx: Indices of the supernodes in the input point cloud.
      :param batch_idx: Batch index of each point in the sparse tensor. Defaults to None.
      :returns: Tensor with the aggregated messages for each supernode.


   .. py:method:: forward(input_pos, supernode_idx, batch_idx = None)

      Forward pass of the supernode pooling layer.

      :param input_pos: Sparse tensor with shape ``(batch_size * number_of_points_per_sample, ndim)``, representing the point cloud of the input geometry.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :returns: Tensor with the aggregated messages for each supernode.
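
The neighbor search performed by ``compute_src_and_dst_indices`` can be sketched in plain PyTorch. This is a minimal, hypothetical illustration, not the module's implementation, and it rests on several assumptions:

* It builds a dense supernode-to-point distance matrix (O(num_supernodes * num_points) memory) instead of calling ``radius``.
* Exactly one of ``radius`` and ``k`` is set.
* The radius graph is capped at ``max_degree`` by keeping the nearest candidates.
* ``dst_idx`` is returned as positions in ``supernode_idx``.

The function name ``compute_src_and_dst_indices_sketch`` is made up for this example.

.. code-block:: python

   import torch


   def compute_src_and_dst_indices_sketch(input_pos, supernode_idx, batch_idx=None,
                                          radius=None, k=None, max_degree=32):
       # hypothetical helper; assumes exactly one of radius / k is given
       if (radius is None) == (k is None):
           raise ValueError("set exactly one of radius or k")
       if batch_idx is None:
           # a single sample: every point belongs to batch 0
           batch_idx = torch.zeros(len(input_pos), dtype=torch.long, device=input_pos.device)
       # dense distances, rows = supernodes, columns = all points
       dist = torch.cdist(input_pos[supernode_idx], input_pos)
       # a supernode only receives messages from points of its own sample
       same_sample = batch_idx[supernode_idx].unsqueeze(1) == batch_idx.unsqueeze(0)
       dist = dist.masked_fill(~same_sample, float("inf"))
       # keep the nearest candidates: k for k-NN, max_degree for the radius graph
       num_candidates = min(k if k is not None else max_degree, dist.size(1))
       nn_dist, nn_idx = dist.topk(num_candidates, dim=1, largest=False)
       if radius is not None:
           valid = nn_dist <= radius
       else:
           valid = torch.isfinite(nn_dist)  # drop candidates from other samples
       # dst_idx: position in supernode_idx, src_idx: row of input_pos
       dst_idx, col = valid.nonzero(as_tuple=True)
       src_idx = nn_idx[dst_idx, col]
       return src_idx, dst_idx

In this sketch a supernode is its own neighbor at distance 0, so it also sends a message to itself. This documentation does not say whether the actual module keeps or drops such self-loops.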
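
The difference between the two ``mode`` settings of ``create_messages`` can be sketched with a small hypothetical module. Everything here is an assumption for illustration:

* ``pos_embed`` is a single ``nn.Linear`` and ``message`` is a two-layer MLP.
* ``dst_idx`` holds positions in ``supernode_idx``.
* The class name ``MessageSketch`` is made up.

In ``"relpos"`` mode each neighbor is embedded by its offset from the receiving supernode; in ``"abspos"`` mode it is embedded by its absolute coordinates. Any other mode raises ``NotImplementedError``.

.. code-block:: python

   import torch
   from torch import nn


   class MessageSketch(nn.Module):
       """Hypothetical stand-in for the pos_embed / message submodules (layer shapes are assumptions)."""

       def __init__(self, hidden_dim, ndim, mode="relpos"):
           super().__init__()
           self.mode = mode
           self.pos_embed = nn.Linear(ndim, hidden_dim)
           self.message = nn.Sequential(
               nn.Linear(hidden_dim, hidden_dim),
               nn.GELU(),
               nn.Linear(hidden_dim, hidden_dim),
           )

       def create_messages(self, input_pos, src_idx, dst_idx, supernode_idx):
           # assumption: dst_idx holds positions in supernode_idx (0..num_supernodes-1)
           supernode_pos = input_pos[supernode_idx]
           if self.mode == "abspos":
               # embed each neighbor by its absolute coordinates
               x = self.pos_embed(input_pos[src_idx])
           elif self.mode == "relpos":
               # embed each neighbor by its offset from the receiving supernode
               x = self.pos_embed(input_pos[src_idx] - supernode_pos[dst_idx])
           else:
               raise NotImplementedError(f"mode {self.mode!r} is not implemented")
           # one message per (neighbor, supernode) edge, plus absolute supernode embeddings
           return self.message(x), self.pos_embed(supernode_pos)

In ``"relpos"`` mode, only the offset ``input_pos[src_idx] - supernode_pos[dst_idx]`` enters the message. The messages of this sketch therefore do not change when the whole point cloud is translated, while the absolute supernode embeddings are returned separately.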
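
The aggregation done by ``accumulate_messages`` amounts to a scatter reduction over ``dst_idx``. The sketch below makes these assumptions:

* The reduction is a mean (this documentation does not state which reduction is used).
* ``dst_idx`` holds positions in ``supernode_idx``.
* Every sample in the batch has the same number of supernodes, stored sample by sample, so the result can be reshaped to ``(batch_size, supernodes_per_sample, hidden_dim)``.

The name ``accumulate_messages_sketch`` is hypothetical.

.. code-block:: python

   import torch


   def accumulate_messages_sketch(x, dst_idx, supernode_idx, batch_idx=None):
       num_supernodes = len(supernode_idx)
       # sum all messages that point to the same supernode
       summed = torch.zeros(num_supernodes, x.size(1), dtype=x.dtype, device=x.device)
       summed.index_add_(0, dst_idx, x)
       # divide by the number of incoming messages (at least 1 to avoid 0 / 0)
       counts = torch.bincount(dst_idx, minlength=num_supernodes).clamp(min=1)
       pooled = summed / counts.unsqueeze(1).to(x.dtype)
       if batch_idx is None:
           # single sample: add a leading batch dimension
           return pooled.unsqueeze(0)
       batch_size = int(batch_idx.max()) + 1
       # assumption: equal number of supernodes per sample, stored sample by sample
       return pooled.view(batch_size, num_supernodes // batch_size, x.size(1))

``index_add_`` sums the messages per supernode, and ``bincount`` supplies the number of messages each supernode received. Clamping that count to at least 1 leaves a supernode without neighbors at zero instead of producing NaN.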