noether.training.callbacks

Submodules

Classes

OfflineLossCallback

A periodic callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

OfflineLossCallbackConfig

Configuration for OfflineLossCallback.

PyTorchProfilerCallback

Profiles the training loop with torch.profiler.profile.

PyTorchProfilerCallbackConfig

Configuration for the PyTorch profiler callback.

Package Contents

class noether.training.callbacks.OfflineLossCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallback

A periodic callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.

Parameters:

callback_config (OfflineLossCallbackConfig) – configuration of the OfflineLossCallback. See OfflineLossCallbackConfig for the available options.

dataset_key
output_patterns_to_log
process_data(batch, *, trainer_model, **_)

Template method that is called for each batch that is loaded from the dataset.

This method should process a single batch and return results that will be collated.

Parameters:
  • batch – The loaded batch.

  • trainer_model – Model of the current training run.

Returns:

Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.

process_results(results, **_)

Template method that is called with the collated results from dataset iteration.

For example, metrics can be computed from the results for the entire test/validation dataset and logged.

Parameters:
  • results (tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.

  • interval_type – The type of interval that triggered this callback invocation.

  • update_counter – UpdateCounter with the current training state.

  • **_ – Additional unused keyword arguments.

Return type:

None
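The division of labour between the two template methods (one call per batch, then one call on the collated results) can be illustrated with a minimal, framework-free sketch. It assumes, hypothetically, that each batch already carries per-sample loss values, that collation concatenates them per output key, and that process_results reduces them to a mean. The functions below are illustrative stand-ins, not noether's implementation, which runs trainer_model and works with torch tensors.

```python
from collections import defaultdict
from typing import Callable, Iterable


def process_data(batch: dict[str, list[float]]) -> dict[str, list[float]]:
    # Hypothetical per-batch step: the real callback would run trainer_model on
    # the batch; here the batch already holds per-sample loss values.
    return {key: list(values) for key, values in batch.items()}


def collate(per_batch: list[dict[str, list[float]]]) -> dict[str, list[float]]:
    # Concatenate per-batch results key by key, so every sample of the
    # dataset ends up in one list per output key.
    merged: dict[str, list[float]] = defaultdict(list)
    for result in per_batch:
        for key, values in result.items():
            merged[key].extend(values)
    return dict(merged)


def process_results(results: dict[str, list[float]]) -> dict[str, float]:
    # Hypothetical dataset-level step: reduce each collated output to its mean,
    # which is the value that would then be logged.
    return {key: sum(values) / len(values) for key, values in results.items()}


def iterate_over_dataset(
    batches: Iterable[dict[str, list[float]]],
    per_batch: Callable = process_data,
) -> dict[str, list[float]]:
    # Call the per-batch template method for every batch, then collate.
    return collate([per_batch(batch) for batch in batches])


batches = [{"loss": [1.0, 3.0]}, {"loss": [2.0]}]
print(process_results(iterate_over_dataset(batches)))  # {'loss': 2.0}
```

Averaging only after collation gives the mean over all samples rather than a mean of per-batch means, which differ when the last batch is smaller.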

class noether.training.callbacks.OfflineLossCallbackConfig(/, **data)

Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallbackConfig

Configuration for OfflineLossCallback.

Like all registry-based configs, it is registered automatically via __init_subclass__.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

name: Literal['OfflineLossCallback'] = None
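output_patterns_to_log: list[str] | None = None

Patterns selecting which outputs are logged. For instance, if the output key is ‘some_loss’ and the pattern is [‘loss’], the key matches the pattern and is logged.

Type:

list[str] | None

**kwargs – additional arguments passed to the parent class.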
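The exact matching rule is not spelled out here. A minimal sketch, assuming plain substring matching (consistent with ‘loss’ matching ‘some_loss’) and assuming that None disables filtering; the function name select_outputs_to_log is hypothetical:

```python
from __future__ import annotations


def select_outputs_to_log(
    outputs: dict[str, float], patterns: list[str] | None
) -> dict[str, float]:
    # Assumption: None means "no filtering", i.e. every output is kept.
    if patterns is None:
        return dict(outputs)
    # Assumption: a key is kept if any pattern occurs as a substring of it,
    # so the pattern "loss" matches the key "some_loss".
    return {
        key: value
        for key, value in outputs.items()
        if any(pattern in key for pattern in patterns)
    }


outputs = {"some_loss": 0.25, "accuracy": 0.9}
print(select_outputs_to_log(outputs, ["loss"]))  # {'some_loss': 0.25}
```

If the real implementation uses regular expressions or glob patterns instead, the ‘some_loss’ example would still match, but patterns containing special characters would behave differently.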

class noether.training.callbacks.PyTorchProfilerCallback(callback_config, **kwargs)

Bases: noether.core.callbacks.periodic.PeriodicCallback

Profiles the training loop with torch.profiler.profile.

The profiler is entered in before_training(), stepped once per optimizer update in track_after_update_step(), and exited in after_training(). Traces are written to <run_output_path>/<trace_subdir> via tensorboard_trace_handler and can be loaded in TensorBoard (tensorboard --logdir <path>) or inspected in chrome://tracing.
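Because the profiler's lifetime spans three separate hooks, it cannot be wrapped in an ordinary with block; its context-manager methods have to be called by hand. The sketch below shows that pattern with a hypothetical FakeProfiler standing in for torch.profiler.profile, which exposes the same __enter__ / step() / __exit__ protocol, so the example runs without PyTorch. ProfilerHooksSketch is illustrative only, not the noether class.

```python
from __future__ import annotations


class FakeProfiler:
    """Stand-in for torch.profiler.profile that only records calls."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def __enter__(self) -> "FakeProfiler":
        self.events.append("start")
        return self

    def step(self) -> None:
        self.events.append("step")

    def __exit__(self, exc_type, exc, tb) -> None:
        self.events.append("stop")


class ProfilerHooksSketch:
    """Hypothetical callback driving a profiler across training hooks."""

    def __init__(self) -> None:
        self.profiler: FakeProfiler | None = None

    def before_training(self) -> None:
        # Enter once; with PyTorch this object would be
        # torch.profiler.profile(schedule=..., on_trace_ready=...).
        self.profiler = FakeProfiler()
        self.profiler.__enter__()

    def track_after_update_step(self) -> None:
        # Advance the profiler schedule once per optimizer update.
        if self.profiler is not None:
            self.profiler.step()

    def after_training(self) -> None:
        # Exit explicitly so the profiler is stopped when training ends.
        if self.profiler is not None:
            self.profiler.__exit__(None, None, None)


hooks = ProfilerHooksSketch()
hooks.before_training()
for _ in range(3):
    hooks.track_after_update_step()
hooks.after_training()
print(hooks.profiler.events)  # ['start', 'step', 'step', 'step', 'stop']
```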

Note

An every_n_* interval must be set, for example every_n_updates=1. Its value does not affect profiling: it only gates the unused periodic_callback hook, not the tracking hooks, so track_after_update_step is still called on every update.

Example

callbacks:
  - kind: callbacks.PyTorchProfilerCallback
    every_n_updates: 1
    wait: 1
    warmup: 1
    active: 3
    repeat: 2
    record_shapes: true
    profile_memory: false
    with_stack: false
    with_flops: false
    with_modules: true
    profile_cpu: true
    profile_cuda: true
Parameters:
  • callback_config (PyTorchProfilerCallbackConfig) – Configuration for the callback. See PyTorchProfilerCallbackConfig for available options.

  • trainer – Trainer of the current run.

  • model – Model of the current run.

  • data_container – DataContainer instance that provides access to all datasets.

  • tracker – BaseTracker instance to log metrics to stdout/disk/online platform.

  • log_writer – LogWriter instance to log metrics.

  • checkpoint_writer – CheckpointWriter instance to save checkpoints.

  • metric_property_provider – MetricPropertyProvider instance to access properties of metrics.

  • name – Name of the callback.

before_training(*, update_counter)

Hook called once before the training loop starts.

This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:

  • Initializing experiment tracking (e.g., logging hyperparameters)

  • Printing model summaries or architecture details

  • Initializing specific data structures or buffers needed during training

  • Performing sanity checks on the data or configuration

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

track_after_update_step(*, update_counter, times)

Hook called after each optimizer update step.

This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:

  • Latest loss values

  • Learning rates

  • Model parameter statistics (norms, etc.)

  • Training throughput and timing measurements

Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
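As an illustration of such per-update telemetry, the hypothetical tracker below keeps an exponential moving average of the loss. The class name, the smoothing parameter, the choice of an EMA over a windowed mean, and the simplified hook signature (a plain loss value instead of update_counter and times) are all assumptions for the example.

```python
from __future__ import annotations


class RunningLossTracker:
    """Hypothetical per-update tracker keeping an exponential moving average."""

    def __init__(self, smoothing: float = 0.9) -> None:
        if not 0.0 <= smoothing < 1.0:
            raise ValueError("smoothing must be in [0, 1)")
        self.smoothing = smoothing
        self.average: float | None = None
        self.num_updates = 0

    def track_after_update_step(self, loss: float) -> None:
        # Simplified hook: called once per optimizer update with the latest loss.
        self.num_updates += 1
        if self.average is None:
            # The first update initializes the average directly.
            self.average = loss
        else:
            self.average = self.smoothing * self.average + (1.0 - self.smoothing) * loss


tracker = RunningLossTracker(smoothing=0.5)
for loss in [4.0, 2.0, 1.0]:
    tracker.track_after_update_step(loss)
print(tracker.average)  # 2.0
```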

Note

This method is executed within a torch.no_grad() context.

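Parameters:
  • update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • times

Return type: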

None

after_training(*, update_counter)

Hook called once after the training loop finishes.

This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:

  • Performing a final evaluation on the test set

  • Saving final model weights or artifacts

  • Sending notifications (e.g., via Slack or email) about the completed run

  • Closing or finalizing experiment tracking sessions

Note

This method is executed within a torch.no_grad() context.

Parameters:

update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

class noether.training.callbacks.PyTorchProfilerCallbackConfig(/, **data)

Bases: noether.core.callbacks.base.CallBackBaseConfig

Configuration for the PyTorch profiler callback.

The profiler uses torch.profiler.profile with a scheduled trace. Profiling is driven by the track_after_update_step hook, i.e. the profiler is stepped once per optimizer update. The resulting traces are written to <run_output_path>/profiler and can be opened in TensorBoard or chrome://tracing.

Recommended usage: limit training by setting trainer.max_updates to a value slightly larger than wait + warmup + active, multiplied by repeat if repeat > 1.
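The arithmetic behind this recommendation can be checked with a small re-implementation of the cycle logic of torch.profiler.schedule (ignoring its skip_first option). This is a sketch for reasoning about config values, not a replacement for the real scheduler. With the example values from the PyTorchProfilerCallback docstring (wait=1, warmup=1, active=3, repeat=2), each cycle lasts 5 updates, so both cycles need 10 updates.

```python
from enum import Enum


class Phase(Enum):
    NONE = "none"
    WARMUP = "warmup"
    RECORD = "record"
    RECORD_AND_SAVE = "record_and_save"


def phase_at(step: int, wait: int, warmup: int, active: int, repeat: int) -> Phase:
    # Mirrors the cycle logic of torch.profiler.schedule with skip_first=0.
    cycle = wait + warmup + active
    if repeat > 0 and step >= repeat * cycle:
        return Phase.NONE  # all requested cycles are finished
    pos = step % cycle
    if pos < wait:
        return Phase.NONE
    if pos < wait + warmup:
        return Phase.WARMUP
    # The last active step of each cycle is where a trace gets saved.
    return Phase.RECORD if pos < cycle - 1 else Phase.RECORD_AND_SAVE


def min_updates(wait: int, warmup: int, active: int, repeat: int) -> int:
    # One profiler step per optimizer update, so each cycle costs
    # wait + warmup + active updates. Assumption: for repeat=0 (indefinite)
    # we report the cost of the first complete cycle.
    return (wait + warmup + active) * max(repeat, 1)


cfg = dict(wait=1, warmup=1, active=3, repeat=2)
print(min_updates(**cfg))  # 10
print([phase_at(s, **cfg).value for s in range(5)])
# ['none', 'warmup', 'record', 'record', 'record_and_save']
```

Setting trainer.max_updates a little above this minimum, as recommended, leaves room for the final profiler step that flushes the last trace.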

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

kind: str | None = 'aero_cfd.callbacks.PyTorchProfilerCallback'
wait: int = None

Number of steps to idle before warming up.

warmup: int = None

Number of warmup steps (profiler runs but traces are discarded).

active: int = None

Number of active steps that are recorded in the trace.

repeat: int = None

Number of times the (wait, warmup, active) cycle is repeated. 0 means repeat indefinitely.

record_shapes: bool = None

Whether to record input tensor shapes for each op.

profile_memory: bool = None

Whether to profile tensor memory usage (can add significant overhead).

with_stack: bool = None

Whether to record Python call stacks for each op (can add significant overhead).

with_flops: bool = None

Whether to record estimated FLOPs for each op.

with_modules: bool = None

Whether to record nn.Module hierarchy for each op.

profile_cuda: bool = None

Whether to profile CUDA operations. If False, only CPU operations are profiled.

profile_cpu: bool = None

Whether to profile CPU operations. If False, only CUDA operations are profiled.

trace_subdir: str = None

Subdirectory (relative to run_output_path) where the trace files are written.

rank0_only: bool = None

If True, only rank 0 profiles; on all other ranks the callback is a no-op. This avoids noisy or conflicting traces in multi-GPU runs.
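Several of these fields share their names with keyword arguments of torch.profiler.profile (record_shapes, profile_memory, with_stack, with_flops, with_modules), while wait, warmup, active and repeat correspond to torch.profiler.schedule. The sketch below shows one plausible translation into plain dictionaries, plus a rank check for rank0_only. The function names, the string activity names, and reading the rank from the RANK environment variable set by torchrun are assumptions; the actual callback may obtain the rank differently, for example from torch.distributed.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from typing import Any

PROFILE_FLAGS = ("record_shapes", "profile_memory", "with_stack", "with_flops", "with_modules")
SCHEDULE_KEYS = ("wait", "warmup", "active", "repeat")


def should_profile(rank0_only: bool, env: Mapping[str, str] | None = None) -> bool:
    # Assumption: the global rank comes from RANK (set by torchrun);
    # a missing RANK is treated as a single-process run, i.e. rank 0.
    env = os.environ if env is None else env
    rank = int(env.get("RANK", "0"))
    return rank == 0 or not rank0_only


def profiler_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
    # Hypothetical translation of config fields into keyword arguments.
    # With PyTorch, "cpu"/"cuda" would become torch.profiler.ProfilerActivity.CPU/.CUDA
    # and the schedule dict would be passed as torch.profiler.schedule(**schedule).
    activities = [
        name
        for name, flag in (("cpu", "profile_cpu"), ("cuda", "profile_cuda"))
        if cfg[flag]
    ]
    kwargs: dict[str, Any] = {
        "activities": activities,
        "schedule": {key: cfg[key] for key in SCHEDULE_KEYS},
    }
    kwargs.update({key: cfg[key] for key in PROFILE_FLAGS})
    return kwargs


cfg = {
    "wait": 1, "warmup": 1, "active": 3, "repeat": 2,
    "record_shapes": True, "profile_memory": False, "with_stack": False,
    "with_flops": False, "with_modules": True,
    "profile_cpu": True, "profile_cuda": False,
}
print(profiler_kwargs(cfg)["activities"])  # ['cpu']
print(should_profile(rank0_only=True, env={"RANK": "1"}))  # False
```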