noether.core.callbacks.online.track_outputs

Classes

TrackAdditionalOutputsCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step.

Module Contents

class noether.core.callbacks.online.track_outputs.TrackAdditionalOutputsCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step.

The provided update_outputs are assumed to be a dictionary, and only entries whose key matches the configured keys or patterns are tracked. A key matches if it is equal to one of the keys, e.g. {“some_output”: …} with keys = [“some_output”], or if one of the patterns is a substring of it, e.g. {“some_loss”: …} with patterns = [“loss”].
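The matching rule can be illustrated with a small standalone function. This is a sketch of the documented behavior, not the callback's actual code; the function name select_tracked_outputs is hypothetical.

```python
from typing import Any, Dict, Iterable


def select_tracked_outputs(
    update_outputs: Dict[str, Any],
    keys: Iterable[str] = (),
    patterns: Iterable[str] = (),
) -> Dict[str, Any]:
    """Hypothetical helper: keep entries whose name equals a key or contains a pattern."""
    key_set = set(keys)
    pattern_list = list(patterns)
    selected = {}
    for name, value in update_outputs.items():
        # Exact match against the configured keys.
        if name in key_set:
            selected[name] = value
        # Substring match against any configured pattern.
        elif any(pattern in name for pattern in pattern_list):
            selected[name] = value
    return selected


outputs = {"some_output": 1.0, "some_loss": 0.5, "grad_norm": 3.2}
print(select_tracked_outputs(outputs, keys=["some_output"], patterns=["loss"]))
# {'some_output': 1.0, 'some_loss': 0.5}
```

Here grad_norm is dropped because it neither equals a key nor contains the pattern "loss".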

Initializes a TrackAdditionalOutputsCallback instance.

Attributes:
  • out: pathlib.Path | None

  • patterns

  • keys

  • verbose

  • tracked_values: collections.defaultdict[str, list]

  • reduce

  • log_output

  • save_output

track_after_accumulation_step(*, update_counter, update_outputs, **_)

Hook called after each individual gradient accumulation step.

This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step; the two differ only when accumulation_steps > 1. It is primarily used to track metrics that should be averaged or aggregated across accumulation steps.

Common use cases include:

  • Logging per-batch losses for high-frequency monitoring

  • Accumulating statistics across batches before an optimizer update

  • Implementing custom logging that needs access to individual batch data

Note

This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • batch – The current data batch processed in this accumulation step.

  • losses – Dictionary of computed losses for the current batch.

  • update_outputs – Optional dictionary of model outputs for the current batch.

  • accumulation_steps – Total number of accumulation steps before an optimizer update.

  • accumulation_step – The current accumulation step index (0-indexed).

Return type:

None
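A minimal sketch of how per-step tracking could populate tracked_values, assuming the callback appends every matching value to a per-key list (consistent with the documented defaultdict[str, list] type). OutputTracker is a hypothetical stand-in, not the real class; the real hook also receives update_counter and other keyword arguments, which this sketch absorbs via **_ and ignores.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, Optional


class OutputTracker:
    """Hypothetical stand-in for the tracking state of TrackAdditionalOutputsCallback."""

    def __init__(self, keys: Iterable[str] = (), patterns: Iterable[str] = ()):
        self.keys = set(keys)
        self.patterns = list(patterns)
        # Same type as the documented attribute: defaultdict[str, list].
        self.tracked_values = defaultdict(list)

    def _matches(self, name: str) -> bool:
        # Exact key match, or any pattern contained in the output name.
        return name in self.keys or any(p in name for p in self.patterns)

    def track_after_accumulation_step(
        self, *, update_outputs: Optional[Dict[str, Any]] = None, **_: Any
    ) -> None:
        # update_outputs is documented as optional, so a missing dict is a no-op.
        if update_outputs is None:
            return
        for name, value in update_outputs.items():
            if self._matches(name):
                self.tracked_values[name].append(value)


tracker = OutputTracker(patterns=["loss"])
# Two accumulation steps; update_counter is accepted but ignored (absorbed by **_).
tracker.track_after_accumulation_step(update_counter=None, update_outputs={"total_loss": 0.9, "lr": 1e-3})
tracker.track_after_accumulation_step(update_counter=None, update_outputs={"total_loss": 0.7, "lr": 1e-3})
print(dict(tracker.tracked_values))  # {'total_loss': [0.9, 0.7]}
```

Using a defaultdict(list) means a key seen for the first time needs no special-casing before its first value is appended.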

periodic_callback(*, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state
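One common way to decide whether a counter-based interval has been reached, assuming counters only grow, is to check whether a multiple of the interval falls between the previous and the current count. This matters mostly for sample-based intervals, where the count advances by the batch size and can skip over an exact multiple. The helper crossed_interval is hypothetical and is not taken from noether's PeriodicCallback.

```python
from typing import Optional


def crossed_interval(previous: int, current: int, every_n: Optional[int]) -> bool:
    """Hypothetical helper: True if a multiple of every_n lies in (previous, current]."""
    # An unset or non-positive interval never triggers.
    if every_n is None or every_n <= 0:
        return False
    return previous // every_n < current // every_n


# Updates advance by one, so this reduces to "current % every_n == 0".
print(crossed_interval(99, 100, every_n=100))  # True
# Samples advance by the batch size (64 here), so a boundary can fall between two updates.
print(crossed_interval(960, 1024, every_n=1000))  # True
print(crossed_interval(1024, 1088, every_n=1000))  # False
```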

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch or after_update).

Return type:

None
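For this callback, periodic_callback presumably reduces the values collected in tracked_values before logging or saving them. The sketch below assumes that reduce is a string naming the reduction ("mean" or "last") and that the buffer is cleared after each interval; neither detail is confirmed by this page, and reduce_tracked_values is a hypothetical helper.

```python
from collections import defaultdict
from statistics import mean
from typing import Callable, Dict, List


def reduce_tracked_values(
    tracked_values: Dict[str, List[float]],
    reduce: str = "mean",
) -> Dict[str, float]:
    """Hypothetical helper: collapse each tracked list to one scalar, then clear the buffer."""
    reducers: Dict[str, Callable[[List[float]], float]] = {
        "mean": mean,
        "last": lambda values: values[-1],
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce mode: {reduce!r}")
    # Skip keys with no values collected in this interval.
    reduced = {name: reducers[reduce](values) for name, values in tracked_values.items() if values}
    # Start a fresh window for the next interval.
    tracked_values.clear()
    return reduced


tracked = defaultdict(list, {"total_loss": [1.0, 2.0, 3.0]})
print(reduce_tracked_values(tracked))  # {'total_loss': 2.0}
print(dict(tracked))  # {}
```

Clearing after each reduction means every logged value summarizes only the steps since the previous trigger, rather than a running average over the whole run.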