noether.core.callbacks.periodic

Attributes

IntervalType

Type alias for periodic callback interval types.

Classes

PeriodicCallback

Base class for callbacks that are invoked periodically during training.

PeriodicDataIteratorCallback

Base class for callbacks that perform periodic iterations over a dataset.

Module Contents

noether.core.callbacks.periodic.IntervalType

Type alias for periodic callback interval types.

Defines the unit of training progress used to trigger periodic callbacks:

  • “epoch”: Callback is triggered based on completed epochs

  • “update”: Callback is triggered based on optimizer update steps

  • “sample”: Callback is triggered based on number of samples processed

  • “eval”: Callback is triggered independent of schedule for post-training evaluation
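As a sketch (not the library's actual source), the alias described above could be expressed with `typing.Literal`; the helper function is purely illustrative:

```python
from typing import Literal, get_args

# Illustrative sketch of the IntervalType alias documented above;
# the actual definition in noether.core.callbacks.periodic may differ.
IntervalType = Literal["epoch", "update", "sample", "eval"]


def is_valid_interval_type(value: str) -> bool:
    """Check whether a string is one of the recognized interval types."""
    return value in get_args(IntervalType)
```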

class noether.core.callbacks.periodic.PeriodicCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: noether.core.callbacks.base.CallbackBase

Base class for callbacks that are invoked periodically during training.

PeriodicCallback extends CallbackBase to support periodic execution based on training progress. Callbacks can be configured to run at regular intervals defined by epochs, updates (optimizer steps), or samples (data points processed). This class implements the infrastructure for periodic invocation while child classes define the actual behavior via the periodic_callback() method.

Interval Configuration:

Callbacks can be configured to run periodically using one or more of:

  • every_n_epochs: Execute callback every N epochs

  • every_n_updates: Execute callback every N optimizer updates

  • every_n_samples: Execute callback every N samples processed

Tracking vs. Periodic Execution:

The class provides two types of hooks:

  • Tracking hooks (track_after_accumulation_step(), track_after_update_step()): Called on every accumulation/update step to track metrics continuously (e.g., running averages). For example, to log an exponential moving average of the loss every epoch, the logging happens in the periodic callback, while the loss values feeding the moving average are collected in the tracking hook.

  • Periodic hook (periodic_callback()): Called only when the configured interval is reached, typically for expensive operations like evaluation or checkpointing.

Examples

Creating a custom periodic callback that logs metrics every 10 epochs:

class CustomMetricCallback(PeriodicCallback):
    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # This method is called every 10 epochs
        metric_value = self.compute_expensive_metric()
        self.writer.add_scalar(
            key="custom_metric",
            value=metric_value,
            logger=self.logger,
        )


# Configure in YAML:
# callbacks:
#   - kind: path.to.CustomMetricCallback
#     every_n_epochs: 10

Tracking metrics at every update and logging periodically:

class RunningAverageCallback(PeriodicCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_accumulator = []

    def track_after_update_step(self, *, update_counter: UpdateCounter, times: dict[str, float]) -> None:
        # Track at every update
        self.loss_accumulator.append(self.trainer.last_loss)

    def periodic_callback(
        self,
        *,
        interval_type: IntervalType,
        update_counter: UpdateCounter,
        **kwargs,
    ) -> None:
        # Log periodically
        avg_loss = sum(self.loss_accumulator) / len(self.loss_accumulator)
        self.writer.add_scalar("avg_loss", avg_loss, logger=self.logger)
        self.loss_accumulator.clear()

every_n_epochs

If set, callback is invoked every N epochs.

every_n_updates

If set, callback is invoked every N optimizer updates.

every_n_samples

If set, callback is invoked every N samples processed.

batch_size

Batch size used during training.

Parameters:
  • every_n_epochs – If set, callback is invoked every N epochs.

  • every_n_updates – If set, callback is invoked every N optimizer updates.

  • every_n_samples – If set, callback is invoked every N samples processed.

  • batch_size – Batch size used during training.

periodic_callback(*, interval_type, update_counter, **kwargs)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

Return type:

None

track_after_accumulation_step(*, update_counter, batch, losses, update_outputs, accumulation_steps, accumulation_step)

Hook called after each individual gradient accumulation step.

This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (relevant when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.

Common use cases include:

  • Logging per-batch losses for high-frequency monitoring

  • Accumulating statistics across batches before an optimizer update

  • Implementing custom logging that needs access to individual batch data

Note

This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.

Parameters:
  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • batch (Any) – The current data batch processed in this accumulation step.

  • losses (dict[str, torch.Tensor]) – Dictionary of computed losses for the current batch.

  • update_outputs (dict[str, torch.Tensor] | None) – Optional dictionary of model outputs for the current batch.

  • accumulation_steps (int) – Total number of accumulation steps before an optimizer update.

  • accumulation_step (int) – The current accumulation step index (0-indexed).

Return type:

None
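A minimal sketch of how such a hook might aggregate per-batch losses across accumulation steps. Plain Python floats stand in for the torch.Tensor values the real hook receives, and the class name is hypothetical:

```python
class LossAccumulator:
    """Illustrative aggregator for per-accumulation-step losses."""

    def __init__(self) -> None:
        self._sums: dict[str, float] = {}
        self._counts: dict[str, int] = {}

    def track_after_accumulation_step(self, *, losses: dict[str, float], **_) -> None:
        # Called once per batch, even when accumulation_steps > 1.
        for key, value in losses.items():
            self._sums[key] = self._sums.get(key, 0.0) + value
            self._counts[key] = self._counts.get(key, 0) + 1

    def averages(self) -> dict[str, float]:
        """Mean of each tracked loss since the last reset."""
        return {k: self._sums[k] / self._counts[k] for k in self._sums}
```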

track_after_update_step(*, update_counter, times)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • times (dict[str, float]) – Dictionary of timing measurements for the update step.

Return type:

None

after_epoch(update_counter, **kwargs)

Invoked after every epoch to check whether the configured interval has been reached; if so, periodic_callback() is run.

Applies torch.no_grad() context.

Parameters:
Return type:

None

after_update(update_counter, **kwargs)

Invoked after every update to check whether the configured interval has been reached; if so, periodic_callback() is run.

Applies torch.no_grad() context.

Parameters:
Return type:

None

at_eval(update_counter, **kwargs)

Invoked for post-training evaluation, independent of the configured interval schedule.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter)

Return type:

None

updates_till_next_invocation(update_counter)

Calculate how many updates remain until this callback is invoked.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates remaining until the next callback invocation.

Return type:

int

updates_per_interval(update_counter)

Calculate how many updates elapse from one invocation of this callback to the next.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates between callback invocations.

Return type:

int

get_interval_string_verbose()

Return interval configuration as a verbose string.

Returns:

Interval as, e.g., “every_n_epochs=1” for epoch-based intervals.

Return type:

str

to_short_interval_string()

Return interval configuration as a short string.

Returns:

Interval as, e.g., “E1” if every_n_epochs=1 for epoch-based intervals.

Return type:

str
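Based on the documented outputs (“every_n_epochs=1” and “E1”), the two formatters might look like the sketch below. The single-letter prefixes for update- and sample-based intervals (“U”, “S”) are assumptions, not confirmed by the source:

```python
def interval_string_verbose(unit: str, n: int) -> str:
    """e.g. ("epoch", 1) -> "every_n_epochs=1" (sketch of the verbose form)."""
    return f"every_n_{unit}s={n}"


def interval_string_short(unit: str, n: int) -> str:
    """e.g. ("epoch", 1) -> "E1"; prefixes for update/sample are assumed."""
    return f"{unit[0].upper()}{n}"
```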

class noether.core.callbacks.periodic.PeriodicDataIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: PeriodicCallback

Base class for callbacks that perform periodic iterations over a dataset.

PeriodicDataIteratorCallback extends PeriodicCallback to support evaluations or computations that require iterating over an entire dataset. This is commonly used for validation/test set evaluation, computing metrics on held-out data, or any operation that needs to process batches from a dataset at regular training intervals.

The class integrates with the training data pipeline by registering samplers that control when and how data is loaded. It handles the complete iteration workflow: data loading, batch processing, result collation across distributed ranks, and final processing.

Workflow:
  1. Iteration (_iterate_over_dataset()): When the periodic interval is reached, iterate through the dataset in batches.

  2. Process Data (process_data()): Process a single batch (e.g., run model inference) and return results.

  3. Collation (_collate_result()): Aggregate results across all batches and distributed ranks.

  4. Processing (process_results()): Compute final metrics or perform actions with the aggregated results.
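The four workflow steps above can be sketched as a plain-Python loop. Lists of values stand in for tensors, and the key-wise dict collation shown here is only one of the collation strategies the class supports; distributed gathering is omitted:

```python
def run_periodic_iteration(data_iter, process_data, process_results):
    """Illustrative version of the iterate -> process -> collate -> finalize flow."""
    per_batch = []
    # 1. + 2. Iterate over the dataset and process each batch.
    for batch in data_iter:
        per_batch.append(process_data(batch))
    # 3. Collate: here, dicts of lists are concatenated key-wise.
    collated: dict[str, list] = {}
    for result in per_batch:
        for key, values in result.items():
            collated.setdefault(key, []).extend(values)
    # 4. Final processing (e.g., metric computation and logging).
    return process_results(collated)
```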

Key Features:
  • Distributed Support: Automatically handles distributed evaluation with proper gathering across ranks and padding removal.

  • Flexible Collation: Supports collating various result types (tensors, dicts of tensors, lists).

  • Data Pipeline Integration: Uses SamplerIntervalConfig to integrate with the interleaved sampler for efficient data loading.

  • Progress Tracking: Provides progress bars and timing information for data loading.

Template Methods to Override:

Child classes must implement process_data() and process_results().

Examples

Basic validation accuracy callback that evaluates on a test set every epoch:

class AccuracyCallback(PeriodicDataIteratorCallback):
    def __init__(self, *args, dataset_key="test", **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_key = dataset_key

    def process_data(self, batch, *, trainer_model):
        # Run inference on batch
        x = batch["x"].to(trainer_model.device)
        y_true = batch["class"].clone()
        y_pred = trainer_model(x)
        return {"predictions": y_pred, "labels": y_true}

    def process_results(self, results, *, interval_type, update_counter, **_):
        # Compute accuracy from aggregated results
        y_pred = results["predictions"]
        y_true = results["labels"]
        accuracy = (y_pred.argmax(dim=1) == y_true).float().mean()

        self.writer.add_scalar(
            key="test/accuracy",
            value=accuracy.item(),
            logger=self.logger,
            format_str=".4f",
        )


# Configure in YAML:
# callbacks:
#   - kind: path.to.AccuracyCallback
#     every_n_epochs: 1
#     dataset_key: "test"

Advanced example with multiple return values and custom collation:

class DetailedEvaluationCallback(PeriodicDataIteratorCallback):
    def process_data(self, batch, *, trainer_model):
        x = batch["x"].to(trainer_model.device)
        y = batch["label"]

        # Return multiple outputs as tuple
        logits = trainer_model(x)
        embeddings = trainer_model.get_embeddings(x)
        return logits, embeddings, y

    def process_results(self, results, *, interval_type, update_counter, **_):
        # results is a tuple: (all_logits, all_embeddings, all_labels)
        logits, embeddings, labels = results

        # Compute multiple metrics
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        mean_embedding_norm = embeddings.norm(dim=-1).mean()

        self.writer.add_scalar("accuracy", accuracy.item())
        self.writer.add_scalar("embedding_norm", mean_embedding_norm.item())

dataset_key

Key to identify the dataset to iterate over from self.data_container. Automatically set from the callback config.

sampler_config

Configuration for the sampler that controls dataset iteration. Automatically set when dataset is initialized.

total_data_time

Cumulative time spent waiting for data loading across all periodic callbacks.

Note

  • The process_data() method is called within a torch.no_grad() context automatically.

  • For distributed training, results are automatically gathered across all ranks with proper padding removal.
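The padding-removal step mentioned above can be sketched in plain Python (real gathering uses torch.distributed). The interleaved sample order assumed here is for illustration; the actual sampler may distribute samples differently:

```python
def gather_and_trim(per_rank_results: list[list], dataset_size: int) -> list:
    """Merge per-rank results (each padded to equal length) and drop padding.

    Sketch only: assumes rank r processed samples r, r + world_size,
    r + 2 * world_size, ... and that trailing entries are padding.
    """
    world_size = len(per_rank_results)
    per_rank_len = len(per_rank_results[0])
    merged = []
    # Re-interleave so samples come back in dataset order.
    for i in range(per_rank_len):
        for rank in range(world_size):
            merged.append(per_rank_results[rank][i])
    # Drop padding entries beyond the true dataset size.
    return merged[:dataset_size]
```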

Parameters:

dataset_key – Key to identify the dataset to iterate over from self.data_container.

total_data_time = 0.0

sampler_config

_sampler_config_from_key(key, properties=None, max_size=None)

Register the dataset that is used for this callback in the dataloading pipeline.

Parameters:
  • key (str | None) – Key for identifying the dataset from self.data_container. Uses the first dataset if None.

  • properties (set[str] | None) – Optionally specifies a subset of properties to load from the dataset.

  • max_size (int | None) – If provided, only uses a subset of the full dataset. Default: None (no subset).

Returns:

SamplerIntervalConfig for the registered dataset.

Return type:

noether.data.samplers.SamplerIntervalConfig

abstractmethod process_data(batch, *, trainer_model)

Template method called once for each batch loaded from the dataset.

This method should process a single batch and return results that will be collated.

Parameters:
  • batch – The loaded batch.

  • trainer_model (torch.nn.Module) – Model of the current training run.

Returns:

Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.

Return type:

Any

process_results(results, *, interval_type, update_counter, **_)

Template method that is called with the collated results from dataset iteration.

For example, metrics can be computed from the results for the entire test/validation dataset and logged.

Parameters:
Return type:

None

periodic_callback(*, interval_type, update_counter, data_iter, trainer_model, batch_size, **_)

Hook called periodically based on the configured intervals.

This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) are reached.

Subclasses should override this method to implement periodic logic such as:

  • Calculating and logging expensive validation metrics

  • Saving specific model checkpoints or artifacts

  • Visualizing training progress (e.g., plotting samples)

  • Adjusting training hyperparameters or model state

Note

This method is executed within a torch.no_grad() context.

Parameters:
  • interval_type (IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.

  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).

  • **kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).

  • data_iter (collections.abc.Iterator) – Iterator yielding batches from the registered dataset.

  • trainer_model (torch.nn.Module) – Model of the current training run.

  • batch_size (int) – Batch size used when iterating over the dataset.

Return type:

None

after_training(**_)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter – UpdateCounter instance to access current training progress.

Return type:

None