emmi.modules.encoder.supernode_pooling_posonly

Classes

SupernodePoolingPosonly

Supernode pooling layer that operates only on positional information; no additional node features are used.

Module Contents

class emmi.modules.encoder.supernode_pooling_posonly.SupernodePoolingPosonly(config)

Bases: torch.nn.Module

Supernode pooling layer that operates only on positional information; no additional node features are used.

The permutation of the supernodes is preserved through the message passing (unlike in the (GP-)UPT code). Additionally, radius is used instead of radius_graph because it is more efficient.

Initialize the SupernodePoolingPosonly.

Parameters:
  • radius – Radius around each supernode. From points within this radius, messages are passed to the supernode.

  • k – Number of neighbors for each supernode. From the k-NN points, messages are passed to the supernode.

  • hidden_dim – Hidden dimension for positional embeddings, messages and the resulting output vector.

  • input_dim – Number of positional dimensions, i.e., the size of the last tensor dimension (e.g., input_dim=2 for a 2D position, input_dim=3 for a 3D position).

  • max_degree – Maximum degree of the radius graph. Defaults to 32.

  • mode – Whether positions are embedded in absolute space (“abspos”), relative space (“relpos”) or both (“absrelpos”). “readd_supernode_pos” always uses the absolute position.

  • init_weights – Weight initialization of linear layers. Defaults to “truncnormal002”.

  • readd_supernode_pos – If true, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to hidden_dim. Defaults to True.

  • aggregation – Aggregation for message passing (“mean” or “sum”).

  • message_mode – How messages are created: “mlp” (2-layer MLP), “linear” (nn.Linear) or “identity” (nn.Identity). Defaults to “mlp”.

  • config (emmi.schemas.modules.encoder.SupernodePoolingConfig)

radius
spool_pos_mode
max_degree
aggregation
readd_supernode_pos
k
pos_embed
output_dim
compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx=None)

Compute the source and destination indices for the message passing to the supernodes.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number of points, 3), representing the input geometries.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.

Returns:

Tuple of source and destination index tensors for the message passing into the supernodes.

Return type:

tuple[torch.Tensor, torch.Tensor]
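The index computation can be pictured with a small dependency-free sketch. It is an illustration in plain Python rather than the module's tensor implementation: the helper name radius_edges, the inclusive distance test, and the rule for which points are kept once max_degree is reached (first in index order) are all assumptions made for the example.

```python
import math

def radius_edges(input_pos, supernode_idx, radius, max_degree=32, batch_idx=None):
    """Return (src_idx, dst_idx): points within `radius` of each supernode."""
    if batch_idx is None:
        batch_idx = [0] * len(input_pos)  # treat all points as one sample
    src_idx, dst_idx = [], []
    for s in supernode_idx:  # supernodes are visited in the order given
        degree = 0
        for p, pos in enumerate(input_pos):
            if batch_idx[p] != batch_idx[s]:
                continue  # never connect points from different samples
            if math.dist(pos, input_pos[s]) <= radius:
                src_idx.append(p)
                dst_idx.append(s)
                degree += 1
                if degree == max_degree:
                    break  # cap the number of incoming edges per supernode
    return src_idx, dst_idx

# Two samples flattened into one list of 2D points.
input_pos = [(0.0, 0.0), (0.1, 0.0), (2.0, 0.0), (0.0, 0.0), (0.0, 0.2)]
batch_idx = [0, 0, 0, 1, 1]
src, dst = radius_edges(input_pos, supernode_idx=[0, 3], radius=0.5, batch_idx=batch_idx)
print(src, dst)
```

Points 0 and 3 share the same coordinates but belong to different samples, so no edge is created between them.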

create_messages(input_pos, src_idx, dst_idx, supernode_idx)

Create messages for the message passing to the supernodes, based on different positional encoding representations.

Parameters:
  • input_pos (torch.Tensor) – Tensor of shape (batch_size * number_of_points_per_sample, {2,3}), representing the point cloud representation of the input geometry.

  • src_idx (torch.Tensor) – Index of the source nodes from input_pos.

  • dst_idx (torch.Tensor) – Index of the destination nodes in the input_pos tensor. These indexes should match the supernode indexes.

  • supernode_idx (torch.Tensor) – Indexes of the nodes in input_pos that are considered supernodes.

Raises:

NotImplementedError – Raised if the mode is not implemented. Either “abspos”, “relpos” or “absrelpos” are allowed.

Returns:

Tensor with messages for the message passing into the supernodes and the embedding coordinates of the supernodes.

Return type:

tuple[torch.Tensor, torch.Tensor]
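The three positional modes differ in which coordinates feed each message. The sketch below builds only the raw per-edge coordinates in plain Python; the helper name edge_features and the sign convention of the relative position (source minus supernode) are assumptions, and the embedding and message network applied by the real module are left out.

```python
def edge_features(input_pos, src_idx, dst_idx, mode="relpos"):
    """Build raw per-edge position features for the given mode."""
    feats = []
    for s, d in zip(src_idx, dst_idx):
        src_pos, dst_pos = input_pos[s], input_pos[d]
        # Assumed convention: position of the source relative to its supernode.
        rel = tuple(a - b for a, b in zip(src_pos, dst_pos))
        if mode == "abspos":
            feats.append(tuple(src_pos))
        elif mode == "relpos":
            feats.append(rel)
        elif mode == "absrelpos":
            feats.append(tuple(src_pos) + rel)  # absolute and relative side by side
        else:
            raise NotImplementedError(f"unknown mode: {mode!r}")
    return feats

# Node 0 is the supernode; nodes 1 and 2 send messages to it.
input_pos = [(1.0, 1.0), (2.0, 1.0), (1.0, 3.0)]
feats = edge_features(input_pos, src_idx=[1, 2], dst_idx=[0, 0], mode="absrelpos")
print(feats)
```

With “absrelpos” each edge carries twice as many raw coordinates as with the other two modes, which is why the sketch concatenates the tuples instead of adding them.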

accumulate_messages(x, dst_idx, supernode_idx, batch_idx=None)

Method to accumulate the messages of neighbouring points into the supernodes.

Parameters:
  • x (torch.Tensor) – Tensor containing the message representation of each neighbouring point.

  • dst_idx (torch.Tensor) – Index of the destination (i.e., supernode) where each message should go to.

  • supernode_idx (torch.Tensor) – Indexes of the supernode in the input point cloud.

  • batch_idx (torch.Tensor | None) – Batch index of the points in the sparse tensor.

Returns:

Tensor with the aggregated messages for each supernode.

Return type:

tuple[torch.Tensor, int]
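As a rough picture of the accumulation step, the following plain-Python sketch reduces per-edge messages into one vector per supernode with either “mean” or “sum”. The helper name aggregate is hypothetical, and the second element of the documented return tuple is not modelled here.

```python
def aggregate(messages, dst_idx, supernode_idx, aggregation="mean"):
    """Reduce per-edge message vectors to one vector per supernode."""
    dim = len(messages[0])
    # Map each supernode index to its output row, keeping the given order.
    slot = {s: i for i, s in enumerate(supernode_idx)}
    sums = [[0.0] * dim for _ in supernode_idx]
    counts = [0] * len(supernode_idx)
    for msg, d in zip(messages, dst_idx):
        row = slot[d]
        counts[row] += 1
        for j, value in enumerate(msg):
            sums[row][j] += value
    if aggregation == "sum":
        return sums
    if aggregation == "mean":
        # max(c, 1) avoids dividing by zero for a supernode without messages.
        return [[v / max(c, 1) for v in row] for row, c in zip(sums, counts)]
    raise ValueError(f"unknown aggregation: {aggregation!r}")

messages = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
pooled = aggregate(messages, dst_idx=[0, 0, 3], supernode_idx=[0, 3])
print(pooled)
```

Because the output rows follow supernode_idx, the order of the supernodes is unchanged by the reduction.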

forward(input_pos, supernode_idx, batch_idx=None)

Forward pass of the supernode pooling layer.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, 3), representing the point cloud representation of the input geometry.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.

Returns:

Dictionary containing the aggregated messages for each supernode.

Return type:

dict[str, torch.Tensor]
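The flattened ("sparse") input layout expected by forward can be illustrated without any tensor library. In this sketch the helper flatten_batch is hypothetical and the random choice of supernodes is an assumption for the example; in practice the three lists would be converted to torch tensors before being passed to the layer.

```python
import random

def flatten_batch(samples, num_supernodes, seed=0):
    """Concatenate per-sample point lists and pick supernodes per sample."""
    rng = random.Random(seed)
    input_pos, batch_idx, supernode_idx = [], [], []
    for b, points in enumerate(samples):
        offset = len(input_pos)  # where this sample starts in the flat list
        input_pos.extend(points)
        batch_idx.extend([b] * len(points))
        # Assumed strategy: choose supernodes uniformly at random per sample.
        chosen = rng.sample(range(len(points)), num_supernodes)
        supernode_idx.extend(offset + i for i in chosen)
    return input_pos, batch_idx, supernode_idx

samples = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],  # sample 0: 3 points
    [(5.0, 5.0, 5.0), (6.0, 5.0, 5.0)],                   # sample 1: 2 points
]
input_pos, batch_idx, supernode_idx = flatten_batch(samples, num_supernodes=1)
print(len(input_pos), batch_idx, supernode_idx)
```

The supernode indexes refer to rows of the flattened input_pos, so each one is offset by the number of points in the preceding samples.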