noether.modeling.modules.layers.scalar_conditioner
Classes
Embeds num_scalars scalars into a single conditioning vector by first encoding every scalar with sine-cosine embeddings followed by a per-scalar MLP.
Module Contents
- class noether.modeling.modules.layers.scalar_conditioner.ScalarsConditionerConfig(/, **data)
Bases: pydantic.BaseModel
- Parameters:
data (Any)
Dimension for embedding the scalars and the per-scalar MLP.
- condition_dim: int | None = None
Dimension of the final conditioning vector. Defaults to 4 * dim if condition_dim is None.
- init_weights: noether.core.types.InitWeightsMode = 'truncnormal002'
Weight initialization for MLPs.
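The mode name 'truncnormal002' suggests a truncated normal distribution with a standard deviation of 0.02, a common transformer convention. The sketch below is one possible reading of that mode, written against the real `torch.nn.init.trunc_normal_` API. The helper name `init_truncnormal002`, the choice of std=0.02, the default cutoffs, and zero biases are all assumptions, not noether's confirmed behavior.

```python
import torch
from torch import nn


def init_truncnormal002(module: nn.Module) -> None:
    """Hypothetical reading of 'truncnormal002': truncated normal weights, std=0.02."""
    for m in module.modules():
        if isinstance(m, nn.Linear):
            # Assumption: std=0.02. PyTorch's a=-2 and b=2 cutoffs are absolute
            # values, so at std=0.02 the truncation effectively never triggers
            # (the timm-style convention).
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Assumption: biases start at zero.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
```

For example, `init_truncnormal002(nn.Sequential(nn.Linear(4, 64), nn.GELU(), nn.Linear(64, 64)))` re-initializes both linear layers in place. To truncate at two standard deviations instead, one would pass `a=-0.04, b=0.04` explicitly.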
- class noether.modeling.modules.layers.scalar_conditioner.ScalarsConditioner(config)
Bases: torch.nn.Module
Embeds num_scalars scalars into a single conditioning vector by first encoding every scalar with sine-cosine embeddings followed by a per-scalar MLP. The resulting vectors are then concatenated and projected down to condition_dim with an MLP.
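The pipeline described above (sine-cosine embedding per scalar, a per-scalar MLP, concatenation, then a projection MLP) can be sketched in PyTorch as follows. This is a minimal illustration, not the noether implementation: the names `sincos_embed` and `ScalarsConditionerSketch`, the transformer-style frequency schedule with max wavelength 10000, the single-layer GELU MLPs, and the two-layer projection are all assumptions. Only the 4 * dim default for condition_dim comes from the documentation above.

```python
import math
from typing import Optional

import torch
from torch import nn


def sincos_embed(x: torch.Tensor, dim: int, max_wavelength: float = 10000.0) -> torch.Tensor:
    """Assumed transformer-style embedding: (batch_size,) -> (batch_size, dim)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even (half sine, half cosine)")
    half = dim // 2
    # Geometric progression of frequencies from 1 down toward 1 / max_wavelength.
    freqs = torch.exp(-math.log(max_wavelength) * torch.arange(half, dtype=torch.float32) / half)
    angles = x.float().unsqueeze(-1) * freqs  # (batch_size, half)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class ScalarsConditionerSketch(nn.Module):
    """Hypothetical re-implementation for illustration only."""

    def __init__(self, dim: int, num_scalars: int, condition_dim: Optional[int] = None):
        super().__init__()
        self.dim = dim
        self.num_scalars = num_scalars
        # Documented default: 4 * dim when condition_dim is None.
        self.condition_dim = condition_dim if condition_dim is not None else 4 * dim
        # One small MLP per scalar, applied to that scalar's sine-cosine embedding.
        self.mlps = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(num_scalars)]
        )
        # Projects the concatenated per-scalar vectors down to condition_dim.
        self.projection = nn.Sequential(
            nn.Linear(num_scalars * dim, dim),
            nn.GELU(),
            nn.Linear(dim, self.condition_dim),
        )

    def forward(self, *scalars: torch.Tensor) -> torch.Tensor:
        if len(scalars) != self.num_scalars:
            raise ValueError(f"expected {self.num_scalars} scalars, got {len(scalars)}")
        # flatten() maps both (batch_size,) and (batch_size, 1) to (batch_size,).
        per_scalar = [mlp(sincos_embed(s.flatten(), self.dim)) for mlp, s in zip(self.mlps, scalars)]
        return self.projection(torch.cat(per_scalar, dim=-1))  # (batch_size, condition_dim)
```

Embedding each scalar separately before mixing lets every scalar get its own learned frequency response; the final projection is where interactions between scalars are learned.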
- Parameters:
config (ScalarsConditionerConfig) – Configuration for the ScalarsConditioner. See ScalarsConditionerConfig for available options.
- num_scalars
- condition_dim
- embed
- mlps
- forward(*args, **kwargs)
Embeds scalars into a single conditioning vector. Scalars can be passed as *args or as **kwargs. Passing them as **kwargs is recommended, because it avoids bugs caused by passing scalars in a different order at two locations in the code. Recommended usage: condition = conditioner(geometry_angle=75.3, friction_angle=24.6)
- Parameters:
args (torch.Tensor) – Scalars as tensors of shape (batch_size,) or (batch_size, 1).
kwargs (torch.Tensor) – Scalars as tensors of shape (batch_size,) or (batch_size, 1).
- Returns:
Conditioning vector with shape (batch_size, condition_dim)
- Return type:
Example:

.. code-block:: python

    import torch
    from noether.modeling.modules.layers.scalar_conditioner import ScalarsConditioner, ScalarsConditionerConfig

    conditioner = ScalarsConditioner(
        ScalarsConditionerConfig(
            hidden_dim=64, num_scalars=2, condition_dim=128, init_weights="truncnormal002",
        )
    )
    geometry_angle = torch.tensor([75.3, 80.1])  # shape (batch_size,)
    friction_angle = torch.tensor([24.6, 30.2])  # shape (batch_size,)
    condition = conditioner(
        geometry_angle=geometry_angle, friction_angle=friction_angle
    )  # shape (batch_size, condition_dim)
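The input contract of forward (each scalar as (batch_size,) or (batch_size, 1), passed positionally or by keyword) can be illustrated with a small helper that gathers the scalars into one (batch_size, num_scalars) tensor. This is a hypothetical sketch: `stack_scalars` is not part of noether, and sorting keyword names is just one way to make the keyword path independent of call-site order. The documentation above does not say how ScalarsConditioner orders its kwargs internally.

```python
import torch


def stack_scalars(*args: torch.Tensor, **kwargs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather scalars into one (batch_size, num_scalars) tensor."""
    if args and kwargs:
        raise ValueError("pass scalars either positionally or by keyword, not both")
    # Assumption: sort keyword names so that f(a=x, b=y) and f(b=y, a=x) agree.
    values = list(args) if args else [kwargs[name] for name in sorted(kwargs)]
    columns = []
    for value in values:
        if value.ndim == 1:  # (batch_size,) -> (batch_size, 1)
            value = value.unsqueeze(1)
        if value.ndim != 2 or value.shape[1] != 1:
            raise ValueError(f"expected (batch_size,) or (batch_size, 1), got {tuple(value.shape)}")
        columns.append(value)
    return torch.cat(columns, dim=1)
```

With positional arguments, the column order is simply the call order, so two call sites that disagree on order silently produce different conditioning vectors. A keyword interface with a canonical ordering removes that failure mode.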