noether.core.callbacks.base

Classes

CallbackBase

Base class for callbacks that execute something before/after training.

Module Contents

class noether.core.callbacks.base.CallbackBase(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Base class for callbacks that execute something before/after training.

Allows overriding before_training and after_training.

If the callback is stateful (i.e., it tracks something across the training process that needs to be loaded if the run is resumed), there are two ways to implement loading the callback state:

  • state_dict: write the current state into a state dict. When the trainer saves a checkpoint to disk, it also stores the state_dict of all callbacks within the trainer state_dict. Once a run is resumed, a callback can load its state from the previously stored state_dict by overriding load_state_dict.

  • resume_from_checkpoint: if a callback stores large files to disk, it would be redundant to also store them within its state_dict. Therefore, this method is called on resume to allow callbacks to load their state from files on the disk.
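The state_dict/load_state_dict round trip can be sketched as follows. The minimal CallbackBase stub stands in for the real base class so the example runs on its own; BestAccuracyCallback is a hypothetical stateful callback.

```python
class CallbackBase:
    """Minimal stand-in for noether's CallbackBase (illustration only)."""

    def state_dict(self):
        # Non-stateful by default.
        return None

    def load_state_dict(self, state_dict):
        # Non-stateful by default: nothing to restore.
        pass


class BestAccuracyCallback(CallbackBase):
    """Tracks the best accuracy seen so far, surviving run resumption."""

    def __init__(self):
        self.best_accuracy = 0.0

    def state_dict(self):
        # Saved inside the trainer state_dict whenever a checkpoint is written.
        return {"best_accuracy": self.best_accuracy}

    def load_state_dict(self, state_dict):
        # Called with the previously saved state when the run resumes.
        self.best_accuracy = state_dict["best_accuracy"]


# Round trip: save state from one instance, restore it into a fresh one.
saved = BestAccuracyCallback()
saved.best_accuracy = 0.9
restored = BestAccuracyCallback()
restored.load_state_dict(saved.state_dict())
```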

Callbacks have access to a LogWriter, which they can use to log metrics. The LogWriter is a singleton.

Examples

# THIS IS INSIDE A CUSTOM CALLBACK

# log only to experiment tracker, not stdout
self.writer.add_scalar(key="classification_accuracy", value=0.2)
# log to experiment tracker and stdout (as "0.24")
self.writer.add_scalar(
    key="classification_accuracy",
    value=0.23623,
    logger=self.logger,
    format_str=".2f",
)

Note

As evaluations are almost always performed within torch.no_grad() contexts, the hooks implemented by callbacks are always executed within a torch.no_grad() context.

Parameters:
trainer: noether.training.trainers.BaseTrainer

Trainer of the current run. Can be used to access training state.

model: noether.core.models.ModelBase

Model of the current run. Can be used to access model parameters.

data_container: noether.data.container.DataContainer

Data container of the current run. Can be used to access all datasets.

tracker: noether.core.trackers.BaseTracker

Tracker of the current run. Can be used for direct access to the experiment tracking platform.

log_writer: noether.core.writers.LogWriter

Log writer of the current run. Can be used to log metrics to stdout/disk/online platform.

metric_property_provider: noether.core.providers.metric_property.MetricPropertyProvider

Metric property provider of the current run. Defines properties of metrics (e.g., whether higher values are better).

checkpoint_writer: noether.core.writers.CheckpointWriter

Checkpoint writer of the current run. Can be used to store checkpoints during training.

name = None
state_dict()

If a callback is stateful, its state is saved whenever a checkpoint is written to disk.

Returns:

State of the callback. By default, callbacks are non-stateful and return None.

Return type:

dict[str, torch.Tensor] | None

load_state_dict(state_dict)

If a callback is stateful, its state is saved whenever a checkpoint is written to disk and can be loaded with this method upon resuming a run.

Parameters:

state_dict (dict[str, Any]) – State to be loaded. By default, callbacks are non-stateful and load_state_dict does nothing.

Return type:

None

resume_from_checkpoint(resumption_paths, model)

If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be unnecessarily wasteful to also store the state in the callback's state_dict. Therefore, resume_from_checkpoint is called when resuming a run, which allows callbacks to load their state from any file that was stored on the disk.

Parameters:

resumption_paths – Paths to the files stored by the run being resumed.

model (noether.core.models.ModelBase) – Model of the current run.

Return type:

None
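A minimal sketch of the disk-based pattern, assuming for illustration that resumption_paths points at a directory and that the callback previously wrote a file named "ema.json" there (both are assumptions, not part of the documented API):

```python
import json
import os
import tempfile


class EmaCallback:
    """Sketch of a callback that keeps its (large) EMA state on disk."""

    def __init__(self):
        self.ema_state = None

    def state_dict(self):
        # The EMA state already lives on disk, so storing it again in the
        # trainer state_dict would be redundant.
        return None

    def resume_from_checkpoint(self, resumption_paths, model):
        # Reload the previously written file instead of a state_dict entry.
        with open(os.path.join(resumption_paths, "ema.json")) as f:
            self.ema_state = json.load(f)


# Demo: write a fake EMA file, then "resume" from it.
checkpoint_dir = tempfile.mkdtemp()
with open(os.path.join(checkpoint_dir, "ema.json"), "w") as f:
    json.dump({"decay": 0.999}, f)

callback = EmaCallback()
callback.resume_from_checkpoint(checkpoint_dir, model=None)
```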

property logger: logging.Logger

Logger for logging to stdout.

Return type:

logging.Logger

before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None
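A sketch of the hyperparameter-logging use case, with a plain list standing in for the real tracker/log-writer calls (HyperparamLoggingCallback and its attributes are hypothetical):

```python
class HyperparamLoggingCallback:
    """Sketch of a before_training hook that records hyperparameters once."""

    def __init__(self, hyperparams):
        self.hyperparams = hyperparams
        self.logged = []

    def before_training(self, *, update_counter):
        # Runs once before the training loop starts (inside torch.no_grad()
        # in a real run); a good place for one-time setup and sanity checks.
        for key, value in sorted(self.hyperparams.items()):
            self.logged.append(f"{key}={value}")


cb = HyperparamLoggingCallback({"lr": 3e-4, "batch_size": 256})
cb.before_training(update_counter=None)
```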

after_training(*, update_counter)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None
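The final-evaluation use case can be sketched as below; evaluate_fn is a hypothetical stand-in for whatever evaluation the callback runs on the trainer's model and data:

```python
class FinalEvalCallback:
    """Sketch of an after_training hook that runs a final test-set evaluation."""

    def __init__(self, evaluate_fn):
        self.evaluate_fn = evaluate_fn
        self.final_metrics = None

    def after_training(self, *, update_counter):
        # Runs once after the training loop finishes (inside torch.no_grad()
        # in a real run); results could then be logged via the LogWriter.
        self.final_metrics = self.evaluate_fn()


cb = FinalEvalCallback(lambda: {"test_accuracy": 0.91})
cb.after_training(update_counter=None)
```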