noether.core.callbacks.checkpoint¶
Submodules¶
Classes¶
BestCheckpointCallback – Callback to save the best model based on a metric.
CheckpointCallback – Callback to save the model and optimizer state periodically.
EmaCallback – Callback for exponential moving average (EMA) of model weights.
Package Contents¶
- class noether.core.callbacks.checkpoint.BestCheckpointCallback(callback_config, **kwargs)¶
Bases: noether.core.callbacks.periodic.PeriodicCallback
Callback to save the best model based on a metric.
This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.
Example config:
callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names: # only applies when training a CompositeModel
      - encoder
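The save decision this callback makes can be sketched independently of the framework. `BestTracker` and its fields are a hypothetical simplification for illustration, not the actual noether API:

```python
class BestTracker:
    """Minimal sketch of best-metric checkpoint logic (hypothetical, not the noether API)."""

    def __init__(self, higher_is_better: bool = False):
        self.higher_is_better = higher_is_better
        self.best_metric_value = None  # no metric observed yet

    def is_new_best(self, value: float) -> bool:
        """Return True (and record the new best) when `value` improves on the stored best."""
        if self.best_metric_value is None:
            improved = True  # first observation always counts as a new best
        elif self.higher_is_better:
            improved = value > self.best_metric_value
        else:
            improved = value < self.best_metric_value
        if improved:
            self.best_metric_value = value
        return improved

    # mirrors the callback's state_dict()/load_state_dict() pair so the
    # best value survives a training resume
    def state_dict(self) -> dict:
        return {"best_metric_value": self.best_metric_value}

    def load_state_dict(self, state: dict) -> None:
        self.best_metric_value = state["best_metric_value"]
```

For a loss-like metric (`higher_is_better=False`), the sequence 0.9, 0.8, 0.85 would trigger a save on the first two values only.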
- Parameters:
callback_config (noether.core.schemas.callbacks.BestCheckpointCallbackConfig) – Configuration for the callback. See BestCheckpointCallbackConfig for available options including metric key, model names, and tolerance settings.
**kwargs – Additional arguments passed to the parent class.
- metric_key¶
- model_names¶
- higher_is_better¶
- best_metric_value¶
- save_frozen_weights¶
- tolerances_is_exceeded¶
- tolerance_counter = 0¶
- state_dict()¶
Return the state of the callback for checkpointing.
- load_state_dict(state_dict)¶
Load the callback state from a checkpoint.
Note
This modifies the input state_dict in place.
- before_training(*, update_counter)¶
Validate callback configuration before training starts.
- Parameters:
update_counter – The training update counter.
- Raises:
NotImplementedError – If resuming training with tolerances is attempted.
- Return type:
None
- periodic_callback(**_)¶
Execute the periodic callback to check and save best model.
This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.
- Raises:
KeyError – If the log cache is empty or the metric key is not found.
- Return type:
None
- after_training(**kwargs)¶
Log the best metric values at different tolerance thresholds after training completes.
- Parameters:
**kwargs – Additional keyword arguments (unused).
- Return type:
None
- class noether.core.callbacks.checkpoint.CheckpointCallback(callback_config, **kwargs)¶
Bases: noether.core.callbacks.periodic.PeriodicCallback
Callback to save the model and optimizer state periodically.
Example config:
- kind: noether.core.callbacks.CheckpointCallback
  name: CheckpointCallback
  every_n_epochs: 1
  save_weights: true
  save_optim: true
- Parameters:
callback_config (noether.core.schemas.callbacks.CheckpointCallbackConfig) – Configuration for the callback. See CheckpointCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- save_weights¶
- save_optim¶
- save_latest_weights¶
- save_latest_optim¶
- model_names¶
- before_training(*, update_counter)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- periodic_callback(*, interval_type, update_counter, **kwargs)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- class noether.core.callbacks.checkpoint.EmaCallback(callback_config, **kwargs)¶
Bases: noether.core.callbacks.periodic.PeriodicCallback
Callback for exponential moving average (EMA) of model weights.
Example config:
- kind: noether.core.callbacks.EmaCallback
  every_n_epochs: 10
  save_weights: false
  save_last_weights: false
  save_latest_weights: true
  target_factors:
    - 0.9999
  name: EmaCallback
- Parameters:
callback_config (noether.core.schemas.callbacks.EmaCallbackConfig) – Configuration for the callback. See EmaCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- model_paths¶
- target_factors¶
- save_weights¶
- save_last_weights¶
- save_latest_weights¶
- resume_from_checkpoint(resumption_paths, model)¶
Resume EMA state from a checkpoint.
- Parameters:
resumption_paths (noether.core.providers.path.PathProvider) – PathProvider with paths to checkpoint files.
model – Model to resume EMA state for.
- Return type:
None
- before_training(**_)¶
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- apply_ema(cur_model, model_path, target_factor)¶
Apply the EMA update for target_factor using a fused in-place implementation.
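The arithmetic of an in-place EMA step of the kind apply_ema performs can be sketched without tensors; this is a hypothetical simplification (plain floats instead of model parameters), not the actual fused implementation:

```python
def ema_update_(ema_params: list[float], cur_params: list[float], target_factor: float) -> None:
    """In-place EMA step: ema = factor * ema + (1 - factor) * current.
    With target_factor close to 1 (e.g. 0.9999) the EMA copy changes slowly
    and smooths out per-update noise in the weights."""
    for i, cur in enumerate(cur_params):
        ema_params[i] = target_factor * ema_params[i] + (1.0 - target_factor) * cur
```

In the real callback the parameters are tensors, the update runs inside a torch.no_grad() context, and implementations typically fuse the per-parameter loop into batched tensor operations.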
- track_after_update_step(**_)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
times – Dictionary containing time measurements for various parts of the training step (e.g., 'data_time', 'forward_time', 'backward_time', 'update_time').
- Return type:
None
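A per-update running average of the sort this hook supports (e.g. for latest loss values or step timings) can be sketched as follows; `RunningMean` is a hypothetical helper, not part of noether:

```python
class RunningMean:
    """Numerically stable running mean, updated once per optimizer step."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def update(self, value: float) -> None:
        # incremental mean update avoids storing the full history
        self.count += 1
        self.mean += (value - self.mean) / self.count

# e.g. feeding in a loss value after every update step
tracker = RunningMean()
for loss in (4.0, 2.0, 3.0):
    tracker.update(loss)
```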
- periodic_callback(*, interval_type, update_counter, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None