noether.training.trainers.weighted_loss

Attributes

Classes

WeightedLossTrainerConfig

Config for a generic trainer that computes weighted loss per output field.

WeightedLossTrainer

Generic trainer that computes weighted loss per output field.

Module Contents

noether.training.trainers.weighted_loss.LOSS_REGISTRY: dict[str, collections.abc.Callable[..., torch.Tensor]]
class noether.training.trainers.weighted_loss.WeightedLossTrainerConfig(/, **data)

Bases: noether.training.trainers.BaseTrainerConfig

Config for a generic trainer that computes weighted loss per output field.

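field_weights maps output field names to their loss weights. Keys must match the keys of the dict returned by the model. For each field, the batch is expected to contain its target under the key <field_name>_target; for example, the target for surface_pressure is expected under surface_pressure_target.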
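The key convention can be illustrated with a minimal plain-Python sketch. The helper collect_targets is hypothetical (not part of noether), lists of floats stand in for torch.Tensor values, and returning the targets keyed by field name is an assumption about how the trainer organizes them.

```python
# Hypothetical helper illustrating the <field_name>_target convention.
# Lists of floats stand in for torch.Tensor values.
def collect_targets(batch: dict, field_weights: dict[str, float]) -> dict:
    """Return {field: target} for every weighted field, reading <field>_target from the batch."""
    targets = {}
    missing = []
    for field in field_weights:
        key = f"{field}_target"
        if key in batch:
            targets[field] = batch[key]
        else:
            missing.append(key)
    if missing:
        # Report every missing key at once instead of failing on the first one.
        raise KeyError(f"batch is missing target keys: {missing}")
    return targets


batch = {
    "surface_pressure_target": [0.1, 0.2],
    "volume_velocity_target": [1.0, 0.5],
    "mesh_points": [0.0, 1.0],  # entries without a matching weighted field are ignored
}
field_weights = {"surface_pressure": 1.0, "volume_velocity": 1.0}
targets = collect_targets(batch, field_weights)
# targets == {"surface_pressure": [0.1, 0.2], "volume_velocity": [1.0, 0.5]}
```

Checking all keys up front turns a missing target into one error that names every absent key, which is easier to debug than a KeyError raised midway through the loss computation.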

Built-in loss example:

WeightedLossTrainerConfig(
    kind="noether.training.trainers.WeightedLossTrainer",
    field_weights={"surface_pressure": 1.0, "volume_velocity": 1.0},
    loss_fn="l1",
)

Custom loss function from a downstream project:

WeightedLossTrainerConfig(
    kind="noether.training.trainers.WeightedLossTrainer",
    field_weights={"surface_pressure": 1.0},
    loss_fn="my_project.losses.weighted_huber",
)

Create a new config instance by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid config.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

field_weights: dict[str, float] = None
loss_fn: str = None
class noether.training.trainers.weighted_loss.WeightedLossTrainer(trainer_config, **kwargs)

Bases: noether.training.trainers.BaseTrainer

Generic trainer that computes weighted loss per output field.

Expects the model forward to return dict[str, Tensor] with keys matching field_weights keys, and the batch to contain <field_name>_target keys.

The loss function defaults to MSE and can be changed via the loss_fn config parameter. Use a built-in short name or a dotted import path for custom losses.
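A minimal sketch of how a short name or a dotted import path could be turned into a callable, assuming a registry dict in the spirit of LOSS_REGISTRY plus an importlib fallback. The registry contents are hypothetical (the key "mse" in particular is assumed; this page only names "l1" and an MSE default), resolve_loss is an illustrative name, and the stand-in losses operate on lists rather than torch.Tensor so the sketch runs without PyTorch.

```python
from __future__ import annotations

import importlib
from collections.abc import Callable


def mse(input: list[float], target: list[float]) -> float:
    # Stand-in for torch.nn.functional.mse_loss (mean reduction) on plain lists.
    return sum((a - b) ** 2 for a, b in zip(input, target)) / len(input)


def l1(input: list[float], target: list[float]) -> float:
    # Stand-in for torch.nn.functional.l1_loss (mean reduction) on plain lists.
    return sum(abs(a - b) for a, b in zip(input, target)) / len(input)


# Hypothetical registry; the real LOSS_REGISTRY may hold different keys.
LOSS_REGISTRY: dict[str, Callable[..., float]] = {"mse": mse, "l1": l1}


def resolve_loss(loss_fn: str | None) -> Callable[..., float]:
    """Return a loss callable from a built-in short name or a dotted import path."""
    if loss_fn is None:
        return LOSS_REGISTRY["mse"]  # MSE is the documented default
    if loss_fn in LOSS_REGISTRY:
        return LOSS_REGISTRY[loss_fn]
    # Not a short name: split "package.module.attr" into module path and attribute.
    module_path, _, attr = loss_fn.rpartition(".")
    if not module_path:
        raise ValueError(
            f"unknown loss {loss_fn!r}; expected one of {sorted(LOSS_REGISTRY)} or a dotted path"
        )
    module = importlib.import_module(module_path)
    return getattr(module, attr)
```

Under this scheme, loss_fn="my_project.losses.weighted_huber" would import my_project.losses and return its weighted_huber attribute. Checking the registry first means a short name always wins over an import attempt.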

Built-in loss example:

trainer_params = dict(field_weights={"pressure": 1.0}, loss_fn="l1")

Custom loss function from your project:

trainer_params = dict(
    field_weights={"pressure": 1.0},
    loss_fn="my_project.losses.weighted_huber",
)

The custom callable must have the signature (input, target) -> Tensor, matching the calling convention of torch.nn.functional loss functions such as mse_loss and l1_loss.
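As an illustration of that contract, a downstream module (for example a hypothetical my_project/losses.py) could define a Huber loss like the one below. This is a plain-Python stand-in on lists of floats with an assumed delta default and mean reduction; it does not reproduce the weighted_huber referenced above. A real implementation would take and return torch.Tensor and could delegate to torch.nn.functional.huber_loss.

```python
def huber(input: list[float], target: list[float], delta: float = 1.0) -> float:
    """Mean Huber loss with the same (input, target) call shape as torch.nn.functional losses."""
    total = 0.0
    for a, b in zip(input, target):
        err = abs(a - b)
        if err <= delta:
            # Quadratic region for small errors.
            total += 0.5 * err**2
        else:
            # Linear region for large errors, continuous with the quadratic one at delta.
            total += delta * (err - 0.5 * delta)
    return total / len(input)


loss = huber([0.0, 2.0], [0.5, 0.0])
# loss == 0.8125  (0.125 from the first element, 1.5 from the second, averaged)
```

Since the documented signature is (input, target), any extra parameters such as delta need defaults.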

Parameters:
  • config – Configuration for the trainer. See BaseTrainerConfig for the available options.

  • data_container – The DataContainer which includes the data and dataloader.

  • device – The device to use for training (e.g., "cuda"). The process is assumed to be configured so that only one device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).

  • main_sampler_kwargs – Kwargs passed to instantiate the main sampler.

  • tracker – The tracker to use for training.

  • path_provider – The PathProvider to use for training.

  • metric_property_provider – The MetricPropertyProvider to use for training.

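  • trainer_config (WeightedLossTrainerConfig) – Configuration for this trainer, including field_weights and loss_fn. The other parameters listed above are passed as keyword arguments through **kwargs.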

loss_items: list[tuple[str, float]] = []
loss_compute(forward_output, targets)

Computes the loss from the model output and the targets. Each trainer that extends BaseTrainer must implement this method.

Parameters:
  • forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.

  • targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.

Returns:

A dict with the (weighted) sub-losses to log, or a tuple (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass that is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets needed to calculate a training accuracy in a callback).

Return type:

dict[str, torch.Tensor]

Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
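For WeightedLossTrainer, a loss computation along these lines would produce one weighted sub-loss per entry in field_weights. The sketch below is a standalone function, not the actual method: the name weighted_loss_compute, the return_extras flag, keying sub-losses and targets by field name, and the contents of additional_outputs are all assumptions, and lists of floats stand in for torch.Tensor. How BaseTrainer aggregates the returned sub-losses is not described on this page, so the sketch does not compute a total.

```python
from __future__ import annotations

from collections.abc import Callable


def l1(input: list[float], target: list[float]) -> float:
    # Stand-in for torch.nn.functional.l1_loss (mean reduction) on plain lists.
    return sum(abs(a - b) for a, b in zip(input, target)) / len(input)


def weighted_loss_compute(
    forward_output: dict[str, list[float]],
    targets: dict[str, list[float]],
    field_weights: dict[str, float],
    loss_fn: Callable[[list[float], list[float]], float] = l1,
    return_extras: bool = False,
) -> dict[str, float] | tuple[dict[str, float], dict]:
    """Return weighted per-field sub-losses, optionally with additional outputs."""
    losses = {}
    for field, weight in field_weights.items():
        # Assumption: targets are keyed by field name here, although the batch
        # stores them under f"{field}_target".
        losses[field] = weight * loss_fn(forward_output[field], targets[field])
    if return_extras:
        # The second element plays the role of additional_outputs for callbacks.
        return losses, {"predictions": forward_output, "targets": targets}
    return losses


out = {"pressure": [1.0, 2.0], "velocity": [0.0, 0.0]}
tgt = {"pressure": [0.0, 4.0], "velocity": [1.0, 1.0]}
losses = weighted_loss_compute(out, tgt, {"pressure": 2.0, "velocity": 0.5})
# losses == {"pressure": 3.0, "velocity": 0.5}
```

Returning the sub-losses separately rather than a single sum keeps each field's contribution visible in the logs, which is the purpose of the "(weighted) sub-losses to log" return value.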