noether.training.callbacks
Submodules
Classes
OfflineLossCallback | A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.
OfflineLossCallbackConfig | Internal base class for all registry-based configs.
PyTorchProfilerCallback | Profiles the training loop with torch.profiler.profile.
PyTorchProfilerCallbackConfig | Configuration for the PyTorch profiler callback.
Package Contents
- class noether.training.callbacks.OfflineLossCallback(callback_config, **kwargs)
Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallback
A periodic Callback that is invoked at the end of each epoch to calculate and track the loss on a dataset.
- Parameters:
callback_config (OfflineLossCallbackConfig) – configuration of the OfflineLossCallback. See OfflineLossCallbackConfig for the available options.
- dataset_key
- output_patterns_to_log
- process_data(batch, *, trainer_model, **_)
Template method that is called for each batch that is loaded from the dataset.
This method should process a single batch and return results that will be collated.
- Parameters:
batch – The loaded batch.
trainer_model – Model of the current training run.
- Returns:
Processed results for this batch. Can be a tensor, dict of tensors, list, or tuple.
- process_results(results, **_)
Template method that is called with the collated results from dataset iteration.
For example, metrics can be computed from the results for the entire test/validation dataset and logged.
- Parameters:
results (tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]) – The collated results that were produced by _iterate_over_dataset() and the individual process_data() calls.
interval_type – The type of interval that triggered this callback invocation.
update_counter – UpdateCounter with the current training state.
**_ – Additional unused keyword arguments.
- Return type:
None
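The two template methods split the work: process_data() runs once per batch, the per-batch outputs are collated, and process_results() sees the collated result once for the whole dataset. The plain-Python sketch below illustrates that flow. It assumes a hypothetical iterate_over_dataset() helper that collates dicts by concatenating lists, and a made-up per-sample "loss". The real _iterate_over_dataset() works on tensors, and its collation details are not documented here. The real process_results() returns None and logs. This sketch returns a dict only so the result is visible.

```python
from typing import Callable, Iterable


def iterate_over_dataset(
    batches: Iterable[list[float]],
    process_data: Callable[[list[float]], dict[str, list[float]]],
) -> dict[str, list[float]]:
    """Hypothetical stand-in for _iterate_over_dataset(): call process_data
    once per batch and collate the per-batch dicts by concatenating lists."""
    collated: dict[str, list[float]] = {}
    for batch in batches:
        for key, values in process_data(batch).items():
            collated.setdefault(key, []).extend(values)
    return collated


def process_data(batch: list[float]) -> dict[str, list[float]]:
    # Hypothetical per-sample "loss": the square of each value in the batch.
    return {"loss": [x * x for x in batch]}


def process_results(results: dict[str, list[float]]) -> dict[str, float]:
    # Reduce once over the whole dataset, after all batches were processed.
    return {key: sum(values) / len(values) for key, values in results.items()}


batches = [[1.0, 2.0], [3.0]]
print(process_results(iterate_over_dataset(batches, process_data)))
```

Keeping per-sample values until process_results() gives an exact dataset mean (14/3 here). Averaging the per-batch means (2.5 and 9.0) would give 5.75 instead, because it over-weights the smaller last batch.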
- class noether.training.callbacks.OfflineLossCallbackConfig(/, **data)
Bases: noether.core.callbacks.periodic.PeriodicDataIteratorCallbackConfig
Internal base class for all registry-based configs.
Provides auto-registration via __init_subclass__. Not meant to be used directly; use specific config base classes instead.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- name: Literal['OfflineLossCallback'] = None
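The config base is described as auto-registering subclasses via __init_subclass__, and OfflineLossCallbackConfig pins its name to 'OfflineLossCallback'. The plain-Python sketch below shows one way such a registry can work. RegistryConfig, build_config and OfflineLossConfigSketch are hypothetical names, not noether's API. The real configs are pydantic models with validated fields, and this sketch only stores keyword arguments.

```python
from typing import Any, ClassVar


class RegistryConfig:
    """Hypothetical base: every subclass registers itself under a string key."""

    registry: ClassVar[dict[str, type["RegistryConfig"]]] = {}
    name: ClassVar[str]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Use the subclass's own `name` attribute, falling back to the class name.
        key = cls.__dict__.get("name", cls.__name__)
        if key in RegistryConfig.registry:
            raise ValueError(f"duplicate config name: {key!r}")
        RegistryConfig.registry[key] = cls

    def __init__(self, **data: Any) -> None:
        self.data = data  # no validation here, unlike a pydantic model


class OfflineLossConfigSketch(RegistryConfig):
    name = "OfflineLossCallback"


def build_config(name: str, **data: Any) -> RegistryConfig:
    # Look up the concrete config class by name, e.g. from a parsed config entry.
    return RegistryConfig.registry[name](**data)


cfg = build_config("OfflineLossCallback", example_field=1)
print(type(cfg).__name__, cfg.data)
```

Raising on a duplicate name catches two config classes that claim the same key. Otherwise the later class would silently shadow the earlier one.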
- class noether.training.callbacks.PyTorchProfilerCallback(callback_config, **kwargs)
Bases: noether.core.callbacks.periodic.PeriodicCallback
Profiles the training loop with torch.profiler.profile.
The profiler is entered in before_training(), stepped once per optimizer update in track_after_update_step(), and exited in after_training(). Traces are written to <run_output_path>/<trace_subdir> via tensorboard_trace_handler and can be loaded in TensorBoard (tensorboard --logdir <path>) or inspected in chrome://tracing.
Note
An every_n_* interval must be set in the config (the example below uses every_n_updates: 1). Any every_n_* value works: the interval only gates the unused periodic_callback hook, not the tracking hooks, so track_after_update_step is called on every update regardless.
Example
    callbacks:
      - kind: callbacks.PyTorchProfilerCallback
        every_n_updates: 1
        wait: 1
        warmup: 1
        active: 3
        repeat: 2
        record_shapes: true
        profile_memory: false
        with_stack: false
        with_flops: false
        with_modules: true
        activities:
          - cpu
          - cuda
- Parameters:
callback_config (PyTorchProfilerCallbackConfig) – Configuration for the callback. See PyTorchProfilerCallbackConfig and its base class CallBackBaseConfig for the available options.
trainer – Trainer of the current run.
model – Model of the current run.
data_container – DataContainer instance that provides access to all datasets.
tracker – BaseTracker instance to log metrics to stdout/disk/online platform.
log_writer – LogWriter instance to log metrics.
checkpoint_writer – CheckpointWriter instance to save checkpoints.
metric_property_provider – MetricPropertyProvider instance to access properties of metrics.
name – Name of the callback.
- before_training(*, update_counter)
Hook called once before the training loop starts.
This method is intended to be overridden by derived classes to perform initialization tasks before training begins. Common use cases include:
  - Initializing experiment tracking (e.g., logging hyperparameters)
  - Printing model summaries or architecture details
  - Initializing specific data structures or buffers needed during training
  - Performing sanity checks on the data or configuration
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- track_after_update_step(*, update_counter, times)
Hook called after each optimizer update step.
This method is invoked after a successful optimizer step and parameter update. It is typically used for tracking metrics that should be recorded once per update cycle, such as:
  - Latest loss values
  - Learning rates
  - Model parameter statistics (norms, etc.)
  - Training throughput and timing measurements
Unlike periodic_callback(), this hook is called on every update step, making it suitable for maintaining running averages or high-frequency telemetry.
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
times (dict[str, float]) – Dictionary containing time measurements for various parts of the training step (e.g., 'data_time', 'forward_time', 'backward_time', 'update_time').
- Return type:
None
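Because track_after_update_step() runs on every update, it is a natural place to keep running statistics over the times dict. The sketch below assumes a hypothetical TimingTracker class that keeps a running mean per timing key. The key names follow the examples in the docstring. A real callback would also receive update_counter and would log through its tracker rather than expose a means() method.

```python
from collections import defaultdict


class TimingTracker:
    """Hypothetical helper: running mean of each timing key across updates."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = defaultdict(float)
        self.counts: dict[str, int] = defaultdict(int)

    def track_after_update_step(self, *, times: dict[str, float]) -> None:
        # Called once per optimizer update; accumulate every timing entry.
        for key, seconds in times.items():
            self.totals[key] += seconds
            self.counts[key] += 1

    def means(self) -> dict[str, float]:
        # Per-key counts keep the mean correct if a key is missing in some updates.
        return {key: self.totals[key] / self.counts[key] for key in self.totals}


tracker = TimingTracker()
tracker.track_after_update_step(times={"data_time": 0.25, "forward_time": 0.5})
tracker.track_after_update_step(times={"data_time": 0.75, "forward_time": 0.5})
print(tracker.means())
```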
- after_training(*, update_counter)
Hook called once after the training loop finishes.
This method is intended to be overridden by derived classes to perform cleanup or final reporting tasks after training is complete. Common use cases include:
  - Performing a final evaluation on the test set
  - Saving final model weights or artifacts
  - Sending notifications (e.g., via Slack or email) about the completed run
  - Closing or finalizing experiment tracking sessions
Note
This method is executed within a torch.no_grad() context.
- Parameters:
update_counter (noether.core.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
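PyTorchProfilerCallback spreads one context manager across three hooks. The profiler is entered in before_training(), stepped in track_after_update_step(), and exited in after_training(). It also relies on the every_n_* interval gating only periodic_callback(). The plain-Python sketch below illustrates both points, using contextlib.ExitStack to keep the context open between hook calls. FakeProfiler, ProfilerCallbackSketch and run() are hypothetical stand-ins for torch.profiler.profile, the callback, and the trainer loop. The real torch.profiler.profile object also offers start() and stop() as an alternative to the context-manager protocol.

```python
from contextlib import ExitStack
from types import TracebackType
from typing import Optional


class FakeProfiler:
    """Stand-in for torch.profiler.profile: only records lifecycle events."""

    def __init__(self, events: list[str]) -> None:
        self.events = events

    def __enter__(self) -> "FakeProfiler":
        self.events.append("enter")
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        self.events.append("exit")

    def step(self) -> None:
        self.events.append("step")


class ProfilerCallbackSketch:
    """Hypothetical callback: the context is opened and closed in different hooks."""

    def __init__(self) -> None:
        self.events: list[str] = []
        self._stack = ExitStack()
        self._profiler: Optional[FakeProfiler] = None

    def before_training(self) -> None:
        # Enter the context manager and keep it open across later hook calls.
        self._profiler = self._stack.enter_context(FakeProfiler(self.events))

    def track_after_update_step(self) -> None:
        assert self._profiler is not None
        self._profiler.step()  # advance the profiler schedule once per update

    def periodic_callback(self) -> None:
        self.events.append("periodic")  # the real profiler callback leaves this unused

    def after_training(self) -> None:
        self._stack.close()  # runs FakeProfiler.__exit__
        self._profiler = None


def run(callback: ProfilerCallbackSketch, max_updates: int, every_n_updates: int) -> None:
    # Hypothetical trainer loop: tracking runs every update, periodic_callback is gated.
    callback.before_training()
    for update in range(1, max_updates + 1):
        callback.track_after_update_step()
        if update % every_n_updates == 0:
            callback.periodic_callback()
    callback.after_training()


cb = ProfilerCallbackSketch()
run(cb, max_updates=3, every_n_updates=2)
print(cb.events)
```

With every_n_updates=2 the profiler is still stepped on all three updates. Only the "periodic" event is thinned out, which is why the interval value does not matter for profiling.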
- class noether.training.callbacks.PyTorchProfilerCallbackConfig(/, **data)
Bases: noether.core.callbacks.base.CallBackBaseConfig
Configuration for the PyTorch profiler callback.
The profiler uses torch.profiler.profile with a scheduled trace. Profiling is driven by the track_after_update_step hook, i.e. the profiler is stepped once per optimizer update. The resulting traces are written to <run_output_path>/profiler and can be opened in TensorBoard or chrome://tracing.
Recommended usage: limit training with trainer.max_updates to a value slightly larger than wait + warmup + active, or (wait + warmup + active) * repeat if repeat > 1.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- repeat: int = None
Number of times the (wait, warmup, active) cycle is repeated. 0 means repeat indefinitely.
- with_stack: bool = None
Whether to record Python call stacks for each op (can add significant overhead).
- profile_cuda: bool = None
Whether to profile CUDA operations. If False, only CPU operations are profiled.
- profile_cpu: bool = None
Whether to profile CPU operations. If False, only CUDA operations are profiled.
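The wait, warmup, active and repeat values match the arguments of torch.profiler.schedule. Assuming the callback passes them straight through, the plain-Python sketch below re-implements the step-to-action mapping that schedule documents, with skip_first left at 0. You can use it to check how many updates a config needs. It is an illustration, not torch's implementation. The strings stand in for torch.profiler.ProfilerAction members, and updates_needed() is a hypothetical helper. With the values from the example YAML (wait 1, warmup 1, active 3, repeat 2), the last trace is saved at step index 9, so 10 updates cover the full schedule.

```python
def schedule_action(step: int, *, wait: int, warmup: int, active: int, repeat: int) -> str:
    """Map a profiler step index to an action, mirroring the documented
    torch.profiler.schedule semantics with skip_first=0."""
    if step < 0:
        raise ValueError("step must be non-negative")
    cycle = wait + warmup + active
    if repeat > 0 and step // cycle >= repeat:
        return "NONE"  # all requested cycles are finished
    pos = step % cycle
    if pos < wait:
        return "NONE"
    if pos < wait + warmup:
        return "WARMUP"
    # The last active step of each cycle also triggers saving the trace.
    return "RECORD_AND_SAVE" if pos == cycle - 1 else "RECORD"


def updates_needed(*, wait: int, warmup: int, active: int, repeat: int) -> int:
    # Hypothetical helper: updates until the last trace of the last cycle is saved.
    if repeat == 0:
        raise ValueError("repeat=0 cycles indefinitely")
    return (wait + warmup + active) * repeat


cfg = dict(wait=1, warmup=1, active=3, repeat=2)
print([schedule_action(s, **cfg) for s in range(11)])
print(updates_needed(**cfg))
```

Setting trainer.max_updates slightly above updates_needed() leaves a small margin after the final RECORD_AND_SAVE step.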