noether.core.callbacks.online

Classes

BestMetricCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

BestMetricCallbackConfig

Internal base class for all registry-based configs.

TrackAdditionalOutputsCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step.

TrackAdditionalOutputsCallbackConfig

Internal base class for all registry-based configs.

Package Contents

class noether.core.callbacks.online.BestMetricCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

For example, track the test loss at the epoch with the best validation loss to simulate early stopping.

Example config:

- kind: noether.core.callbacks.BestMetricCallback
  every_n_epochs: 1
  source_metric_key: loss/val/total
  target_metric_keys:
    - loss/test/total

In this example, whenever a new best validation loss is found, the corresponding test loss is logged under the key loss/test/total/at_best/loss/val/total.
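
The bookkeeping described above can be sketched in plain Python. This is a minimal, standalone illustration and not the library's implementation: the BestMetricTracker class and its update method are hypothetical names, and passing higher_is_better as a constructor argument is an assumption. Only the log key format (target key, then at_best, then source key) is taken from the example above.

```python
from typing import Dict, List


class BestMetricTracker:
    """Hypothetical stand-in that mimics the behaviour described above."""

    def __init__(
        self,
        source_metric_key: str,
        target_metric_keys: List[str],
        higher_is_better: bool = False,  # assumption: losses are minimised
    ) -> None:
        self.source_metric_key = source_metric_key
        self.target_metric_keys = target_metric_keys
        self.higher_is_better = higher_is_better
        # Start from the worst possible value so the first update always wins.
        self.best_metric_value = float("-inf") if higher_is_better else float("inf")

    def update(self, metrics: Dict[str, float]) -> Dict[str, float]:
        """Return the entries to log if the source metric improved, else {}."""
        value = metrics[self.source_metric_key]
        if self.higher_is_better:
            improved = value > self.best_metric_value
        else:
            improved = value < self.best_metric_value
        if not improved:
            return {}
        self.best_metric_value = value
        # Log each target metric under "<target>/at_best/<source>".
        return {
            f"{key}/at_best/{self.source_metric_key}": metrics[key]
            for key in self.target_metric_keys
        }


tracker = BestMetricTracker("loss/val/total", ["loss/test/total"])
first = tracker.update({"loss/val/total": 0.50, "loss/test/total": 0.60})
second = tracker.update({"loss/val/total": 0.70, "loss/test/total": 0.55})
third = tracker.update({"loss/val/total": 0.40, "loss/test/total": 0.52})
```

In this sketch the second call returns an empty dictionary because the validation loss did not improve, so the lower test loss of that epoch is deliberately ignored.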

Parameters:
  • callback_config (BestMetricCallbackConfig) – Configuration for the callback. See BestMetricCallbackConfig for available options including source and target metric keys.

  • **kwargs – Additional keyword arguments provided to the parent class.

source_metric_key
target_metric_keys
optional_target_metric_keys
higher_is_better
best_metric_value
previous_log_values: dict[str, Any]
periodic_callback(**__)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.online.BestMetricCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Internal base class for all registry-based configs.

Provides auto-registration via __init_subclass__. Not meant to be used directly - use specific config base classes instead.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['BestMetricCallback'] = None
source_metric_key: str = None

The metric used to determine whether the current model obtained a new best (e.g., loss/valid/total).

target_metric_keys: list[str] | None = None

The metrics to keep track of (e.g., loss/test/total).

optional_target_metric_keys: list[str] | None = None

The metrics to keep track of if they are present (useful when different model configurations log different evaluation metrics to avoid reconfiguring the callback).

class noether.core.callbacks.online.TrackAdditionalOutputsCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step. The update_outputs that are provided in the track_after_accumulation_step method are the additional_outputs field from the TrainerResult returned by the trainer’s update step.

The provided update_outputs are assumed to be a dictionary, and outputs that match keys or patterns are tracked. An update output matches if either its key matches exactly, e.g. {“some_output”: …} and keys = [“some_output”], or one of the patterns is contained in the update key name, e.g. {“some_loss”: …} and patterns = [“loss”].
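
The matching rule above can be sketched as a small filter function. This is an illustrative assumption about how such matching could be written, not the callback's actual code; select_tracked_outputs is a hypothetical helper name.

```python
from typing import Any, Dict, Iterable


def select_tracked_outputs(
    update_outputs: Dict[str, Any],
    keys: Iterable[str] = (),
    patterns: Iterable[str] = (),
) -> Dict[str, Any]:
    """Keep outputs whose name equals a key or contains one of the patterns."""
    exact_keys = set(keys)
    substrings = list(patterns)
    return {
        name: value
        for name, value in update_outputs.items()
        if name in exact_keys or any(pattern in name for pattern in substrings)
    }


outputs = {"some_output": 1.0, "some_loss": 0.25, "accuracy": 0.9}
selected = select_tracked_outputs(outputs, keys=["some_output"], patterns=["loss"])
```

Here "some_output" is kept through the exact key match, "some_loss" through the substring pattern "loss", and "accuracy" is dropped because it matches neither.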

Example config:

- kind: noether.core.callbacks.TrackAdditionalOutputsCallback
  name: TrackAdditionalOutputsCallback
  every_n_updates: 1
  keys:
    - "surface_pressure_loss"

Parameters:
  • callback_config (TrackAdditionalOutputsCallbackConfig) – Configuration for the callback. See TrackAdditionalOutputsCallbackConfig for available options including keys and patterns to track.

  • **kwargs – Additional keyword arguments provided to the parent class.

out: pathlib.Path | None
patterns
keys
verbose
tracked_values: collections.defaultdict[str, list]
reduce
log_output
save_output
track_after_accumulation_step(*, update_counter, update_outputs, **_)

Track the specified outputs after each accumulation step.

Parameters:
  • update_counter – UpdateCounter object to track the number of updates.

  • update_outputs – The additional_outputs field from the TrainerResult returned by the trainer’s update step. Note that the base train_step method in the base trainer does not provide any additional outputs by default, and hence this callback can only be used if the train_step is modified to provide additional outputs.

  • **_ – Additional unused keyword arguments.

Return type:

None

periodic_callback(*, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.online.TrackAdditionalOutputsCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Internal base class for all registry-based configs.

Provides auto-registration via __init_subclass__. Not meant to be used directly - use specific config base classes instead.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['TrackAdditionalOutputsCallback'] = None
keys: list[str] | None = None

List of keys to track in the additional_outputs of the TrainerResult returned by the trainer’s update step.

patterns: list[str] | None = None

List of patterns to track in the additional_outputs of the TrainerResult returned by the trainer’s update step. Matched if it is contained in one of the update_outputs keys.

verbose: bool = None

If True, uses the logger to print the tracked values; otherwise, no logger is used.

reduce: Literal['mean', 'last'] = None

The reduction method to be applied to the tracked values to reduce to scalar. Currently supports ‘mean’ and ‘last’.
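
The two reduction modes can be illustrated with a short sketch over a tracked_values mapping like the one listed among the callback's attributes. The reduce_tracked function is a hypothetical helper written for this illustration, and the assumption that tracked values are plain floats is mine; the library may store tensors or other types.

```python
from collections import defaultdict
from typing import Dict, List


def reduce_tracked(tracked_values: Dict[str, List[float]], reduce: str = "mean") -> Dict[str, float]:
    """Collapse each list of tracked values to a single scalar."""
    if reduce == "mean":
        # Average over all values collected since the last reduction.
        return {k: sum(v) / len(v) for k, v in tracked_values.items() if v}
    if reduce == "last":
        # Keep only the most recently tracked value.
        return {k: v[-1] for k, v in tracked_values.items() if v}
    raise ValueError(f"unsupported reduction: {reduce!r}")


tracked: Dict[str, List[float]] = defaultdict(list)
for step_loss in (0.5, 0.25, 0.75):
    tracked["surface_pressure_loss"].append(step_loss)

mean_result = reduce_tracked(tracked, "mean")
last_result = reduce_tracked(tracked, "last")
```

With 'mean', the logged scalar summarises every step in the interval, while 'last' reports only the most recent step, which is cheaper to reason about but noisier.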

log_output: bool = None

Whether to log the tracked scalar values.

save_output: bool = None

Whether to save the tracked scalar values to disk.