noether.core.callbacks.default.online_loss¶
Classes¶
OnlineLossCallback | Callback to track the loss of the model after every gradient accumulation step and log the average loss.
Module Contents¶
- class noether.core.callbacks.default.online_loss.OnlineLossCallback(callback_config, **kwargs)¶
Bases: noether.core.callbacks.periodic.PeriodicCallback
Callback to track the loss of the model after every gradient accumulation step and log the average loss.
This callback is initialized by the BaseTrainer and should not be added manually to the trainer’s callbacks.
Initialize the OnlineLossCallback.
- Parameters:
callback_config (noether.core.schemas.callbacks.OnlineLossCallbackConfig) – Configuration for the callback. See OnlineLossCallbackConfig for available options.
**kwargs – Additional arguments passed to the parent class.
- verbose¶
- tracked_losses: collections.defaultdict[str, list[torch.Tensor]]¶
- track_after_accumulation_step(*, losses, **_)¶
Hook called after each individual gradient accumulation step.
This method is invoked for every batch processed during training, regardless of whether an optimizer update is performed in that step (i.e., when accumulation_steps > 1). It is primarily used for tracking metrics that should be averaged or aggregated across accumulation steps.
Common use cases include:
Logging per-batch losses for high-frequency monitoring
Accumulating statistics across batches before an optimizer update
Implementing custom logging that needs access to individual batch data
Note
This method is generally intended to be called within a torch.no_grad() context by the trainer to ensure no gradients are tracked during logging operations.
- Parameters:
update_counter – UpdateCounter instance to access current training progress.
batch – The current data batch processed in this accumulation step.
losses – Dictionary of computed losses for the current batch.
update_outputs – Optional dictionary of model outputs for the current batch.
accumulation_steps – Total number of accumulation steps before an optimizer update.
accumulation_step – The current accumulation step index (0-indexed).
- Return type:
None
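To make the aggregation pattern concrete, here is a minimal, self-contained sketch of how a hook like track_after_accumulation_step can collect per-batch losses and average them at logging time. This is not the noether implementation: the LossTracker class and its pop_averages helper are hypothetical stand-ins, and plain floats replace the torch.Tensor values used by the real callback.

```python
from collections import defaultdict


class LossTracker:
    """Hypothetical stand-in illustrating OnlineLossCallback's
    aggregation pattern; not the actual noether class."""

    def __init__(self):
        # One list of recorded values per loss name, mirroring the
        # tracked_losses: defaultdict[str, list[...]] attribute above.
        self.tracked_losses = defaultdict(list)

    def track_after_accumulation_step(self, *, losses, **_):
        # Called once per batch, even when accumulation_steps > 1;
        # extra hook kwargs (batch, update_counter, ...) are ignored.
        for name, value in losses.items():
            self.tracked_losses[name].append(value)

    def pop_averages(self):
        # Average the accumulated values and reset the buffers,
        # e.g. when the trainer logs after an optimizer update.
        averages = {
            name: sum(values) / len(values)
            for name, values in self.tracked_losses.items()
        }
        self.tracked_losses.clear()
        return averages


tracker = LossTracker()
tracker.track_after_accumulation_step(losses={"total": 1.0})
tracker.track_after_accumulation_step(losses={"total": 3.0})
print(tracker.pop_averages())  # {'total': 2.0}
```

The keyword-only signature with a trailing **_ mirrors the documented hook, so subclasses can pick out only the arguments they need.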
- periodic_callback(*, interval_type, **_)¶
Hook called periodically based on the configured intervals.
This method is the primary entry point for periodic actions in subclasses. It is triggered when any of the configured intervals (every_n_epochs, every_n_updates, or every_n_samples) is reached.
Subclasses should override this method to implement periodic logic such as:
Calculating and logging expensive validation metrics
Saving specific model checkpoints or artifacts
Visualizing training progress (e.g., plotting samples)
Adjusting training hyperparameters or model state
Note
This method is executed within a torch.no_grad() context.
- Parameters:
interval_type (noether.core.callbacks.periodic.IntervalType) – “epoch”, “update”, “sample” or “eval” indicating which interval triggered this callback.
update_counter – UpdateCounter instance providing details about the current training progress (epoch, update, sample counts).
**kwargs – Additional keyword arguments passed from the triggering hook (e.g., from after_epoch() or after_update()).
- Return type:
None
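The override contract described above can be sketched with a minimal stand-in base class. Note that PeriodicCallbackSketch and SampleLoggerCallback below are hypothetical: the real PeriodicCallback dispatches via an UpdateCounter and the configured every_n_* intervals, while here a plain modulo check on the update count stands in for that logic.

```python
class PeriodicCallbackSketch:
    """Hypothetical stand-in for noether's PeriodicCallback dispatch;
    it only illustrates the periodic_callback hook contract."""

    def after_update(self, update, every_n_updates):
        # The real trainer tracks progress with an UpdateCounter;
        # a modulo check stands in for the interval logic here.
        if update % every_n_updates == 0:
            self.periodic_callback(interval_type="update", update=update)

    def periodic_callback(self, *, interval_type, **_):
        raise NotImplementedError


class SampleLoggerCallback(PeriodicCallbackSketch):
    """Example subclass: record which interval fired and when."""

    def __init__(self):
        self.events = []

    def periodic_callback(self, *, interval_type, **kwargs):
        # Subclasses receive the triggering interval type plus any
        # extra kwargs forwarded by the hook (here, the update count).
        self.events.append((interval_type, kwargs.get("update")))


cb = SampleLoggerCallback()
for step in range(1, 7):
    cb.after_update(step, every_n_updates=3)
print(cb.events)  # [('update', 3), ('update', 6)]
```

Keeping interval_type keyword-only, as in the documented signature, lets one override branch on "epoch", "update", "sample", or "eval" without caring which hook forwarded the call.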