noether.training.callbacks.offline_loss

Classes

OfflineLossCallback

A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

Module Contents

class noether.training.callbacks.offline_loss.OfflineLossCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallback

A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

Parameters:

callback_config (noether.core.schemas.callbacks.OfflineLossCallbackConfig) – Configuration of the OfflineLossCallback. See OfflineLossCallbackConfig for the available options.
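To make "periodic" concrete, the following is a minimal pure-Python sketch of a trainer invoking such a callback after each completed epoch. The PeriodicSketch class, its every_n_epochs parameter, the on_epoch_end hook and the train loop are hypothetical illustrations, not part of the noether API; with the default every_n_epochs=1 the callback fires after every epoch, matching the description above.

```python
class PeriodicSketch:
    """Hypothetical stand-in for a periodic callback; not the noether API."""

    def __init__(self, every_n_epochs: int = 1) -> None:
        self.every_n_epochs = every_n_epochs
        # Records the epochs after which the callback actually ran.
        self.invoked_at: list[int] = []

    def should_run(self, epoch: int) -> bool:
        # Fire after every `every_n_epochs`-th completed epoch (epochs are 1-based).
        return epoch % self.every_n_epochs == 0

    def on_epoch_end(self, epoch: int) -> None:
        if self.should_run(epoch):
            # A real callback would iterate over its dataset and log the loss here.
            self.invoked_at.append(epoch)


def train(num_epochs: int, callback: PeriodicSketch) -> None:
    for epoch in range(1, num_epochs + 1):
        # ... one epoch of training updates would run here ...
        callback.on_epoch_end(epoch)


cb = PeriodicSketch(every_n_epochs=2)
train(5, cb)
print(cb.invoked_at)  # [2, 4]
```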

dataset_key
output_patterns_to_log
process_data(batch, *, trainer_model, **_)

Template method that is called for each batch that is loaded from the dataset.

This method should process a single batch and return results that will be collated.

Parameters:
  • batch – The loaded batch.

  • trainer_model – Model of the current training run.

Returns:

Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
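A minimal sketch of the per-batch step and the collation that follows it, written in plain Python so it runs without PyTorch (lists of floats stand in for torch.Tensor). It assumes, hypothetically, that each batch is a dict with "x" and "y" entries, that the model is a plain callable, and that the per-batch result is a (losses, outputs) pair of dicts, mirroring the tuple[dict, dict] shape that process_results receives. The names toy_model, process_data_sketch and collate are illustrative only.

```python
from typing import Callable

Batch = dict[str, list[float]]
Result = tuple[dict[str, list[float]], dict[str, list[float]]]


def process_data_sketch(
    batch: Batch, *, trainer_model: Callable[[list[float]], list[float]], **_
) -> Result:
    """Hypothetical per-batch step: run the model and compute per-sample losses."""
    preds = trainer_model(batch["x"])
    # Per-sample squared error; a real callback would use the run's training loss.
    losses = [(p - t) ** 2 for p, t in zip(preds, batch["y"])]
    return {"mse": losses}, {"prediction": preds}


def collate(results: list[Result]) -> Result:
    """Concatenate per-batch dicts key by key, analogous to torch.cat along dim 0."""
    losses: dict[str, list[float]] = {}
    outputs: dict[str, list[float]] = {}
    for batch_losses, batch_outputs in results:
        for key, values in batch_losses.items():
            losses.setdefault(key, []).extend(values)
        for key, values in batch_outputs.items():
            outputs.setdefault(key, []).extend(values)
    return losses, outputs


def toy_model(xs: list[float]) -> list[float]:
    return [2.0 * x for x in xs]


batches = [{"x": [1.0, 2.0], "y": [2.0, 5.0]}, {"x": [3.0], "y": [6.0]}]
collated = collate([process_data_sketch(b, trainer_model=toy_model) for b in batches])
print(collated)  # ({'mse': [0.0, 1.0, 0.0]}, {'prediction': [2.0, 4.0, 6.0]})
```

Returning per-sample values rather than a per-batch mean keeps the later reduction exact when the final batch is smaller than the others.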

process_results(results, **_)

Template method that is called with the collated results from dataset iteration.

For example, metrics can be computed from the results for the entire test/validation dataset and logged.

Parameters:
  • results (tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.

  • interval_type – The type of interval that triggered this callback invocation.

  • update_counter – UpdateCounter with the current training state.

  • **_ – Additional unused keyword arguments.

Return type:

None
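A minimal sketch of what a process_results step could do with the collated (losses, outputs) tuple: reduce each per-sample loss to a dataset-wide mean and record it. OfflineLossSketch, the logged dict that stands in for the run's metric logger, and the "loss/{dataset_key}/{name}" key format are all hypothetical; plain lists of floats again stand in for torch.Tensor. Extra keyword arguments such as interval_type are absorbed by **_, as in the documented signature, and the method returns None.

```python
from statistics import fmean


class OfflineLossSketch:
    """Hypothetical stand-in illustrating process_results; not the noether implementation."""

    def __init__(self, dataset_key: str) -> None:
        self.dataset_key = dataset_key
        # Stand-in for the run's metric logger.
        self.logged: dict[str, float] = {}

    def process_results(
        self, results: tuple[dict[str, list[float]], dict[str, list[float]]], **_
    ) -> None:
        losses, _outputs = results
        for name, values in losses.items():
            # Reduce the per-sample losses over the whole dataset to one scalar.
            self.logged[f"loss/{self.dataset_key}/{name}"] = fmean(values)


cb = OfflineLossSketch(dataset_key="test")
cb.process_results(({"mse": [0.0, 1.0, 0.0], "l1": [0.5, 1.5]}, {}), interval_type="epoch")
print(cb.logged)
```

Averaging over all samples at once, rather than averaging per-batch means, gives the true dataset mean regardless of batch sizes.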