noether.core.callbacks.periodic¶
Attributes¶
IntervalType – Type alias for periodic callback interval types.
Classes¶
PeriodicCallback – Base class for callbacks that are invoked periodically during training.
PeriodicDataIteratorCallback – Base class for callbacks that perform periodic iterations over a dataset.
Module Contents¶
- noether.core.callbacks.periodic.IntervalType¶
Type alias for periodic callback interval types.
Defines the unit of training progress used to trigger periodic callbacks:
“epoch”: Callback is triggered based on completed epochs
“update”: Callback is triggered based on optimizer update steps
“sample”: Callback is triggered based on number of samples processed
“eval”: Callback is triggered independent of schedule for post-training evaluation
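The alias is plausibly a Literal over these four strings; the sketch below is illustrative (the `describe` helper is not part of the API, and the actual definition in noether.core.callbacks.periodic may differ):

```python
from typing import Literal

# Plausible shape of the alias; the real definition may differ.
IntervalType = Literal["epoch", "update", "sample", "eval"]

def describe(interval_type: str) -> str:
    # Illustrative helper: map each interval type to the progress unit
    # that triggers it, per the list above.
    units = {
        "epoch": "completed epochs",
        "update": "optimizer update steps",
        "sample": "samples processed",
        "eval": "post-training evaluation, independent of schedule",
    }
    return units[interval_type]
```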
- class noether.core.callbacks.periodic.PeriodicCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases: noether.core.callbacks.base.CallbackBase
Base class for callbacks that are invoked periodically during training.
PeriodicCallback extends CallbackBase to support periodic execution based on training progress. Callbacks can be configured to run at regular intervals defined by epochs, updates (optimizer steps), or samples (data points processed). This class implements the infrastructure for periodic invocation, while child classes define the actual behavior via the periodic_callback() method.
- Interval Configuration:
Callbacks can be configured to run periodically using one or more of:
every_n_epochs: Execute callback every N epochs
every_n_updates: Execute callback every N optimizer updates
every_n_samples: Execute callback every N samples processed
- Tracking vs. Periodic Execution:
The class provides two types of hooks:
Tracking hooks (track_after_accumulation_step(), track_after_update_step()): Called on every accumulation/update step to track metrics continuously (e.g., for running averages). For example, to log an exponential moving average of the loss every epoch, the logging happens in the periodic callback, while the loss values for the moving average are collected in the tracking hook.
Periodic hook (periodic_callback()): Called only when the configured interval is reached, typically for expensive operations like evaluation or checkpointing.
Examples
Creating a custom periodic callback that logs metrics every 10 epochs:
class CustomMetricCallback(PeriodicCallback):
    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # This method is called every 10 epochs
        metric_value = self.compute_expensive_metric()
        self.writer.add_scalar(
            key="custom_metric",
            value=metric_value,
            logger=self.logger,
        )

# Configure in YAML:
# callbacks:
#   - kind: path.to.CustomMetricCallback
#     every_n_epochs: 10
Tracking metrics at every update and logging periodically:
class RunningAverageCallback(PeriodicCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_accumulator = []

    def track_after_update_step(self, *, update_counter: UpdateCounter, times: dict[str, float]) -> None:
        # Track at every update
        self.loss_accumulator.append(self.trainer.last_loss)

    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # Log periodically
        avg_loss = sum(self.loss_accumulator) / len(self.loss_accumulator)
        self.writer.add_scalar("avg_loss", avg_loss, logger=self.logger)
        self.loss_accumulator.clear()
- every_n_epochs¶
If set, callback is invoked every N epochs.
- every_n_updates¶
If set, callback is invoked every N optimizer updates.
- every_n_samples¶
If set, callback is invoked every N samples processed.
- batch_size¶
Batch size used during training.
- Parameters:
callback_config (noether.core.schemas.callbacks.CallBackBaseConfig) – Configuration for the callback. See CallBackBaseConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- every_n_epochs¶
- every_n_updates¶
- every_n_samples¶
- batch_size¶
- periodic_callback(*, interval_type, update_counter, **kwargs)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
- track_after_accumulation_step(*, update_counter, batch, losses, update_outputs, accumulation_steps, accumulation_step)¶
Hook called after each individual gradient accumulation step.
This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.
Common use cases include:
Logging per-batch losses for high-frequency monitoring
Accumulating statistics across batches before an optimizer update
Implementing custom logging that needs access to individual batch data
Note
This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
batch (Any) – The current data batch processed in this accumulation step.
losses (dict[str, torch.Tensor]) – Dictionary of computed losses for the current batch.
update_outputs (dict[str, torch.Tensor] | None) – Optional dictionary of model outputs for the current batch.
accumulation_steps (int) – Total number of accumulation steps before an optimizer update.
accumulation_step (int) – The current accumulation step index (0-indexed).
- Return type:
None
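As a concrete illustration of the accumulation bookkeeping described above, here is a minimal sketch; plain floats stand in for torch tensors, and the class and attribute names are hypothetical, not part of the API:

```python
class AccumulationLossTracker:
    """Hypothetical sketch: collect per-micro-batch losses and reduce
    them to one value per optimizer update."""

    def __init__(self):
        self.partial_losses = []
        self.per_update_loss = None

    def track_after_accumulation_step(self, *, losses, accumulation_steps, accumulation_step, **_):
        # Collect the loss of every accumulation micro-batch.
        self.partial_losses.append(losses["total"])
        # On the last micro-batch of the cycle, average into one
        # per-update value and reset the buffer.
        if accumulation_step == accumulation_steps - 1:
            self.per_update_loss = sum(self.partial_losses) / len(self.partial_losses)
            self.partial_losses.clear()
```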
- track_after_update_step(*, update_counter, times)¶
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
Latest loss values
Learning rates
Model parameter statistics (norms, etc.)
Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
times (dict[str, float]) – Dictionary containing time measurements for various parts of the training step (e.g., 'data_time', 'forward_time', 'backward_time', 'update_time').
- Return type:
None
- after_epoch(update_counter, **kwargs)¶
Invoked after every epoch to check whether the callback should be invoked.
Applies a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- after_update(update_counter, **kwargs)¶
Invoked after every update to check whether the callback should be invoked.
Applies a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
**kwargs – Additional keyword arguments.
- Return type:
None
- at_eval(update_counter, **kwargs)¶
Invoked to trigger the callback for post-training evaluation, independent of the configured schedule.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter)
- Return type:
None
- updates_till_next_invocation(update_counter)¶
Calculate how many updates remain until this callback is invoked.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Returns:
Number of updates remaining until the next callback invocation.
- Return type:
- updates_per_interval(update_counter)¶
Calculate how many updates pass from one invocation of this callback to the next.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Returns:
Number of updates between callback invocations.
- Return type:
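For the update-based case, the two interval helpers can be sketched as plain functions. These are hypothetical standalone versions; the real methods are instance methods that consult an UpdateCounter and also handle epoch- and sample-based intervals:

```python
def updates_per_interval(every_n_updates: int) -> int:
    # For an update-based interval, the spacing between invocations
    # is simply the configured interval itself.
    return every_n_updates

def updates_till_next_invocation(current_update: int, every_n_updates: int) -> int:
    # Distance from the current update to the next multiple of the
    # interval; exactly on a multiple, a full interval remains.
    remainder = current_update % every_n_updates
    return every_n_updates - remainder if remainder else every_n_updates
```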
- get_interval_string_verbose()¶
Return interval configuration as a verbose string.
- Returns:
Interval rendered as a string, e.g., “every_n_epochs=1” for epoch-based intervals.
- Return type:
- class noether.core.callbacks.periodic.PeriodicDataIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Bases: PeriodicCallback
Base class for callbacks that perform periodic iterations over a dataset.
PeriodicDataIteratorCallback extends PeriodicCallback to support evaluations or computations that require iterating over an entire dataset. This is commonly used for validation/test set evaluation, computing metrics on held-out data, or any operation that needs to process batches from a dataset at regular training intervals.
The class integrates with the training data pipeline by registering samplers that control when and how data is loaded. It handles the complete iteration workflow: data loading, batch processing, result collation across distributed ranks, and final processing.
- Workflow:
Iteration (_iterate_over_dataset()): When the periodic interval is reached, iterate through the dataset in batches.
Process Data (process_data()): Process a single batch (e.g., run model inference) and return results.
Collation (_collate_result()): Aggregate results across all batches and distributed ranks.
Processing (process_results()): Compute final metrics or perform actions with the aggregated results.
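The four-step workflow can be sketched, for a single non-distributed rank, roughly as follows; plain lists stand in for tensors and distributed gathering, and the function names are illustrative:

```python
def run_periodic_iteration(data_iter, process_data, process_results, num_batches):
    """Hypothetical sketch of the workflow: iterate over the dataset,
    process each batch, collate the per-batch results, then hand the
    aggregate to process_results."""
    per_batch = []
    for _ in range(num_batches):
        batch = next(data_iter)          # step 1: iteration
        per_batch.append(process_data(batch))  # step 2: per-batch processing
    # Step 3: collation - flatten per-batch lists into one result list
    # (stands in for tensor concatenation and distributed gathering).
    results = [item for chunk in per_batch for item in chunk]
    process_results(results)             # step 4: final processing
    return results
```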
- Key Features:
Distributed Support: Automatically handles distributed evaluation with proper gathering across ranks and padding removal.
Flexible Collation: Supports collating various result types (tensors, dicts of tensors, lists).
Data Pipeline Integration: Uses SamplerIntervalConfig to integrate with the interleaved sampler for efficient data loading.
Progress Tracking: Provides progress bars and timing information for data loading.
- Template Methods to Override:
Child classes must implement process_data() and process_results():
process_data(): Process a single batch (e.g., run model inference).
process_results(): Process the aggregated results from all batches.
Examples
Basic validation accuracy callback that evaluates on a test set every epoch:
class AccuracyCallback(PeriodicDataIteratorCallback):
    def __init__(self, *args, dataset_key="test", **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_key = dataset_key

    def process_data(self, batch, *, trainer_model):
        # Run inference on batch
        x = batch["x"].to(trainer_model.device)
        y_true = batch["class"].clone()
        y_pred = trainer_model(x)
        return {"predictions": y_pred, "labels": y_true}

    def process_results(self, results, *, interval_type, update_counter, **_):
        # Compute accuracy from aggregated results
        y_pred = results["predictions"]
        y_true = results["labels"]
        accuracy = (y_pred.argmax(dim=1) == y_true).float().mean()
        self.writer.add_scalar(
            key="test/accuracy",
            value=accuracy.item(),
            logger=self.logger,
            format_str=".4f",
        )

# Configure in YAML:
# callbacks:
#   - kind: path.to.AccuracyCallback
#     every_n_epochs: 1
#     dataset_key: "test"
Advanced example with multiple return values and custom collation:
class DetailedEvaluationCallback(PeriodicDataIteratorCallback):
    def process_data(self, batch, *, trainer_model):
        x = batch["x"].to(trainer_model.device)
        y = batch["label"]
        # Return multiple outputs as tuple
        logits = trainer_model(x)
        embeddings = trainer_model.get_embeddings(x)
        return logits, embeddings, y

    def process_results(self, results, *, interval_type, update_counter, **_):
        # results is a tuple: (all_logits, all_embeddings, all_labels)
        logits, embeddings, labels = results
        # Compute multiple metrics
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        mean_embedding_norm = embeddings.norm(dim=-1).mean()
        self.writer.add_scalar("accuracy", accuracy.item())
        self.writer.add_scalar("embedding_norm", mean_embedding_norm.item())
- dataset_key¶
Key to identify the dataset to iterate over from
self.data_container. Automatically set from the callback config.
- sampler_config¶
Configuration for the sampler that controls dataset iteration. Automatically set when dataset is initialized.
- total_data_time¶
Cumulative time spent waiting for data loading across all periodic callbacks.
Note
The process_data() method is called within a torch.no_grad() context automatically.
For distributed training, results are automatically gathered across all ranks with proper padding removal.
- Parameters:
callback_config (noether.core.schemas.callbacks.PeriodicDataIteratorCallbackConfig) – Configuration for the callback. See PeriodicDataIteratorCallbackConfig for available options.
trainer (noether.training.trainers.BaseTrainer) – Trainer of the current run.
model (noether.core.models.ModelBase) – Model of the current run.
data_container (noether.data.container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (noether.core.trackers.BaseTracker) – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer (noether.core.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (noether.core.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (noether.core.providers.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- dataset_key¶
- total_data_time = 0.0¶
- sampler_config¶
- _sampler_config_from_key(key, properties=None, max_size=None)¶
Register the dataset that is used for this callback in the dataloading pipeline.
- Parameters:
key (str | None) – Key for identifying the dataset from self.data_container. Uses the first dataset if None.
properties (set[str] | None) – Optionally specifies a subset of properties to load from the dataset.
max_size (int | None) – If provided, only uses a subset of the full dataset. Default: None (no subset).
- Returns:
SamplerIntervalConfig for the registered dataset.
- Return type:
- abstractmethod process_data(batch, *, trainer_model)¶
Template method that is called for each batch that is loaded from the dataset.
This method should process a single batch and return results that will be collated.
- Parameters:
batch – The loaded batch.
trainer_model (torch.nn.Module) – Model of the current training run.
- Returns:
Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
- Return type:
Any
- process_results(results, *, interval_type, update_counter, **_)¶
Template method that is called with the collated results from dataset iteration.
For example, metrics can be computed from the results for the entire test/validation dataset and logged.
- Parameters:
results (Any) – The collated results produced by _iterate_over_dataset() and the individual process_data() calls.
interval_type – The type of interval that triggered this callback invocation.
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter with the current training state.
**_ – Additional unused keyword arguments.
- Return type:
None
- periodic_callback(*, interval_type, update_counter, data_iter, trainer_model, batch_size, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
data_iter (collections.abc.Iterator)
batch_size (int)
- Return type:
None
- after_training(**_)¶
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
Performing a final evaluation on the test set
Saving final model weights or artifacts
Sending notifications (e.g., via Slack or email) about the completed run
Closing or finalizing experiment tracking sessions
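A minimal sketch of an after_training override follows; a dict stands in for the LogWriter, and compute_final_metric is a hypothetical stand-in for a real test-set evaluation, so nothing here is part of the documented API:

```python
class FinalEvalCallback:
    """Hypothetical sketch: a real callback would derive from
    PeriodicDataIteratorCallback and log via self.writer."""

    def __init__(self):
        self.logged = {}

    def compute_final_metric(self) -> float:
        # Stand-in for a final evaluation pass over the test set.
        return 0.93

    def after_training(self, **_):
        # Run the final evaluation once the training loop has finished
        # and record the result (here into a dict instead of a LogWriter).
        self.logged["final/metric"] = self.compute_final_metric()
```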
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
- Return type:
None