noether.core.callbacks

Submodules

Exceptions

EarlyStopIteration

Custom StopIteration exception for Early Stoppers.

Classes

CallbackBase

Base class for callbacks that execute something before/after training.

BestCheckpointCallback

Callback to save the best model based on a metric.

CheckpointCallback

Callback to save the model and optimizer state periodically.

EmaCallback

Callback for exponential moving average (EMA) of model weights.

DatasetStatsCallback

A callback that logs the length of each dataset in the data container. It is initialized by the BaseTrainer and should not be added manually to the trainer's callbacks.

EtaCallback

Callback to print the progress and estimated duration until the periodic callback will be invoked.

LrCallback

Callback to log the learning rate of the optimizer.

OnlineLossCallback

Callback to track the loss of the model after every gradient accumulation step and log the average loss.

ParamCountCallback

Callback to log the number of trainable and frozen parameters of the model.

PeakMemoryCallback

Callback to log the peak memory usage of the model. It is initialized by the BaseTrainer and should not be added manually to the trainer's callbacks.

ProgressCallback

Callback to print the progress of the training such as number of epochs and updates.

TrainTimeCallback

Callback to log the time spent on dataloading. It is initialized by the BaseTrainer and should not be added manually to the trainer's callbacks.

EarlyStopperBase

Base class for early stoppers that is used to define the interface for early stoppers used by the trainers.

FixedEarlyStopper

Early stopper (training) based on a fixed number of epochs, updates, or samples.

MetricEarlyStopper

Early stopper (training) based on a metric value to be monitored.

BestMetricCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

TrackAdditionalOutputsCallback

Callback that is invoked during training after every gradient step to track certain outputs from the update step.

PeriodicCallback

Base class for callbacks that are invoked periodically during training.

PeriodicDataIteratorCallback

Base class for callbacks that perform periodic iterations over a dataset.

Package Contents

class noether.core.callbacks.CallbackBase(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Base class for callbacks that execute something before/after training.

Allows overriding before_training and after_training.

If the callback is stateful (i.e., it tracks something across the training process that needs to be loaded if the run is resumed), there are two ways to implement loading the callback state:

  • state_dict: write the current state into a state dict. When the trainer saves the current checkpoint to disk, it also stores the state_dict of all callbacks within the trainer state_dict. Once a run is resumed, a callback can load its state from the previously stored state_dict by overriding load_state_dict.

  • resume_from_checkpoint: If a callback is storing large files on disk, it would be redundant to also store them within its state_dict. Therefore, this method is called on resume to allow callbacks to load their state from files on the disk.

Callbacks have access to a LogWriter, with which callbacks can log metrics. The LogWriter is a singleton.

Examples

# THIS IS INSIDE A CUSTOM CALLBACK

# log only to experiment tracker, not stdout
self.writer.add_scalar(key="classification_accuracy", value=0.2)
# log to experiment tracker and stdout (as "0.24")
self.writer.add_scalar(
    key="classification_accuracy",
    value=0.23623,
    logger=self.logger,
    format_str=".2f",
)

Note

As evaluations are pretty much always done in torch.no_grad() contexts, the hooks implemented by callbacks are always executed within a torch.no_grad() context.
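As a hedged sketch of the state_dict/load_state_dict path described above: the CallbackBase stand-in below only mirrors the documented defaults, and RunningLossCallback is a hypothetical stateful callback, not part of the library.

```python
class CallbackBase:
    """Stand-in mirroring the documented defaults of
    noether.core.callbacks.CallbackBase (illustration only)."""

    def state_dict(self):
        return None  # callbacks are non-stateful by default

    def load_state_dict(self, state_dict):
        pass  # no-op by default


class RunningLossCallback(CallbackBase):
    """Hypothetical stateful callback whose state survives run resumption."""

    def __init__(self):
        self.running_loss = 0.0

    def state_dict(self):
        # Stored inside the trainer state_dict whenever a checkpoint is written.
        return {"running_loss": self.running_loss}

    def load_state_dict(self, state_dict):
        # Called with the previously stored state when the run is resumed.
        self.running_loss = state_dict["running_loss"]
```

Because the trainer persists every callback's state_dict alongside its own, no extra file handling is needed for small state like this.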

Parameters:
trainer: noether.training.trainers.BaseTrainer

Trainer of the current run. Can be used to access training state.

model: noether.core.models.ModelBase

Model of the current run. Can be used to access model parameters.

data_container: noether.data.container.DataContainer

Data container of the current run. Can be used to access all datasets.

tracker: noether.core.trackers.BaseTracker

Tracker of the current run. Can be used for direct access to the experiment tracking platform.

writer: noether.core.writers.LogWriter

Log writer of the current run. Can be used to log metrics to stdout/disk/online platform.

metric_property_provider: noether.core.providers.metric_property.MetricPropertyProvider

Metric property provider of the current run. Defines properties of metrics (e.g., whether higher values are better).

checkpoint_writer: noether.core.writers.CheckpointWriter

Checkpoint writer of the current run. Can be used to store checkpoints during training.

name = None
state_dict()

If a callback is stateful, the state will be stored when a checkpoint is stored to the disk.

Returns:

State of the callback. By default, callbacks are non-stateful and return None.

Return type:

dict[str, torch.Tensor] | None

load_state_dict(state_dict)

If a callback is stateful, the state will be stored when a checkpoint is stored to the disk and can be loaded with this method upon resuming a run.

Parameters:

state_dict (dict[str, Any]) – State to be loaded. By default, callbacks are non-stateful and load_state_dict does nothing.

Return type:

None

resume_from_checkpoint(resumption_paths, model)

If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be unnecessarily wasteful to also store the state in the callback's state_dict. Therefore, resume_from_checkpoint is called when resuming a run, which allows callbacks to load their state from any file that was stored on the disk.

Parameters:
Return type:

None

property logger: logging.Logger

Logger for logging to stdout.

Return type:

logging.Logger

before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

after_training(*, update_counter)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.BestCheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.

This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.

Example config:

callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names:  # only applies when training a CompositeModel
      - encoder
Parameters:
metric_key
model_names
higher_is_better
best_metric_value
save_frozen_weights
tolerances_is_exceeded
tolerance_counter = 0
metric_at_exceeded_tolerance: dict[float, float]
state_dict()

Return the state of the callback for checkpointing.

Returns:

Dictionary containing the best metric value, tolerance tracking state, and counter information.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the callback state from a checkpoint.

Note

This modifies the input state_dict in place.

Parameters:

state_dict (dict[str, Any]) – Dictionary containing the saved callback state.

Return type:

None

before_training(*, update_counter)

Validate callback configuration before training starts.

Parameters:

update_counter – The training update counter.

Raises:

NotImplementedError – If resuming training with tolerances is attempted.

Return type:

None

periodic_callback(**_)

Execute the periodic callback to check and save best model.

This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.

Raises:

KeyError – If the log cache is empty or the metric key is not found.

Return type:

None
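The decision behind "a new best value" can be sketched as follows. This is a hypothetical helper, not the callback's actual code; it only assumes the higher_is_better semantics described above.

```python
def is_new_best(current, best, higher_is_better):
    """Return True if `current` improves on `best` (None means no best yet)."""
    if best is None:
        return True
    return current > best if higher_is_better else current < best
```

For a loss-like metric (higher_is_better=False), a smaller value triggers a checkpoint save; for an accuracy-like metric, a larger one does.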

after_training(**kwargs)

Log the best metric values at different tolerance thresholds after training completes.

Parameters:

**kwargs – Additional keyword arguments (unused).

Return type:

None

class noether.core.callbacks.CheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the model and optimizer state periodically.

Example config:

- kind: noether.core.callbacks.CheckpointCallback
  name: CheckpointCallback
  every_n_epochs: 1
  save_weights: true
  save_optim: true
Parameters:
save_weights
save_optim
save_latest_weights
save_latest_optim
model_names
before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

periodic_callback(*, interval_type, update_counter, **kwargs)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None
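Assuming the hook signature shown above, overriding periodic_callback in a subclass might look like this. The PeriodicCallback stand-in and ProgressPrintCallback are illustrative only; the real base class additionally handles the interval bookkeeping.

```python
class PeriodicCallback:
    """Stand-in for noether.core.callbacks.periodic.PeriodicCallback
    (illustration only; the real base handles interval bookkeeping)."""

    def periodic_callback(self, *, interval_type, update_counter, **kwargs):
        raise NotImplementedError


class ProgressPrintCallback(PeriodicCallback):
    """Hypothetical subclass that records each periodic invocation."""

    def __init__(self):
        self.invocations = []

    def periodic_callback(self, *, interval_type, update_counter, **kwargs):
        # interval_type is "epoch", "update", "sample" or "eval", indicating
        # which configured interval triggered this call.
        self.invocations.append((interval_type, update_counter))
```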

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.EmaCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback for exponential moving average (EMA) of model weights.

Example config:

- kind: noether.core.callbacks.EmaCallback
  every_n_epochs: 10
  save_weights: false
  save_last_weights: false
  save_latest_weights: true
  target_factors:
    - 0.9999
  name: EmaCallback
Parameters:
model_paths
target_factors
save_weights
save_last_weights
save_latest_weights
parameters: dict[tuple[str | None, float], dict[str, torch.Tensor]]
buffers: dict[str | None, dict[str, torch.Tensor]]
resume_from_checkpoint(resumption_paths, model)

Resume EMA state from a checkpoint.

Parameters:
Return type:

None

before_training(**_)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

apply_ema(cur_model, model_path, target_factor)

Apply the EMA update to the current model parameters using a fused, in-place implementation.
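The EMA update itself follows ema ← factor · ema + (1 − factor) · current. A minimal sketch with plain floats is shown below; the actual implementation operates in place on torch tensors (typically via fused element-wise ops), and `ema_update` is a hypothetical helper, not the callback's real method.

```python
def ema_update(ema_params, cur_params, target_factor):
    """Sketch of one EMA step: ema <- factor * ema + (1 - factor) * cur.

    Both arguments map parameter names to values; the real callback
    applies the same formula in place to torch tensors.
    """
    for name, cur in cur_params.items():
        ema_params[name] = (
            target_factor * ema_params[name] + (1.0 - target_factor) * cur
        )
    return ema_params
```

With target_factor close to 1 (e.g., 0.9999 as in the config above), the EMA weights change slowly and smooth out per-update noise in the raw model weights.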

track_after_update_step(**_)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

periodic_callback(*, interval_type, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.DatasetStatsCallback(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.base.CallbackBase

A callback that logs the length of each dataset in the data container. It is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
before_training(**_)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.EtaCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to print the progress and estimated duration until the periodic callback will be invoked.

It also counts up the current epoch/update/sample totals and reports the average update duration. It is only used in “unmanaged” runs, i.e., it is not used when the run was started via SLURM.

This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
class LoggerWasCalledHandler

Bases: logging.Handler

Handler instances dispatch logging events to specific destinations.

The base handler class. Acts as a placeholder which defines the Handler interface. Handlers can optionally use Formatter instances to format records as desired. By default, no formatter is specified; in this case, the ‘raw’ message as determined by record.message is logged.

Initializes the instance - basically setting the formatter to None and the filter list to empty.

was_called = False
emit(_)

Do whatever it takes to actually log the specified logging record.

This version is intended to be implemented by subclasses and so raises a NotImplementedError.

total_time = 0.0
time_since_last_log = 0.0
handler
before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

track_after_update_step(*, update_counter, times)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

  • times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

periodic_callback(*, interval_type, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.LrCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to log the learning rate of the optimizer.

This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
periodic_callback(**_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.OnlineLossCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to track the loss of the model after every gradient accumulation step and log the average loss.

This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Initialize the OnlineLossCallback.

Parameters:
verbose
tracked_losses: collections.defaultdict[str, list[torch.Tensor]]
track_after_accumulation_step(*, losses, **_)

Hook called after each individual gradient accumulation step.

This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.

Common use cases include:

  • Logging per-batch losses for high-frequency monitoring

  • Accumulating statistics across batches before an optimizer update

  • Implementing custom logging that needs access to individual batch data

Note

This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • batch – The current data batch processed in this accumulation step.

  • losses – Dictionary of computed losses for the current batch.

  • update_outputs – Optional dictionary of model outputs for the current batch.

  • accumulation_steps – Total number of accumulation steps before an optimizer update.

  • accumulation_step – The current accumulation step index (0-indexed).

Return type:

None
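The tracking pattern described above can be sketched as follows: collect per-batch losses across accumulation steps, then average and reset when the periodic hook fires. OnlineLossTracker is a simplified hypothetical stand-in; the real callback stores torch tensors in its tracked_losses defaultdict and logs the averages via the LogWriter.

```python
from collections import defaultdict


class OnlineLossTracker:
    """Hypothetical sketch of the OnlineLossCallback tracking pattern."""

    def __init__(self):
        self.tracked_losses = defaultdict(list)

    def track_after_accumulation_step(self, *, losses, **_):
        # Called for every batch, even when no optimizer step happens.
        for key, value in losses.items():
            self.tracked_losses[key].append(value)

    def periodic_callback(self, **_):
        # Average the collected losses and reset; the real callback
        # logs each average through the LogWriter instead of returning it.
        means = {k: sum(v) / len(v) for k, v in self.tracked_losses.items()}
        self.tracked_losses.clear()
        return means
```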

periodic_callback(*, interval_type, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.ParamCountCallback(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.base.CallbackBase

Callback to log the number of trainable and frozen parameters of the model.

This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
before_training(**_)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.PeakMemoryCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to log the peak memory usage of the model. It is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
periodic_callback(**__)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.ProgressCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to print the progress of the training such as number of epochs and updates.

This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
before_training(**_)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

periodic_callback(*, interval_type, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

track_after_update_step(*, update_counter, **_)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

  • times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

class noether.core.callbacks.TrainTimeCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to log the time spent on dataloading. Is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.

Parameters:
  • train_times (dict[str, list[float]])

  • total_train_times (dict[str, torch.Tensor])
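
As a rough illustration of what these attributes hold, the following sketch accumulates the per-step times dictionary into per-key lists and reduces them to totals. The class name and totals() helper are hypothetical, and plain floats stand in for the torch.Tensors the real callback stores:

```python
from collections import defaultdict


class TimeAccumulator:
    """Hypothetical stand-in for TrainTimeCallback's time bookkeeping."""

    def __init__(self):
        # mirrors the documented train_times: dict[str, list[float]] attribute
        self.train_times: dict = defaultdict(list)

    def track_after_update_step(self, *, times: dict) -> None:
        # append each measured duration under its key ('data_time', ...)
        for key, value in times.items():
            self.train_times[key].append(value)

    def totals(self) -> dict:
        # stand-in for total_train_times (the real class stores torch.Tensors)
        return {key: sum(values) for key, values in self.train_times.items()}


acc = TimeAccumulator()
acc.track_after_update_step(times={"data_time": 0.1, "forward_time": 0.3})
acc.track_after_update_step(times={"data_time": 0.2, "forward_time": 0.4})
print(acc.totals())
```
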
track_after_update_step(*, times, **_)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • times (dict[str, float]) – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

periodic_callback(**_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

exception noether.core.callbacks.EarlyStopIteration

Bases: StopIteration

Custom StopIteration exception for Early Stoppers.
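
To illustrate how this exception is typically consumed, here is a hedged sketch of a training loop that treats it as a clean stop signal. The class below is a stand-in for the real exception, and the loop is not noether's actual trainer code:

```python
class EarlyStopIteration(StopIteration):
    """Stand-in for noether.core.callbacks.EarlyStopIteration."""


def run_training(max_updates: int, stop_at_update: int) -> int:
    updates_done = 0
    try:
        for update in range(max_updates):
            # ... forward/backward/optimizer step would happen here ...
            updates_done += 1
            if update + 1 >= stop_at_update:  # hypothetical stopping criterion
                raise EarlyStopIteration
    except EarlyStopIteration:
        pass  # training ends cleanly; after_training hooks can still run
    return updates_done


print(run_training(max_updates=100, stop_at_update=10))  # → 10
```

Because EarlyStopIteration subclasses StopIteration, a trainer can distinguish an early stop from other failures while still unwinding the loop normally.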

class noether.core.callbacks.EarlyStopperBase(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Base class for early stoppers that is used to define the interface for early stoppers used by the trainers.

to_short_interval_string()

Convert the interval to a short string representation used for logging.

Return type:

str

periodic_callback(*, interval_type, update_counter, **kwargs)

Check if training should stop and raise exception if needed.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – Type of interval that triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance with current training state.

  • **kwargs – Additional keyword arguments.

Raises:

EarlyStopIteration – If training should be stopped based on the stopping criterion.

Return type:

None

class noether.core.callbacks.FixedEarlyStopper(callback_config, **kwargs)

Bases: noether.core.callbacks.early_stoppers.base.EarlyStopperBase

Early stopper (training) based on a fixed number of epochs, updates, or samples.

Example config:

- kind: noether.core.callbacks.FixedEarlyStopper
  stop_at_epoch: 10
  name: FixedEarlyStopper
Parameters:
  • stop_at_sample

  • stop_at_update

  • stop_at_epoch
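
The stopping criterion can be sketched as a simple threshold check. This is an illustration of the documented semantics only; the function name and signature are assumptions, not the actual implementation:

```python
from typing import Optional


def should_stop(*, epoch: int, update: int, sample: int,
                stop_at_epoch: Optional[int] = None,
                stop_at_update: Optional[int] = None,
                stop_at_sample: Optional[int] = None) -> bool:
    """Stop as soon as any configured fixed threshold is reached."""
    if stop_at_epoch is not None and epoch >= stop_at_epoch:
        return True
    if stop_at_update is not None and update >= stop_at_update:
        return True
    if stop_at_sample is not None and sample >= stop_at_sample:
        return True
    return False


print(should_stop(epoch=10, update=5000, sample=640000, stop_at_epoch=10))
```

With the example config above (stop_at_epoch: 10), the stopper would raise EarlyStopIteration the first time the epoch counter reaches 10.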
class noether.core.callbacks.MetricEarlyStopper(callback_config, **kwargs)

Bases: noether.core.callbacks.early_stoppers.base.EarlyStopperBase

Early stopper (training) based on a metric value to be monitored.

Example config:

- kind: noether.core.callbacks.MetricEarlyStopper
  every_n_epochs: 1
  metric_key: loss/val/total
  tolerance: 0.10
  name: MetricEarlyStopper
Parameters:
  • metric_key

  • higher_is_better

  • tolerance

  • tolerance_counter = 0

  • best_metric
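
One plausible reading of these attributes is sketched below: a new best value resets the tolerance counter, and stopping triggers once the metric trails the best value by more than tolerance. The class name is hypothetical and the exact tolerance semantics in noether may differ:

```python
from typing import Optional


class ToleranceStopper:
    """Hedged sketch of a tolerance-based metric early stopper."""

    def __init__(self, tolerance: float, higher_is_better: bool = False):
        self.tolerance = tolerance
        self.higher_is_better = higher_is_better
        self.tolerance_counter = 0
        self.best_metric: Optional[float] = None

    def check(self, metric: float) -> bool:
        """Return True if training should stop."""
        sign = 1.0 if self.higher_is_better else -1.0
        if self.best_metric is None or sign * metric > sign * self.best_metric:
            # new best: remember it and reset the counter
            self.best_metric = metric
            self.tolerance_counter = 0
            return False
        self.tolerance_counter += 1
        # stop once the metric trails the best by more than the tolerance
        return sign * (self.best_metric - metric) > self.tolerance
```

An EarlyStopperBase subclass would raise EarlyStopIteration from periodic_callback() when such a check returns True.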
class noether.core.callbacks.BestMetricCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.

For example, track the test loss at the epoch with the best validation loss to simulate early stopping.

Example config:

- kind: noether.core.callbacks.BestMetricCallback
  every_n_epochs: 1
  source_metric_key: loss/val/total
  target_metric_keys:
    -  loss/test/total

In this example, whenever a new best validation loss is found, the corresponding test loss is logged under the key loss/test/total/at_best/loss/val/total.
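
The best-metric tracking and the /at_best/ key construction described above can be sketched as follows. The class and method names are illustrative, not noether's actual implementation:

```python
class BestMetricTracker:
    """Hypothetical sketch of BestMetricCallback's bookkeeping."""

    def __init__(self, source_metric_key, target_metric_keys, higher_is_better=False):
        self.source_metric_key = source_metric_key
        self.target_metric_keys = target_metric_keys
        self.higher_is_better = higher_is_better
        self.best_metric_value = None

    def update(self, log_values: dict) -> dict:
        """Return the target metrics to log if a new best was found."""
        value = log_values[self.source_metric_key]
        sign = 1.0 if self.higher_is_better else -1.0
        if self.best_metric_value is None or sign * value > sign * self.best_metric_value:
            self.best_metric_value = value
            # build "target/at_best/source" keys as in the example above
            return {
                f"{key}/at_best/{self.source_metric_key}": log_values[key]
                for key in self.target_metric_keys
            }
        return {}


tracker = BestMetricTracker("loss/val/total", ["loss/test/total"])
print(tracker.update({"loss/val/total": 0.8, "loss/test/total": 0.9}))
```
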

Parameters:
  • source_metric_key

  • target_metric_keys

  • optional_target_metric_keys

  • higher_is_better

  • best_metric_value

  • previous_log_values (dict[str, Any])
periodic_callback(**__)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.TrackAdditionalOutputsCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback that is invoked during training after every gradient accumulation step to track certain outputs from the update step. The update_outputs that are provided in the track_after_accumulation_step method are the additional_outputs field from the TrainerResult returned by the trainer’s update step.

The provided update_outputs are assumed to be a dictionary; outputs whose keys match the configured keys or patterns are tracked. An update output matches if either its key matches exactly, e.g. {“some_output”: …} with keys = [“some_output”], or if one of the patterns is a substring of the output key, e.g. {“some_loss”: …} with patterns = [“loss”].
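
The matching rule can be expressed as a small predicate. This is a sketch; the callback's actual helper may be named and structured differently:

```python
def matches(output_key: str, *, keys=(), patterns=()) -> bool:
    """True if the key matches exactly or contains any of the patterns."""
    return output_key in keys or any(p in output_key for p in patterns)


update_outputs = {"some_output": 1.0, "surface_pressure_loss": 0.2, "lr": 3e-4}
tracked = {k: v for k, v in update_outputs.items()
           if matches(k, keys=["some_output"], patterns=["loss"])}
print(sorted(tracked))  # → ['some_output', 'surface_pressure_loss']
```
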

Example config:

- kind: noether.core.callbacks.TrackAdditionalOutputsCallback
  name: TrackAdditionalOutputsCallback
  every_n_updates: 1
  keys:
    - "surface_pressure_loss"
Parameters:
  • out (pathlib.Path | None)

  • patterns

  • keys

  • verbose

  • tracked_values (collections.defaultdict[str, list])

  • reduce

  • log_output

  • save_output
track_after_accumulation_step(*, update_counter, update_outputs, **_)

Track the specified outputs after each accumulation step.

Parameters:
  • update_counter – UpdateCounter object to track the number of updates.

  • update_outputs – The additional_outputs field from the TrainerResult returned by the trainer’s update step. Note that the base train_step method in the base trainer does not provide any additional outputs by default, and hence this callback can only be used if the train_step is modified to provide additional outputs.

  • **_ – Additional unused keyword arguments.

Return type:

None

periodic_callback(*, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

class noether.core.callbacks.PeriodicCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.base.CallbackBase

Base class for callbacks that are invoked periodically during training.

PeriodicCallback extends CallbackBase to support periodic execution based on training progress. Callbacks can be configured to run at regular intervals defined by epochs, updates (optimizer steps), or samples (data points processed). This class implements the infrastructure for periodic invocation while child classes define the actual behavior via the periodic_callback() method.

Interval Configuration:

Callbacks can be configured to run periodically using one or more of:

  • every_n_epochs: Execute callback every N epochs

  • every_n_updates: Execute callback every N optimizer updates

  • every_n_samples: Execute callback every N samples processed
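
The interval semantics above can be sketched as modular arithmetic. The actual scheduling is handled by the trainer and UpdateCounter, and the sample-interval alignment shown here (fire on the update whose batch crosses a multiple of every_n_samples) is an assumption:

```python
def fires(*, epoch=None, update=None, samples=None, batch_size=1,
          every_n_epochs=None, every_n_updates=None, every_n_samples=None):
    """Hedged sketch: does any configured interval trigger at this point?"""
    if every_n_epochs is not None and epoch is not None and epoch % every_n_epochs == 0:
        return True
    if every_n_updates is not None and update is not None and update % every_n_updates == 0:
        return True
    if every_n_samples is not None and samples is not None:
        # fire on the update whose batch crossed a multiple of every_n_samples
        return samples // every_n_samples > (samples - batch_size) // every_n_samples
    return False


print(fires(update=30, every_n_updates=10))                        # → True
print(fires(samples=1024, batch_size=256, every_n_samples=1000))   # → True
```
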

Tracking vs. Periodic Execution:

The class provides two types of hooks:

  • Tracking hooks (track_after_accumulation_step(), track_after_update_step()): Called on every accumulation/update step to track metrics continuously (e.g., for running averages). I.e., if you want to log an exponential moving average of the loss every epoch, the logging is done in the periodic callback; however, the tracking of the loss values for computing the moving average is done in the tracking hook.

  • Periodic hook (periodic_callback()): Called only when the configured interval is reached, typically for expensive operations like evaluation or checkpointing.

Examples

Creating a custom periodic callback that logs metrics every 10 epochs:

class CustomMetricCallback(PeriodicCallback):
    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # This method is called every 10 epochs
        metric_value = self.compute_expensive_metric()
        self.writer.add_scalar(
            key="custom_metric",
            value=metric_value,
            logger=self.logger,
        )


# Configure in YAML:
# callbacks:
#   - kind: path.to.CustomMetricCallback
#     every_n_epochs: 10

Tracking metrics at every update and logging periodically:

class RunningAverageCallback(PeriodicCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_accumulator = []

    def track_after_update_step(self, *, update_counter: UpdateCounter, times: dict[str, float]) -> None:
        # Track at every update
        self.loss_accumulator.append(self.trainer.last_loss)

    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # Log periodically
        if not self.loss_accumulator:
            return  # nothing tracked since the last invocation
        avg_loss = sum(self.loss_accumulator) / len(self.loss_accumulator)
        self.writer.add_scalar("avg_loss", avg_loss, logger=self.logger)
        self.loss_accumulator.clear()
every_n_epochs

If set, callback is invoked every N epochs.

every_n_updates

If set, callback is invoked every N optimizer updates.

every_n_samples

If set, callback is invoked every N samples processed.

batch_size

Batch size used during training.

Parameters:
  • every_n_epochs

  • every_n_updates

  • every_n_samples

  • batch_size
periodic_callback(*, interval_type, update_counter, **kwargs)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Return type:

None

track_after_accumulation_step(*, update_counter, batch, losses, update_outputs, accumulation_steps, accumulation_step)

Hook called after each individual gradient accumulation step.

This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.

Common use cases include:

  • Logging per-batch losses for high-frequency monitoring

  • Accumulating statistics across batches before an optimizer update

  • Implementing custom logging that needs access to individual batch data

Note

This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.

Parameters:
  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • batch (Any) – The current data batch processed in this accumulation step.

  • losses (dict[str, torch.Tensor]) – Dictionary of computed losses for the current batch.

  • update_outputs (dict[str, torch.Tensor] | None) – Optional dictionary of model outputs for the current batch.

  • accumulation_steps (int) – Total number of accumulation steps before an optimizer update.

  • accumulation_step (int) – The current accumulation step index (0-indexed).

Return type:

None

track_after_update_step(*, update_counter, times)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Return type:

None

after_epoch(update_counter, **kwargs)

Invoked after every epoch to check if callback should be invoked.

Applies torch.no_grad() context.

Return type:

None

after_update(update_counter, **kwargs)

Invoked after every update to check if callback should be invoked.

Applies torch.no_grad() context.

Return type:

None

at_eval(update_counter, **kwargs)
Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter)

Return type:

None

updates_till_next_invocation(update_counter)

Calculate how many updates remain until this callback is invoked.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates remaining until the next callback invocation.

Return type:

int

updates_per_interval(update_counter)

Calculate how many updates elapse from one invocation of this callback to the next.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates between callback invocations.

Return type:

int
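
For an epoch-based interval, the two helpers above reduce to simple arithmetic. The worked example below assumes a fixed number of updates per epoch; the function names mirror the documented methods but are free functions here for illustration:

```python
def updates_per_interval(every_n_epochs: int, updates_per_epoch: int) -> int:
    """Updates between two invocations of an epoch-based periodic callback."""
    return every_n_epochs * updates_per_epoch


def updates_till_next_invocation(current_update: int, interval_updates: int) -> int:
    """Updates remaining until the callback fires again."""
    return interval_updates - (current_update % interval_updates)


interval = updates_per_interval(every_n_epochs=2, updates_per_epoch=100)  # 200
print(updates_till_next_invocation(current_update=150, interval_updates=interval))  # → 50
```
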

get_interval_string_verbose()

Return interval configuration as a verbose string.

Returns:

Interval as, e.g., “every_n_epochs=1” for epoch-based intervals.

Return type:

str

to_short_interval_string()

Return interval configuration as a short string.

Returns:

Interval as, e.g., “E1” if every_n_epochs=1 for epoch-based intervals.

Return type:

str
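
Based on the documented “E1” example, a plausible mapping looks like this; the letters for update- and sample-based intervals are assumptions:

```python
def to_short_interval_string(*, every_n_epochs=None, every_n_updates=None,
                             every_n_samples=None) -> str:
    """Sketch of the short interval string, e.g. 'E1' for every_n_epochs=1."""
    if every_n_epochs is not None:
        return f"E{every_n_epochs}"
    if every_n_updates is not None:
        return f"U{every_n_updates}"  # assumed letter
    if every_n_samples is not None:
        return f"S{every_n_samples}"  # assumed letter
    return ""


print(to_short_interval_string(every_n_epochs=1))  # → E1
```
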

class noether.core.callbacks.PeriodicDataIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: PeriodicCallback

Base class for callbacks that perform periodic iterations over a dataset.

PeriodicDataIteratorCallback extends PeriodicCallback to support evaluations or computations that require iterating over an entire dataset. This is commonly used for validation/test set evaluation, computing metrics on held-out data, or any operation that needs to process batches from a dataset at regular training intervals.

The class integrates with the training data pipeline by registering samplers that control when and how data is loaded. It handles the complete iteration workflow: data loading, batch processing, result collation across distributed ranks, and final processing.

Workflow:
  1. Iteration (_iterate_over_dataset()): When the periodic interval is reached, iterate through the dataset in batches.

  2. Process Data (process_data()): Process a single batch (e.g., run model inference) and return results.

  3. Collation (_collate_result()): Aggregate results across all batches and distributed ranks.

  4. Processing (process_results()): Compute final metrics or perform actions with the aggregated results.
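
Step 3 of this workflow can be sketched with plain Python lists standing in for tensors (distributed gathering and padding removal omitted); the function name is illustrative, not the actual _collate_result implementation:

```python
def collate(batch_results: list) -> dict:
    """Merge per-batch result dicts key-wise into one aggregated dict."""
    collated: dict = {}
    for result in batch_results:
        for key, value in result.items():
            collated.setdefault(key, []).extend(value)
    return collated


results = collate([
    {"predictions": [0, 1], "labels": [0, 0]},
    {"predictions": [1, 1], "labels": [1, 0]},
])
print(results)
```

process_results() would then receive the aggregated dict and compute metrics over the whole dataset.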

Key Features:
  • Distributed Support: Automatically handles distributed evaluation with proper gathering across ranks and padding removal.

  • Flexible Collation: Supports collating various result types (tensors, dicts of tensors, lists).

  • Data Pipeline Integration: Uses SamplerIntervalConfig to integrate with the interleaved sampler for efficient data loading.

  • Progress Tracking: Provides progress bars and timing information for data loading.

Template Methods to Override:

Child classes must implement process_data() and process_results().

Examples

Basic validation accuracy callback that evaluates on a test set every epoch:

class AccuracyCallback(PeriodicDataIteratorCallback):
    def __init__(self, *args, dataset_key="test", **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_key = dataset_key

    def process_data(self, batch, *, trainer_model):
        # Run inference on batch
        x = batch["x"].to(trainer_model.device)
        y_true = batch["class"].clone()
        y_pred = trainer_model(x)
        return {"predictions": y_pred, "labels": y_true}

    def process_results(self, results, *, interval_type, update_counter, **_):
        # Compute accuracy from aggregated results
        y_pred = results["predictions"]
        y_true = results["labels"]
        accuracy = (y_pred.argmax(dim=1) == y_true).float().mean()

        self.writer.add_scalar(
            key="test/accuracy",
            value=accuracy.item(),
            logger=self.logger,
            format_str=".4f",
        )


# Configure in YAML:
# callbacks:
#   - kind: path.to.AccuracyCallback
#     every_n_epochs: 1
#     dataset_key: "test"

Advanced example with multiple return values and custom collation:

class DetailedEvaluationCallback(PeriodicDataIteratorCallback):
    def process_data(self, batch, *, trainer_model):
        x = batch["x"].to(trainer_model.device)
        y = batch["label"]

        # Return multiple outputs as tuple
        logits = trainer_model(x)
        embeddings = trainer_model.get_embeddings(x)
        return logits, embeddings, y

    def process_results(self, results, *, interval_type, update_counter, **_):
        # results is a tuple: (all_logits, all_embeddings, all_labels)
        logits, embeddings, labels = results

        # Compute multiple metrics
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        mean_embedding_norm = embeddings.norm(dim=-1).mean()

        self.writer.add_scalar("accuracy", accuracy.item())
        self.writer.add_scalar("embedding_norm", mean_embedding_norm.item())
dataset_key

Key to identify the dataset to iterate over from self.data_container. Automatically set from the callback config.

sampler_config

Configuration for the sampler that controls dataset iteration. Automatically set when the dataset is initialized.

total_data_time

Cumulative time spent waiting for data loading across all periodic callbacks.

Note

  • The process_data() method is called within a torch.no_grad() context automatically.

  • For distributed training, results are automatically gathered across all ranks with proper padding removal.

Parameters:
  • dataset_key

  • total_data_time

  • sampler_config
_sampler_config_from_key(key, properties=None, max_size=None)

Register the dataset that is used for this callback in the dataloading pipeline.

Parameters:
  • key (str | None) – Key for identifying the dataset from self.data_container. Uses the first dataset if None.

  • properties (set[str] | None) – Optionally specifies a subset of properties to load from the dataset.

  • max_size (int | None) – If provided, only uses a subset of the full dataset. Default: None (no subset).

Returns:

SamplerIntervalConfig for the registered dataset.

Return type:

noether.data.samplers.SamplerIntervalConfig

abstractmethod process_data(batch, *, trainer_model)

Template method that is called for each batch that is loaded from the dataset.

This method should process a single batch and return results that will be collated.

Parameters:
  • batch – The loaded batch.

  • trainer_model (torch.nn.Module) – Model of the current training run.

Returns:

Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.

Return type:

Any

process_results(results, *, interval_type, update_counter, **_)

Template method that is called with the collated results from dataset iteration.

For example, metrics can be computed from the results for the entire test/validation dataset and logged.

Return type:

None

periodic_callback(*, interval_type, update_counter, data_iter, trainer_model, batch_size, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

  • data_iter (collections.abc.Iterator)

  • batch_size (int)

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None