emmi.modules.encoder.supernode_pooling_posonly
Classes
SupernodePoolingPosonly – Supernode pooling layer that operates only on positional information; no additional node features are used.
Module Contents
- class emmi.modules.encoder.supernode_pooling_posonly.SupernodePoolingPosonly(config)
Bases: torch.nn.Module
Supernode pooling layer that operates only on positional information; no additional node features are used.
The order of the supernodes is preserved through the message passing (unlike in the (GP-)UPT code). Additionally, radius is used instead of radius_graph, which is more efficient.
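The following sketch illustrates why a radius query is cheaper than a full radius graph. It is not the library's code. The brute-force query starts from the supernodes only, so it evaluates (number of supernodes × number of points) distances. A radius graph over all points would need (number of points)² distances.

Some details are assumptions:
- The function name radius_to_supernodes is hypothetical.
- Each supernode is capped at max_degree neighbours by keeping the first matches in scan order. An accelerated search (presumably torch_cluster.radius) may pick a different subset.

Edges are emitted supernode by supernode in the order of supernode_idx. dst_idx holds each supernode's own index into input_pos.

```python
import math


def radius_to_supernodes(input_pos, supernode_idx, radius, max_degree=32):
    """Hypothetical brute-force radius query from supernodes to points.

    Only supernode-to-point distances are evaluated, never all point pairs.
    Returns (src_idx, dst_idx); both index into input_pos.
    """
    src_idx, dst_idx = [], []
    for s in supernode_idx:  # iterate in supernode order, so the order is preserved
        degree = 0
        for p, pos in enumerate(input_pos):
            if degree >= max_degree:  # assumption: keep the first max_degree matches
                break
            if math.dist(pos, input_pos[s]) <= radius:
                src_idx.append(p)
                dst_idx.append(s)
                degree += 1
    return src_idx, dst_idx


points = [(0.0, 0.0), (0.5, 0.0), (3.0, 3.0), (3.2, 3.0)]
src, dst = radius_to_supernodes(points, supernode_idx=[2, 0], radius=1.0)
print(src, dst)  # [2, 3, 0, 1] [2, 2, 0, 0]
```

In the output, supernode 2 comes first because it is listed first in supernode_idx, even though its index is larger.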
Initialize the SupernodePoolingPosonly.
- Parameters:
radius – Radius around each supernode. From points within this radius, messages are passed to the supernode.
k – Number of neighbors for each supernode. Messages are passed to the supernode from its k nearest neighbors.
hidden_dim – Hidden dimension for positional embeddings, messages and the resulting output vector.
input_dim – Number of positional dimensions, i.e., the size of the last tensor dimension (e.g., input_dim=2 for 2D positions, input_dim=3 for 3D positions).
max_degree – Maximum degree of the radius graph. Defaults to 32.
mode – Whether positions are embedded in absolute space (“abspos”), relative space (“relpos”), or both (“absrelpos”). “readd_supernode_pos” always uses the absolute position.
init_weights – Weight initialization of linear layers. Defaults to “truncnormal002”.
readd_supernode_pos – If true, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to hidden_dim. Defaults to True.
aggregation – Aggregation for message passing (“mean” or “sum”).
message_mode – How messages are created: “mlp” (2-layer MLP), “linear” (nn.Linear), or “identity” (nn.Identity). Defaults to “mlp”.
config (emmi.schemas.modules.encoder.SupernodePoolingConfig) – Configuration object that holds the parameters listed above.
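All of these settings reach the constructor through a single config object. The dataclass below is a hypothetical stand-in for emmi.schemas.modules.encoder.SupernodePoolingConfig, not its actual definition.

- The field names follow the parameter list above.
- The stated defaults (max_degree, init_weights, readd_supernode_pos, message_mode) are copied from it.
- Making mode and aggregation required, and radius and k optional, are assumptions.
- Validating the string options at construction time is an illustrative design choice. The documented behaviour is that create_messages raises NotImplementedError for an unknown mode.

```python
from dataclasses import dataclass
from typing import Optional

# Allowed values as listed in the parameter documentation.
POS_MODES = {"abspos", "relpos", "absrelpos"}
AGGREGATIONS = {"mean", "sum"}
MESSAGE_MODES = {"mlp", "linear", "identity"}


@dataclass
class SupernodePoolingConfigSketch:
    """Hypothetical mirror of SupernodePoolingConfig (not the real schema)."""

    hidden_dim: int
    input_dim: int
    mode: str  # assumption: no default documented
    aggregation: str  # assumption: no default documented
    radius: Optional[float] = None  # neighbourhood by distance ...
    k: Optional[int] = None  # ... or by k nearest neighbours
    max_degree: int = 32
    init_weights: str = "truncnormal002"
    readd_supernode_pos: bool = True
    message_mode: str = "mlp"

    def __post_init__(self):
        # Fail early on unsupported string options.
        if self.mode not in POS_MODES:
            raise ValueError(f"unknown mode: {self.mode!r}")
        if self.aggregation not in AGGREGATIONS:
            raise ValueError(f"unknown aggregation: {self.aggregation!r}")
        if self.message_mode not in MESSAGE_MODES:
            raise ValueError(f"unknown message_mode: {self.message_mode!r}")


cfg = SupernodePoolingConfigSketch(hidden_dim=192, input_dim=3, mode="relpos",
                                   aggregation="mean", radius=0.25)
print(cfg.max_degree, cfg.message_mode)  # 32 mlp
```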
- radius
- spool_pos_mode
- max_degree
- aggregation
- readd_supernode_pos
- k
- pos_embed
- output_dim
- compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx=None)
Compute the source and destination indices for the message passing to the supernodes.
- Parameters:
input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, 3), representing the input geometries.
supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.
batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.
- Returns:
Tuple of source and destination indexes for the message passing into the supernodes.
- Return type:
tuple[torch.Tensor, torch.Tensor]
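When batch_idx is given, a supernode must only receive messages from points of its own sample. The brute-force sketch below shows this for the k-NN variant. It is not the library's implementation, which presumably uses an accelerated neighbour search.

Some details are assumptions:
- The function name knn_src_dst is hypothetical.
- A supernode counts as one of its own neighbours.
- Distance ties are broken by point index.

```python
import math


def knn_src_dst(input_pos, supernode_idx, k, batch_idx=None):
    """Hypothetical brute-force k-NN edges into supernodes, batch-aware.

    Returns (src_idx, dst_idx); both index into input_pos.
    """
    if batch_idx is None:
        batch_idx = [0] * len(input_pos)  # treat everything as one sample
    src_idx, dst_idx = [], []
    for s in supernode_idx:
        target = input_pos[s]
        # Only points from the supernode's own sample are candidates.
        candidates = [p for p in range(len(input_pos)) if batch_idx[p] == batch_idx[s]]
        # Stable sort by distance to the supernode; ties keep index order.
        candidates.sort(key=lambda p: math.dist(input_pos[p], target))
        for p in candidates[:k]:
            src_idx.append(p)
            dst_idx.append(s)
    return src_idx, dst_idx


# Two samples flattened into one "sparse" list of positions.
pts = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0), (0.0, 0.0), (0.2, 0.0), (9.0, 9.0)]
src, dst = knn_src_dst(pts, [0, 3], k=2, batch_idx=[0, 0, 0, 1, 1, 1])
print(src, dst)  # [0, 1, 3, 4] [0, 0, 3, 3]
```

Without batch_idx, supernode 0 would pick point 3 from the other sample, because both points sit at the origin.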
- create_messages(input_pos, src_idx, dst_idx, supernode_idx)
Create messages for the message passing to the supernodes, based on different positional encoding representations.
- Parameters:
input_pos (torch.Tensor) – Tensor of shape (batch_size * number_of_points_per_sample, {2,3}), representing the point cloud representation of the input geometry.
src_idx (torch.Tensor) – Index of the source nodes from input_pos.
dst_idx (torch.Tensor) – Index of the destination nodes in the input_pos tensor. These indexes should match the supernode indexes.
supernode_idx (torch.Tensor) – Indexes of the nodes in input_pos that are considered supernodes.
- Raises:
NotImplementedError – Raised if the mode is not implemented. Either “abspos”, “relpos” or “absrelpos” are allowed.
- Returns:
Tuple of the messages for the message passing into the supernodes and the embedded coordinates of the supernodes.
- Return type:
tuple[torch.Tensor, torch.Tensor]
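The three modes differ in which coordinates feed the positional embedding. The sketch below computes those raw per-edge inputs.

Assumptions:
- "relpos" means the source position minus the destination supernode's position.
- "absrelpos" concatenates the absolute and relative coordinates. The implementation may instead embed both separately.
- The function name positional_inputs is hypothetical.

The subsequent embedding (pos_embed) and the message network selected by message_mode are omitted. The second return value mirrors the documented tuple by returning the absolute supernode positions.

```python
VALID_MODES = ("abspos", "relpos", "absrelpos")


def positional_inputs(input_pos, src_idx, dst_idx, supernode_idx, mode):
    """Hypothetical per-edge positional inputs for each embedding mode."""
    if mode not in VALID_MODES:
        raise NotImplementedError(f"mode {mode!r} is not implemented")
    features = []
    for src, dst in zip(src_idx, dst_idx):
        absolute = list(input_pos[src])
        relative = [a - b for a, b in zip(input_pos[src], input_pos[dst])]
        if mode == "abspos":
            features.append(absolute)
        elif mode == "relpos":
            features.append(relative)
        else:  # "absrelpos": assumption, both concatenated
            features.append(absolute + relative)
    # Absolute supernode positions (e.g. what readd_supernode_pos would use).
    supernode_pos = [list(input_pos[s]) for s in supernode_idx]
    return features, supernode_pos


pts = [(1.0, 1.0), (3.0, 2.0)]
feats, sn = positional_inputs(pts, src_idx=[1, 0], dst_idx=[0, 0],
                              supernode_idx=[0], mode="absrelpos")
print(feats)  # [[3.0, 2.0, 2.0, 1.0], [1.0, 1.0, 0.0, 0.0]]
```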
- accumulate_messages(x, dst_idx, supernode_idx, batch_idx=None)
Accumulates the messages of neighbouring points into the supernodes.
- Parameters:
x (torch.Tensor) – Tensor containing the message of each neighbouring point.
dst_idx (torch.Tensor) – Index of the destination (i.e., supernode) where each message should go to.
supernode_idx (torch.Tensor) – Indexes of the supernodes in the input point cloud.
batch_idx (torch.Tensor | None) – Batch index of the points in the sparse tensor.
- Returns:
Tensor with the aggregated messages for each supernode.
- Return type:
torch.Tensor
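Accumulation is a scatter-reduce from edges to supernodes. If dst_idx refers to positions in input_pos while the output has one row per supernode, supernode_idx is needed to map each destination to its output row. That mapping also keeps the output in supernode order.

The plain-list sketch below assumes that reading of the parameters. Its function name, accumulate, is hypothetical. batch_idx is not needed in this flat version, and supernodes without incoming messages get a zero vector under "mean".

```python
def accumulate(messages, dst_idx, supernode_idx, aggregation="mean"):
    """Hypothetical scatter-reduce of per-edge messages into supernodes."""
    if aggregation not in ("mean", "sum"):
        raise ValueError(f"unknown aggregation: {aggregation!r}")
    # Map a supernode's point index to its output row (preserves order).
    row_of = {s: row for row, s in enumerate(supernode_idx)}
    dim = len(messages[0]) if messages else 0
    totals = [[0.0] * dim for _ in supernode_idx]
    counts = [0] * len(supernode_idx)
    for msg, dst in zip(messages, dst_idx):
        row = row_of[dst]
        totals[row] = [t + m for t, m in zip(totals[row], msg)]
        counts[row] += 1
    if aggregation == "sum":
        return totals
    # Mean: divide by the number of incoming messages (guard empty rows).
    return [[t / max(c, 1) for t in vec] for vec, c in zip(totals, counts)]


msgs = [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]]
print(accumulate(msgs, dst_idx=[2, 2, 0], supernode_idx=[2, 0]))
# [[2.0, 3.0], [10.0, 0.0]]
```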
- forward(input_pos, supernode_idx, batch_idx=None)
Forward pass of the supernode pooling layer.
- Parameters:
input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, 3), representing the point cloud representation of the input geometry.
supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.
batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.
- Returns:
Tensor with the aggregated messages for each supernode.
- Return type:
torch.Tensor
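The forward pass can be read as four steps: find neighbours, create messages, accumulate them per supernode, and optionally re-add the supernode position. The sketch below strings these steps together to illustrate the data flow only. It does not reproduce the module.

It simplifies in several ways:
- It uses radius neighbourhoods within each sample and relative-position inputs.
- It assumes an “identity” message function with no positional embedding, so the hidden size equals input_dim.
- It uses mean aggregation and no max_degree cap.
- A fixed matrix passed as readd_weight stands in for the learned linear layer behind readd_supernode_pos.

The function supernode_pooling_sketch and its readd_weight argument are hypothetical.

```python
import math


def supernode_pooling_sketch(input_pos, supernode_idx, radius, batch_idx=None,
                             readd_weight=None):
    """Hypothetical end-to-end data flow of supernode pooling (relpos, mean)."""
    if batch_idx is None:
        batch_idx = [0] * len(input_pos)
    dim = len(input_pos[0])
    if readd_weight is not None and any(len(r) != 2 * dim for r in readd_weight):
        raise ValueError("each readd_weight row must have length 2 * input_dim")
    outputs = []
    for s in supernode_idx:  # one output row per supernode, in the given order
        # 1) Neighbours: same sample and within radius of the supernode.
        neighbors = [p for p in range(len(input_pos))
                     if batch_idx[p] == batch_idx[s]
                     and math.dist(input_pos[p], input_pos[s]) <= radius]
        # 2) Messages: relative positions (identity message function).
        messages = [[a - b for a, b in zip(input_pos[p], input_pos[s])]
                    for p in neighbors]
        # 3) Mean aggregation (the supernode is its own neighbour, so never empty).
        pooled = [sum(col) / len(messages) for col in zip(*messages)]
        # 4) Optionally concatenate the absolute supernode position and project.
        if readd_weight is not None:
            concat = pooled + list(input_pos[s])  # length 2 * dim
            pooled = [sum(w * x for w, x in zip(row, concat)) for row in readd_weight]
        outputs.append(pooled)
    return outputs


pts = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
W = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]  # hypothetical fixed projection
print(supernode_pooling_sketch(pts, [0, 3], radius=1.5, readd_weight=W))
```

In the example, supernode 3 has no other point within the radius, so its pooled message is zero. After the re-add step, its output is driven entirely by its absolute position.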