noether.training.callbacks.offline_loss

Classes

OfflineLossCallbackConfig

Configuration for OfflineLossCallback.

OfflineLossCallback

A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

Module Contents

class noether.training.callbacks.offline_loss.OfflineLossCallbackConfig(/, **data)

Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallbackConfig

Internal base class for all registry-based configs.

Provides auto-registration via __init_subclass__. Not meant to be used directly; use specific config base classes instead.
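Auto-registration through __init_subclass__ can be illustrated with plain Python. The sketch below is not noether's implementation: RegistryConfigBase, CallbackConfigBase, OfflineLossConfigSketch and the name-keyed registry dict are all hypothetical. It only shows the general technique, where every subclass that declares its own `name` is recorded in a shared table at class-creation time and can then be looked up by that string.

```python
from typing import ClassVar


class RegistryConfigBase:
    """Hypothetical stand-in for an internal registry-based config base."""

    # Shared lookup table: config name -> config class (assumed keying scheme).
    registry: ClassVar[dict[str, type]] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only register classes that declare their own `name`, so intermediate
        # bases (and names merely inherited from a parent) are skipped.
        name = cls.__dict__.get("name")
        if name is not None:
            RegistryConfigBase.registry[name] = cls


class CallbackConfigBase(RegistryConfigBase):
    """Intermediate base without a `name`: not registered."""


class OfflineLossConfigSketch(CallbackConfigBase):
    name = "OfflineLossCallback"


# Resolve a config class from its name string.
config_cls = RegistryConfigBase.registry["OfflineLossCallback"]
print(config_cls.__name__)  # OfflineLossConfigSketch
```

Reading `name` from `cls.__dict__` rather than with `getattr` is a deliberate choice in this sketch: it stops a subclass of a registered config from silently overwriting its parent's registry entry.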

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.
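The `(/, **data)` signature shown above relies on standard Python positional-only parameters. The two toy classes below are hypothetical and unrelated to pydantic internals. They show why the `/` matters: without it, passing a field literally named `self` as a keyword collides with the instance argument.

```python
class WithPositionalOnlySelf:
    # The `/` makes `self` positional-only, so a keyword argument literally
    # named "self" is collected into **data like any other field.
    def __init__(self, /, **data):
        self.fields = dict(data)


class WithoutPositionalOnlySelf:
    # Here `self` can also be bound by keyword, which causes a clash.
    def __init__(self, **data):
        self.fields = dict(data)


ok = WithPositionalOnlySelf(self="field value", name="OfflineLossCallback")
print(ok.fields)  # {'self': 'field value', 'name': 'OfflineLossCallback'}

clash = None
try:
    WithoutPositionalOnlySelf(self="field value")
except TypeError as exc:
    # The instance is already bound to `self`, so the keyword is a duplicate.
    clash = exc
    print(exc)
```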

Parameters:

data (Any)

name: Literal['OfflineLossCallback'] = None
output_patterns_to_log: list[str] | None = None

Patterns selecting which output keys are logged. For instance, if the output key is ‘some_loss’ and the pattern is [‘loss’], …

**kwargs – Additional arguments passed to the parent class.
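The documented example, an output key ‘some_loss’ with the pattern [‘loss’], suggests some form of pattern matching against output keys. The sketch below assumes two behaviors the documentation does not confirm: a pattern matches when it is a substring of the key, and `None` means every key is logged. The helper name `select_outputs_to_log` is hypothetical.

```python
def select_outputs_to_log(output_keys, patterns):
    """Return the output keys that match at least one pattern.

    Assumptions (not confirmed by the docs): a pattern matches when it is a
    substring of the key, and `patterns=None` means "log every key".
    """
    if patterns is None:
        return list(output_keys)
    return [key for key in output_keys if any(p in key for p in patterns)]


print(select_outputs_to_log(["some_loss", "accuracy"], ["loss"]))  # ['some_loss']
print(select_outputs_to_log(["some_loss", "accuracy"], None))  # ['some_loss', 'accuracy']
```

If the real callback uses regular expressions or glob-style patterns instead, replace the substring test `p in key` accordingly (for example with `re.search(p, key)`).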

class noether.training.callbacks.offline_loss.OfflineLossCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallback

A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

Parameters:

callback_config (OfflineLossCallbackConfig) – configuration of the OfflineLossCallback. See OfflineLossCallbackConfig for the available options.

dataset_key
output_patterns_to_log
process_data(batch, *, trainer_model, **_)

Template method that is called for each batch that is loaded from the dataset.

This method should process a single batch and return results that will be collated.

Parameters:
  • batch – The loaded batch.

  • trainer_model – Model of the current training run.

Returns:

Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
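The per-batch contract can be sketched without torch by using plain Python lists in place of tensors. Everything here except the `process_data(batch, *, trainer_model, **_)` signature is hypothetical: `toy_trainer_model` is a stand-in model, and `iterate_over_dataset` is only a rough imitation of `_iterate_over_dataset()` that collates by concatenating each key's values across batches. The real callback's collated results are a tuple of two dicts; this sketch uses a single dict for brevity.

```python
def toy_trainer_model(batch):
    """Hypothetical model: per-sample squared error between "x" and "y"."""
    return {"mse_loss": [(x - y) ** 2 for x, y in zip(batch["x"], batch["y"])]}


def process_data(batch, *, trainer_model, **_):
    # Handle exactly one batch; extra keyword arguments are ignored via **_.
    return trainer_model(batch)


def iterate_over_dataset(batches, trainer_model):
    """Hypothetical driver: call process_data per batch, then collate."""
    collated = {}
    for batch in batches:
        outputs = process_data(batch, trainer_model=trainer_model)
        # Collate by concatenating each key's per-sample values across batches.
        for key, values in outputs.items():
            collated.setdefault(key, []).extend(values)
    return collated


batches = [
    {"x": [1.0, 2.0], "y": [1.0, 0.0]},
    {"x": [3.0], "y": [1.0]},
]
print(iterate_over_dataset(batches, toy_trainer_model))  # {'mse_loss': [0.0, 4.0, 4.0]}
```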

process_results(results, **_)

Template method that is called with the collated results from dataset iteration.

For example, metrics can be computed from the results for the entire test/validation dataset and logged.

Parameters:
  • results (tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.

  • interval_type – The type of interval that triggered this callback invocation.

  • update_counter – UpdateCounter with the current training state.

  • **_ – Additional unused keyword arguments.

Return type:

None
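A dataset-level reduction in process_results might look like the sketch below. It is an illustration under stated assumptions, not the callback's actual logic. It assumes that `results` is a pair of dicts mapping names to per-sample values (lists standing in for tensors), that the reduction is a mean, and that keys are filtered by substring patterns. The `patterns` and `log` keyword arguments are hypothetical. Documented keywords such as interval_type and update_counter are simply absorbed by `**_`, as in the real signature.

```python
from statistics import fmean


def _print_metric(key, value):
    # Default logger for the sketch: print one metric per line.
    print(f"{key}: {value:.4f}")


def process_results(results, *, patterns=None, log=None, **_):
    """Reduce collated per-sample values to dataset-level means and log them.

    `patterns` and `log` are hypothetical parameters for this sketch; other
    keyword arguments (e.g. interval_type, update_counter) are ignored via **_.
    """
    if log is None:
        log = _print_metric
    # `results` is assumed to be a pair of dicts (name -> per-sample values).
    for group in results:
        for key, values in group.items():
            if patterns is None or any(p in key for p in patterns):
                log(key, fmean(values))


logged = {}
process_results(
    ({"mse_loss": [0.0, 4.0, 4.0]}, {"accuracy": [1.0, 0.0]}),
    patterns=["loss"],
    log=logged.__setitem__,
    interval_type="epoch",  # hypothetical value; accepted and ignored by **_
)
print(logged)
```

Passing the logger in as a callable keeps the sketch testable and returns None, matching the documented return type.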