noether.core.callbacks.checkpoint.best_checkpoint¶
Classes¶
BestCheckpointCallback | Callback to save the best model based on a metric.
Module Contents¶
- class noether.core.callbacks.checkpoint.best_checkpoint.BestCheckpointCallback(callback_config, **kwargs)¶
Bases:
noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.
This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.
Example config:
callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names: # only applies when training a CompositeModel
      - encoder
- Parameters:
callback_config (noether.core.schemas.callbacks.BestCheckpointCallbackConfig) – Configuration for the callback. See BestCheckpointCallbackConfig for available options including metric key, model names, and tolerance settings.
**kwargs – Additional arguments passed to the parent class.
- metric_key¶
- model_names¶
- higher_is_better¶
- best_metric_value¶
- save_frozen_weights¶
- tolerances_is_exceeded¶
- tolerance_counter = 0¶
- state_dict()¶
Return the state of the callback for checkpointing.
- load_state_dict(state_dict)¶
Load the callback state from a checkpoint.
Note
This modifies the input state_dict in place.
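The in-place behavior noted above can be illustrated with a minimal standalone sketch. The key names (`best_metric_value`, `tolerance_counter`) mirror the attributes listed for this class, but the function below is a hypothetical illustration of the pattern, not the callback's actual implementation:

```python
def load_state_dict(state_dict):
    # Consuming entries with dict.pop() mutates the caller's dictionary,
    # which is why the docstring warns that state_dict is modified in place.
    best = state_dict.pop("best_metric_value")
    counter = state_dict.pop("tolerance_counter")
    return best, counter

state = {"best_metric_value": 0.25, "tolerance_counter": 2}
best, counter = load_state_dict(state)
# After the call, `state` no longer contains the popped keys.
```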
- before_training(*, update_counter)¶
Validate callback configuration before training starts.
- Parameters:
update_counter – The training update counter.
- Raises:
NotImplementedError – If resuming training with tolerances is attempted.
- Return type:
None
- periodic_callback(**_)¶
Execute the periodic callback to check and save best model.
This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.
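The "is the current value better than the previous best" check described above can be sketched independently of the library. `higher_is_better` and `best_metric_value` correspond to the attributes listed for this class; the helper itself is an assumed illustration, not the actual method body:

```python
def is_new_best(current, best, higher_is_better):
    """Return True when `current` improves on the previous `best` value."""
    if best is None:  # no checkpoint has been saved yet
        return True
    return current > best if higher_is_better else current < best

# For a loss metric such as loss/val/total, lower values are better,
# so the callback would save a checkpoint when the metric decreases.
improved = is_new_best(0.4, 0.5, higher_is_better=False)
```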
- Raises:
KeyError – If the log cache is empty or the metric key is not found.
- Return type:
None
- after_training(**kwargs)¶
Log the best metric values at different tolerance thresholds after training completes.
- Parameters:
**kwargs – Additional keyword arguments (unused).
- Return type:
None