noether.core.callbacks¶
Submodules¶
Exceptions¶
- Custom StopIteration exception for Early Stoppers.

Classes¶
- CallbackBase – Base class for callbacks that execute something before/after training.
- BestCheckpointCallback – Callback to save the best model based on a metric.
- CheckpointCallback – Callback to save the model and optimizer state periodically.
- EmaCallback – Callback for exponential moving average (EMA) of model weights.
- DatasetStatsCallback – A callback that logs the length of each dataset in the data container. Is initialized by the BaseTrainer.
- EtaCallback – Callback to print the progress and estimated duration until the periodic callback will be invoked.
- LrCallback – Callback to log the learning rate of the optimizer.
- OnlineLossCallback – Callback to track the loss of the model after every gradient accumulation step and log the average loss.
- ParamCountCallback – Callback to log the number of trainable and frozen parameters of the model.
- PeakMemoryCallback – Callback to log the peak memory usage of the model. Is initialized by the BaseTrainer.
- ProgressCallback – Callback to print the progress of the training such as number of epochs and updates.
- Callback to log the time spent on dataloading. Is initialized by the BaseTrainer.
- Base class for early stoppers; defines the interface for early stoppers used by the trainers.
- Early stopper (training) based on a fixed number of epochs, updates, or samples.
- Early stopper (training) based on a metric value to be monitored.
- A callback that keeps track of the best value of a certain metric (i.e., source_metric_key) over a training run while also logging one or more target metrics.
- Callback that is invoked during training after every gradient step to track certain outputs from the update step.
- Base class for callbacks that are invoked periodically during training.
- Base class for callbacks that perform periodic iterations over a dataset.
Package Contents¶
- class noether.core.callbacks.CallbackBase(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Base class for callbacks that execute something before/after training.
Allows overwriting before_training and after_training.
If the callback is stateful (i.e., it tracks something across the training process that needs to be loaded if the run is resumed), there are two ways to implement loading the callback state:
state_dict: write current state into a state dict. When the trainer saves the current checkpoint to the disk, it will also store the state_dict of all callbacks within the trainer state_dict. Once a run is resumed, a callback can load its state from the previously stored state_dict by overwriting the load_state_dict.
resume_from_checkpoint: If a callback is storing large files onto the disk, it would be redundant to also store them within its state_dict. Therefore, this method is called on resume to allow callbacks to load their state from files on the disk.
Callbacks have access to a LogWriter, with which callbacks can log metrics. The LogWriter is a singleton.
Examples
# THIS IS INSIDE A CUSTOM CALLBACK
# log only to experiment tracker, not stdout
self.writer.add_scalar(key="classification_accuracy", value=0.2)
# log to experiment tracker and stdout (as "0.24")
self.writer.add_scalar(
    key="classification_accuracy",
    value=0.23623,
    logger=self.logger,
    format_str=".2f",
)
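The state_dict / load_state_dict pattern described above can be sketched in isolation. The class and attribute names below (BestAccuracyTracker, best_accuracy) are made up for illustration; a real callback subclasses CallbackBase and receives the trainer, model, and writers in its constructor.

```python
# Minimal sketch of a stateful callback's persistence pattern.
# The tracked value uses a plain float here; real callbacks may
# return dict[str, torch.Tensor] from state_dict.
class BestAccuracyTracker:
    def __init__(self):
        self.best_accuracy = 0.0  # state that must survive a resumed run

    def update(self, accuracy):
        self.best_accuracy = max(self.best_accuracy, accuracy)

    def state_dict(self):
        # stored inside the trainer state_dict at every checkpoint
        return {"best_accuracy": self.best_accuracy}

    def load_state_dict(self, state_dict):
        # called when the run is resumed from a checkpoint
        self.best_accuracy = state_dict["best_accuracy"]


tracker = BestAccuracyTracker()
tracker.update(0.7)
saved = tracker.state_dict()

resumed = BestAccuracyTracker()
resumed.load_state_dict(saved)
print(resumed.best_accuracy)  # → 0.7
```

For callbacks that write large artifacts to disk instead, resume_from_checkpoint is the appropriate hook, as noted above.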
Note
As evaluations are pretty much always done in torch.no_grad() contexts, the hooks implemented by callbacks are always executed within a torch.no_grad() context.
- Parameters:
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics to stdout/disk/online platform.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints during training.
metric_property_provider (noether.core.providers.metric_property.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- trainer: noether.training.trainers.BaseTrainer¶
Trainer of the current run. Can be used to access training state.
- model: noether.core.models.ModelBase¶
Model of the current run. Can be used to access model parameters.
- data_container: noether.data.container.DataContainer¶
Data container of the current run. Can be used to access all datasets.
- tracker: noether.core.trackers.BaseTracker¶
Tracker of the current run. Can be used for direct access to the experiment tracking platform.
- writer: noether.core.writers.LogWriter¶
Log writer of the current run. Can be used to log metrics to stdout/disk/online platform.
- metric_property_provider: noether.core.providers.metric_property.MetricPropertyProvider¶
Metric property provider of the current run. Defines properties of metrics (e.g., whether higher values are better).
- checkpoint_writer: noether.core.writers.CheckpointWriter¶
Checkpoint writer of the current run. Can be used to store checkpoints during training.
- name = None¶
- state_dict()¶
If a callback is stateful, the state will be stored when a checkpoint is stored to the disk.
- Returns:
State of the callback. By default, callbacks are non-stateful and return None.
- Return type:
dict[str, torch.Tensor] | None
- load_state_dict(state_dict)¶
If a callback is stateful, the state will be stored when a checkpoint is stored to the disk and can be loaded with this method upon resuming a run.
- resume_from_checkpoint(resumption_paths, model)¶
If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be unnecessarily wasteful to also store the state in the callback's state_dict. Therefore, resume_from_checkpoint is called when resuming a run, which allows callbacks to load their state from any file that was stored on the disk.
- Parameters:
resumption_paths (noether.core.providers.path.PathProvider) – PathProvider instance to access paths from the checkpoint to resume from.
model (noether.core.models.ModelBase) – Model of the current training run.
- Return type:
None
- property logger: logging.Logger¶
Logger for logging to stdout.
- Return type:
logging.Logger
- before_training(*, update_counter)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- after_training(*, update_counter)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.BestCheckpointCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to save the best model based on a metric.
This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.
Example config:
callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names: # only applies when training a CompositeModel
      - encoder
- Parameters:
callback_config (noether.core.schemas.callbacks.BestCheckpointCallbackConfig) – Configuration for the callback. See BestCheckpointCallbackConfig for available options including metric key, model names, and tolerance settings.
**kwargs – Additional arguments passed to the parent class.
- metric_key¶
- model_names¶
- higher_is_better¶
- best_metric_value¶
- save_frozen_weights¶
- tolerances_is_exceeded¶
- tolerance_counter = 0¶
- state_dict()¶
Return the state of the callback for checkpointing.
- load_state_dict(state_dict)¶
Load the callback state from a checkpoint.
Note
This modifies the input state_dict in place.
- before_training(*, update_counter)¶
Validate callback configuration before training starts.
- Parameters:
update_counter – The training update counter.
- Raises:
NotImplementedError – If resuming training with tolerances is attempted.
- Return type:
None
- periodic_callback(**_)¶
Execute the periodic callback to check and save best model.
This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.
- Raises:
KeyError – If the log cache is empty or the metric key is not found.
- Return type:
None
- after_training(**kwargs)¶
Log the best metric values at different tolerance thresholds after training completes.
- Parameters:
**kwargs – Additional keyword arguments (unused).
- Return type:
None
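The decision this callback makes each period can be sketched from its attributes (metric_key, higher_is_better, best_metric_value). The helper below is illustrative only, not the actual implementation; the None-means-first-evaluation convention is an assumption.

```python
# Illustrative sketch of the best-metric comparison suggested by the
# callback's attributes: a new checkpoint is saved only when the
# monitored metric improves in the configured direction.
def is_new_best(current, best, higher_is_better):
    if best is None:  # assumed convention: first evaluation always wins
        return True
    return current > best if higher_is_better else current < best

assert is_new_best(0.92, 0.90, higher_is_better=True)       # accuracy improved
assert is_new_best(0.31, 0.35, higher_is_better=False)      # loss decreased
assert not is_new_best(0.40, 0.35, higher_is_better=False)  # loss got worse
```

Whether higher values are better for a given metric_key is resolved via the MetricPropertyProvider.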
- class noether.core.callbacks.CheckpointCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to save the model and optimizer state periodically.
Example config:
- kind: noether.core.callbacks.CheckpointCallback
  name: CheckpointCallback
  every_n_epochs: 1
  save_weights: true
  save_optim: true
- Parameters:
callback_config (noether.core.schemas.callbacks.CheckpointCallbackConfig) – Configuration for the callback. See CheckpointCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- save_weights¶
- save_optim¶
- save_latest_weights¶
- save_latest_optim¶
- model_names¶
- before_training(*, update_counter)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- periodic_callback(*, interval_type, update_counter, **kwargs)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.EmaCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback for exponential moving average (EMA) of model weights.
Example config:
- kind: noether.core.callbacks.EmaCallback
  every_n_epochs: 10
  save_weights: false
  save_last_weights: false
  save_latest_weights: true
  target_factors:
    - 0.9999
  name: EmaCallback
- Parameters:
callback_config (noether.core.schemas.callbacks.EmaCallbackConfig) – Configuration for the callback. See EmaCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- model_paths¶
- target_factors¶
- save_weights¶
- save_last_weights¶
- save_latest_weights¶
- resume_from_checkpoint(resumption_paths, model)¶
Resume EMA state from a checkpoint.
- Parameters:
resumption_paths (noether.core.providers.path.PathProvider) – PathProvider with paths to checkpoint files.
model – Model to resume EMA state for.
- Return type:
None
- before_training(**_)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- apply_ema(cur_model, model_path, target_factor)¶
Fused in-place implementation of the EMA update.
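The recurrence that apply_ema computes in fused, in-place form on parameter tensors can be illustrated with plain floats. The update convention assumed here (a target_factor close to 1, e.g. 0.9999, makes the EMA change slowly) matches the example config above but is an assumption about the exact formula:

```python
# Assumed EMA recurrence: ema = f * ema + (1 - f) * current, applied
# element-wise; the real implementation mutates tensors in place.
def ema_update(ema_params, cur_params, target_factor):
    return [
        target_factor * e + (1.0 - target_factor) * c
        for e, c in zip(ema_params, cur_params)
    ]

ema = [0.0, 0.0]
for _ in range(3):  # three update steps toward fixed current weights
    ema = ema_update(ema, [1.0, 2.0], target_factor=0.5)
print(ema)  # → [0.875, 1.75]
```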
- track_after_update_step(**_)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
- periodic_callback(*, interval_type, update_counter, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.DatasetStatsCallback(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.base.CallbackBase
A callback that logs the length of each dataset in the data container. Is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics to stdout/disk/online platform.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints during training.
metric_property_provider (noether.core.providers.metric_property.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- before_training(**_)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.EtaCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to print the progress and estimated duration until the periodic callback will be invoked.
Also counts up the current epoch/update/samples and provides the average update duration. Only used in “unmanaged” runs, i.e., it is not used when the run was started via SLURM.
This callback is initialized by the
BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer – Trainer of the current run.
model – Model of the current run.
data_container – DataContainer instance that provides access to all datasets.
tracker – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer – LogWriter instance to log metrics.
checkpoint_writer – CheckpointWriter instance to save checkpoints.
metric_property_provider – MetricPropertyProvider instance to access properties of metrics.
name – Name of the callback.
- class LoggerWasCalledHandler¶
Bases:
logging.Handler
Handler instances dispatch logging events to specific destinations.
The base handler class. Acts as a placeholder which defines the Handler interface. Handlers can optionally use Formatter instances to format records as desired. By default, no formatter is specified; in this case, the ‘raw’ message as determined by record.message is logged.
Initializes the instance - basically setting the formatter to None and the filter list to empty.
- was_called = False¶
- emit(_)¶
Do whatever it takes to actually log the specified logging record.
This version is intended to be implemented by subclasses and so raises a NotImplementedError.
- total_time = 0.0¶
- time_since_last_log = 0.0¶
- handler¶
- before_training(*, update_counter)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- track_after_update_step(*, update_counter, times)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.
times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
- periodic_callback(*, interval_type, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
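The ETA estimate this callback prints follows from the average update duration it accumulates via track_after_update_step. The helper name and numbers below are illustrative, not the actual implementation:

```python
# Sketch of the ETA arithmetic: from the total time spent on past
# update steps, derive the average step duration and project it over
# the remaining updates.
def estimate_eta_seconds(total_update_time, updates_done, total_updates):
    avg_update_duration = total_update_time / updates_done
    return avg_update_duration * (total_updates - updates_done)

# 100 updates took 50 s -> 0.5 s/update; 900 updates remain -> 450 s
eta = estimate_eta_seconds(total_update_time=50.0, updates_done=100, total_updates=1000)
print(eta)  # → 450.0
```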
- class noether.core.callbacks.LrCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to log the learning rate of the optimizer.
This callback is initialized by the
BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- periodic_callback(**_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- class noether.core.callbacks.OnlineLossCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to track the loss of the model after every gradient accumulation step and log the average loss.
This callback is initialized by the
BaseTrainer and should not be added manually to the trainer’s callbacks.
Initialize the OnlineLossCallback.
- Parameters:
callback_config (noether.core.schemas.callbacks.OnlineLossCallbackConfig) – Configuration for the callback. See OnlineLossCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- verbose¶
- tracked_losses: collections.defaultdict[str, list[torch.Tensor]]¶
- track_after_accumulation_step(*, losses, **_)¶
Hook called after each individual gradient accumulation step.
This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.
Common use cases include:
Logging per-batch losses for high-frequency monitoring
Accumulating statistics across batches before an optimizer update
Implementing custom logging that needs access to individual batch data
Note
This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
batch – The current data batch processed in this accumulation step.
losses – Dictionary of computed losses for the current batch.
update_outputs – Optional dictionary of model outputs for the current batch.
accumulation_steps – Total number of accumulation steps before an optimizer update.
accumulation_step – The current accumulation step index (0-indexed).
- Return type:
None
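The collect-then-average pattern behind tracked_losses can be sketched in isolation. Plain floats stand in for the torch.Tensor losses, and the two free functions below are illustrative stand-ins for the callback's two hooks:

```python
from collections import defaultdict

# Losses are appended after every accumulation step and averaged when
# the periodic callback fires; the buffer is then cleared for the next
# logging interval.
tracked_losses = defaultdict(list)

def track_after_accumulation_step(losses):
    for key, value in losses.items():
        tracked_losses[key].append(value)

def average_and_reset():
    averages = {k: sum(v) / len(v) for k, v in tracked_losses.items()}
    tracked_losses.clear()
    return averages

track_after_accumulation_step({"total": 0.75})
track_after_accumulation_step({"total": 0.25})
averages = average_and_reset()
print(averages)  # → {'total': 0.5}
```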
- periodic_callback(*, interval_type, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- class noether.core.callbacks.ParamCountCallback(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.base.CallbackBase
Callback to log the number of trainable and frozen parameters of the model.
This callback is initialized by the
BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics to stdout/disk/online platform.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints during training.
metric_property_provider (noether.core.providers.metric_property.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- before_training(**_)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.PeakMemoryCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to log the peak memory usage of the model. Is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- periodic_callback(**__)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- class noether.core.callbacks.ProgressCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to print the progress of the training such as number of epochs and updates.
This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- before_training(**_)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- periodic_callback(*, interval_type, update_counter, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- track_after_update_step(*, update_counter, **_)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.
times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
- class noether.core.callbacks.TrainTimeCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback to log the time spent on dataloading. Is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer – Trainer of the current run.
model – Model of the current run.
data_container – DataContainer instance that provides access to all datasets.
tracker – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer – LogWriter instance to log metrics.
checkpoint_writer – CheckpointWriter instance to save checkpoints.
metric_property_provider – MetricPropertyProvider instance to access properties of metrics.
name – Name of the callback.
- total_train_times: dict[str, torch.Tensor]¶
- track_after_update_step(*, times, **_)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
times (dict[str, float]) – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
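The accumulation performed by this hook can be pictured with a standalone sketch. The class and method names below mirror the documented attributes (total_train_times, track_after_update_step), but the code is an illustrative assumption, not the library's implementation:

```python
from collections import defaultdict


class TimeAccumulator:
    """Illustrative stand-in for a TrainTimeCallback-style time tracker."""

    def __init__(self):
        # running totals per timing key, e.g. 'data_time', 'forward_time'
        self.total_train_times = defaultdict(float)

    def track_after_update_step(self, *, times, **_):
        # add the per-step measurements to the running totals
        for key, value in times.items():
            self.total_train_times[key] += value
```

A periodic hook could then log and reset these totals at each configured interval.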
- periodic_callback(**_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- exception noether.core.callbacks.EarlyStopIteration¶
Bases:
StopIteration
Custom StopIteration exception for Early Stoppers.
- class noether.core.callbacks.EarlyStopperBase(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Base class for early stoppers that defines the interface for early stoppers used by the trainers.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- to_short_interval_string()¶
Convert the interval to a short string representation used for logging.
- Return type:
str
- periodic_callback(*, interval_type, update_counter, **kwargs)¶
Check if training should stop and raise exception if needed.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – Type of interval that triggered this callback.
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance with current training state.
**kwargs – Additional keyword arguments.
- Raises:
EarlyStopIteration – If training should be stopped based on the stopping criterion.
- Return type:
None
- class noether.core.callbacks.FixedEarlyStopper(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.early_stoppers.base.EarlyStopperBase
Early stopper (training) based on a fixed number of epochs, updates, or samples.
Example config:
- kind: noether.core.callbacks.FixedEarlyStopper
  stop_at_epoch: 10
  name: FixedEarlyStopper
- Parameters:
callback_config (noether.core.schemas.callbacks.FixedEarlyStopperConfig) – The configuration for the callback. See FixedEarlyStopperConfig for available options.
**kwargs – Additional arguments to pass to the parent class.
- stop_at_sample¶
- stop_at_update¶
- stop_at_epoch¶
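The stopping criterion behind the attributes above can be sketched as a standalone function: raise EarlyStopIteration as soon as any configured limit is reached. The helper below is illustrative (it redefines a local EarlyStopIteration for self-containment) and is not the library's actual code:

```python
class EarlyStopIteration(StopIteration):
    """Raised to signal that training should stop (local stand-in)."""


def check_fixed_stop(*, epoch, update, sample,
                     stop_at_epoch=None, stop_at_update=None, stop_at_sample=None):
    # Each limit is optional; the first one reached stops training.
    if stop_at_epoch is not None and epoch >= stop_at_epoch:
        raise EarlyStopIteration(f"reached epoch limit {stop_at_epoch}")
    if stop_at_update is not None and update >= stop_at_update:
        raise EarlyStopIteration(f"reached update limit {stop_at_update}")
    if stop_at_sample is not None and sample >= stop_at_sample:
        raise EarlyStopIteration(f"reached sample limit {stop_at_sample}")
```

With stop_at_epoch: 10 as in the config above, the check would pass at epoch 9 and raise at epoch 10.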
- class noether.core.callbacks.MetricEarlyStopper(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.early_stoppers.base.EarlyStopperBase
Early stopper (training) based on a metric value to be monitored.
Example config:
- kind: noether.core.callbacks.MetricEarlyStopper
  every_n_epochs: 1
  metric_key: loss/val/total
  tolerance: 0.10
  name: MetricEarlyStopper
- Parameters:
callback_config (noether.core.schemas.callbacks.MetricEarlyStopperConfig) – Configuration for the callback. See MetricEarlyStopperConfig for available options including metric key and tolerance.
**kwargs – Additional arguments to pass to the parent class.
- metric_key¶
- higher_is_better¶
- tolerance¶
- tolerance_counter = 0¶
- best_metric¶
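As a rough illustration of the attributes above (higher_is_better, tolerance, tolerance_counter, best_metric), the following standalone sketch implements one plausible patience-style criterion: the counter grows while the monitored metric fails to improve on the best value seen, and stopping is signalled once the counter exceeds the tolerance. This is an assumption about the semantics, not the library's actual implementation:

```python
class MetricStopperSketch:
    """Illustrative patience-based stopper, not the library's internals."""

    def __init__(self, tolerance, higher_is_better=False):
        self.tolerance = tolerance            # allowed consecutive non-improving checks
        self.higher_is_better = higher_is_better
        self.tolerance_counter = 0
        self.best_metric = None

    def check(self, value):
        """Return True if training should stop after observing `value`."""
        improved = self.best_metric is None or (
            value > self.best_metric if self.higher_is_better else value < self.best_metric
        )
        if improved:
            self.best_metric = value
            self.tolerance_counter = 0
        else:
            self.tolerance_counter += 1
        return self.tolerance_counter > self.tolerance
```

See MetricEarlyStopperConfig for how tolerance is actually interpreted by the library.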
- class noether.core.callbacks.BestMetricCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
A callback that keeps track of the best metric value over a training run for a certain metric (i.e., source_metric_key) while also logging one or more target metrics.
For example, track the test loss at the epoch with the best validation loss to simulate early stopping.
Example config:
- kind: noether.core.callbacks.BestMetricCallback
  every_n_epochs: 1
  source_metric_key: loss/val/total
  target_metric_keys:
    - loss/test/total
In this example, whenever a new best validation loss is found, the corresponding test loss is logged under the key
loss/test/total/at_best/loss/val/total.
- Parameters:
callback_config (noether.core.schemas.callbacks.BestMetricCallbackConfig) – Configuration for the callback. See BestMetricCallbackConfig for available options including source and target metric keys.
**kwargs – Additional keyword arguments provided to the parent class.
- source_metric_key¶
- target_metric_keys¶
- optional_target_metric_keys¶
- higher_is_better¶
- best_metric_value¶
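The bookkeeping described above can be sketched as a small standalone function. The key naming follows the documented pattern "&lt;target&gt;/at_best/&lt;source&gt;"; the function itself, and the `log` callable it takes, are illustrative assumptions rather than the library's code:

```python
def track_best(best_value, source_value, target_values, *,
               source_key, higher_is_better, log):
    """On a new best source metric, log each target metric under
    '<target>/at_best/<source>'. Returns the (possibly updated) best value."""
    is_new_best = best_value is None or (
        source_value > best_value if higher_is_better else source_value < best_value
    )
    if is_new_best:
        best_value = source_value
        for key, value in target_values.items():
            log(f"{key}/at_best/{source_key}", value)
    return best_value
```

With the config above, a new best loss/val/total triggers logging of loss/test/total under loss/test/total/at_best/loss/val/total.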
- periodic_callback(**__)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- class noether.core.callbacks.TrackAdditionalOutputsCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback
Callback that is invoked during training after every gradient step to track certain outputs from the update step. The update_outputs that are provided in the track_after_accumulation_step method are the additional_outputs field from the TrainerResult returned by the trainer’s update step.
The provided update_outputs are assumed to be a dictionary, and outputs that match keys or patterns are tracked. An update output matches if either the key matches exactly, e.g. {"some_output": …} with keys = ["some_output"], or if one of the patterns is contained in the update key name, e.g. {"some_loss": …} with patterns = ["loss"].
- kind: noether.core.callbacks.TrackAdditionalOutputsCallback
  name: TrackAdditionalOutputsCallback
  every_n_updates: 1
  keys:
    - "surface_pressure_loss"
- Parameters:
callback_config (noether.core.schemas.callbacks.TrackAdditionalOutputsCallbackConfig) – Configuration for the callback. See TrackAdditionalOutputsCallbackConfig for available options including keys and patterns to track.
**kwargs – Additional keyword arguments provided to the parent class.
- out: pathlib.Path | None¶
- patterns¶
- keys¶
- verbose¶
- tracked_values: collections.defaultdict[str, list]¶
- reduce¶
- log_output¶
- save_output¶
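The documented matching rule (exact key match, or a pattern contained in the key name) can be sketched as follows; the helper names are illustrative, not the library's internals:

```python
def matches(key, keys=(), patterns=()):
    """True if `key` matches one of `keys` exactly or contains any pattern."""
    return key in keys or any(p in key for p in patterns)


def select_tracked(update_outputs, keys=(), patterns=()):
    """Filter a dict of update outputs down to the tracked entries."""
    return {k: v for k, v in update_outputs.items() if matches(k, keys, patterns)}
```

For example, with keys = ["some_output"] and patterns = ["loss"], both "some_output" and "some_loss" are tracked, while "acc" is not.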
- track_after_accumulation_step(*, update_counter, update_outputs, **_)¶
Track the specified outputs after each accumulation step.
- Parameters:
update_counter – UpdateCounter object to track the number of updates.
update_outputs – The additional_outputs field from the TrainerResult returned by the trainer’s update step. Note that the base train_step method in the base trainer does not provide any additional outputs by default; hence, this callback can only be used if train_step is modified to provide additional outputs.
**_ – Additional unused keyword arguments.
- Return type:
None
- periodic_callback(*, update_counter, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- class noether.core.callbacks.PeriodicCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
noether.core.callbacks.base.CallbackBase
Base class for callbacks that are invoked periodically during training.
PeriodicCallback extends CallbackBase to support periodic execution based on training progress. Callbacks can be configured to run at regular intervals defined by epochs, updates (optimizer steps), or samples (data points processed). This class implements the infrastructure for periodic invocation while child classes define the actual behavior via the periodic_callback() method.
- Interval Configuration:
Callbacks can be configured to run periodically using one or more of:
every_n_epochs: Execute callback every N epochs
every_n_updates: Execute callback every N optimizer updates
every_n_samples: Execute callback every N samples processed
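For instance, a callback configured to fire every 500 optimizer updates could be declared like this, following the YAML shape used by the examples in this module (the callback kind shown is a placeholder):

```yaml
callbacks:
  - kind: path.to.SomePeriodicCallback
    every_n_updates: 500
```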
- Tracking vs. Periodic Execution:
The class provides two types of hooks:
Tracking hooks (track_after_accumulation_step(), track_after_update_step()): Called on every accumulation/update step to track metrics continuously (e.g., for running averages). I.e., if you want to log an exponential moving average of the loss every epoch, the logging is done in the periodic callback; however, the tracking of the loss values for computing the moving average is done in the tracking hook.
Periodic hook (periodic_callback()): Called only when the configured interval is reached, typically for expensive operations like evaluation or checkpointing.
Examples
Creating a custom periodic callback that logs metrics every 10 epochs:
class CustomMetricCallback(PeriodicCallback):
    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # This method is called every 10 epochs
        metric_value = self.compute_expensive_metric()
        self.writer.add_scalar(
            key="custom_metric",
            value=metric_value,
            logger=self.logger,
        )

# Configure in YAML:
# callbacks:
#   - kind: path.to.CustomMetricCallback
#     every_n_epochs: 10
Tracking metrics at every update and logging periodically:
class RunningAverageCallback(PeriodicCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_accumulator = []

    def track_after_update_step(
        self, *, update_counter: UpdateCounter, times: dict[str, float]
    ) -> None:
        # Track at every update
        self.loss_accumulator.append(self.trainer.last_loss)

    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # Log periodically
        avg_loss = sum(self.loss_accumulator) / len(self.loss_accumulator)
        self.writer.add_scalar("avg_loss", avg_loss, logger=self.logger)
        self.loss_accumulator.clear()
- every_n_epochs¶
If set, callback is invoked every N epochs.
- every_n_updates¶
If set, callback is invoked every N optimizer updates.
- every_n_samples¶
If set, callback is invoked every N samples processed.
- batch_size¶
Batch size used during training.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- every_n_epochs¶
- every_n_updates¶
- every_n_samples¶
- batch_size¶
- periodic_callback(*, interval_type, update_counter, **kwargs)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- track_after_accumulation_step(*, update_counter, batch, losses, update_outputs, accumulation_steps, accumulation_step)¶
Hook called after each individual gradient accumulation step.
This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when
accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.
Common use cases include:
Logging per-batch losses for high-frequency monitoring
Accumulating statistics across batches before an optimizer update
Implementing custom logging that needs access to individual batch data
Note
This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
batch (Any) – The current data batch processed in this accumulation step.
losses (dict[str, torch.Tensor]) – Dictionary of computed losses for the current batch.
update_outputs (dict[str, torch.Tensor] | None) – Optional dictionary of model outputs for the current batch.
accumulation_steps (int) – Total number of accumulation steps before an optimizer update.
accumulation_step (int) – The current accumulation step index (0-indexed).
- Return type:
None
- track_after_update_step(*, update_counter, times)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
times (dict[str, float]) – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
- after_epoch(update_counter, **kwargs)¶
Invoked after every epoch to check if callback should be invoked.
Applies torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- after_update(update_counter, **kwargs)¶
Invoked after every update to check if callback should be invoked.
Applies torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- at_eval(update_counter, **kwargs)¶
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter)
- Return type:
None
- updates_till_next_invocation(update_counter)¶
Calculate how many updates remain until this callback is invoked.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Returns:
Number of updates remaining until the next callback invocation.
- Return type:
int
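For an update-based interval, the remaining-update arithmetic can be pictured with this standalone sketch. The behavior at an exact multiple (a full interval remaining, since the callback has just fired) is an assumption about the semantics, not the library's code:

```python
def updates_till_next(current_update, every_n_updates):
    """Updates remaining until the next multiple of `every_n_updates`."""
    remainder = current_update % every_n_updates
    # At an exact multiple the callback has just fired, so a full interval remains.
    return every_n_updates - remainder if remainder else every_n_updates
```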
- updates_per_interval(update_counter)¶
Calculate how many updates elapse from one invocation of this callback to the next.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Returns:
Number of updates between callback invocations.
- Return type:
int
- get_interval_string_verbose()¶
Return interval configuration as a verbose string.
- Returns:
Interval as, e.g., “every_n_epochs=1” for epoch-based intervals.
- Return type:
str
- class noether.core.callbacks.PeriodicDataIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases:
PeriodicCallback
Base class for callbacks that perform periodic iterations over a dataset.
PeriodicDataIteratorCallback extends PeriodicCallback to support evaluations or computations that require iterating over an entire dataset. This is commonly used for validation/test set evaluation, computing metrics on held-out data, or any operation that needs to process batches from a dataset at regular training intervals.
- Workflow:
Iteration (_iterate_over_dataset()): When the periodic interval is reached, iterate through the dataset in batches.
Process Data (process_data()): Process a single batch (e.g., run model inference) and return results.
Collation (_collate_result()): Aggregate results across all batches and distributed ranks.
Processing (process_results()): Compute final metrics or perform actions with the aggregated results.
- Key Features:
Distributed Support: Automatically handles distributed evaluation with proper gathering across ranks and padding removal.
Flexible Collation: Supports collating various result types (tensors, dicts of tensors, lists).
Data Pipeline Integration: Uses SamplerIntervalConfig to integrate with the interleaved sampler for efficient data loading.
Progress Tracking: Provides progress bars and timing information for data loading.
- Template Methods to Override:
Child classes must implement process_data() and process_results():
process_data(): Process a single batch (e.g., run model inference).
process_results(): Process the aggregated results from all batches.
Examples
Basic validation accuracy callback that evaluates on a test set every epoch:
class AccuracyCallback(PeriodicDataIteratorCallback):
    def __init__(self, *args, dataset_key="test", **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_key = dataset_key

    def process_data(self, batch, *, trainer_model):
        # Run inference on batch
        x = batch["x"].to(trainer_model.device)
        y_true = batch["class"].clone()
        y_pred = trainer_model(x)
        return {"predictions": y_pred, "labels": y_true}

    def process_results(self, results, *, interval_type, update_counter, **_):
        # Compute accuracy from aggregated results
        y_pred = results["predictions"]
        y_true = results["labels"]
        accuracy = (y_pred.argmax(dim=1) == y_true).float().mean()
        self.writer.add_scalar(
            key="test/accuracy",
            value=accuracy.item(),
            logger=self.logger,
            format_str=".4f",
        )

# Configure in YAML:
# callbacks:
#   - kind: path.to.AccuracyCallback
#     every_n_epochs: 1
#     dataset_key: "test"
Advanced example with multiple return values and custom collation:
class DetailedEvaluationCallback(PeriodicDataIteratorCallback):
    def process_data(self, batch, *, trainer_model):
        x = batch["x"].to(trainer_model.device)
        y = batch["label"]
        # Return multiple outputs as tuple
        logits = trainer_model(x)
        embeddings = trainer_model.get_embeddings(x)
        return logits, embeddings, y

    def process_results(self, results, *, interval_type, update_counter, **_):
        # results is a tuple: (all_logits, all_embeddings, all_labels)
        logits, embeddings, labels = results
        # Compute multiple metrics
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        mean_embedding_norm = embeddings.norm(dim=-1).mean()
        self.writer.add_scalar("accuracy", accuracy.item())
        self.writer.add_scalar("embedding_norm", mean_embedding_norm.item())
- dataset_key¶
Key to identify the dataset to iterate over from
self.data_container. Automatically set from the callback config.
- sampler_config¶
Configuration for the sampler that controls dataset iteration. Automatically set when dataset is initialized.
- total_data_time¶
Cumulative time spent waiting for data loading across all periodic callbacks.
Note
The process_data() method is called within a torch.no_grad() context automatically.
For distributed training, results are automatically gathered across all ranks with proper padding removal.
- Parameters:
callback_config (noether.core.schemas.callbacks.PeriodicDataIteratorCallbackConfig) – Configuration for the callback. See PeriodicDataIteratorCallbackConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- dataset_key¶
- total_data_time¶
- sampler_config¶
- _sampler_config_from_key(key, properties=None, max_size=None)¶
Register the dataset that is used for this callback in the dataloading pipeline.
- Parameters:
key (str | None) – Key for identifying the dataset from self.data_container. Uses the first dataset if None.
properties (set[str] | None) – Optionally specifies a subset of properties to load from the dataset.
max_size (int | None) – If provided, only uses a subset of the full dataset. Default: None (no subset).
- Returns:
SamplerIntervalConfig for the registered dataset.
- Return type:
SamplerIntervalConfig
- abstractmethod process_data(batch, *, trainer_model)¶
Template method that is called for each batch that is loaded from the dataset.
This method should process a single batch and return results that will be collated.
- Parameters:
batch – The loaded batch.
trainer_model (torch.nn.Module) – Model of the current training run.
- Returns:
Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
- Return type:
Any
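Besides the tuple return shown in the class example above, process_data() may also return a dict. A minimal self-contained sketch of that shape (plain Python lists stand in for tensors, and the identity trainer_model is a stub, so neither torch nor noether is needed; the key-wise collation mentioned in the comment is an assumption about the base class, not documented behavior):

```python
# Sketch of a dict-returning process_data(); lists stand in for tensors
# and the lambda below is a stub model used purely for illustration.
def process_data(batch, *, trainer_model):
    predictions = trainer_model(batch["x"])
    # Dict outputs are presumably collated key-wise across batches
    # before being handed to process_results().
    return {"prediction": predictions, "label": batch["label"]}

out = process_data({"x": [1, 2], "label": [1, 0]}, trainer_model=lambda x: x)
# out == {"prediction": [1, 2], "label": [1, 0]}
```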
- process_results(results, *, interval_type, update_counter, **_)¶
Template method that is called with the collated results from dataset iteration.
For example, metrics can be computed from the results for the entire test/validation dataset and logged.
- Parameters:
results (Any) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.
interval_type – The type of interval that triggered this callback invocation.
update_counter (noether.core.utils.training.counter.UpdateCounter) –
UpdateCounter with the current training state.
**_ – Additional unused keyword arguments.
- Return type:
None
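The actual collation is handled internally by _iterate_over_dataset(), but the idea can be sketched in plain Python. The collate_batches helper below is illustrative only (it is not part of noether) and assumes tuple outputs are simply concatenated element-wise across batches:

```python
def collate_batches(per_batch_results):
    """Illustrative stand-in for the library's collation step: turn a
    list of per-batch tuples into one tuple of flat, concatenated lists."""
    if not per_batch_results:
        return ()
    n_outputs = len(per_batch_results[0])
    collated = tuple([] for _ in range(n_outputs))
    for batch_result in per_batch_results:
        for i, output in enumerate(batch_result):
            collated[i].extend(output)
    return collated

# Two batches, each returning (logits, labels) from process_data():
batches = [([0.9, 0.1], [1, 0]), ([0.4, 0.6], [0, 1])]
logits, labels = collate_batches(batches)
# logits == [0.9, 0.1, 0.4, 0.6]; labels == [1, 0, 0, 1]
```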
- periodic_callback(*, interval_type, update_counter, data_iter, trainer_model, batch_size, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.counter.UpdateCounter) –
UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
data_iter (collections.abc.Iterator)
trainer_model (torch.nn.Module)
batch_size (int)
- Return type:
None
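A self-contained sketch of overriding this hook follows. Only the keyword-only hook signature mirrors the docs above; _StubWriter and _StubPeriodicCallback are minimal stand-ins for noether's writer and PeriodicCallback base class, and LossRatioLogger is a hypothetical example, not part of the library:

```python
class _StubWriter:
    """Stand-in for the real log writer; records scalars in a dict."""
    def __init__(self):
        self.scalars = {}

    def add_scalar(self, key, value):
        self.scalars[key] = value

class _StubPeriodicCallback:
    """Stand-in for noether's PeriodicCallback base class."""
    def __init__(self):
        self.writer = _StubWriter()

class LossRatioLogger(_StubPeriodicCallback):
    """Logs the val/train loss ratio whenever a configured interval fires."""

    def __init__(self, train_loss, val_loss):
        super().__init__()
        self.train_loss = train_loss
        self.val_loss = val_loss

    def periodic_callback(self, *, interval_type, update_counter, **_):
        # In the real trainer this runs inside torch.no_grad().
        self.writer.add_scalar("val_over_train_loss", self.val_loss / self.train_loss)

cb = LossRatioLogger(train_loss=0.5, val_loss=0.6)
cb.periodic_callback(interval_type="update", update_counter=None)
# cb.writer.scalars["val_over_train_loss"] == 1.2
```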
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
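A brief sketch of the final-reporting use case listed above; FinalReportCallback and its attributes are hypothetical and stand alone (no noether base class), so only the hook name and **_ signature come from the docs:

```python
class FinalReportCallback:
    """Hypothetical callback that emits one summary message after training."""
    def __init__(self):
        self.best_accuracy = 0.0
        self.messages = []

    def after_training(self, **_):
        # Runs once, inside torch.no_grad(), after the training loop ends;
        # a real callback might run a final test-set evaluation or send a
        # Slack/email notification here instead.
        self.messages.append(f"run finished, best accuracy={self.best_accuracy:.2%}")

cb = FinalReportCallback()
cb.after_training()
# cb.messages == ["run finished, best accuracy=0.00%"]
```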