noether.core.callbacks.online

Classes

BestMetricCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

TrackAdditionalOutputsCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step.

Package Contents

class noether.core.callbacks.online.BestMetricCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

For example, track the test loss at the epoch with the best validation loss to simulate early stopping.

Example config:

- kind: noether.core.callbacks.BestMetricCallback
  every_n_epochs: 1
  source_metric_key: loss/val/total
  target_metric_keys:
    -  loss/test/total

In this example, whenever a new best validation loss is found, the corresponding test loss is logged under the key loss/test/total/at_best/loss/val/total.
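The tracking behavior described above can be sketched as follows. This is an illustrative stand-in, not the actual implementation: `BestMetricTracker` is a hypothetical name, and the real callback inherits from PeriodicCallback and hooks into the framework's logging rather than returning a dict.

```python
# Hypothetical sketch of the best-metric tracking logic. Attribute names
# mirror the documented config; the PeriodicCallback machinery is omitted.
class BestMetricTracker:
    def __init__(self, source_metric_key, target_metric_keys, higher_is_better=False):
        self.source_metric_key = source_metric_key
        self.target_metric_keys = target_metric_keys
        self.higher_is_better = higher_is_better
        self.best_metric_value = None

    def update(self, log_values):
        """Check for a new best source metric; return target metrics to log."""
        value = log_values[self.source_metric_key]
        is_new_best = (
            self.best_metric_value is None
            or (value > self.best_metric_value if self.higher_is_better
                else value < self.best_metric_value)
        )
        if not is_new_best:
            return {}
        self.best_metric_value = value
        # Log each target metric under "<target>/at_best/<source>".
        return {
            f"{key}/at_best/{self.source_metric_key}": log_values[key]
            for key in self.target_metric_keys
        }

tracker = BestMetricTracker("loss/val/total", ["loss/test/total"])
tracker.update({"loss/val/total": 0.5, "loss/test/total": 0.6})
# -> {"loss/test/total/at_best/loss/val/total": 0.6}
tracker.update({"loss/val/total": 0.7, "loss/test/total": 0.4})
# -> {} (0.7 is not a new best, since lower is better here)
```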

Parameters:
source_metric_key
target_metric_keys
optional_target_metric_keys
higher_is_better
best_metric_value
previous_log_values: dict[str, Any]
periodic_callback(**__)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None
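A subclass override of periodic_callback might look like the sketch below. The stub base class stands in for the real noether.core.callbacks.periodic.PeriodicCallback so the example is self-contained, and the callback name and metric computation are placeholders.

```python
# Minimal sketch of overriding periodic_callback in a subclass.
class PeriodicCallback:  # stub standing in for the real base class
    pass

class ValidationMetricCallback(PeriodicCallback):
    """Logs a (placeholder) validation metric each time an interval fires."""

    def __init__(self):
        self.logged = []

    def periodic_callback(self, *, interval_type=None, update_counter=None, **kwargs):
        # In the real framework this runs inside torch.no_grad(), so an
        # expensive validation computation here would not track gradients.
        metric = 0.0  # placeholder for the actual metric computation
        self.logged.append((interval_type, metric))

cb = ValidationMetricCallback()
cb.periodic_callback(interval_type="epoch")
# cb.logged -> [("epoch", 0.0)]
```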

class noether.core.callbacks.online.TrackAdditionalOutputsCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step. The update_outputs provided to the track_after_accumulation_step method are the additional_outputs field from the TrainerResult returned by the trainer's update step.

The provided update_outputs are assumed to be a dictionary; outputs whose keys match are tracked. An update output matches if either the key matches exactly (e.g., {"some_output": ...} with keys = ["some_output"]), or one of the patterns is a substring of the output key name (e.g., {"some_loss": ...} with patterns = ["loss"]).
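The matching rule can be expressed as a one-line predicate. The function name below is illustrative, not part of the callback's API:

```python
# Sketch of the documented matching rule: an output is tracked if its key
# appears exactly in `keys`, or any entry of `patterns` is a substring of it.
def matches(output_key, keys=(), patterns=()):
    return output_key in keys or any(p in output_key for p in patterns)

matches("some_output", keys=["some_output"])   # exact key match -> True
matches("some_loss", patterns=["loss"])        # substring match -> True
matches("accuracy", keys=["some_output"], patterns=["loss"])  # -> False
```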

Example config:

- kind: noether.core.callbacks.TrackAdditionalOutputsCallback
  name: TrackAdditionalOutputsCallback
  every_n_updates: 1
  keys:
    - "surface_pressure_loss"

Parameters:
out: pathlib.Path | None
patterns
keys
verbose
tracked_values: collections.defaultdict[str, list]
reduce
log_output
save_output
track_after_accumulation_step(*, update_counter, update_outputs, **_)

Track the specified outputs after each accumulation step.

Parameters:
  • update_counter – UpdateCounter object to track the number of updates.

  • update_outputs – The additional_outputs field from the TrainerResult returned by the trainer’s update step. Note that the base train_step method in the base trainer does not provide any additional outputs by default, and hence this callback can only be used if the train_step is modified to provide additional outputs.

  • **_ – Additional unused keyword arguments.

Return type:

None
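Since the base train_step provides no additional outputs, using this callback requires a trainer whose train_step surfaces them. The sketch below is illustrative: TrainerResult is a stand-in dataclass (the real type lives in noether), and the key name follows the example config above.

```python
# Illustrative sketch: a modified train_step that surfaces an extra output,
# and how the callback's tracked_values would accumulate it per update.
from collections import defaultdict
from dataclasses import dataclass, field

@dataclass
class TrainerResult:  # stand-in for the real trainer result type
    loss: float
    additional_outputs: dict = field(default_factory=dict)

def train_step(batch):
    # The base train_step returns no additional outputs; here we add one so
    # TrackAdditionalOutputsCallback has something to track.
    loss = sum(batch) / len(batch)  # placeholder loss computation
    return TrainerResult(loss=loss,
                         additional_outputs={"surface_pressure_loss": loss * 0.5})

tracked_values = defaultdict(list)
for batch in ([1.0, 2.0], [3.0, 5.0]):
    result = train_step(batch)
    for key, value in result.additional_outputs.items():
        if key in ["surface_pressure_loss"]:  # keys from the example config
            tracked_values[key].append(value)
# tracked_values["surface_pressure_loss"] -> [0.75, 2.0]
```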

periodic_callback(*, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None