emmi.modules.encoder.supernode_pooling
======================================

.. py:module:: emmi.modules.encoder.supernode_pooling


Classes
-------

.. autoapisummary::

   emmi.modules.encoder.supernode_pooling.SupernodePooling


Module Contents
---------------

.. py:class:: SupernodePooling(config)

   Bases: :py:obj:`torch.nn.Module`

   Supernode pooling layer.

   Unlike the (GP-)UPT code, the order (permutation) of the supernodes is preserved through the message passing.
   Additionally, neighbours are found with ``radius`` instead of ``radius_graph``, which is more efficient
   because only the neighbourhoods of the supernodes have to be computed.

   Initialize the SupernodePooling.

   :param config: Configuration for the SupernodePooling module.


   .. py:attribute:: radius


   .. py:attribute:: k


   .. py:attribute:: max_degree


   .. py:attribute:: spool_pos_mode


   .. py:attribute:: readd_supernode_pos


   .. py:attribute:: aggregation


   .. py:attribute:: num_input_features


   .. py:attribute:: pos_embed


   .. py:attribute:: output_dim


   .. py:method:: compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx = None)

      Compute the source and destination indices for the message passing to the supernodes.

      :param input_pos: Sparse tensor with shape ``(batch_size * number_of_points, 3)``, representing the input geometries.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :returns: Tensor with the source and destination indices for the message passing into the supernodes.


   .. py:method:: create_messages(input_pos, src_idx, dst_idx, supernode_idx, input_features = None)

      Create the messages for the message passing to the supernodes, using one of several positional encoding representations.

      :param input_pos: Tensor of shape ``(batch_size * number_of_points_per_sample, {2, 3})``, representing the point cloud of the input geometry.
      :param src_idx: Indices of the source nodes in ``input_pos``.
      :param dst_idx: Destination index of each message in ``input_pos``. These indices should match the supernode indices.
      :param supernode_idx: Indices of the nodes in ``input_pos`` that are treated as supernodes.
      :param input_features: Optional tensor with shape ``(batch_size * number_of_points_per_sample, number_of_features)``. Defaults to None.
      :raises NotImplementedError: Raised if the positional encoding mode is not one of ``"abspos"``, ``"relpos"`` or ``"absrelpos"``.
      :returns: The messages for the message passing into the supernodes, together with the embedded coordinates of the supernodes.


   .. py:method:: accumulate_messages(x, dst_idx, supernode_idx, batch_idx = None)

      Accumulate the messages of neighbouring points into the supernodes.

      :param x: Tensor containing the message of each neighbour.
      :param dst_idx: Index of the destination (i.e., the supernode) that each message is sent to.
      :param supernode_idx: Indices of the supernodes in the input point cloud.
      :param batch_idx: Batch index of each point in the sparse tensor. Defaults to None.
      :returns: Tensor with the aggregated messages for each supernode.


   .. py:method:: forward(input_pos, supernode_idx, batch_idx = None, input_features = None)

      Forward pass of the supernode pooling layer.

      :param input_pos: Sparse tensor with shape ``(batch_size * number_of_points_per_sample, 3)``, representing the point cloud of the input geometry.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :param input_features: Optional sparse tensor with shape ``(batch_size * number_of_points_per_sample, number_of_features)``. Defaults to None.
      :returns: Tensor with the aggregated messages for each supernode.
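
The packed inputs expected by ``forward`` (one flat point list, a per-point ``batch_idx`` and global ``supernode_idx``) can be illustrated with a minimal sketch. It assumes each sample is a list of ``(x, y, z)`` tuples and that supernodes are drawn uniformly at random per sample; ``pack_point_clouds`` is a hypothetical helper, not part of ``emmi``, and plain Python lists stand in for the tensors that would normally be passed.

.. code-block:: python

   import random
   from typing import List, Optional, Sequence, Tuple

   Point = Tuple[float, float, float]


   def pack_point_clouds(
       clouds: Sequence[Sequence[Point]],
       num_supernodes: int,
       seed: Optional[int] = None,
   ) -> Tuple[List[Point], List[int], List[int]]:
       """Hypothetical helper: flatten samples into (input_pos, batch_idx, supernode_idx)."""
       rng = random.Random(seed)
       input_pos: List[Point] = []
       batch_idx: List[int] = []
       supernode_idx: List[int] = []
       for sample, cloud in enumerate(clouds):
           if num_supernodes > len(cloud):
               raise ValueError(f"sample {sample} has fewer than {num_supernodes} points")
           offset = len(input_pos)  # global index of this sample's first point
           input_pos.extend(cloud)
           batch_idx.extend([sample] * len(cloud))
           # Sample local indices without replacement, then shift them so they
           # index into the packed list rather than into the individual sample.
           local = sorted(rng.sample(range(len(cloud)), num_supernodes))
           supernode_idx.extend(offset + i for i in local)
       return input_pos, batch_idx, supernode_idx


   clouds = [
       [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
       [(5.0, 5.0, 5.0), (6.0, 5.0, 5.0)],
   ]
   input_pos, batch_idx, supernode_idx = pack_point_clouds(clouds, num_supernodes=2, seed=0)
   print(batch_idx)  # [0, 0, 0, 1, 1]

With equal-sized samples, this layout matches the ``(batch_size * number_of_points_per_sample, 3)`` shape documented for ``input_pos``.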
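
The neighbour search performed by ``compute_src_and_dst_indices`` can be approximated with a brute-force sketch. This is not the module's implementation, which uses ``radius``; the hypothetical ``radius_src_dst`` returns the edges as two parallel lists and assumes that each destination index is the position of the supernode within ``supernode_idx``.

.. code-block:: python

   import math
   from typing import List, Optional, Sequence, Tuple

   Point = Tuple[float, ...]


   def radius_src_dst(
       input_pos: Sequence[Point],
       supernode_idx: Sequence[int],
       radius: float,
       batch_idx: Optional[Sequence[int]] = None,
   ) -> Tuple[List[int], List[int]]:
       """Brute-force radius search, O(num_supernodes * num_points).

       src[e] indexes input_pos; dst[e] is the position of the supernode in
       supernode_idx (an assumed index convention).
       """
       src: List[int] = []
       dst: List[int] = []
       for s, node in enumerate(supernode_idx):
           for p, pos in enumerate(input_pos):
               # Never connect points that belong to different samples.
               if batch_idx is not None and batch_idx[p] != batch_idx[node]:
                   continue
               # A supernode is at distance 0 from itself, so it is its own neighbour.
               if math.dist(pos, input_pos[node]) <= radius:
                   src.append(p)
                   dst.append(s)
       return src, dst


   input_pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 0.0), (0.5, 0.0, 0.0)]
   batch_idx = [0, 0, 0, 1, 1]
   src, dst = radius_src_dst(input_pos, supernode_idx=[0, 3], radius=1.0, batch_idx=batch_idx)
   print(src, dst)  # [0, 1, 3, 4] [0, 0, 1, 1]

Point 3 coincides with supernode 0 but is not connected to it because it belongs to the second sample; without ``batch_idx`` it would be.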
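
How ``create_messages`` and ``accumulate_messages`` fit together can be sketched as follows, under several assumptions: messages are built from raw coordinates instead of the learned ``pos_embed``, ``"relpos"`` is taken to mean the neighbour position minus the supernode position, ``"absrelpos"`` concatenates the absolute and relative coordinates, and messages are averaged, although the module's ``aggregation`` setting may select a different reduction. ``make_messages`` and ``mean_aggregate`` are hypothetical names.

.. code-block:: python

   from typing import List, Sequence, Tuple

   Vec = Tuple[float, ...]
   MODES = ("abspos", "relpos", "absrelpos")


   def make_messages(
       input_pos: Sequence[Vec],
       src_idx: Sequence[int],
       dst_idx: Sequence[int],
       supernode_idx: Sequence[int],
       mode: str = "relpos",
   ) -> List[Vec]:
       """Hypothetical per-edge messages built from raw coordinates."""
       if mode not in MODES:
           raise NotImplementedError(f"mode must be one of {MODES}, got {mode!r}")
       messages: List[Vec] = []
       for p, s in zip(src_idx, dst_idx):
           src_pos = tuple(input_pos[p])
           sn_pos = input_pos[supernode_idx[s]]  # dst is a position in supernode_idx
           rel = tuple(a - b for a, b in zip(src_pos, sn_pos))
           if mode == "abspos":
               messages.append(src_pos)
           elif mode == "relpos":
               messages.append(rel)
           else:  # "absrelpos": absolute and relative coordinates side by side
               messages.append(src_pos + rel)
       return messages


   def mean_aggregate(messages: Sequence[Vec], dst_idx: Sequence[int], num_supernodes: int) -> List[Vec]:
       """Average all messages sent to the same supernode; row s belongs to supernode s."""
       dim = len(messages[0]) if messages else 0
       sums = [[0.0] * dim for _ in range(num_supernodes)]
       counts = [0] * num_supernodes
       for msg, s in zip(messages, dst_idx):
           counts[s] += 1
           for d, value in enumerate(msg):
               sums[s][d] += value
       # Supernodes without messages keep a zero vector instead of dividing by zero.
       return [tuple(v / max(counts[s], 1) for v in sums[s]) for s in range(num_supernodes)]


   input_pos = [(0.0, 0.0), (2.0, 0.0), (5.0, 5.0), (5.0, 6.0)]
   supernode_idx = [0, 2]
   src_idx, dst_idx = [0, 1, 2, 3], [0, 0, 1, 1]
   msgs = make_messages(input_pos, src_idx, dst_idx, supernode_idx, mode="relpos")
   print(mean_aggregate(msgs, dst_idx, len(supernode_idx)))  # [(1.0, 0.0), (0.0, 0.5)]

Because output row ``s`` always belongs to ``supernode_idx[s]``, the supernode order of the input is kept, which mirrors the property described in the class docstring.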