noether.core.callbacks.checkpoint.best_checkpoint

Classes

BestCheckpointCallback

Callback to save the best model based on a metric.

Module Contents

class noether.core.callbacks.checkpoint.best_checkpoint.BestCheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.

This callback monitors a specified metric and saves a model checkpoint whenever the metric reaches a new best value. When training a composite model, it can save only selected model components. It can also save additional checkpoints at different tolerance thresholds.
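The comparison at the heart of the callback can be sketched as follows. This is a minimal illustration, not the library's implementation. The `BestValueTracker` class and its methods are hypothetical. The sketch assumes that the direction of improvement comes from a `higher_is_better` flag (mirroring the attribute of the same name), that tracking starts from plus or minus infinity, and that a tie does not count as a new best.

```python
import math


class BestValueTracker:
    """Hypothetical helper that remembers the best metric value seen so far."""

    def __init__(self, higher_is_better: bool) -> None:
        self.higher_is_better = higher_is_better
        # Start from the worst possible value so the first observation always wins.
        self.best_metric_value = -math.inf if higher_is_better else math.inf

    def is_new_best(self, value: float) -> bool:
        # Strict comparison: an equal value is not treated as an improvement (assumption).
        if self.higher_is_better:
            return value > self.best_metric_value
        return value < self.best_metric_value

    def update(self, value: float) -> bool:
        """Record `value`; return True if it is a new best, i.e. a checkpoint would be saved."""
        if self.is_new_best(value):
            self.best_metric_value = value
            return True
        return False


tracker = BestValueTracker(higher_is_better=False)  # e.g. a loss such as loss/val/total
saved = [tracker.update(v) for v in [0.9, 0.7, 0.8, 0.5]]
print(saved)  # [True, True, False, True]
print(tracker.best_metric_value)  # 0.5
```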

Example config:

callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names:  # only applies when training a CompositeModel
      - encoder
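The `model_names` option can be pictured as a filter over the components of a composite model. The sketch below is hypothetical. It represents a composite model as a plain dict of named components and assumes that omitting `model_names` means every component is saved. The real callback may resolve components differently.

```python
from typing import Optional


def select_components(
    components: dict[str, object], model_names: Optional[list[str]] = None
) -> dict[str, object]:
    """Return the components to checkpoint (hypothetical helper).

    Assumption: model_names=None means "save every component".
    """
    if model_names is None:
        return dict(components)
    missing = [name for name in model_names if name not in components]
    if missing:
        raise KeyError(f"unknown model names: {missing}")
    return {name: components[name] for name in model_names}


composite = {"encoder": "encoder-weights", "decoder": "decoder-weights"}
print(select_components(composite, ["encoder"]))  # {'encoder': 'encoder-weights'}
print(sorted(select_components(composite)))  # ['decoder', 'encoder']
```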
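
Parameters:

callback_config – Configuration for the callback, with fields such as those in the example config above.

**kwargs – Additional keyword arguments.

Attributes:

metric_key
model_names
higher_is_better
best_metric_value
save_frozen_weights
tolerances_is_exceeded
tolerance_counter = 0
metric_at_exceeded_tolerance: dict[float, float]

state_dict()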

Return the state of the callback for checkpointing.

Returns:

Dictionary containing the best metric value, tolerance tracking state, and counter information.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the callback state from a checkpoint.

Note

This modifies the input state_dict in place.

Parameters:

state_dict (dict[str, Any]) – Dictionary containing the saved callback state.

Return type:

None
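A round trip through `state_dict()` and `load_state_dict()` might look like the sketch below. The key names are borrowed from the attributes listed above, but the exact layout of the real checkpoint dictionary is an assumption, and the `CallbackState` class is hypothetical. Reading entries with `dict.pop` is one way a loader can end up modifying its input in place, which is the behavior the note above warns about.

```python
import copy
from typing import Any


class CallbackState:
    """Hypothetical stand-in for the callback's checkpointable state."""

    def __init__(self) -> None:
        self.best_metric_value = float("inf")
        self.tolerances_is_exceeded: dict[float, bool] = {}
        self.tolerance_counter = 0
        self.metric_at_exceeded_tolerance: dict[float, float] = {}

    def state_dict(self) -> dict[str, Any]:
        return {
            "best_metric_value": self.best_metric_value,
            "tolerances_is_exceeded": dict(self.tolerances_is_exceeded),
            "tolerance_counter": self.tolerance_counter,
            "metric_at_exceeded_tolerance": dict(self.metric_at_exceeded_tolerance),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # pop() removes each entry, so the caller's dict is emptied in place.
        self.best_metric_value = state_dict.pop("best_metric_value")
        self.tolerances_is_exceeded = state_dict.pop("tolerances_is_exceeded")
        self.tolerance_counter = state_dict.pop("tolerance_counter")
        self.metric_at_exceeded_tolerance = state_dict.pop("metric_at_exceeded_tolerance")


saved = CallbackState()
saved.best_metric_value = 0.42
saved.tolerance_counter = 3
checkpoint = saved.state_dict()

restored = CallbackState()
# Pass a deep copy when the original checkpoint dict must stay intact.
restored.load_state_dict(copy.deepcopy(checkpoint))
print(restored.best_metric_value, restored.tolerance_counter)  # 0.42 3

consumed = saved.state_dict()
restored.load_state_dict(consumed)
print(consumed)  # {}
```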

before_training(*, update_counter)

Validate callback configuration before training starts.

Parameters:

update_counter – The training update counter.

Raises:

NotImplementedError – If an attempt is made to resume training while tolerances are configured.

Return type:

None
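The validation in `before_training` can be sketched as a simple guard. This sketch assumes that a resumed run is one whose update counter does not start at zero and that tolerances are held in a possibly empty list. The `UpdateCounter` dataclass, its `start_update` field, and the error message are hypothetical stand-ins, not the library's types.

```python
from dataclasses import dataclass


@dataclass
class UpdateCounter:
    """Hypothetical stand-in for the training update counter."""

    start_update: int  # 0 for a fresh run, > 0 when resuming (assumption)


def check_before_training(update_counter: UpdateCounter, tolerances: list[float]) -> None:
    """Reject resuming when tolerance tracking is configured (mirrors the documented NotImplementedError)."""
    is_resuming = update_counter.start_update > 0
    if is_resuming and tolerances:
        raise NotImplementedError("resuming training with tolerances is not supported")


check_before_training(UpdateCounter(start_update=0), tolerances=[5.0])  # fresh run: OK
check_before_training(UpdateCounter(start_update=1000), tolerances=[])  # resume without tolerances: OK
try:
    check_before_training(UpdateCounter(start_update=1000), tolerances=[5.0])
except NotImplementedError as err:
    print(f"rejected: {err}")
```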

periodic_callback(**_)

Execute the periodic callback to check for and save the best model.

This method is called at the configured frequency (e.g., every N epochs). It checks whether the current metric value is better than the best value so far and, if so, saves a model checkpoint. It also tracks the tolerance-based checkpoints.

Raises:

KeyError – If the log cache is empty or the metric key is not found.

Return type:

None
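One periodic check might be sketched as follows. The class is hypothetical and only reuses the documented attribute names and `KeyError` cases. It models the log cache as a plain dict of metric values. It replaces checkpoint writes with labels appended to a list. It also assumes that a tolerance acts as a patience: the number of consecutive non-improving checks after which that tolerance counts as exceeded and stops following new bests.

```python
import math


class BestCheckpointSketch:
    """Hypothetical re-creation of one periodic check; not the library's implementation."""

    def __init__(self, metric_key: str, higher_is_better: bool = False,
                 tolerances: tuple[float, ...] = ()) -> None:
        self.metric_key = metric_key
        self.higher_is_better = higher_is_better
        self.best_metric_value = -math.inf if higher_is_better else math.inf
        self.tolerance_counter = 0
        # Assumption: a tolerance is a patience, i.e. how many non-improving checks it survives.
        self.tolerances_is_exceeded = {tol: False for tol in tolerances}
        self.metric_at_exceeded_tolerance: dict[float, float] = {}
        self.saved: list[str] = []  # stands in for checkpoint writes

    def periodic_callback(self, log_cache: dict[str, float]) -> None:
        if not log_cache:
            raise KeyError("log cache is empty")
        if self.metric_key not in log_cache:
            raise KeyError(f"{self.metric_key!r} not found in log cache")
        value = log_cache[self.metric_key]
        if self.higher_is_better:
            improved = value > self.best_metric_value
        else:
            improved = value < self.best_metric_value
        if improved:
            self.best_metric_value = value
            self.tolerance_counter = 0
            self.saved.append("best")
            # Every tolerance that has not been exceeded yet also records this new best.
            for tol, exceeded in self.tolerances_is_exceeded.items():
                if not exceeded:
                    self.metric_at_exceeded_tolerance[tol] = value
                    self.saved.append(f"best@tolerance={tol}")
        else:
            self.tolerance_counter += 1
            for tol in self.tolerances_is_exceeded:
                if self.tolerance_counter > tol:
                    self.tolerances_is_exceeded[tol] = True


cb = BestCheckpointSketch("loss/val/total", tolerances=(1.0,))
for loss in [0.9, 0.95, 0.99, 0.5]:
    cb.periodic_callback({"loss/val/total": loss})
print(cb.best_metric_value)  # 0.5
print(cb.metric_at_exceeded_tolerance)  # {1.0: 0.9}
```

In this sketch, the value stored for a tolerance freezes once that tolerance is exceeded. The plain best-model checkpoint keeps updating regardless.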

after_training(**kwargs)

Log the best metric values at different tolerance thresholds after training completes.

Parameters:

**kwargs – Additional keyword arguments (unused).

Return type:

None
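After training, the recorded per-tolerance values only need to be reported. The following is a minimal sketch. It assumes the values live in a dict shaped like `metric_at_exceeded_tolerance` (tolerance to metric value) and that reporting goes through Python's standard `logging` module. The function name and message format are hypothetical.

```python
import logging

logger = logging.getLogger("best_checkpoint_sketch")


def log_tolerance_metrics(metric_key: str, metric_at_exceeded_tolerance: dict[float, float]) -> list[str]:
    """Log one line per tolerance (hypothetical format) and return the lines for inspection."""
    lines = [
        f"best {metric_key} with tolerance={tol}: {value}"
        for tol, value in sorted(metric_at_exceeded_tolerance.items())
    ]
    for line in lines:
        logger.info(line)
    return lines


logging.basicConfig(level=logging.INFO)
lines = log_tolerance_metrics("loss/val/total", {5.0: 0.31, 1.0: 0.42})
```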