noether.core.callbacks.checkpoint.ema
Classes
- EmaCallbackConfig: Configuration for EmaCallback.
- EmaCallback: Callback for exponential moving average (EMA) of model weights.
Module Contents
- class noether.core.callbacks.checkpoint.ema.EmaCallbackConfig(/, **data)
Bases: noether.core.callbacks.base.CallBackBaseConfig
Configuration for EmaCallback. Auto-registration is inherited from the registry-based config base class via __init_subclass__.
Creates a new config instance by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- name: Literal['EmaCallback'] = None
- model_paths: list[str | None] | None = None
Paths to the models that EMA is applied to, given as the paths of the PyTorch nn.Modules in the checkpoint (e.g., composite_model.encoder or composite_model.decoder). If None, EMA is applied to the whole model. When training with a CompositeModel, provide the submodule paths (e.g., 'encoder', 'decoder') through this field. Otherwise EMA is applied to the CompositeModel as a whole, and the result cannot be restored later.
- save_last_weights: bool = None
Save the EMA weights once when training is over.
- save_latest_weights: bool = None
Save the latest EMA weights. Each subsequent invocation of this callback overwrites the previously saved latest weights.
- eval_callbacks: list[Annotated[Any, Discriminated(CallBackBaseConfig)]] | None = None
Optional nested periodic callbacks to run against EMA weights. Each child keeps its own schedule (every_n_epochs, etc.). The EMA callback swaps its stored EMA parameters into the live model around eval-time hooks (after_epoch, after_update, at_eval) and restores the live weights on exit. Children are dispatched once per target factor, and their metric keys are automatically prefixed with ema=<factor>/ to avoid collisions with live-model metrics. Note: before_training and after_training are forwarded without swapping, so EMA initialization and the final save see live weights.
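The swap-and-restore behavior around eval-time hooks can be sketched in plain Python. This is an illustration under stated assumptions and is not noether's implementation. Parameters are modeled as dicts of float lists instead of PyTorch tensors. The names swapped_weights and run_children_with_ema are hypothetical.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterable, Iterator, List

# Hypothetical stand-in: a "model" is a dict of parameter name -> list of floats.
# Real code would operate on the tensors of a torch.nn.Module.
Params = Dict[str, List[float]]


@contextmanager
def swapped_weights(live: Params, ema: Params) -> Iterator[None]:
    """Temporarily copy EMA values into `live`, restoring the originals on exit."""
    backup = {name: values.copy() for name, values in live.items()}
    try:
        for name, values in ema.items():
            live[name][:] = values  # in-place, so anyone holding `live` sees EMA weights
        yield
    finally:
        # Restore the live weights even if a child callback raised.
        for name, values in backup.items():
            live[name][:] = values


def run_children_with_ema(
    live: Params, ema: Params, children: Iterable[Callable[[Params], object]]
) -> None:
    """Dispatch each child while the EMA weights are loaded into the live model."""
    with swapped_weights(live, ema):
        for child in children:
            child(live)
```

In this sketch, the restore step runs in a `finally` block. As a result, the live weights come back even when a child fails, so training cannot silently continue from EMA weights.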
- class noether.core.callbacks.checkpoint.ema.EmaCallback(callback_config, **kwargs)
Bases: noether.core.callbacks.periodic.PeriodicCallback
Callback for exponential moving average (EMA) of model weights.
In addition to maintaining and checkpointing EMA weights, this callback can optionally own a list of child evaluation callbacks via eval_callbacks. At each eval-time hook (after_epoch, after_update, at_eval), the EMA weights are swapped into the live model, the children are dispatched, and the live weights are restored. Children are dispatched once per target factor, and their metric keys are automatically prefixed with ema=<factor>/ to avoid collisions with live-model metrics.
Example config:

    - kind: noether.core.callbacks.EmaCallback
      every_n_epochs: 10
      save_weights: false
      save_last_weights: false
      save_latest_weights: true
      target_factors:
        - 0.9999
      name: EmaCallback
      eval_callbacks:
        - kind: noether.training.callbacks.OfflineLossCallback
          every_n_epochs: 1
          dataset_key: val

- Parameters:
callback_config (EmaCallbackConfig) – Configuration for the callback. See EmaCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- model_paths
- target_factors
- save_weights
- save_last_weights
- save_latest_weights
- eval_callbacks: dict[float, list[noether.core.callbacks.base.CallbackBase]]
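The eval_callbacks attribute is typed as a dict that maps each target factor to its own list of children. The sketch below shows one way to dispatch those children with prefixed metric keys and to flatten them the way get_children() is documented to. The sketch rests on three assumptions:

- Children are plain callables that return a metrics dict.
- The `<factor>` in the prefix is rendered with Python's default float formatting.
- The names prefix_metrics, dispatch, and flatten_children are hypothetical.

```python
from typing import Callable, Dict, List

# Hypothetical child type: a callable that returns a metrics dict.
Child = Callable[[], Dict[str, float]]


def prefix_metrics(factor: float, metrics: Dict[str, float]) -> Dict[str, float]:
    """Prefix each key with ema=<factor>/ so it cannot clash with live-model keys."""
    return {f"ema={factor}/{key}": value for key, value in metrics.items()}


def dispatch(eval_callbacks: Dict[float, List[Child]]) -> Dict[str, float]:
    """Run every child once per target factor and collect the prefixed metrics."""
    logged: Dict[str, float] = {}
    for factor, children in eval_callbacks.items():
        # In the real callback, the EMA weights for `factor` would be swapped in here.
        for child in children:
            logged.update(prefix_metrics(factor, child()))
    return logged


def flatten_children(eval_callbacks: Dict[float, List[Child]]) -> List[Child]:
    """Flat list of all children across all target factors."""
    return [child for children in eval_callbacks.values() for child in children]
```

For example, two factors that each run a loss callback would log `ema=0.99/loss` and `ema=0.9999/loss` side by side with the live model's `loss`.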
- get_children()
Flat list of child eval callbacks (across all target_factors).
Exposed to the trainer so that nested PeriodicDataIteratorCallback instances have their samplers registered on the shared data loader. The EMA callback remains responsible for dispatching lifecycle hooks to its children.
- Return type:
- resume_from_checkpoint(resumption_paths, model)
Resume EMA state from a checkpoint.
Tries cp=latest first (written by periodic saves), then cp=last (written by after_training, e.g., on a graceful signal interrupt). If neither exists, falls back to initializing EMA from the current model weights.
- Parameters:
resumption_paths (noether.core.providers.path.PathProvider) – PathProvider with paths to checkpoint files.
model – Model to resume EMA state for.
- Return type:
None
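The fallback order of resume_from_checkpoint can be sketched as follows. The file-naming scheme `{stem}_cp={tag}.th` and the function name pick_resume_checkpoint are hypothetical placeholders. Real paths come from the PathProvider, and this sketch does not reproduce its API. Only the order is taken from the description above: latest, then last, then the live weights.

```python
from pathlib import Path
from typing import Optional


def pick_resume_checkpoint(checkpoint_dir: Path, stem: str) -> Optional[Path]:
    """Return the EMA checkpoint to resume from, or None to initialize from live weights.

    Hypothetical naming: files are assumed to be called ``{stem}_cp={tag}.th``.
    """
    for tag in ("latest", "last"):  # cp=latest (periodic saves) wins over cp=last
        candidate = checkpoint_dir / f"{stem}_cp={tag}.th"
        if candidate.exists():
            return candidate
    return None  # caller falls back to copying the current model weights
```

Preferring cp=latest makes sense under the stated save semantics. A run interrupted mid-training has typically written a periodic save more recently than any end-of-training checkpoint left over from earlier.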
- before_training(**kwargs)
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
Initializing experiment tracking (e.g., logging hyperparameters)
Printing model summaries or architecture details
Initializing specific data structures or buffers needed during training
Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None
- apply_ema(cur_model, model_path, target_factor)
Updates the stored EMA weights for model_path from cur_model using target_factor. Uses a fused, in-place implementation.
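The sketch below assumes the conventional EMA rule. Each update moves every EMA parameter toward the current value: ema ← f · ema + (1 − f) · cur, where f is target_factor. It applies that rule in place to plain float lists written in the equivalent lerp form. The real method operates on PyTorch tensors with a fused kernel, and apply_ema_step is a hypothetical name.

```python
from typing import List


def apply_ema_step(ema: List[float], cur: List[float], target_factor: float) -> None:
    """In-place EMA update: ema = target_factor * ema + (1 - target_factor) * cur."""
    if len(ema) != len(cur):
        raise ValueError("ema and cur must have the same length")
    for i, value in enumerate(cur):
        # Lerp form: move ema[i] a fraction (1 - f) of the way toward the current value.
        ema[i] += (1.0 - target_factor) * (value - ema[i])
```

With target_factor = 0.9999, as in the example config, each step moves the average 0.01% of the way toward the live weights. The average therefore effectively spans on the order of 1 / (1 − f) = 10,000 updates.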
- track_after_update_step(**_)
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
times – Dictionary containing time measurements for various parts of the training step (e.g., ‘data_time’, ‘forward_time’, ‘backward_time’, ‘update_time’).
- Return type:
None
- after_epoch(update_counter, **kwargs)
Invoked after every epoch to check whether the callback should run.
Applies a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- after_update(update_counter, **kwargs)
Invoked after every update to check whether the callback should run.
Applies a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- at_eval(update_counter, **kwargs)
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter)
- Return type:
None
- periodic_callback(*, interval_type, update_counter, **_)
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval”, indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, and sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
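The sketch below shows one plausible way to decide whether each configured interval (every_n_epochs, every_n_updates, every_n_samples) has been reached. It is an assumption and is not PeriodicCallback's actual code. The function names are hypothetical. The sample check compares interval buckets before and after an update, because batch sizes rarely divide every_n_samples exactly.

```python
from typing import Optional


def epoch_due(epoch: int, every_n_epochs: Optional[int]) -> bool:
    """Checked after an epoch finishes; `epoch` counts completed epochs."""
    return every_n_epochs is not None and epoch % every_n_epochs == 0


def update_due(update: int, every_n_updates: Optional[int]) -> bool:
    """Checked after an optimizer update; `update` counts completed updates."""
    return every_n_updates is not None and update % every_n_updates == 0


def samples_due(samples_before: int, samples_after: int, every_n_samples: Optional[int]) -> bool:
    """True if a multiple of every_n_samples was crossed during the last update."""
    if every_n_samples is None:
        return False
    # Compare which interval "bucket" the sample counter was in before and after.
    return samples_before // every_n_samples < samples_after // every_n_samples
```

For instance, with every_n_samples = 100 and a batch size of 32, an update that moves the counter from 96 to 128 crosses 100 and triggers. The next update, from 128 to 160, does not trigger.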
- after_training(**kwargs)
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None