noether.training.trainers.weighted_loss

Attributes

Classes

WeightedLossTrainer

Generic trainer that computes weighted loss per output field.

Module Contents

noether.training.trainers.weighted_loss.LOSS_REGISTRY: dict[str, collections.abc.Callable[..., torch.Tensor]]

class noether.training.trainers.weighted_loss.WeightedLossTrainer(trainer_config, **kwargs)

Bases: noether.training.trainers.base.BaseTrainer

Generic trainer that computes weighted loss per output field.

Expects the model's forward pass to return a dict[str, Tensor] whose keys match the keys of field_weights, and the batch to contain a <field_name>_target entry for each of those fields.
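The key contract described above can be illustrated with a minimal sketch. The helper name weighted_field_loss is hypothetical and is not part of noether; it only shows how a per-field weighted loss could be assembled from a forward output and a batch that follow this naming convention, and the trainer's actual implementation may differ.

```python
import torch
import torch.nn.functional as F


def weighted_field_loss(forward_output, batch, field_weights, loss_fn=F.mse_loss):
    """Hypothetical helper: one weighted sub-loss per field in field_weights."""
    losses = {}
    for field, weight in field_weights.items():
        prediction = forward_output[field]   # key matches the field_weights key
        target = batch[f"{field}_target"]    # batch uses the <field_name>_target convention
        losses[field] = weight * loss_fn(prediction, target)
    return losses


# Toy data: a single "pressure" field with three values.
forward_output = {"pressure": torch.tensor([1.0, 2.0, 3.0])}
batch = {"pressure_target": torch.tensor([1.0, 2.0, 5.0])}
losses = weighted_field_loss(forward_output, batch, {"pressure": 0.5})
```

A missing key in either dict raises a KeyError in this sketch, which makes a mismatch between field_weights, the model output, and the batch visible early.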

The loss function defaults to MSE and can be changed via the loss_fn config parameter. Use a built-in short name or a dotted import path for custom losses.
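As a rough illustration of how a short name or a dotted import path might be turned into a callable, consider the sketch below. It is an assumption, not the library's resolution code: EXAMPLE_REGISTRY and resolve_loss are hypothetical names, the "mse" short name is assumed, and the registry holds pure-Python stand-ins rather than the torch functions stored in LOSS_REGISTRY.

```python
import importlib
from collections.abc import Callable

# Stand-in registry for illustration only; the real LOSS_REGISTRY maps
# short names to callables that return torch.Tensor.
EXAMPLE_REGISTRY: dict[str, Callable[..., float]] = {
    "mse": lambda input, target: sum((a - b) ** 2 for a, b in zip(input, target)) / len(input),
    "l1": lambda input, target: sum(abs(a - b) for a, b in zip(input, target)) / len(input),
}


def resolve_loss(name: str = "mse") -> Callable[..., float]:
    """Hypothetical resolver: try the short name first, then a dotted import path."""
    if name in EXAMPLE_REGISTRY:
        return EXAMPLE_REGISTRY[name]
    module_path, _, attribute = name.rpartition(".")
    if not module_path:
        raise ValueError(f"Unknown loss {name!r}: not a short name or dotted path")
    return getattr(importlib.import_module(module_path), attribute)


l1 = resolve_loss("l1")           # found in the stand-in registry
dist = resolve_loss("math.dist")  # imported from a dotted path (stdlib, for demonstration)
```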

Built-in loss, selected by its short name:

trainer_params = dict(field_weights={"pressure": 1.0}, loss_fn="l1")

Custom loss function from your project:

trainer_params = dict(
    field_weights={"pressure": 1.0},
    loss_fn="my_project.losses.weighted_huber",
)

The custom callable must have the signature (input, target) -> Tensor, matching torch.nn.functional loss functions.
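A custom loss with the required signature could look like the following sketch. The function body is an assumption made for illustration: the name weighted_huber is taken from the config example above, but the scaling factor and the delta value are hypothetical choices, not something the library prescribes.

```python
import torch
import torch.nn.functional as F


def weighted_huber(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical custom loss: Huber loss scaled by a fixed factor."""
    scale = 2.0  # illustrative constant, chosen arbitrarily
    return scale * F.huber_loss(input, target, delta=1.0)


# Quick check on toy tensors: the result is a scalar tensor.
loss = weighted_huber(torch.zeros(4), torch.ones(4))
```

Placed in a module such as my_project/losses.py, a function like this would be referenced by its dotted path, as in the config example above.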

Parameters:
  • config – Configuration for the trainer. See BaseTrainerConfig for the available options.

  • data_container – The DataContainer which includes the data and dataloader.

  • device – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only 1 device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).

  • main_sampler_kwargs – Kwargs passed to instantiate the main sampler.

  • tracker – The tracker to use for training.

  • path_provider – The PathProvider to use for training.

  • metric_property_provider – The MetricPropertyProvider to use for training.

  • trainer_config (noether.core.schemas.trainers.WeightedLossTrainerConfig)

loss_items: list[tuple[str, float]] = []

loss_compute(forward_output, targets)

Each trainer that extends this class needs to implement a custom loss computation using the targets and the model output.

Parameters:
  • forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.

  • targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.

Returns:

A dict with the (weighted) sub-losses to log, or a tuple of (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass. This dict is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets used to calculate a training accuracy in a callback).

Return type:

dict[str, torch.Tensor]

Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
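The two return forms can be handled uniformly by the caller. The sketch below is only an illustration of that idea: split_loss_output is a hypothetical helper, not a noether function, and plain floats and lists stand in for tensors.

```python
def split_loss_output(result):
    """Hypothetical helper: return (losses, additional_outputs) for either return form."""
    if isinstance(result, tuple):
        # Tuple form: (losses, additional_outputs)
        losses, additional_outputs = result
        return losses, additional_outputs
    # Dict form: only the sub-losses, no additional outputs
    return result, None


# Dict-only return value.
losses_only = split_loss_output({"pressure": 0.25})
# Tuple return value carrying extra information for callbacks.
with_extras = split_loss_output(({"pressure": 0.25}, {"logits": [0.1, 0.9]}))
```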