noether.modeling.modules.encoders.supernode_pooling

Classes

SupernodePoolingConfig

SupernodePooling: Supernode pooling layer.

Module Contents

class noether.modeling.modules.encoders.supernode_pooling.SupernodePoolingConfig(/, **data)

Bases: pydantic.BaseModel

Parameters:

data (Any)

hidden_dim: int = None

Hidden dimension for positional embeddings, messages and the resulting output vector.

input_dim: int = None

Number of positional dimensions (e.g., input_dim=2 for 2D positions, input_dim=3 for 3D positions).

radius: float | None = None

Radius around each supernode. Messages are passed to the supernode from all points within this radius.

k: int | None = None

Number of nearest neighbors per supernode. Messages are passed to the supernode from its k nearest points.
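The k-NN neighborhood can be illustrated with a brute-force sketch in plain PyTorch. The helper knn_edges is hypothetical and not part of noether. It assumes every sample in the batch has at least k points, uses torch.cdist plus topk instead of whatever neighbor search the module actually uses, and counts each supernode as its own nearest neighbor.

```python
import torch


def knn_edges(input_pos, supernode_idx, k, batch_idx=None):
    """Hypothetical sketch: connect each supernode to its k nearest points (brute force)."""
    query = input_pos[supernode_idx]           # (S, D) supernode positions
    dist = torch.cdist(query, input_pos)       # (S, N) pairwise Euclidean distances
    if batch_idx is not None:
        # Points belonging to another sample must never become neighbors.
        other = batch_idx[supernode_idx].unsqueeze(1) != batch_idx.unsqueeze(0)
        dist = dist.masked_fill(other, float("inf"))
    nn_idx = dist.topk(k, dim=1, largest=False).indices  # (S, k), nearest first
    src_idx = nn_idx.reshape(-1)                       # neighbors send the messages ...
    dst_idx = supernode_idx.repeat_interleave(k)       # ... to their supernode
    return src_idx, dst_idx


pos = torch.rand(50, 3)
src, dst = knn_edges(pos, torch.tensor([0, 10, 20]), k=4)
print(src.shape, dst.shape)  # 12 edges: 3 supernodes x 4 neighbors
```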

max_degree: int = None

Maximum degree of the radius graph. Defaults to 32.

spool_pos_mode: Literal['abspos', 'relpos', 'absrelpos'] = None

Type of position embedding: absolute space (“abspos”), relative space (“relpos”), or both (“absrelpos”).

init_weights: noether.core.types.InitWeightsMode = None

Weight initialization of linear layers. Defaults to “truncnormal002”.

readd_supernode_pos: bool = None

If True, the absolute positional encoding of the supernode is concatenated to the supernode vector after message passing and linearly projected back to hidden_dim. Defaults to True.
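The re-adding step described above amounts to a concatenation followed by a linear projection. The sketch below shows the shapes involved; a plain nn.Linear stands in for the module's positional embedding (an assumption, since its exact form is not documented here), and readd_proj is a hypothetical name.

```python
import torch
from torch import nn

hidden_dim = 8
num_supernodes = 5

# Stand-in for the module's positional embedding (assumption: any map from
# positions to hidden_dim features is enough for this illustration).
pos_embed = nn.Linear(3, hidden_dim)
# Projects [aggregated messages, supernode position embedding] back to hidden_dim.
readd_proj = nn.Linear(2 * hidden_dim, hidden_dim)

x = torch.randn(num_supernodes, hidden_dim)     # aggregated messages, one row per supernode
supernode_pos = torch.randn(num_supernodes, 3)  # absolute supernode positions

x = readd_proj(torch.cat([x, pos_embed(supernode_pos)], dim=-1))
print(tuple(x.shape))  # (5, 8)
```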

aggregation: Literal['mean', 'sum'] = None

Aggregation used to combine the incoming messages at each supernode: “mean” or “sum”.

message_mode: Literal['mlp', 'linear', 'identity'] = None

How messages are created: “mlp” (2-layer MLP), “linear” (nn.Linear), or “identity” (nn.Identity). Defaults to “mlp”.
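A minimal sketch of what each message_mode could build, assuming messages keep the same width dim on input and output and that the MLP uses a GELU activation (neither detail is stated above). make_message_net is a hypothetical helper, not a noether function; its bias argument mirrors the bias option of this config.

```python
import torch
from torch import nn


def make_message_net(message_mode: str, dim: int, bias: bool = True) -> nn.Module:
    """Hypothetical helper: build the per-edge message network for a given mode."""
    if message_mode == "mlp":
        # 2-layer MLP; the GELU activation is an assumption of this sketch.
        return nn.Sequential(
            nn.Linear(dim, dim, bias=bias),
            nn.GELU(),
            nn.Linear(dim, dim, bias=bias),
        )
    if message_mode == "linear":
        return nn.Linear(dim, dim, bias=bias)
    if message_mode == "identity":
        return nn.Identity()  # messages are passed on unchanged
    raise NotImplementedError(f"unknown message_mode: {message_mode}")


msgs = torch.randn(10, 16)  # 10 edges, 16 channels each
for mode in ("mlp", "linear", "identity"):
    print(mode, tuple(make_message_net(mode, 16)(msgs).shape))  # (10, 16) for every mode
```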

input_features_dim: int | None = None

Number of input features per point. Defaults to None, which means no input features are used and the module operates on positions only.

bias: bool = None

Whether to use bias in the linear layers. Defaults to True.

validate_radius_and_k()

class noether.modeling.modules.encoders.supernode_pooling.SupernodePooling(config)

Bases: torch.nn.Module

Supernode pooling layer.

The order of the supernodes is preserved through message passing (unlike in the (GP-)UPT code). Additionally, radius is used instead of radius_graph, which is more efficient because neighbors are only searched around the supernodes instead of building a graph over all points.

Initialize the SupernodePooling.

Parameters:

config (SupernodePoolingConfig) – Configuration for the SupernodePooling module. See SupernodePoolingConfig for available options.

radius
k
max_degree
spool_pos_mode
readd_supernode_pos
aggregation
input_features_dim
pos_embed
output_dim
compute_src_and_dst_indices(input_pos, supernode_idx, batch_idx=None)

Compute the source and destination indices for the message passing to the supernodes.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, input_dim), representing the input geometries.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.

Returns:

Tuple of (src_idx, dst_idx, local_dst_idx) where src_idx and dst_idx are absolute indices into input_pos and local_dst_idx is a 0-indexed position into supernode_idx (used for scatter_reduce_).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
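The three returned index tensors can be illustrated with a brute-force radius search. radius_edges is a hypothetical helper that uses torch.cdist as a stand-in for the radius routine mentioned above. When more than max_degree points fall inside the radius, this sketch simply keeps the first ones in index order; the selection made by the real implementation may differ. Note that nonzero forces a CUDA synchronization, which the module itself may avoid.

```python
import torch


def radius_edges(input_pos, supernode_idx, radius, max_degree, batch_idx=None):
    """Hypothetical sketch of (src_idx, dst_idx, local_dst_idx) for radius-based pooling."""
    query = input_pos[supernode_idx]          # (S, D) supernode positions
    dist = torch.cdist(query, input_pos)      # (S, N) pairwise distances
    within = dist <= radius                   # candidate edges
    if batch_idx is not None:
        # Only connect points that belong to the same sample as the supernode.
        within &= batch_idx[supernode_idx].unsqueeze(1) == batch_idx.unsqueeze(0)
    # Cap the degree: keep at most max_degree neighbors per supernode.
    within &= within.cumsum(dim=1) <= max_degree
    # Row = position in supernode_idx (local), column = point index in input_pos.
    local_dst_idx, src_idx = within.nonzero(as_tuple=True)
    dst_idx = supernode_idx[local_dst_idx]    # absolute indices into input_pos
    return src_idx, dst_idx, local_dst_idx
```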

create_messages(input_pos, src_idx, dst_idx, supernode_idx, input_features=None)

Create messages for the message passing to the supernodes, based on different positional encoding representations.

Parameters:
  • input_pos (torch.Tensor) – Tensor of shape (batch_size * number_of_points_per_sample, {2,3}), representing the point cloud representation of the input geometry.

  • src_idx (torch.Tensor) – Index of the source nodes from input_pos.

  • dst_idx (torch.Tensor) – Indexes into input_pos of the destination nodes, i.e. the supernode that receives each message.

  • supernode_idx (torch.Tensor) – Indexes of the nodes in input_pos that are considered supernodes.

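  • input_features (torch.Tensor | None) – Optional per-point features with shape (batch_size * number_of_points_per_sample, number_of_features). Default None.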

Raises:

NotImplementedError – Raised if spool_pos_mode is not implemented; only “abspos”, “relpos”, and “absrelpos” are allowed.

Returns:

Tuple of the messages for the message passing into the supernodes and the embedded coordinates of the supernodes.

Return type:

tuple[torch.Tensor, torch.Tensor]
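One plausible reading of the three positional modes is sketched below for the positional part of each message. This is an illustration under stated assumptions, not the module's actual code: a plain nn.Linear stands in for the positional embedding, “relpos” embeds the offset of the sending point from its supernode, and “absrelpos” is modeled as a concatenation of both embeddings. message_inputs is a hypothetical helper.

```python
import torch
from torch import nn

hidden_dim = 8
# Stand-in for the module's positional embedding (an assumption for this sketch).
pos_embed = nn.Linear(3, hidden_dim)


def message_inputs(input_pos, src_idx, dst_idx, mode):
    """Hypothetical helper: positional part of each edge's message for a given spool_pos_mode."""
    src_pos = input_pos[src_idx]  # position of the sending point
    dst_pos = input_pos[dst_idx]  # position of the receiving supernode
    if mode == "abspos":
        return pos_embed(src_pos)  # absolute space
    if mode == "relpos":
        return pos_embed(src_pos - dst_pos)  # offset relative to the supernode
    if mode == "absrelpos":
        # "Both" is modeled here as a concatenation of the two embeddings.
        return torch.cat([pos_embed(src_pos), pos_embed(src_pos - dst_pos)], dim=-1)
    raise NotImplementedError(f"unknown spool_pos_mode: {mode}")


pos = torch.rand(6, 3)
src, dst = torch.tensor([1, 2, 4]), torch.tensor([0, 0, 3])
print(tuple(message_inputs(pos, src, dst, "absrelpos").shape))  # (3, 16)
```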

accumulate_messages(x, local_dst_idx, supernode_idx)

Method to accumulate the messages of neighbouring points into the supernodes.

Parameters:
  • x (torch.Tensor) – Tensor containing the message of each neighbour, one row per edge.

  • local_dst_idx (torch.Tensor) – 0-indexed position into supernode_idx for each message; using these local indices avoids a CUDA synchronization.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the input point cloud.

Returns:

Tensor with the aggregated messages for each supernode.

Return type:

torch.Tensor
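The accumulation can be sketched with torch.Tensor.scatter_reduce_, which the description of local_dst_idx refers to. accumulate is a hypothetical helper; in this sketch the output size is taken from the shape of supernode_idx rather than from local_dst_idx.max(), so no value has to be read back from the GPU. Supernodes that receive no message stay at zero here, which is a choice of this sketch.

```python
import torch


def accumulate(x, local_dst_idx, supernode_idx, aggregation="mean"):
    """Hypothetical sketch: reduce per-edge messages x (E, C) into one row per supernode."""
    # Output size comes from supernode_idx's shape (metadata only, no device sync).
    out = x.new_zeros((supernode_idx.numel(), x.size(1)))
    # scatter_reduce_ expects an index tensor with the same shape as the source.
    index = local_dst_idx.unsqueeze(1).expand_as(x)
    # include_self=False: the zero initialization does not count toward the mean.
    out.scatter_reduce_(0, index, x, reduce=aggregation, include_self=False)
    return out
```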

forward(input_pos, supernode_idx, batch_idx=None, input_features=None)

Forward pass of the supernode pooling layer.

Parameters:
  • input_pos (torch.Tensor) – Sparse tensor with shape (batch_size * number_of_points_per_sample, input_dim), representing the input geometry as a point cloud.

  • supernode_idx (torch.Tensor) – Indexes of the supernodes in the sparse tensor input_pos.

  • batch_idx (torch.Tensor | None) – 1D tensor, containing the batch index of each entry in input_pos. Default None.

  • input_features (torch.Tensor | None) – Sparse tensor with shape (batch_size * number_of_points_per_sample, number_of_features). Default None.

Returns:

Tensor with the aggregated messages for each supernode.

Return type:

torch.Tensor | dict[str, torch.Tensor]
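The sketch below shows one way to assemble the flattened inputs for forward for a batch of two point clouds with the same number of points. Random supernode selection is an assumption for illustration. The essential point is that supernode_idx holds global indices into input_pos, so each sample's local indices are shifted by that sample's offset (with varying point counts, the offsets would be cumulative sums of the sample sizes). pooling stands for an already configured SupernodePooling instance and is not created here.

```python
import torch

batch_size, points_per_sample, num_supernodes = 2, 100, 16

# Samples are concatenated along dim 0 (the "sparse" layout used above).
clouds = [torch.rand(points_per_sample, 3) for _ in range(batch_size)]
input_pos = torch.cat(clouds)                                               # (200, 3)
batch_idx = torch.arange(batch_size).repeat_interleave(points_per_sample)   # (200,)

# Pick supernodes per sample, then shift to global indices into input_pos.
supernode_idx = torch.cat([
    torch.randperm(points_per_sample)[:num_supernodes] + i * points_per_sample
    for i in range(batch_size)
])                                                                          # (32,)

# With a configured module (hypothetical instance):
# out = pooling(input_pos, supernode_idx, batch_idx=batch_idx)
```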