noether.training.trainers.weighted_loss
Attributes
LOSS_REGISTRY
Classes
WeightedLossTrainer – Generic trainer that computes weighted loss per output field.
Module Contents
- noether.training.trainers.weighted_loss.LOSS_REGISTRY: dict[str, collections.abc.Callable[..., torch.Tensor]]
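LOSS_REGISTRY maps short loss names to callables that return a torch.Tensor. The sketch below builds a dictionary of the same type to show what such entries look like. Its contents are assumptions for illustration: this page only mentions "l1" (in the built-in example) and an MSE default, so the key "mse" and the exact functions stored are hypothetical.

```python
from collections.abc import Callable

import torch
import torch.nn.functional as F

# Hypothetical registry with the same type as LOSS_REGISTRY.
# The key names "mse" and "l1" are assumptions for illustration.
EXAMPLE_LOSS_REGISTRY: dict[str, Callable[..., torch.Tensor]] = {
    "mse": F.mse_loss,
    "l1": F.l1_loss,
}

prediction = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 0.0, 3.0])

# Each entry follows the torch.nn.functional convention: (input, target) -> Tensor.
mse = EXAMPLE_LOSS_REGISTRY["mse"](prediction, target)  # mean of squared errors [0, 4, 0]
l1 = EXAMPLE_LOSS_REGISTRY["l1"](prediction, target)  # mean of absolute errors [0, 2, 0]
```

Storing the functions themselves (rather than names) lets a trainer call the looked-up entry directly with (input, target).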
- class noether.training.trainers.weighted_loss.WeightedLossTrainer(trainer_config, **kwargs)
Bases: noether.training.trainers.base.BaseTrainer

Generic trainer that computes weighted loss per output field.
Expects the model forward pass to return a dict[str, Tensor] whose keys match the field_weights keys, and each batch to contain a <field_name>_target key for every field. The loss function defaults to MSE and can be changed via the loss_fn config parameter: use a built-in short name, or a dotted import path for custom losses.

Built-in loss, selected by short name:
trainer_params = dict(field_weights={"pressure": 1.0}, loss_fn="l1")
Custom loss function from your project:
trainer_params = dict(
    field_weights={"pressure": 1.0},
    loss_fn="my_project.losses.weighted_huber",
)
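The loss_fn value is either a registry short name or a dotted import path. One plausible way to resolve it is sketched below: look the name up in a registry first, otherwise import the module part of the path and fetch the attribute. The function name resolve_loss_fn and this lookup order are assumptions, not necessarily what WeightedLossTrainer does internally. So that the example runs without a project package, it resolves the standard-library function operator.sub as a stand-in for a path such as my_project.losses.weighted_huber.

```python
import importlib
from collections.abc import Callable
from typing import Any


def resolve_loss_fn(name: str, registry: dict[str, Callable[..., Any]]) -> Callable[..., Any]:
    """Hypothetical resolver: registry short name first, then dotted import path."""
    if name in registry:
        return registry[name]
    module_path, sep, attr_name = name.rpartition(".")
    if not sep:
        # No dot and not in the registry, so there is nothing to import.
        raise KeyError(f"unknown loss {name!r}; expected a registry key or a dotted path")
    module = importlib.import_module(module_path)
    loss_fn = getattr(module, attr_name)
    if not callable(loss_fn):
        raise TypeError(f"{name!r} does not resolve to a callable")
    return loss_fn


# Stand-in registry entry (plain Python so the example runs without torch).
registry = {"l1": lambda prediction, target: abs(prediction - target)}

short = resolve_loss_fn("l1", registry)  # found by short name
dotted = resolve_loss_fn("operator.sub", registry)  # found via importlib
```

Checking the registry first keeps short names cheap and unambiguous; anything containing a dot falls through to a regular import, so custom losses only need to be importable.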
The custom callable must have the signature (input, target) -> Tensor, matching the torch.nn.functional loss functions.

- Parameters:
config – Configuration for the trainer. See BaseTrainerConfig for the available options.
data_container – The DataContainer which includes the data and dataloader.
device – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only one device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).
main_sampler_kwargs – Kwargs passed to instantiate the main sampler.
tracker – The tracker to use for training.
path_provider – The PathProvider to use for training.
metric_property_provider – The MetricPropertyProvider to use for training.
trainer_config (noether.core.schemas.trainers.WeightedLossTrainerConfig) – Configuration for this trainer, including field_weights and loss_fn.
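The device description assumes that only one GPU is visible to the process. A common way to arrange this from Python is sketched below. The GPU index "0" is an arbitrary example, and in practice the variable is often set in the launching shell or job script instead.

```python
import os

# Restrict this process to a single GPU. This has to happen before CUDA is
# initialized (for example before the first torch.cuda call); changing it
# afterwards does not affect an already-initialized process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # physical GPU 0 becomes the only visible one

# With exactly one visible GPU, the generic "cuda" device string refers to it.
device = "cuda"
```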
- loss_compute(forward_output, targets)
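Computes the loss from the model output and the targets. Every trainer that extends BaseTrainer must implement this method; WeightedLossTrainer computes one loss per output field and weights it by the corresponding entry in field_weights.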
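Following the contract described for the class (model output keyed by field name, targets stored under <field_name>_target, one weight per field in field_weights), the per-field weighting can be sketched as below. The function name weighted_field_losses, the assumption that targets uses the <field_name>_target keys, and the "<field>_loss" names of the returned sub-losses are all illustrative; the actual loss_compute may differ in these details.

```python
from collections.abc import Callable

import torch
import torch.nn.functional as F


def weighted_field_losses(
    forward_output: dict[str, torch.Tensor],
    targets: dict[str, torch.Tensor],
    field_weights: dict[str, float],
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = F.mse_loss,
) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: one weighted sub-loss per field in field_weights."""
    losses: dict[str, torch.Tensor] = {}
    for field, weight in field_weights.items():
        prediction = forward_output[field]
        target = targets[f"{field}_target"]  # assumed key convention
        losses[f"{field}_loss"] = weight * loss_fn(prediction, target)
    return losses


forward_output = {
    "pressure": torch.tensor([1.0, 2.0]),
    "velocity": torch.tensor([0.0, 0.0]),
}
targets = {
    "pressure_target": torch.tensor([1.0, 4.0]),
    "velocity_target": torch.tensor([1.0, 1.0]),
}
losses = weighted_field_losses(forward_output, targets, {"pressure": 1.0, "velocity": 0.5})
total = sum(losses.values())  # scalar tensor to backpropagate
```

Here the pressure MSE is 2.0 (weight 1.0) and the velocity MSE is 1.0 (weight 0.5), so the weighted sub-losses are 2.0 and 0.5. Returning them separately keeps each field loggable while their sum drives the optimizer.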
- Parameters:
forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.
targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.
- Returns:
A dict with the (weighted) sub-losses to log, or a tuple (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass that is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
- Return type:
Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
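Because loss_compute may return either a dict or a (losses, additional_outputs) tuple, code that consumes the result has to handle both shapes. A minimal normalization helper is sketched below; split_loss_result is a hypothetical name, not part of the documented API, and the example uses plain floats so it runs without torch.

```python
from typing import Any, Optional


def split_loss_result(
    result: Any,
) -> tuple[dict[str, Any], Optional[dict[str, Any]]]:
    """Hypothetical helper: normalize a loss_compute result to (losses, additional_outputs)."""
    if isinstance(result, tuple):
        losses, additional_outputs = result
        return losses, additional_outputs
    # A plain dict means there are no additional outputs.
    return result, None


# Dict-only return.
losses_a, extra_a = split_loss_result({"pressure_loss": 0.25})

# Tuple return: the second element carries extra information, e.g. for a callback.
losses_b, extra_b = split_loss_result(
    ({"pressure_loss": 0.25}, {"logits": [0.1, 0.9], "targets": [1]})
)
```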