How to Use and Build Callbacks

Callbacks are the primary mechanism in the Noether Framework for injecting custom logic into various stages of the training process. They are most commonly used for monitoring, checkpointing, evaluation, and experiment tracking.

What are Callbacks?

In Noether, a callback is a class that inherits from CallbackBase. These objects provide hooks that the trainer calls at specific points:

  • Before training starts: Initialization, logging hyperparameters, or printing model summaries.

  • After each accumulation step: Tracking metrics across multiple batches.

  • After each optimizer update: Updating learning rate schedules or tracking update-level metrics.

  • After each epoch: Performing validation or saving periodic checkpoints.

  • After training ends: Final evaluation, saving final results, or cleanup.

  • At evaluation time: Running inference on validation or test datasets.
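The hook points above can be pictured with a toy trainer that walks its callback list at each stage. This is a self-contained sketch of the pattern only; the names used here (CallbackSketch, TrainerSketch, before_training, after_epoch, after_training) are illustrative assumptions, not Noether's actual API:

```python
# Illustrative sketch of the callback pattern; names are NOT Noether's real API.
class CallbackSketch:
    def before_training(self): pass
    def after_epoch(self, epoch: int): pass
    def after_training(self): pass

class RecordingCallback(CallbackSketch):
    def __init__(self):
        self.events = []

    def after_epoch(self, epoch: int):
        self.events.append(f"epoch {epoch} done")

class TrainerSketch:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def fit(self, num_epochs: int):
        # The trainer, not the callback, decides when each hook fires.
        for cb in self.callbacks:
            cb.before_training()
        for epoch in range(num_epochs):
            # ... one epoch of training would happen here ...
            for cb in self.callbacks:
                cb.after_epoch(epoch)
        for cb in self.callbacks:
            cb.after_training()
```

The key design point carries over to the real framework: callbacks only implement the hooks they care about, and the trainer owns the schedule.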

Types of Callbacks

Noether provides several base classes and a wide range of built-in callbacks.

Pre/Post Training Callbacks

These callbacks inherit directly from CallbackBase and are designed to run logic only once at the start or end of training. Examples include logging hyperparameters before training or saving final results after training.

Periodic Callbacks

Most callbacks during training are periodic. Inheriting from PeriodicCallback allows you to configure how often a callback should run using one of the following intervals:

  • every_n_epochs: Runs after every N epochs.

  • every_n_updates: Runs after every N optimizer steps.

  • every_n_samples: Runs after every N training samples have been processed.
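The gating logic behind these intervals can be sketched in a few lines. This is illustrative only, not Noether's actual scheduler; note that every_n_samples needs a threshold-crossing check, because batch sizes generally do not divide N evenly:

```python
def due_every_n(counter: int, n: int) -> bool:
    """For every_n_epochs / every_n_updates: fire when the counter
    (epochs or optimizer steps completed) hits a multiple of n."""
    return n > 0 and counter % n == 0

def crossed_sample_threshold(prev_samples: int, cur_samples: int, n: int) -> bool:
    """For every_n_samples: fire when the running sample count crosses
    a multiple of n between two consecutive updates."""
    return cur_samples // n > prev_samples // n
```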

Periodic Data Iterator Callbacks

For tasks that require iterating over an entire dataset (like validation or computing complex metrics on a test set), Noether provides PeriodicDataIteratorCallback.

This class handles:

  • Distributed data sampling across multiple GPUs.

  • Automatic collation of results from different ranks.

  • Integration with the training data pipeline.
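The first two points can be pictured in plain Python. This is a sketch of the idea only; in a PyTorch codebase this would typically be built on torch.distributed primitives (e.g., all_gather), and Noether's actual implementation will differ in detail:

```python
def shard_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    # Distributed sampling: each rank processes an interleaved slice of the dataset.
    return list(range(rank, num_samples, world_size))

def collate_ranks(per_rank_results: list[list[float]]) -> list[float]:
    # Collation: merge per-rank outputs back into the original dataset order,
    # inverting the interleaved sharding above.
    world_size = len(per_rank_results)
    merged = [0.0] * sum(len(r) for r in per_rank_results)
    for rank, results in enumerate(per_rank_results):
        for i, value in enumerate(results):
            merged[rank + i * world_size] = value
    return merged
```

Inheriting from PeriodicDataIteratorCallback means you write only the per-batch and aggregation logic; the sharding and collation are handled for you.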

Commonly Used Callbacks

Noether includes many pre-defined callbacks organized by their purpose:

  • Monitoring: ProgressCallback, DatasetStatsCallback, LrCallback, PeakMemoryCallback, OnlineLossCallback, ParamCountCallback, EtaCallback, TrainTimeCallback. Used for real-time tracking of training progress and hardware usage. These callbacks are all initialized by default by the BaseTrainer; the user does not need to add them manually.

  • Checkpointing: BestCheckpointCallback, CheckpointCallback, EMACallback. Used to save model weights periodically or when a new best metric is achieved.

  • Early Stopping: MetricEarlyStopper, FixedEarlyStopper. Used to stop training automatically if progress plateaus.

  • Evaluation: BestMetricCallback, TrackOutputsCallback. Specialized monitoring for tracked metrics.

When to Use What?

  1. Use existing callbacks for standard tasks like logging, checkpointing, and validation. These are highly configurable via YAML.

  2. Inherit from PeriodicCallback if you need to perform an action at regular intervals (e.g., logging a custom internal state of the model).

  3. Inherit from PeriodicDataIteratorCallback if you need to run inference on a specific dataset and aggregate the results to compute a metric. These callbacks require a dataset key in their configuration to specify which dataset to run on.

  4. Inherit from CallbackBase if your logic only needs to run once at the very beginning or end of training.

How to Configure Callbacks

Callbacks are usually defined in your experiment configuration under the callbacks key. Each callback entry requires the fully qualified class name (e.g., noether.core.callbacks.progress.Progress) as its kind, exactly one frequency setting (every_n_*), and any additional parameters specific to that callback.

Example YAML configuration:

callbacks:
 - kind: noether.core.callbacks.CallbackClassName
   name: CallbackInstanceName
   every_n_epochs: 1
   # or every_n_updates: 1
   # additional_param: value
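A kind string like this is typically resolved by importing the module part of the dotted path and looking up the class attribute on it. The helper below is a hedged sketch of that common pattern, not Noether's actual loader:

```python
import importlib

def resolve_kind(kind: str) -> type:
    """Resolve a dotted path like 'package.module.ClassName' to the class object.

    Illustrative only: shows how 'kind' strings are conventionally turned
    into classes, not how Noether itself does it.
    """
    module_path, _, class_name = kind.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

For example, resolve_kind("collections.Counter") returns the Counter class, which the trainer could then instantiate with the remaining YAML parameters as keyword arguments.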

How to Implement Custom Callbacks

To create a custom callback, define a new class that inherits from one of the base callback classes. Override the relevant methods to inject your logic at the desired points in the training process.

import torch
from noether.core.schemas.callbacks import PeriodicDataIteratorCallbackConfig
from noether.core.callbacks.periodic import PeriodicDataIteratorCallback

class CustomCallbackConfig(PeriodicDataIteratorCallbackConfig):
    pass  # Define any configuration parameters your callback needs.

class MyCustomCallback(PeriodicDataIteratorCallback):
    def __init__(self, callback_config: CustomCallbackConfig, **kwargs):
        super().__init__(callback_config, **kwargs)

    def process_data(self, batch: dict[str, torch.Tensor], **_) -> dict[str, torch.Tensor]:
        # Called once per batch of the configured dataset.
        with torch.no_grad():
            model_output = self.model(**batch)
        # ... any further per-batch logic ...
        return {"custom_output": model_output}

    def process_results(self, results: dict[str, torch.Tensor], **_) -> None:
        # Receives the outputs of process_data, aggregated across the whole
        # dataset (and collated across ranks in distributed runs).
        self.writer.add_scalar("custom_metric", results["custom_output"].mean().item())