noether.core.callbacks.checkpoint

Classes

BestCheckpointCallback

Callback to save the best model based on a metric.

CheckpointCallback

Callback to save the model and optimizer state periodically.

EmaCallback

Callback for exponential moving average (EMA) of model weights.

Package Contents

class noether.core.callbacks.checkpoint.BestCheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the best model based on a metric.

This callback monitors a specified metric and saves the model checkpoint whenever a new best value is achieved. It supports storing different model components when using a composite model and can save checkpoints at different tolerance thresholds.

Example config:

callbacks:
  - kind: noether.core.callbacks.BestCheckpointCallback
    name: BestCheckpointCallback
    every_n_epochs: 1
    metric_key: loss/val/total
    model_names:  # only applies when training a CompositeModel
      - encoder
Parameters:
  • metric_key
  • model_names
  • higher_is_better
  • best_metric_value
  • save_frozen_weights
  • tolerances_is_exceeded
  • tolerance_counter = 0
  • metric_at_exceeded_tolerance: dict[float, float]
state_dict()

Return the state of the callback for checkpointing.

Returns:

Dictionary containing the best metric value, tolerance tracking state, and counter information.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the callback state from a checkpoint.

Note

This modifies the input state_dict in place.

Parameters:

state_dict (dict[str, Any]) – Dictionary containing the saved callback state.

Return type:

None

before_training(*, update_counter)

Validate callback configuration before training starts.

Parameters:

update_counter – The training update counter.

Raises:

NotImplementedError – If resuming training with tolerances is attempted.

Return type:

None

periodic_callback(**_)

Execute the periodic callback to check and save best model.

This method is called at the configured frequency (e.g., every N epochs). It checks if the current metric value is better than the previous best, and if so, saves the model checkpoint. Also tracks tolerance-based checkpoints.

Raises:

KeyError – If the log cache is empty or the metric key is not found.

Return type:

None
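The "better than the previous best" comparison must respect the metric direction. A minimal sketch of that check, assuming the higher_is_better and best_metric_value attributes listed for this class (the helper name is hypothetical, not part of the library):

```python
def is_new_best(value: float, best: float, higher_is_better: bool) -> bool:
    """Return True when `value` improves on `best` for the given direction."""
    return value > best if higher_is_better else value < best


# For a loss metric such as loss/val/total, lower is better:
is_new_best(0.45, 0.50, higher_is_better=False)  # True: new checkpoint saved
```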

after_training(**kwargs)

Log the best metric values at different tolerance thresholds after training completes.

Parameters:

**kwargs – Additional keyword arguments (unused).

Return type:

None

class noether.core.callbacks.checkpoint.CheckpointCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback to save the model and optimizer state periodically.

Example config:

- kind: noether.core.callbacks.CheckpointCallback
  name: CheckpointCallback
  every_n_epochs: 1
  save_weights: true
  save_optim: true
Parameters:
  • save_weights
  • save_optim
  • save_latest_weights
  • save_latest_optim
  • model_names
before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

periodic_callback(*, interval_type, update_counter, **kwargs)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

class noether.core.callbacks.checkpoint.EmaCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Callback for exponential moving average (EMA) of model weights.

Example config:

- kind: noether.core.callbacks.EmaCallback
  every_n_epochs: 10
  save_weights: false
  save_last_weights: false
  save_latest_weights: true
  target_factors:
    - 0.9999
  name: EmaCallback
Parameters:
  • model_paths
  • target_factors
  • save_weights
  • save_last_weights
  • save_latest_weights
  • parameters: dict[tuple[str | None, float], dict[str, torch.Tensor]]
  • buffers: dict[str | None, dict[str, torch.Tensor]]
resume_from_checkpoint(resumption_paths, model)

Resume EMA state from a checkpoint.

Parameters:
  • resumption_paths
  • model

Return type:

None

before_training(**_)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None

apply_ema(cur_model, model_path, target_factor)

Fused in-place implementation of the EMA update.
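The EMA update itself is the standard exponential blend, ema ← factor · ema + (1 − factor) · current, applied per target factor. A minimal pure-Python sketch of that formula (the real method operates on torch tensors in place):

```python
def ema_update(ema: list[float], current: list[float], factor: float) -> None:
    # ema <- factor * ema + (1 - factor) * current, updated in place
    for i in range(len(ema)):
        ema[i] = factor * ema[i] + (1.0 - factor) * current[i]
```

With a factor close to 1 (e.g. the 0.9999 from the example config), the averaged weights track the training weights slowly, smoothing out step-to-step noise.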

track_after_update_step(**_)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter – UpdateCounter instance to access current training progress.

  • times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).

Return type:

None

periodic_callback(*, interval_type, update_counter, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None