noether.core.callbacks.checkpoint.best_checkpoint

Classes

BestCheckpointCallbackConfig

Internal base class for all registry-based configs.

BestCheckpointCallback

Callback to save the best model based on a metric.

Module Contents

class noether.core.callbacks.checkpoint.best_checkpoint.BestCheckpointCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Internal base class for all registry-based configs.

Provides auto-registration via __init_subclass__. Not meant to be used directly - use specific config base classes instead.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['BestCheckpointCallback'] = None
metric_key: str = None

The key of the metric used to determine the best model.

save_frozen_weights: bool = None

Whether to also save the frozen weights of the model.

tolerances: list[int] | None = None

If provided, this callback produces multiple best models that differ in the number of intervals they allow the metric to go without improving. For example, tolerances=[5] with every_n_epochs=1 stores a checkpoint where at most 5 epochs have passed without the metric improving. Additionally, the best checkpoint over the whole training run is always stored (i.e., tolerance=infinite). Setting different tolerances makes it possible to evaluate several early stopping configurations with a single training run.
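The tolerance bookkeeping described above can be illustrated with a small standalone simulation. This is only a sketch and not the callback's actual implementation: the function name simulate_tolerances is hypothetical, and it assumes a tolerance counts as exceeded once that many consecutive intervals have passed without improvement (the exact comparison used by the library is not stated here).

```python
import math


def simulate_tolerances(values, tolerances, higher_is_better=False):
    """Replay a metric history and record the best value frozen at each tolerance."""
    best = -math.inf if higher_is_better else math.inf
    counter = 0  # consecutive intervals without improvement
    frozen = {}  # tolerance -> best value at the moment that tolerance was exceeded
    for value in values:
        improved = value > best if higher_is_better else value < best
        if improved:
            best = value
            counter = 0
            continue
        counter += 1
        for tol in tolerances:
            # Assumption: exceeded once `tol` intervals pass without improvement.
            if tol not in frozen and counter >= tol:
                frozen[tol] = best
    return best, frozen


# Hypothetical validation-loss history (lower is better).
history = [1.0, 0.8, 0.9, 0.85, 0.7, 0.75, 0.72, 0.71, 0.6]
best, frozen = simulate_tolerances(history, tolerances=[2, 3])
print(best, frozen)
```

Under these assumptions, each entry in frozen corresponds to the model an early stopping run with that patience would have kept, while best is the overall best value (the tolerance=infinite case).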

model_names: list[str] | None = None

Which models to save (e.g., to store only the encoder of an autoencoder, pass model_names=['encoder'] here). This only applies when training a CompositeModel. If None, all models are saved.

eval_callbacks: list[Annotated[Any, Discriminated(CallBackBaseConfig)]] | None = None

Optional nested callbacks to dispatch whenever a new best model is detected. Each child’s metric keys are automatically prefixed with best=<metric_key>/ (slashes in the metric key are replaced with dots) so they don’t collide with the live-model metrics. Children are invoked via their at_eval hook, which bypasses their own schedule: the trigger is the new-best event, not the child’s every_n_*. Tolerance-exceeded saves do not trigger children. before_training and after_training are forwarded unconditionally so children can initialize and finalize cleanly.

PeriodicDataIteratorCallback children get a dedicated DataLoader built from their sampler_config; they are not registered on the shared InterleavedSampler. This means a child’s every_n_* is irrelevant here (only the dataset_key / batch_size / pipeline matter) and the child’s schedule does not need to match this callback’s.
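The key-prefixing rule for child metrics can be sketched as follows. The helper name prefix_child_key and the child key loss/test/total are hypothetical and used only for illustration; the library's internal function may look different.

```python
def prefix_child_key(parent_metric_key: str, child_key: str) -> str:
    """Build a child metric key of the form best=<metric_key>/<child_key>."""
    # Slashes in the parent's metric key become dots so the prefix stays one path segment.
    parent = parent_metric_key.replace("/", ".")
    return f"best={parent}/{child_key}"


# Parent tracks loss/val/total; the child logs a (hypothetical) loss/test/total.
print(prefix_child_key("loss/val/total", "loss/test/total"))
# prints "best=loss.val.total/loss/test/total"
```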

class noether.core.callbacks.checkpoint.best_checkpoint.BestCheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.

This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.

Example config:

callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names:  # only applies when training a CompositeModel
      - encoder
    eval_callbacks:
      - kind: noether.training.callbacks.OfflineLossCallback
        every_n_epochs: 1  # ignored; the parent triggers on new-best
        dataset_key: test

Parameters:
  • callback_config (BestCheckpointCallbackConfig) – Configuration for the callback. See BestCheckpointCallbackConfig for available options including metric key, model names, and tolerance settings.

  • **kwargs – Additional arguments passed to the parent class.

metric_key
model_names
higher_is_better
best_metric_value
save_frozen_weights
tolerances_is_exceeded
tolerance_counter = 0
metric_at_exceeded_tolerance: dict[float, float]
eval_callbacks: list[noether.core.callbacks.periodic.PeriodicCallback] = []
get_children()

Return the non-iterator children only. Iterator children are owned end-to-end by this callback and must not be registered on the shared InterleavedSampler; their loaders are built on dispatch instead. Because the trainer always passes batch_size to every PeriodicCallback hook, child loaders can be built without the trainer’s iterator-args bundle.

Return type:

list[noether.core.callbacks.base.CallbackBase]

state_dict()

Return the state of the callback for checkpointing.

Returns:

Dictionary containing the best metric value, tolerance tracking state, and counter information.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the callback state from a checkpoint.

Note

This modifies the input state_dict in place.

Parameters:

state_dict (dict[str, Any]) – Dictionary containing the saved callback state.

Return type:

None

before_training(*, update_counter, **kwargs)

Validate callback configuration before training starts.

Parameters:
  • update_counter – The training update counter.

  • **kwargs – Additional keyword arguments forwarded to child eval callbacks.

Raises:

NotImplementedError – If resuming training with tolerances is attempted.

Return type:

None

periodic_callback(*, interval_type, **kwargs)

Execute the periodic callback to check for and save the best model.

This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.

When a new best is detected, child eval callbacks (if configured) are dispatched against the live (newly-best) model. Iterator children iterate their own DataLoader (built on first use) — they do not consume from the trainer’s shared data_iter.

On interval_type="eval" (post-training eval, where the trainer loads the saved best checkpoint into the live model and calls every callback’s at_eval), children are dispatched unconditionally so they evaluate the loaded best model. No checkpoint save / tolerance bookkeeping runs in eval mode (the in-memory best_metric_value starts at ±inf in a fresh eval process).

Raises:

KeyError – If the log cache is empty or the metric key is not found.

Parameters:

interval_type (noether.core.callbacks.periodic.IntervalType)

Return type:

None
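The control flow described for periodic_callback can be summarized in a toy decision function. This is a simplified sketch, not the library's code: the function decide, its return tuple, and the "epoch" interval value are assumptions made for illustration, and the log cache is modeled as a plain dict.

```python
import math


def decide(interval_type, log_cache, metric_key, best, higher_is_better):
    """Return (save_checkpoint, dispatch_children, new_best) for one invocation."""
    if interval_type == "eval":
        # Post-training eval: children always run, no save or tolerance bookkeeping.
        return False, True, best
    # A plain dict raises KeyError when the cache is empty or the key is missing.
    value = log_cache[metric_key]
    improved = value > best if higher_is_better else value < best
    if improved:
        # New best: save the checkpoint and dispatch the child eval callbacks.
        return True, True, value
    return False, False, best


key = "loss/val/total"
print(decide("epoch", {key: 0.5}, key, math.inf, False))
print(decide("epoch", {key: 0.9}, key, 0.5, False))
print(decide("eval", {}, key, math.inf, False))
```

In this sketch the tolerance counter is left out; it would be advanced in the branch where the metric did not improve.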

after_training(**kwargs)

Log the best metric values at different tolerance thresholds after training completes.

Parameters:

**kwargs – Additional keyword arguments forwarded to child eval callbacks.

Return type:

None