emmi_inference.models.modules.supernode_pooling_posonly

Classes

SupernodePoolingPosonly

Supernode pooling layer.

Module Contents

class emmi_inference.models.modules.supernode_pooling_posonly.SupernodePoolingPosonly(hidden_dim, ndim, radius=None, k=None, max_degree=32, mode='relpos')

Bases: torch.nn.Module

Supernode pooling layer.

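The order of the supernodes (as given by supernode_idx) is preserved through the message passing, unlike in the (GP-)UPT code. Additionally, radius, which only searches for neighbours of the supernodes, is used instead of radius_graph, which builds a graph over all points; this is more efficient.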

Parameters:
  • radius (float | None) – Radius around each supernode. From points within this radius, messages are passed to the supernode.

  • k (int | None) – Number of neighbors for each supernode. Messages are passed to the supernode from its k nearest neighbors.

  • hidden_dim (int) – Hidden dimension for positional embeddings, messages and the resulting output vector.

  • ndim (int) – Number of positional dimensions (e.g., ndim=2 for 2D positions, ndim=3 for 3D positions).

  • max_degree (int) – Maximum degree of the radius graph. Defaults to 32.

  • mode (str) – Whether positions are embedded in absolute space (“abspos”) or relative space (“relpos”). Defaults to “relpos”. The “readd_supernode_pos” option always uses the absolute position.

  • readd_supernode_pos – If True, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to hidden_dim. Defaults to True.


radius = None
k = None
max_degree = 32
hidden_dim
ndim
mode = 'relpos'
pos_embed
message
proj
output_dim
compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx=None)

Compute the source and destination indices for the message passing to the supernodes.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points, 3), representing the input geometries.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor containing the batch index of each entry in input_pos. Defaults to None.

Returns:

Tuple of source and destination indexes for the message passing into the supernodes.

Return type:

tuple[torch.Tensor, torch.Tensor]
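A minimal, dependency-free sketch of what this neighbourhood search computes, using brute-force distances in place of the module's actual search routine. The function name mirrors the method, but the radius, k and max_degree keyword arguments and the convention that dst holds a position within supernode_idx (which keeps the supernode order) are assumptions of this sketch. When more than max_degree points fall inside the radius, the sketch keeps the closest ones; the real implementation may select a different subset.

```python
import math


def compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx=None,
                                radius=None, k=None, max_degree=32):
    """Brute-force sketch returning (src, dst) edge lists.

    src[e] indexes a point in input_pos; dst[e] is a position in
    supernode_idx (assumption), so output order follows supernode_idx.
    """
    # This sketch requires exactly one neighbourhood criterion.
    if (radius is None) == (k is None):
        raise ValueError("exactly one of radius or k must be set")
    if batch_idx is None:
        batch_idx = [0] * len(input_pos)
    src, dst = [], []
    for j, s in enumerate(supernode_idx):
        # Candidate neighbours: only points from the supernode's own sample.
        cands = sorted(
            (math.dist(input_pos[i], input_pos[s]), i)
            for i in range(len(input_pos))
            if batch_idx[i] == batch_idx[s]
        )
        if radius is not None:
            # Points within the radius, capped at max_degree (closest first).
            chosen = [i for d, i in cands if d <= radius][:max_degree]
        else:
            # The k nearest points (the supernode itself is at distance 0).
            chosen = [i for _, i in cands[:k]]
        src.extend(chosen)
        dst.extend([j] * len(chosen))
    return src, dst
```

Restricting candidates by batch_idx is what keeps messages from crossing sample boundaries in a flattened batch.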

create_messages(input_pos, src_idx, dst_idx, supernode_idx)

Create messages for the message passing to the supernodes, based on different positional encoding representations.

Parameters:
  • input_pos (torch.Tensor) – Tensor of shape (batch_size * number_of_points_per_sample, {2,3}), representing the point cloud representation of the input geometry.

  • src_idx (torch.Tensor) – Index of the source nodes from input_pos.

  • dst_idx (torch.Tensor) – Index of the destination node of each message in the input_pos tensor. These indexes should be the matching supernode indexes.

  • supernode_idx (torch.Tensor) – Indexes of the node in input_pos that are considered supernodes.

Raises:

NotImplementedError – Raised if the mode is not implemented. Either “abspos” or “relpos” are allowed.

Returns:

Tuple of the messages for the message passing into the supernodes and the embedding coordinates of the supernodes.

Return type:

tuple[torch.Tensor, torch.Tensor]
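The difference between the two modes can be sketched in plain Python, assuming dst_idx holds a position within supernode_idx. The sketch only produces the raw positional input of each edge; in the module this input would then pass through pos_embed and message, which are omitted here, and the second return value is left as raw supernode coordinates rather than an embedding. The helper name message_inputs is hypothetical.

```python
def message_inputs(input_pos, src_idx, dst_idx, supernode_idx, mode="relpos"):
    """Hypothetical helper: per-edge positional inputs plus supernode positions."""
    # Absolute coordinates of each supernode, in supernode_idx order.
    supernode_pos = [input_pos[s] for s in supernode_idx]
    if mode == "abspos":
        # Absolute coordinates of each source point.
        edge_inputs = [tuple(input_pos[i]) for i in src_idx]
    elif mode == "relpos":
        # Offset of each source point from its destination supernode.
        edge_inputs = [
            tuple(p - q for p, q in zip(input_pos[i], supernode_pos[j]))
            for i, j in zip(src_idx, dst_idx)
        ]
    else:
        raise NotImplementedError(
            f"mode {mode!r} is not implemented; use 'abspos' or 'relpos'"
        )
    return edge_inputs, supernode_pos
```

With "relpos", every edge input is expressed in the local frame of its supernode, so the same neighbourhood shape produces the same input regardless of where it sits in space.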

static accumulate_messages(x, dst_idx, supernode_idx, batch_idx=None)

Accumulate the messages of neighbouring points into the supernodes.

Parameters:
  • x (torch.Tensor) – Tensor containing the message of each neighbouring point.

  • dst_idx (torch.Tensor) – Index of the destination (i.e., supernode) to which each message is sent.

  • supernode_idx (torch.Tensor) – Indexes of the supernode in the input point cloud.

  • batch_idx (torch.Tensor | None) – Batch index of the points in the sparse tensor.

Returns:

Tuple of the tensor with the aggregated messages for each supernode and an integer.

Return type:

tuple[torch.Tensor, int]
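A plain-Python sketch of the accumulation, assuming mean aggregation (the reduction the module uses is not documented here) and assuming dst_idx holds a position within supernode_idx. Grouping the pooled rows by sample stands in for reshaping to a per-sample tensor; reading the returned int as the batch size is a guess based on the return type.

```python
def accumulate_messages(x, dst_idx, supernode_idx, batch_idx=None):
    """Sketch: mean-pool per-edge messages into supernodes, grouped by sample."""
    num_supernodes = len(supernode_idx)
    dim = len(x[0]) if x else 0
    sums = [[0.0] * dim for _ in range(num_supernodes)]
    counts = [0] * num_supernodes
    # Scatter-add each message into the row of its destination supernode.
    for msg, j in zip(x, dst_idx):
        counts[j] += 1
        for d, v in enumerate(msg):
            sums[j][d] += v
    # Mean aggregation (assumption); supernodes without messages stay at zero.
    pooled = [[v / c for v in row] if c else row for row, c in zip(sums, counts)]

    # Assign each supernode to its sample and group rows per sample.
    if batch_idx is None:
        owners = [0] * num_supernodes
    else:
        owners = [batch_idx[s] for s in supernode_idx]
    batch_size = max(owners, default=0) + 1
    per_sample = [[] for _ in range(batch_size)]
    for row, b in zip(pooled, owners):
        per_sample[b].append(row)
    return per_sample, batch_size
```

Because rows are written at their position in supernode_idx, the output order matches the supernode order given by the caller.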

forward(input_pos, supernode_idx, batch_idx=None)

Forward pass of the supernode pooling layer.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, 3), representing the point cloud representation of the input geometry.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor containing the batch index of each entry in input_pos. Defaults to None.

Returns:

Tensor with the aggregated messages for each supernode.

Return type:

torch.Tensor
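The forward pass expects a flattened ("sparse") batch in which supernode_idx indexes into the concatenated input_pos. A minimal sketch of preparing such inputs from per-sample point clouds, with supernodes sampled uniformly at random per sample; the helper build_sparse_batch and the sampling strategy are assumptions for illustration.

```python
import random


def build_sparse_batch(samples, num_supernodes, seed=0):
    """Hypothetical helper: flatten point clouds into (input_pos, batch_idx, supernode_idx)."""
    rng = random.Random(seed)
    input_pos, batch_idx, supernode_idx = [], [], []
    for b, points in enumerate(samples):
        # Position of this sample's first point in the flattened list.
        offset = len(input_pos)
        input_pos.extend(points)
        batch_idx.extend([b] * len(points))
        # Sample supernodes without replacement within the sample, then
        # shift the local indices by the sample's offset.
        local = rng.sample(range(len(points)), num_supernodes)
        supernode_idx.extend(offset + i for i in local)
    return input_pos, batch_idx, supernode_idx
```

Each returned list can be converted with torch.tensor (using dtype=torch.long for batch_idx and supernode_idx) before calling forward.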