noether.training.callbacks.offline_loss
Classes
OfflineLossCallback – A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.
Module Contents
- class noether.training.callbacks.offline_loss.OfflineLossCallback(callback_config, **kwargs)
Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallback
A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.
- Parameters:
callback_config (noether.core.schemas.callbacks.OfflineLossCallbackConfig) – Configuration of the OfflineLossCallback. See OfflineLossCallbackConfig for the available options.
- dataset_key
- output_patterns_to_log
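The attribute names suggest that `dataset_key` selects the dataset the callback iterates over and that `output_patterns_to_log` selects which model outputs get logged. This page does not document how the patterns are matched. The sketch below assumes glob-style patterns matched with Python's standard `fnmatch` against output names; the helper `select_outputs_to_log` and the example output names are hypothetical and not part of noether.

```python
from fnmatch import fnmatch


def select_outputs_to_log(outputs: dict, patterns: list[str]) -> dict:
    """Keep only the outputs whose name matches at least one pattern.

    Hypothetical helper: assumes glob-style patterns (fnmatch), which
    may differ from how OfflineLossCallback interprets its patterns.
    """
    return {
        name: value
        for name, value in outputs.items()
        if any(fnmatch(name, pattern) for pattern in patterns)
    }


outputs = {"loss/pressure": 0.12, "loss/velocity": 0.30, "aux/grad_norm": 4.2}
print(select_outputs_to_log(outputs, ["loss/*"]))
# {'loss/pressure': 0.12, 'loss/velocity': 0.3}
```

With glob semantics, an empty pattern list logs nothing, and `"*"` logs every output.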
- process_data(batch, *, trainer_model, **_)
Template method that is called for each batch that is loaded from the dataset.
This method should process a single batch and return results that will be collated.
- Parameters:
batch – The loaded batch.
trainer_model – Model of the current training run.
- Returns:
Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
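For illustration, a per-batch step in the spirit of process_data() might run the model on one batch and return per-sample losses together with extra outputs. This is a minimal, framework-free sketch rather than noether's implementation. Plain Python floats stand in for torch.Tensor. The batch layout (`"x"` and `"y"` keys), the squared-error loss, the callable `trainer_model`, and the split of the result into a losses dict and an outputs dict are all assumptions. That split is chosen only to mirror the two-dictionary tuple that process_results() receives after collation.

```python
from typing import Callable


def process_data_sketch(batch: dict, *, trainer_model: Callable[[float], float], **_):
    """Compute per-sample losses for one batch (hypothetical stand-in).

    Assumes the batch is a dict with inputs under "x" and targets under
    "y", and uses squared error as the loss.
    """
    predictions = [trainer_model(x) for x in batch["x"]]
    losses = {"loss": [(p - y) ** 2 for p, y in zip(predictions, batch["y"])]}
    outputs = {"prediction": predictions}
    # Assumed layout: one dict of losses and one dict of model outputs, keyed by name.
    return losses, outputs


batch = {"x": [1.0, 2.0], "y": [2.0, 5.0]}
losses, outputs = process_data_sketch(batch, trainer_model=lambda x: 2 * x)
print(losses, outputs)
# {'loss': [0.0, 1.0]} {'prediction': [2.0, 4.0]}
```

Returning per-sample values instead of a batch mean leaves the reduction to the aggregation step, which then sees every sample.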
- process_results(results, **_)
Template method that is called with the collated results from dataset iteration.
For example, metrics can be computed from the results for the entire test/validation dataset and logged.
- Parameters:
results (tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.
interval_type – The type of interval that triggered this callback invocation.
update_counter – UpdateCounter with the current training state.
**_ – Additional unused keyword arguments.
- Return type:
None
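The step between the two template methods can be sketched as follows: per-batch (losses, outputs) pairs are concatenated key by key into a single pair covering the whole dataset, and a process_results()-style step reduces each loss to a dataset-level mean. Both `collate` and `process_results_sketch` are hypothetical helpers. Plain lists stand in for tensors; with real tensors the concatenation would use something like `torch.cat`. The real method returns None and logs its values. This sketch returns them instead so the result can be checked, and it swallows extra keyword arguments such as interval_type through `**_`, as the documented signature does.

```python
def collate(per_batch_results):
    """Concatenate per-batch (losses, outputs) pairs key by key."""
    losses, outputs = {}, {}
    for batch_losses, batch_outputs in per_batch_results:
        for key, values in batch_losses.items():
            losses.setdefault(key, []).extend(values)
        for key, values in batch_outputs.items():
            outputs.setdefault(key, []).extend(values)
    return losses, outputs


def process_results_sketch(results, **_):
    """Reduce collated losses to dataset-level means (hypothetical).

    Returns the means instead of logging them; extra keyword arguments
    such as interval_type or update_counter are ignored via **_.
    """
    losses, _outputs = results
    return {key: sum(values) / len(values) for key, values in losses.items()}


per_batch = [
    ({"loss": [0.0, 1.0]}, {"prediction": [2.0, 4.0]}),
    ({"loss": [4.0]}, {"prediction": [6.0]}),
]
print(process_results_sketch(collate(per_batch), interval_type="epoch"))
# {'loss': 1.6666666666666667}
```

Concatenating before averaging weights every sample equally, so a smaller final batch does not skew the dataset-level loss the way averaging per-batch means would. Here the mean of batch means would be 2.25, while the per-sample mean is 5/3.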