noether.core.callbacks.checkpoint

Submodules

Classes

BestCheckpointCallback

Callback to save the best model based on a metric.

BestCheckpointCallbackConfig

Configuration for BestCheckpointCallback.

CheckpointCallback

Callback to save the model and optimizer state periodically.

CheckpointCallbackConfig

Configuration for CheckpointCallback.

EmaCallback

Callback for exponential moving average (EMA) of model weights.

EmaCallbackConfig

Configuration for EmaCallback.

Package Contents

class noether.core.callbacks.checkpoint.BestCheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.

This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.

Example config:

callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names:  # only applies when training a CompositeModel
      - encoder
    eval_callbacks:
      - kind: noether.training.callbacks.OfflineLossCallback
        every_n_epochs: 1  # ignored; the parent triggers on new-best
        dataset_key: test

Parameters:
  • callback_config (BestCheckpointCallbackConfig) – Configuration for the callback. See BestCheckpointCallbackConfig for available options including metric key, model names, and tolerance settings.

  • **kwargs – Additional arguments passed to the parent class.

metric_key
model_names
higher_is_better
best_metric_value
save_frozen_weights
tolerances_is_exceeded
tolerance_counter = 0
metric_at_exceeded_tolerance: dict[float, float]
eval_callbacks: list[noether.core.callbacks.periodic.PeriodicCallback] = []
get_children()

Return only the non-iterator child callbacks. Iterator children (PeriodicDataIteratorCallback instances) are owned end-to-end by this callback and must not be registered on the shared InterleavedSampler; their DataLoaders are built on dispatch instead. Because the trainer always passes batch_size to every PeriodicCallback hook, the child loaders can be built without the trainer’s iterator-args bundle.

Return type:

list[noether.core.callbacks.base.CallbackBase]
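The filtering described above can be pictured with the following sketch. All classes in it are hypothetical stand-ins rather than noether's own types; it only illustrates that iterator children are excluded from the list returned to the trainer while still being kept in eval_callbacks.

```python
from dataclasses import dataclass, field


@dataclass
class ChildCallback:
    """Hypothetical stand-in for a non-iterator child callback."""

    name: str


@dataclass
class IteratorChildCallback(ChildCallback):
    """Hypothetical stand-in for a PeriodicDataIteratorCallback child."""

    dataset_key: str = "test"


@dataclass
class BestCheckpointChildrenSketch:
    eval_callbacks: list[ChildCallback] = field(default_factory=list)

    def get_children(self) -> list[ChildCallback]:
        # Iterator children stay private: their DataLoaders are built on dispatch,
        # so they are not handed to the trainer for sampler registration.
        return [c for c in self.eval_callbacks if not isinstance(c, IteratorChildCallback)]


parent = BestCheckpointChildrenSketch([ChildCallback("logger"), IteratorChildCallback("offline_loss")])
print([c.name for c in parent.get_children()])  # ['logger']
```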

state_dict()

Return the state of the callback for checkpointing.

Returns:

Dictionary containing the best metric value, tolerance tracking state, and counter information.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the callback state from a checkpoint.

Note

This modifies the input state_dict in place.

Parameters:

state_dict (dict[str, Any]) – Dictionary containing the saved callback state.

Return type:

None
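Since load_state_dict modifies its argument in place, a caller that still needs the dictionary afterwards (for example, to load it into a second callback) should pass a copy. A minimal sketch, using a hypothetical callback and illustrative key names rather than noether's actual state layout:

```python
import copy
from typing import Any


class InPlaceStateCallback:
    """Hypothetical callback whose load_state_dict consumes its input."""

    def __init__(self) -> None:
        self.best_metric_value = float("inf")
        self.tolerance_counter = 0

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # pop() removes each entry from the caller's dictionary.
        self.best_metric_value = state_dict.pop("best_metric_value")
        self.tolerance_counter = state_dict.pop("tolerance_counter")


saved = {"best_metric_value": 0.12, "tolerance_counter": 3}
callback = InPlaceStateCallback()
callback.load_state_dict(copy.deepcopy(saved))  # the copy is consumed, `saved` survives
print(saved)  # {'best_metric_value': 0.12, 'tolerance_counter': 3}
```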

before_training(*, update_counter, **kwargs)

Validate callback configuration before training starts.

Parameters:
  • update_counter – The training update counter.

  • **kwargs – Additional keyword arguments forwarded to child eval callbacks.

Raises:

NotImplementedError – If resuming training with tolerances is attempted.

Return type:

None

periodic_callback(*, interval_type, **kwargs)

Execute the periodic callback to check and save best model.

This method is called at the configured frequency (e.g., every N epochs). It checks whether the current metric value is better than the previous best and, if so, saves the model checkpoint. It also tracks tolerance-based checkpoints.

When a new best is detected, child eval callbacks (if configured) are dispatched against the live (newly best) model. Iterator children iterate over their own DataLoader (built on first use); they do not consume from the trainer’s shared data_iter.

When interval_type="eval" (post-training evaluation, where the trainer loads the saved best checkpoint into the live model and calls every callback’s at_eval), children are dispatched unconditionally so that they evaluate the loaded best model. Neither checkpoint saving nor tolerance bookkeeping runs in eval mode, because the in-memory best_metric_value starts at ±inf in a fresh eval process.

Raises:

KeyError – If the log cache is empty or the metric key is not found.

Parameters:

interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval”, indicating which interval triggered this callback.

Return type:

None
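The control flow of periodic_callback can be sketched as follows, under simplifying assumptions: the metric arrives as a plain float, saving and child dispatch are passed in as hypothetical callables, the order of save and dispatch is a guess, and tolerance tracking is omitted. This is an illustration, not noether's implementation.

```python
import math
from typing import Callable


class BestTrackerSketch:
    """Hypothetical, simplified version of the new-best logic described above."""

    def __init__(self, higher_is_better: bool) -> None:
        self.higher_is_better = higher_is_better
        # A fresh process starts at -inf when maximizing and +inf when minimizing.
        self.best_metric_value = -math.inf if higher_is_better else math.inf

    def is_new_best(self, value: float) -> bool:
        if self.higher_is_better:
            return value > self.best_metric_value
        return value < self.best_metric_value

    def periodic_callback(
        self,
        interval_type: str,
        value: float,
        save: Callable[[], None],
        dispatch_children: Callable[[], None],
    ) -> None:
        if interval_type == "eval":
            # The best checkpoint is already loaded: run children, skip all bookkeeping.
            dispatch_children()
            return
        if self.is_new_best(value):
            self.best_metric_value = value
            save()
            dispatch_children()  # children see the live, newly best model


events: list[str] = []
tracker = BestTrackerSketch(higher_is_better=False)
for loss in [0.5, 0.4, 0.45, 0.3]:
    tracker.periodic_callback("epoch", loss, lambda: events.append("save"), lambda: events.append("children"))
print(tracker.best_metric_value, events.count("save"))  # 0.3 3
```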

after_training(**kwargs)

Log the best metric values at different tolerance thresholds after training completes.

Parameters:

**kwargs – Additional keyword arguments forwarded to child eval callbacks.

Return type:

None

class noether.core.callbacks.checkpoint.BestCheckpointCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Configuration for BestCheckpointCallback. As a registry-based config, it is auto-registered via __init_subclass__.

Create a new config instance by parsing and validating the input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid config.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['BestCheckpointCallback'] = None
metric_key: str = None

The key of the metric used to determine the best model (e.g., loss/val/total).

save_frozen_weights: bool = None

Whether to also save the frozen weights of the model.

tolerances: list[int] | None = None

If provided, this callback produces multiple best checkpoints that differ in how many consecutive intervals without improvement they tolerate. For example, tolerances=[5] with every_n_epochs=1 stores the best checkpoint found before the metric first went more than 5 consecutive epochs without improving. The best checkpoint over the whole training run (i.e., tolerance=infinite) is always stored as well. Setting several tolerances lets you evaluate different early-stopping configurations from a single training run.
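To make the tolerance semantics concrete, here is an offline sketch that replays a sequence of per-interval metric values and reports which best value each tolerance would retain. It assumes a lower-is-better metric and reads "at most 5 epochs" as "stop after more than 5 consecutive non-improving intervals". The function is hypothetical and only simulates the selection; it does not save anything.

```python
import math


def best_per_tolerance(values: list[float], tolerances: list[int]) -> dict[float, float]:
    """Best (lowest) metric that each tolerance would keep; math.inf covers the whole run.

    Assumes lower is better and that a tolerance t is exceeded once the metric has
    gone more than t consecutive intervals without improving.
    """
    result: dict[float, float] = {}
    for tol in [*tolerances, math.inf]:
        best = math.inf
        since_improvement = 0
        for value in values:
            if value < best:
                best = value
                since_improvement = 0
            else:
                since_improvement += 1
                if since_improvement > tol:
                    break  # tolerance exceeded: this checkpoint is no longer updated
        result[tol] = best
    return result


val_losses = [1.0, 0.8, 0.85, 0.9, 0.7, 0.95, 0.96, 0.97, 0.98, 0.5]
print(best_per_tolerance(val_losses, tolerances=[1, 3]))  # {1: 0.8, 3: 0.7, inf: 0.5}
```

Each tolerance freezes at a different point, while the infinite tolerance always ends at the overall best.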

model_names: list[str] | None = None

Names of the models to save (e.g., to store only the encoder of an autoencoder, pass model_names=['encoder']). This only applies when training a CompositeModel. If None, all models are saved.

eval_callbacks: list[Annotated[Any, Discriminated(CallBackBaseConfig)]] | None = None

Optional nested callbacks to dispatch whenever a new best model is detected. Each child’s metric keys are automatically prefixed with best=<metric_key>/ (slashes in the metric key are replaced with dots) so they don’t collide with the live-model metrics. Children are invoked via their at_eval hook, which bypasses their own schedule: the trigger is the new-best event, not the child’s every_n_*. Tolerance-exceeded saves do not trigger children. before_training and after_training are forwarded unconditionally so children can initialize and finalize cleanly.

PeriodicDataIteratorCallback children get a dedicated DataLoader built from their sampler_config; they are not registered on the shared InterleavedSampler. This means a child’s every_n_* is irrelevant here (only the dataset_key / batch_size / pipeline matter) and the child’s schedule does not need to match this callback’s.
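The key-prefixing rule for eval_callbacks can be expressed as a small helper. This is a sketch of the documented rule, not noether's code, and the child key loss/test/total is a hypothetical example of what an OfflineLossCallback on the test dataset might log. Replacing the slashes keeps the monitored key as a single component in front of the child's own key.

```python
def best_prefix(metric_key: str) -> str:
    """Prefix for child metric keys: best=<metric_key>/ with slashes turned into dots."""
    return f"best={metric_key.replace('/', '.')}/"


def prefixed_child_key(metric_key: str, child_key: str) -> str:
    return best_prefix(metric_key) + child_key


print(prefixed_child_key("loss/val/total", "loss/test/total"))  # best=loss.val.total/loss/test/total
```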

class noether.core.callbacks.checkpoint.CheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the model and optimizer state periodically.

Example config:

- kind: noether.core.callbacks.CheckpointCallback
  name: CheckpointCallback
  every_n_epochs: 1
  save_weights: true
  save_optim: true

Parameters:
  • callback_config (CheckpointCallbackConfig) – Configuration for the callback. See CheckpointCallbackConfig for available options.

  • **kwargs – Additional arguments passed to the parent class.

save_weights
save_optim
save_latest_weights
save_latest_optim
model_names
before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

periodic_callback(*, interval_type, update_counter, **kwargs)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval”, indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.checkpoint.CheckpointCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Configuration for CheckpointCallback. As a registry-based config, it is auto-registered via __init_subclass__.

Create a new config instance by parsing and validating the input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid config.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['CheckpointCallback'] = None
save_weights: bool = None

Whether to save the weights of the model every time this callback is invoked. The checkpoint name will contain the training iteration (e.g., epoch/update/sample) at which the checkpoint was saved.

save_optim: bool = None

Whether to save the optimizer state every time this callback is invoked. The checkpoint name will contain the training iteration (e.g., epoch/update/sample) at which the checkpoint was saved.

save_latest_weights: bool = None

Whether to save the latest weights of the model every time this callback is invoked. Note that the latest weights are always overwritten on the next invocation of this callback.

save_latest_optim: bool = None

Whether to save the latest optimizer state every time this callback is invoked. Note that the latest optimizer state is always overwritten on the next invocation of this callback.

model_names: list[str] | None = None

The names of the models to save. If None, all models are saved.
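The difference between the numbered and the latest variants can be illustrated by simulating which weight checkpoints remain after several invocations. The sketch is hypothetical: the tag names are placeholders rather than noether's file names, and the same logic would apply to save_optim and save_latest_optim.

```python
def checkpoints_after(epochs: list[int], save_weights: bool, save_latest_weights: bool) -> dict[str, int]:
    """Map each weight checkpoint tag left on disk to the epoch whose weights it holds.

    The tag names ("E<epoch>", "latest") are illustrative placeholders.
    """
    on_disk: dict[str, int] = {}
    for epoch in epochs:
        if save_weights:
            on_disk[f"E{epoch}"] = epoch  # one new checkpoint per invocation
        if save_latest_weights:
            on_disk["latest"] = epoch  # overwritten on every invocation
    return on_disk


print(checkpoints_after([1, 2, 3], save_weights=False, save_latest_weights=True))  # {'latest': 3}
```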

class noether.core.callbacks.checkpoint.EmaCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback for exponential moving average (EMA) of model weights.

In addition to maintaining and checkpointing EMA weights, this callback can optionally own a list of child evaluation callbacks via eval_callbacks. At each eval-time hook (after_epoch, after_update, at_eval) the EMA weights are swapped into the live model, the children are dispatched, and the live weights are restored. Children are dispatched once per target_factor and their metric keys are automatically prefixed with ema=<factor>/ to avoid collisions with live-model metrics.
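The swap-dispatch-restore cycle can be sketched with plain dictionaries standing in for model parameters. Everything here is an assumption for illustration: noether operates on torch tensors inside nn.Modules, and the exact string formatting of <factor> in the prefix is not specified above (the sketch uses Python's default float formatting).

```python
from typing import Callable

Weights = dict[str, float]


def dispatch_with_ema(
    live: Weights,
    ema_by_factor: dict[float, Weights],
    run_children: Callable[[Weights, str], None],
) -> None:
    """Swap each factor's EMA weights into `live`, run the children, then restore `live`."""
    backup = dict(live)
    try:
        for factor, ema_weights in ema_by_factor.items():
            live.update(ema_weights)  # every factor covers the same parameter names
            run_children(live, f"ema={factor}/")  # metric-key prefix for this factor
    finally:
        # Restore the live weights even if a child raises.
        live.clear()
        live.update(backup)


model = {"w": 1.0}
seen: list[tuple[str, float]] = []
dispatch_with_ema(model, {0.9: {"w": 0.5}, 0.99: {"w": 0.8}}, lambda w, p: seen.append((p, w["w"])))
print(seen, model)  # [('ema=0.9/', 0.5), ('ema=0.99/', 0.8)] {'w': 1.0}
```

The try/finally is the design point: even if a child fails, the live weights are never left replaced by an EMA copy.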

Example config:

- kind: noether.core.callbacks.EmaCallback
  every_n_epochs: 10
  save_weights: false
  save_last_weights: false
  save_latest_weights: true
  target_factors:
    - 0.9999
  name: EmaCallback
  eval_callbacks:
    - kind: noether.training.callbacks.OfflineLossCallback
      every_n_epochs: 1
      dataset_key: val

Parameters:
  • callback_config (EmaCallbackConfig) – Configuration for the callback. See EmaCallbackConfig for available options.

  • **kwargs – Additional arguments passed to the parent class.

model_paths
target_factors
save_weights
save_last_weights
save_latest_weights
parameters: dict[tuple[str | None, float], dict[str, torch.Tensor]]
buffers: dict[str | None, dict[str, torch.Tensor]]
eval_callbacks: dict[float, list[noether.core.callbacks.base.CallbackBase]]
get_children()

Flat list of child eval callbacks (across all target_factors).

Exposed to the trainer so nested PeriodicDataIteratorCallback instances have their samplers registered on the shared data loader. The EMA callback remains responsible for dispatching lifecycle hooks to its children.

Return type:

list[noether.core.callbacks.base.CallbackBase]

resume_from_checkpoint(resumption_paths, model)

Resume EMA state from a checkpoint.

Tries cp=latest first (written by periodic saves), then cp=last (written by after_training, e.g. on graceful signal interrupt). If neither exists, falls back to initializing EMA from the current model weights.

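Parameters:
  • resumption_paths – Paths used to locate the checkpoint to resume from.

  • model – The model whose current weights initialize the EMA if no EMA checkpoint is found.
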
Return type:

None
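The fallback order of resume_from_checkpoint can be sketched as follows. Checkpoints are modeled as a hypothetical mapping from tag to a dictionary of float weights; real checkpoints are files holding tensors, and resumption_paths is not modeled.

```python
from typing import Optional

Weights = dict[str, float]


def pick_resume_tag(available: set[str]) -> Optional[str]:
    """Return the EMA checkpoint tag to resume from, or None if there is none."""
    for tag in ("cp=latest", "cp=last"):  # periodic saves first, then the after_training save
        if tag in available:
            return tag
    return None


def resume_ema(checkpoints: dict[str, Weights], model_weights: Weights) -> Weights:
    tag = pick_resume_tag(set(checkpoints))
    if tag is None:
        return dict(model_weights)  # fall back: start the EMA from the current weights
    return dict(checkpoints[tag])


print(resume_ema({"cp=last": {"w": 0.7}}, model_weights={"w": 1.0}))  # {'w': 0.7}
```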

before_training(**kwargs)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

apply_ema(cur_model, model_path, target_factor)

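Update the stored EMA parameters for model_path and target_factor from the current weights of cur_model (fused, in-place implementation).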
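Assuming the standard EMA update, in which target_factor weighs the stored average and 1 - target_factor weighs the current weights (consistent with factors such as 0.9999 in the example config), the per-parameter arithmetic is sketched below on plain Python lists. For PyTorch tensors, the same in-place update can be written as ema.lerp_(current, 1 - target_factor); whether noether's fused implementation uses that call is not stated here.

```python
def apply_ema_sketch(ema: list[float], current: list[float], target_factor: float) -> None:
    """In place: ema <- target_factor * ema + (1 - target_factor) * current."""
    for i, value in enumerate(current):
        ema[i] = target_factor * ema[i] + (1.0 - target_factor) * value


ema_weights = [0.0, 2.0]
apply_ema_sketch(ema_weights, current=[1.0, 2.0], target_factor=0.9)
print(ema_weights)
```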

track_after_update_step(**_)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

after_epoch(update_counter, **kwargs)

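Invoked after every epoch to check whether the callback should run. This is also one of the eval-time hooks around which the EMA weights are swapped into the live model for the child eval callbacks.

Runs within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • **kwargs – Additional keyword arguments.
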
Return type:

None

after_update(update_counter, **kwargs)

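Invoked after every update to check whether the callback should run. This is also one of the eval-time hooks around which the EMA weights are swapped into the live model for the child eval callbacks.

Runs within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • **kwargs – Additional keyword arguments.
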
Return type:

None

at_eval(update_counter, **kwargs)
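
Eval-time hook. Swaps the EMA weights into the live model, dispatches the child eval callbacks, and restores the live weights afterwards.

Parameters: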

update_counter (noether.core.utils.training.counter.UpdateCounter)

Return type:

None

periodic_callback(*, interval_type, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval”, indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**kwargs)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.checkpoint.EmaCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Configuration for EmaCallback. As a registry-based config, it is auto-registered via __init_subclass__.

Create a new config instance by parsing and validating the input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid config.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['EmaCallback'] = None
target_factors: list[float] = None

The EMA decay factors (e.g., 0.9999). A separate set of EMA weights is maintained for each factor.

model_paths: list[str | None] | None = None

Paths of the models to apply the EMA to, i.e., the paths of the PyTorch nn.Modules in the checkpoint (e.g., composite_model.encoder and composite_model.decoder). If None, the EMA is applied to the whole model. When training a CompositeModel, provide the submodule paths (e.g., ‘encoder’, ‘decoder’) via this field; otherwise the EMA is applied to the CompositeModel as a whole, which cannot be restored later.

save_weights: bool = None

Whether to save the EMA weights.

save_last_weights: bool = None

Whether to save the EMA weights once training is over (written by after_training as the cp=last checkpoint).

save_latest_weights: bool = None

Whether to save the latest EMA weights (the cp=latest checkpoint). Note that the latest weights are always overwritten on the next invocation of this callback.

eval_callbacks: list[Annotated[Any, Discriminated(CallBackBaseConfig)]] | None = None

Optional nested periodic callbacks to run against EMA weights. Each child retains its own schedule (every_n_epochs etc.); the EMA callback swaps its stored EMA parameters into the live model around eval-time hooks (after_epoch, after_update, at_eval) and restores the live weights on exit. Children are dispatched once per target_factor and their metric keys are automatically prefixed with ema=<factor>/ to avoid collisions with live-model metrics. Note: before_training and after_training are forwarded without swapping, so EMA initialization and the final save see live weights.