noether.modeling.modules.layers.scalar_conditioner

Classes

ScalarsConditionerConfig

ScalarsConditioner

Embeds num_scalars scalars into a single conditioning vector by first encoding every scalar with sine-cosine embeddings followed by a per-scalar MLP.

Module Contents

class noether.modeling.modules.layers.scalar_conditioner.ScalarsConditionerConfig(/, **data)

Bases: pydantic.BaseModel

Parameters:

data (Any)

hidden_dim: int = None

Dimension for embedding the scalars and the per-scalar MLP.

num_scalars: int = None

How many scalars are embedded.

condition_dim: int | None = None

Dimension of the final conditioning vector. If condition_dim is None, it defaults to 4 * hidden_dim.

init_weights: noether.core.types.InitWeightsMode = 'truncnormal002'

Weight initialization mode for the MLPs.
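To make the two defaults above concrete, here is a minimal sketch using a hypothetical dataclass, ConditionerSettings, and a hypothetical helper, apply_init. Neither is part of noether. It assumes two things: that the documented condition_dim default of "4 * dim" means 4 * hidden_dim, and that 'truncnormal002' means a truncated normal with standard deviation 0.02 applied to every nn.Linear weight, with zero biases.

```python
from dataclasses import dataclass
from typing import Optional

from torch import nn


@dataclass
class ConditionerSettings:
    """Hypothetical stand-in for ScalarsConditionerConfig (not the noether class)."""

    hidden_dim: int
    num_scalars: int
    condition_dim: Optional[int] = None
    init_weights: str = "truncnormal002"

    def resolved_condition_dim(self) -> int:
        # Assumption: the documented default "4 * dim" means 4 * hidden_dim.
        return self.condition_dim if self.condition_dim is not None else 4 * self.hidden_dim


def apply_init(module: nn.Module, mode: str) -> None:
    """Hypothetical helper: initialize every nn.Linear according to `mode`."""
    if mode != "truncnormal002":
        raise ValueError(f"this sketch only handles 'truncnormal002', got {mode!r}")
    for m in module.modules():
        if isinstance(m, nn.Linear):
            # Assumption: 'truncnormal002' means a truncated normal with std 0.02.
            # torch's default cutoffs a=-2, b=2 are absolute values, so at std 0.02
            # they practically never bind.
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)  # assumption: zero biases


settings = ConditionerSettings(hidden_dim=64, num_scalars=2)
print(settings.resolved_condition_dim())  # 256
mlp = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, settings.resolved_condition_dim()))
apply_init(mlp, settings.init_weights)
```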

class noether.modeling.modules.layers.scalar_conditioner.ScalarsConditioner(config)

Bases: torch.nn.Module

Embeds num_scalars scalars into a single conditioning vector by first encoding every scalar with sine-cosine embeddings followed by a per-scalar MLP. The resulting vectors are then concatenated and projected to condition_dim with an MLP.

Parameters:

config (ScalarsConditionerConfig) – configuration for the ScalarsConditioner. See ScalarsConditionerConfig for available options.
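The architecture described above (sine-cosine embedding, per-scalar MLP, concatenation, projection MLP) can be sketched in plain PyTorch. This is a minimal, hypothetical re-implementation, not noether's code. The transformer-style frequency schedule (max_period=10000), the layer layout of both MLPs, the 4 * hidden_dim default, and the names sincos_embed and ScalarsConditionerSketch are all assumptions made for illustration.

```python
import math
from typing import Optional

import torch
from torch import nn


def sincos_embed(x: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Map a (batch_size,) tensor to (batch_size, dim) sine-cosine features."""
    half = dim // 2
    # Geometric frequency ladder from 1 towards 1 / max_period (transformer-style).
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    angles = x.float().unsqueeze(-1) * freqs  # (batch_size, half)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class ScalarsConditionerSketch(nn.Module):
    """Hypothetical re-implementation of the described architecture, not noether's code."""

    def __init__(self, hidden_dim: int, num_scalars: int, condition_dim: Optional[int] = None):
        super().__init__()
        assert hidden_dim % 2 == 0, "sin/cos halves need an even hidden_dim"
        condition_dim = condition_dim if condition_dim is not None else 4 * hidden_dim
        self.hidden_dim = hidden_dim
        # One small MLP per scalar (layer layout is an assumption).
        self.mlps = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_scalars)]
        )
        # Projects the concatenated per-scalar vectors to condition_dim.
        self.shared_mlp = nn.Sequential(
            nn.Linear(num_scalars * hidden_dim, condition_dim),
            nn.GELU(),
            nn.Linear(condition_dim, condition_dim),
        )

    def forward(self, *scalars: torch.Tensor) -> torch.Tensor:
        if len(scalars) != len(self.mlps):
            raise ValueError(f"expected {len(self.mlps)} scalars, got {len(scalars)}")
        # Flatten (batch_size, 1) inputs to (batch_size,) before embedding.
        per_scalar = [mlp(sincos_embed(s.reshape(-1), self.hidden_dim)) for mlp, s in zip(self.mlps, scalars)]
        return self.shared_mlp(torch.cat(per_scalar, dim=-1))  # (batch_size, condition_dim)


conditioner = ScalarsConditionerSketch(hidden_dim=64, num_scalars=2, condition_dim=128)
condition = conditioner(torch.tensor([75.3, 80.1]), torch.tensor([[24.6], [30.2]]))
print(condition.shape)  # torch.Size([2, 128])
```

Concatenating the per-scalar vectors, rather than summing them, keeps each scalar in its own slice of the projection's input, so the projection MLP can weight the scalars independently.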

hidden_dim
num_scalars
condition_dim
embed
mlps
shared_mlp
forward(*args, **kwargs)

Embeds scalars into a single conditioning vector. Scalars can be passed as *args or as **kwargs. Passing them as **kwargs is recommended, because it avoids bugs that arise when the scalars are passed in a different order at two places in the code.

Recommended usage: condition = conditioner(geometry_angle=75.3, friction_angle=24.6)

Parameters:

*args – Scalars in tensor representation, each with shape (batch_size,) or (batch_size, 1).

**kwargs – Scalars in tensor representation, each with shape (batch_size,) or (batch_size, 1).

Returns:

Conditioning vector with shape (batch_size, condition_dim)

Return type:

torch.Tensor

Example:

```python
import torch

from noether.modeling.modules.layers.scalar_conditioner import (
    ScalarsConditioner,
    ScalarsConditionerConfig,
)

conditioner = ScalarsConditioner(
    ScalarsConditionerConfig(
        hidden_dim=64,
        num_scalars=2,
        condition_dim=128,
        init_weights="truncnormal002",
    )
)
geometry_angle = torch.tensor([75.3, 80.1])  # shape (batch_size,)
friction_angle = torch.tensor([24.6, 30.2])  # shape (batch_size,)
condition = conditioner(
    geometry_angle=geometry_angle, friction_angle=friction_angle
)  # shape (batch_size, condition_dim)
```
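The *args/**kwargs recommendation for forward can be illustrated with a small, hypothetical helper, collect_scalars (not part of noether). It turns either calling convention into one (batch_size, num_scalars) tensor and normalizes both accepted input shapes. The helper sorts keyword arguments by name, which is one way to make the result independent of call-site keyword order. The documentation does not say whether ScalarsConditioner does this, so treat the sorting as an assumption. Rejecting a mix of positional and keyword scalars is also a design choice of this sketch only.

```python
import torch


def collect_scalars(*args: torch.Tensor, **kwargs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather scalars into a (batch_size, num_scalars) tensor."""
    if args and kwargs:
        # Design choice of this sketch: refuse to mix the two calling conventions.
        raise ValueError("pass scalars either positionally or by keyword, not both")
    if kwargs:
        # Sort by name so that keyword order at the call site does not matter.
        # Assumption: the noether docs do not say whether keys are sorted.
        values = [kwargs[name] for name in sorted(kwargs)]
    else:
        # Positional scalars are used in the order given, so every call site
        # must agree on that order.
        values = list(args)
    # Accept (batch_size,) or (batch_size, 1) and normalize each to (batch_size, 1).
    return torch.cat([v.reshape(-1, 1) for v in values], dim=1)


geometry_angle = torch.tensor([75.3, 80.1])      # shape (batch_size,)
friction_angle = torch.tensor([[24.6], [30.2]])  # shape (batch_size, 1)
a = collect_scalars(geometry_angle=geometry_angle, friction_angle=friction_angle)
b = collect_scalars(friction_angle=friction_angle, geometry_angle=geometry_angle)
print(a.shape, torch.equal(a, b))  # torch.Size([2, 2]) True
```

With positional arguments, swapping geometry_angle and friction_angle at one call site would silently swap the columns. With sorted keywords, both calls above produce the same tensor.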