emmi.modules.encoder.supernode_pooling_posonly
==============================================

.. py:module:: emmi.modules.encoder.supernode_pooling_posonly


Classes
-------

.. autoapisummary::

   emmi.modules.encoder.supernode_pooling_posonly.SupernodePoolingPosonly


Module Contents
---------------

.. py:class:: SupernodePoolingPosonly(config)

   Bases: :py:obj:`torch.nn.Module`


   Supernode pooling layer that operates only on positional information; no additional node features are used.

   The order of the supernodes is preserved through the message passing (contrary to the (GP-)UPT code). Additionally, ``radius`` is used instead of ``radius_graph``, which is more efficient.

   Initialize the SupernodePoolingPosonly.

   :param radius: Radius around each supernode. Messages are passed to the supernode from all points within this radius.
   :param k: Number of neighbors for each supernode. Messages are passed to the supernode from its k nearest neighbors (k-NN).
   :param hidden_dim: Hidden dimension of the positional embeddings, the messages and the resulting output vector.
   :param input_dim: Number of positional dimensions, i.e., the size of the last tensor dimension (e.g., ``input_dim=2`` for 2D positions, ``input_dim=3`` for 3D positions).
   :param max_degree: Maximum degree of the radius graph. Defaults to 32.
   :param mode: Whether positions are embedded in absolute space ("abspos"), relative space ("relpos") or both ("absrelpos"). ``readd_supernode_pos`` always uses the absolute position.
   :param init_weights: Weight initialization of the linear layers. Defaults to "truncnormal002".
   :param readd_supernode_pos: If true, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to ``hidden_dim``. Defaults to True.
   :param aggregation: Aggregation used in message passing ("mean" or "sum").
   :param message_mode: How messages are created: "mlp" (2-layer MLP), "linear" (``nn.Linear``) or "identity" (``nn.Identity``). Defaults to "mlp".


   .. py:attribute:: radius

   .. py:attribute:: spool_pos_mode

   .. py:attribute:: max_degree

   .. py:attribute:: aggregation

   .. py:attribute:: readd_supernode_pos

   .. py:attribute:: k

   .. py:attribute:: pos_embed

   .. py:attribute:: output_dim

   .. py:method:: compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx = None)

      Compute the source and destination indices for the message passing to the supernodes.

      :param input_pos: Sparse tensor with shape (batch_size * number_of_points, 3), representing the input geometries.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :returns: Tensor with the source and destination indices for the message passing into the supernodes.


   .. py:method:: create_messages(input_pos, src_idx, dst_idx, supernode_idx)

      Create messages for the message passing to the supernodes, based on different positional encoding representations.

      :param input_pos: Tensor of shape (batch_size * number_of_points_per_sample, {2,3}), representing the point cloud of the input geometry.
      :param src_idx: Indices of the source nodes in ``input_pos``.
      :param dst_idx: Indices of the destination nodes in ``input_pos``. These should be the indices of the matching supernodes.
      :param supernode_idx: Indices of the nodes in ``input_pos`` that are treated as supernodes.
      :raises NotImplementedError: If ``mode`` is not implemented. Only "abspos", "relpos" and "absrelpos" are allowed.
      :returns: Tensor with the messages for the message passing into the supernodes, and the embedded coordinates of the supernodes.


   .. py:method:: accumulate_messages(x, dst_idx, supernode_idx, batch_idx = None)

      Accumulate the messages of neighbouring points into the supernodes.

      :param x: Tensor containing the message of each neighbour.
      :param dst_idx: Index of the destination (i.e., supernode) that each message is sent to.
      :param supernode_idx: Indices of the supernodes in the input point cloud.
      :param batch_idx: Batch index of each point in the sparse tensor.
      :returns: Tensor with the aggregated messages for each supernode.


   .. py:method:: forward(input_pos, supernode_idx, batch_idx = None)

      Forward pass of the supernode pooling layer.

      :param input_pos: Sparse tensor with shape (batch_size * number_of_points_per_sample, 3), representing the point cloud of the input geometry.
      :param supernode_idx: Indices of the supernodes in the sparse tensor ``input_pos``.
      :param batch_idx: 1D tensor containing the batch index of each entry in ``input_pos``. Defaults to None.
      :returns: Tensor with the aggregated messages for each supernode.
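
The neighbour selection documented for ``compute_src_and_dst_indices`` can be illustrated with a small dependency-free sketch. It is not the module's implementation, which works on tensors and, per the class description, relies on ``radius``. The function name ``radius_src_dst``, the brute-force distance loop and the rule that the first ``max_degree`` neighbours in index order are kept are assumptions made for illustration. Each supernode only collects points from its own sample (same ``batch_idx``), and the destination index is the supernode's own index into ``input_pos``, so the supernode order is preserved.

.. code-block:: python

   import math
   from typing import List, Optional, Sequence, Tuple

   Point = Sequence[float]


   def radius_src_dst(
       input_pos: Sequence[Point],
       supernode_idx: Sequence[int],
       radius: float,
       max_degree: int,
       batch_idx: Optional[Sequence[int]] = None,
   ) -> Tuple[List[int], List[int]]:
       """Return (src_idx, dst_idx) edges from points to supernodes.

       Hypothetical helper: a brute-force search, not the module's code.
       """
       if batch_idx is None:
           # a single sample: every point belongs to batch 0
           batch_idx = [0] * len(input_pos)
       src_idx: List[int] = []
       dst_idx: List[int] = []
       for s in supernode_idx:  # visit supernodes in the given order
           found = 0
           for p, pos in enumerate(input_pos):
               if batch_idx[p] != batch_idx[s]:
                   continue  # never connect points of different samples
               if math.dist(pos, input_pos[s]) <= radius:
                   src_idx.append(p)  # sender: a point inside the ball
                   dst_idx.append(s)  # receiver: supernode's index into input_pos
                   found += 1
                   if found == max_degree:
                       break  # assumed cap: keep the first max_degree hits
       return src_idx, dst_idx

Because a supernode lies within its own radius, it also sends a message to itself in this sketch.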
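
The three values of ``mode`` accepted by ``create_messages`` can be contrasted by looking at the raw coordinates each message is built from. The sketch below is an assumption about what "absolute", "relative" and "both" mean here (the source position, the source position minus the supernode position, and their concatenation). The real method additionally embeds these coordinates with ``pos_embed`` and passes them through the layer selected by ``message_mode``, which is omitted. ``message_inputs`` is a hypothetical helper.

.. code-block:: python

   from typing import List, Sequence

   _MODES = ("abspos", "relpos", "absrelpos")


   def message_inputs(
       input_pos: Sequence[Sequence[float]],
       src_idx: Sequence[int],
       dst_idx: Sequence[int],
       mode: str,
   ) -> List[List[float]]:
       """Per-edge coordinate features before embedding (assumed semantics)."""
       if mode not in _MODES:
           # mirrors the documented NotImplementedError for unknown modes
           raise NotImplementedError(f"mode {mode!r} not in {_MODES}")
       msgs: List[List[float]] = []
       for s, d in zip(src_idx, dst_idx):
           absolute = [float(v) for v in input_pos[s]]
           # offset of the sending point from its supernode
           relative = [a - b for a, b in zip(input_pos[s], input_pos[d])]
           if mode == "abspos":
               msgs.append(absolute)
           elif mode == "relpos":
               msgs.append(relative)
           else:  # "absrelpos": both representations side by side
               msgs.append(absolute + relative)
       return msgs

With "relpos" the message no longer depends on where the supernode sits in space, only on the offset of the sender, which is why "absrelpos" keeps both.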
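
The aggregation performed by ``accumulate_messages`` amounts to a scatter-sum or scatter-mean over ``dst_idx``. The plain-Python sketch below (hypothetical name ``accumulate``) shows this for the "sum" and "mean" settings of ``aggregation``. It assumes that ``dst_idx`` holds indices into the packed ``input_pos`` tensor, as in ``create_messages``, and it returns one row per entry of ``supernode_idx`` in the given order; how the real method uses ``batch_idx`` to shape its output is not modelled.

.. code-block:: python

   from typing import List, Sequence


   def accumulate(
       x: Sequence[Sequence[float]],
       dst_idx: Sequence[int],
       supernode_idx: Sequence[int],
       aggregation: str = "mean",
   ) -> List[List[float]]:
       """Aggregate message x[i] into supernode dst_idx[i] (hypothetical helper)."""
       if aggregation not in ("mean", "sum"):
           raise ValueError(f"unsupported aggregation {aggregation!r}")
       dim = len(x[0]) if x else 0
       # output row j belongs to supernode_idx[j], preserving the given order
       row_of = {node: row for row, node in enumerate(supernode_idx)}
       out = [[0.0] * dim for _ in supernode_idx]
       counts = [0] * len(supernode_idx)
       for msg, d in zip(x, dst_idx):
           row = row_of[d]
           counts[row] += 1
           for j, v in enumerate(msg):
               out[row][j] += v  # scatter-sum
       if aggregation == "mean":
           for row, c in enumerate(counts):
               if c > 0:  # supernodes without messages stay zero
                   out[row] = [v / c for v in out[row]]
       return out

Supernodes that receive no messages keep a zero vector in this sketch, which avoids a division by zero for "mean".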