noether.training.trainers
Submodules
Classes
BaseTrainer | Base class for all trainers that use SGD-based optimizers.
BaseTrainerConfig | Internal base class for all registry-based configs.
TrainerResult |
WeightedLossTrainer | Generic trainer that computes weighted loss per output field.
Package Contents
- class noether.training.trainers.BaseTrainer(config, data_container, device, tracker, path_provider, main_sampler_kwargs=None, metric_property_provider=None)
Base class for all trainers that use SGD-based optimizers.
This class implements the main training loop and provides utility functions for logging, checkpointing, and callbacks. In your downstream trainer, you must implement the loss_compute method, which calculates the loss from the model output and the targets. Optionally, you can also override the train_step method to implement a custom training step (e.g., for multi-loss training or custom backward logic). If you only need a custom loss calculation and want to keep the rest of the training loop, overriding loss_compute is sufficient. For example:
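```python
import torch
import torch.nn.functional as F

from noether.training.trainers import BaseTrainer, BaseTrainerConfig
from noether.training.trainers.types import LossResult


class MyTrainer(BaseTrainer):
    def __init__(self, trainer_config: BaseTrainerConfig, **kwargs):
        super().__init__(trainer_config, **kwargs)

    def loss_compute(
        self, forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
    ) -> LossResult:
        # compute loss based on model output and targets
        # ("prediction" and "target" are placeholder keys for your own properties)
        loss = F.mse_loss(forward_output["prediction"], targets["target"])
        return {"mse": loss}  # dict of (weighted) sub-losses to log
```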
- Parameters:
config (BaseTrainerConfig) – Configuration for the trainer. See BaseTrainerConfig for the available options.
data_container (noether.data.container.DataContainer) – The DataContainer which includes the data and dataloader.
device (str) – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only one device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).
main_sampler_kwargs (dict | None) – Kwargs passed to instantiate the main sampler.
tracker (noether.core.trackers.BaseTracker) – The tracker to use for training.
path_provider (noether.core.providers.PathProvider) – The PathProvider to use for training.
metric_property_provider (noether.core.providers.MetricPropertyProvider | None) – The MetricPropertyProvider to use for training.
- logger
- config
- data_container
- path_provider
- main_sampler_kwargs = None
- device: torch.device
- end_checkpoint
- precision
- updates_per_epoch
- skip_nan_loss_counter = 0
- initializer: noether.core.initializers.InitializerBase | None = None
- tracker
- metric_property_provider = None
- update_counter
- log_writer
- checkpoint_writer
- callbacks: list[noether.core.callbacks.CallbackBase] = []
- forward_properties
- target_properties
- batch_keys
- get_user_callbacks(model, evaluation=False)
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
- get_all_callbacks(model)
Get all callbacks including default/trainer callbacks.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
- get_trainer_callbacks(callback_default_args)
Get trainer-specific callbacks. This may optionally be overridden by derived classes.
- Parameters:
- Return type:
- get_default_callback_intervals()
Get default intervals at which callbacks are called.
- get_default_callbacks(default_kwargs)
- Parameters:
- Return type:
- load_state_dict(state_dict)
Load the state dict of the trainer.
- apply_resume_initializer(model)
Apply the resume initializer to the model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None
- get_data_loader(iterator_callbacks, batch_size, evaluation=False)
Get the data loader for training.
- Parameters:
iterator_callbacks (list[noether.core.callbacks.PeriodicDataIteratorCallback])
batch_size (int)
evaluation (bool)
- Return type:
- abstractmethod loss_compute(forward_output, targets)
Each trainer that extends this class needs to implement a custom loss computation using the targets and the model output.
- Parameters:
forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.
targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.
- Returns:
A dict with the (weighted) sub-losses to log, or a tuple (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass that is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
- Return type:
noether.training.trainers.types.LossResult | tuple[noether.training.trainers.types.LossResult, dict[str, torch.Tensor]]
Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
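The two return conventions above can be illustrated with a small, library-independent sketch that normalizes either form into the three TrainerResult fields. TrainerResultSketch and to_trainer_result are hypothetical names, floats stand in for tensors, and the sketch assumes that the loss used for backpropagation is the sum of the returned sub-losses, which this page does not state.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class TrainerResultSketch:
    """Hypothetical stand-in for TrainerResult; floats replace tensors."""

    total_loss: float
    losses_to_log: dict[str, float] | None = None
    additional_outputs: dict[str, Any] | None = None


def to_trainer_result(loss_result) -> TrainerResultSketch:
    # loss_compute returns either a dict of sub-losses
    # or a (losses, additional_outputs) tuple
    if isinstance(loss_result, tuple):
        losses, additional_outputs = loss_result
    else:
        losses, additional_outputs = loss_result, None
    # Assumption: the backpropagated loss is the sum of the sub-losses
    total = sum(losses.values())
    return TrainerResultSketch(total, dict(losses), additional_outputs)


result = to_trainer_result(({"mse": 0.5, "l1": 0.25}, {"logits": [0.1, 0.9]}))
print(result.total_loss)  # 0.75
```

With a plain dict, additional_outputs stays None and only the losses are forwarded.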
- train_step(batch, model)
Overriding this function is optional. By default, the train_step of the model will be called and is expected to return a TrainerResult. Trainers can override this method to implement custom training logic.
- Parameters:
batch (dict[str, torch.Tensor]) – Batch of data from which the loss is calculated.
model (torch.nn.Module) – Model to use for processing the data.
- Returns:
TrainerResult dataclass with the loss for backpropagation, (optionally) individual losses if multiple losses are used, and (optionally) additional information about the model forward pass that is passed to the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
- Return type:
- wrap_model(model)
Wrap the model for training and return the model, the wrapped model, and the DDP+compiled model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
- wrap_ddp(model)
Wrap the model with DistributedDataParallel in multi-GPU settings.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
noether.core.models.ModelBase | torch.nn.parallel.DistributedDataParallel
- wrap_compile(ddp_model)
Wrap the model with torch.compile.
- Parameters:
ddp_model (noether.core.models.ModelBase | torch.nn.parallel.DistributedDataParallel)
- Return type:
- train(model)
Train the model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None
- static drop_metadata(data)
- update(batch, dist_model, model, accumulation_steps_total, accumulation_step, retain_graph=False)
Perform forward and backward pass.
- Parameters:
batch (dict[str, torch.Tensor])
dist_model (torch.nn.Module)
model (noether.core.models.ModelBase)
accumulation_steps_total (int)
accumulation_step (int)
retain_graph (bool)
- Return type:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor] | None, dict[str, noether.core.utils.common.stopwatch.Stopwatch]]
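The accumulation_steps_total and accumulation_step parameters suggest the usual gradient-accumulation scheme. The sketch below is not the update method itself: it assumes each micro-batch loss is scaled by 1/accumulation_steps_total before the backward pass, and checks with a hand-written gradient for a scalar model y = w * x that the accumulated gradient equals the full-batch gradient when micro-batches have equal size. mse_grad and accumulated_grad are hypothetical names.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y) ** 2) over one (micro-)batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list[float], ys: list[float],
                     accumulation_steps_total: int) -> float:
    """Sum of micro-batch gradients, each scaled by 1/accumulation_steps_total."""
    if len(xs) % accumulation_steps_total != 0:
        raise ValueError("batch must split into equally sized micro-batches")
    micro = len(xs) // accumulation_steps_total
    grad = 0.0  # plays the role of a parameter's .grad buffer
    for accumulation_step in range(accumulation_steps_total):
        part = slice(accumulation_step * micro, (accumulation_step + 1) * micro)
        grad += mse_grad(w, xs[part], ys[part]) / accumulation_steps_total
    # an optimizer step would follow only after the last accumulation step
    return grad


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
print(mse_grad(1.0, xs, ys), accumulated_grad(1.0, xs, ys, 2))  # -15.0 -15.0
```

Scaling by the number of steps keeps the gradient magnitude independent of how many micro-batches the effective batch is split into.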
- call_before_training(callbacks)
Hook that is called before training starts.
- Parameters:
callbacks (list[noether.core.callbacks.CallbackBase])
- Return type:
None
- call_after_training(callbacks)
Hook that is called after training ends.
- Parameters:
callbacks (list[noether.core.callbacks.CallbackBase])
- Return type:
None
- eval(model)
Run evaluation by executing all configured callbacks.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None
- class noether.training.trainers.BaseTrainerConfig[TCallbackConfig: noether.core.callbacks.base.CallBackBaseConfig](/, **data)
Bases: noether.core.schemas.lib._RegistryBase
Internal base class for all registry-based configs.
Provides auto-registration via __init_subclass__. Not meant to be used directly; use specific config base classes instead.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- max_epochs: int | None = None
The maximum number of epochs to train for. Mutually exclusive with max_updates and max_samples. If set to 0, training will be skipped and all callbacks will be invoked once (useful for evaluation-only runs).
- max_updates: int | None = None
The maximum number of updates to train for. Mutually exclusive with max_epochs and max_samples. If set to 0, training will be skipped and all callbacks will be invoked once (useful for evaluation-only runs).
- max_samples: int | None = None
The maximum number of samples to train for. Mutually exclusive with max_epochs and max_updates. If set to 0, training will be skipped and all callbacks will be invoked once (useful for evaluation-only runs).
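The mutual exclusivity of the three limits and the special meaning of 0 can be summarized in a few lines. resolve_training_length is a hypothetical helper written for illustration, not part of noether; in the library, the check is performed by validate_max_training_criteria.

```python
def resolve_training_length(max_epochs=None, max_updates=None, max_samples=None):
    """Return (unit, limit, evaluation_only) for the single configured limit."""
    limits = {"epochs": max_epochs, "updates": max_updates, "samples": max_samples}
    given = {unit: limit for unit, limit in limits.items() if limit is not None}
    if len(given) != 1:
        raise ValueError("exactly one of max_epochs, max_updates, max_samples must be set")
    unit, limit = next(iter(given.items()))
    # a limit of 0 skips training; callbacks still run once (evaluation-only run)
    return unit, limit, limit == 0


print(resolve_training_length(max_updates=0))  # ('updates', 0, True)
```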
- start_at_epoch: int | None = None
The epoch to start training at. The trainer skips all epochs before this epoch, and the learning rate and other schedulers are stepped accordingly. Useful for resuming training from a specific epoch.
- add_default_callbacks: bool | None = None
Whether to add default callbacks. Default callbacks log things like simple dataset statistics or the current value of the learning rate if it is scheduled.
- add_trainer_callbacks: bool | None = None
Whether to add trainer-specific callbacks (e.g., a callback to log the training accuracy for a classification task).
- effective_batch_size: int = None
The effective batch size used for optimization, i.e., the number of samples that are processed before an update step is taken (the “global batch size”). In multi-GPU setups, the batch size per device (the “local batch size”) is effective_batch_size divided by the number of devices. If gradient accumulation is used, the forward-pass batch size is derived by further dividing the local batch size by the number of gradient accumulation steps.
- precision: Literal['float32', 'fp32', 'float16', 'fp16', 'bfloat16', 'bf16'] = None
The precision to use for training (e.g., “float32”). Mixed precision training (e.g., “float16” or “bfloat16”) can be used to speed up training and reduce memory usage on supported hardware (e.g., NVIDIA GPUs).
- callbacks: list[Annotated[TCallbackConfig, Discriminated(CallBackBaseConfig)]] | None = None
The callbacks to use for training.
- initializer: noether.core.initializers.InitializerConfig | None = None
The initializer to use for training. Mainly used for resuming training via ResumeInitializer.
- track_every_n_epochs: int | None = None
The interval, in epochs, at which metrics are periodically tracked.
- track_every_n_updates: int | None = None
The interval, in updates, at which metrics are periodically tracked.
- track_every_n_samples: int | None = None
The interval, in samples, at which metrics are periodically tracked.
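How the three tracking intervals might map onto the update counter is sketched below. should_track is a hypothetical helper; it assumes updates are counted from 1, that each epoch consists of a fixed updates_per_epoch, and that a sample-based interval fires on the update that crosses a multiple of every_n_samples. The library may resolve these intervals differently.

```python
from __future__ import annotations


def should_track(update: int, updates_per_epoch: int, batch_size: int,
                 every_n_epochs: int | None = None,
                 every_n_updates: int | None = None,
                 every_n_samples: int | None = None) -> bool:
    """Return True if metrics should be tracked after `update` (counted from 1)."""
    fire = False
    if every_n_updates is not None:
        fire |= update % every_n_updates == 0
    if every_n_epochs is not None:
        # only at an epoch boundary, and only every n-th epoch
        epoch_done = update % updates_per_epoch == 0
        fire |= epoch_done and (update // updates_per_epoch) % every_n_epochs == 0
    if every_n_samples is not None:
        # fire on the update whose samples cross a multiple of every_n_samples
        before, after = (update - 1) * batch_size, update * batch_size
        fire |= after // every_n_samples > before // every_n_samples
    return fire


# 10 updates per epoch, 4 samples per update
print([u for u in range(1, 31) if should_track(u, 10, 4, every_n_epochs=2)])  # [20]
```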
- max_batch_size: int | None = None
The maximum batch size to use for model forward pass in training. If the effective_batch_size is larger than max_batch_size, gradient accumulation will be used to simulate the larger batch size. For example, if effective_batch_size=8 and max_batch_size=2, 4 gradient accumulation steps will be taken before each optimizer step.
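The batch-size arithmetic involving effective_batch_size, the number of devices, and max_batch_size can be written out as follows. batch_plan is a hypothetical helper that assumes max_batch_size applies per device and that the local batch size splits evenly into max_batch_size chunks; with the values from the example above it yields 4 accumulation steps.

```python
from __future__ import annotations


def batch_plan(effective_batch_size: int, num_devices: int = 1,
               max_batch_size: int | None = None) -> tuple[int, int, int]:
    """Return (local_batch_size, accumulation_steps, forward_batch_size)."""
    if effective_batch_size % num_devices != 0:
        raise ValueError("effective_batch_size must be divisible by num_devices")
    local = effective_batch_size // num_devices  # per-device ("local") batch size
    if max_batch_size is None or max_batch_size >= local:
        return local, 1, local  # no gradient accumulation needed
    # assumption: the local batch splits evenly into max_batch_size chunks
    if local % max_batch_size != 0:
        raise ValueError("local batch size must be divisible by max_batch_size")
    return local, local // max_batch_size, max_batch_size


print(batch_plan(8, max_batch_size=2))  # (8, 4, 2)
print(batch_plan(32, num_devices=4, max_batch_size=2))  # (8, 4, 2)
```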
- skip_nan_loss: bool = None
Whether to skip NaN losses, which can occasionally occur by unlucky coincidence. If True, NaN losses are skipped without terminating training until 100 NaN losses have occurred in a row.
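A minimal sketch of the skipping rule, assuming that any non-NaN loss resets the count and that training is aborted once the number of consecutive NaN losses reaches the limit. NanLossGuard is hypothetical; the trainer exposes a skip_nan_loss_counter attribute that presumably plays a similar role.

```python
import math


class NanLossGuard:
    """Hypothetical helper: skip NaN losses, abort after too many in a row."""

    def __init__(self, max_consecutive: int = 100):
        self.max_consecutive = max_consecutive
        self.counter = 0  # consecutive NaN losses seen so far

    def should_skip(self, loss: float) -> bool:
        if not math.isnan(loss):
            self.counter = 0  # a non-NaN loss ends the streak
            return False
        self.counter += 1
        if self.counter >= self.max_consecutive:
            raise RuntimeError(f"{self.counter} NaN losses in a row; stopping training")
        return True  # skip this update and continue training
```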
- disable_gradient_accumulation: bool = None
Whether to disable gradient accumulation. Gradient accumulation is sometimes used to simulate larger batch sizes, but can lead to worse generalization.
- save_on_sigint: bool = None
Whether to save a checkpoint on SIGINT (Ctrl+C). SIGTERM always triggers a checkpoint save. When False (default), Ctrl+C will stop training immediately without saving.
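The signal policy described above can be sketched with Python's standard signal module. CheckpointOnSignal is a hypothetical class that only records what the training loop should do next; it is not noether's handler, and a real trainer would check these flags between updates.

```python
import signal


class CheckpointOnSignal:
    """Hypothetical sketch of the save_on_sigint / SIGTERM policy."""

    def __init__(self, save_on_sigint: bool = False):
        self.save_on_sigint = save_on_sigint
        self.stop_requested = False
        self.save_requested = False

    def handle(self, signum, frame=None):
        # any handled signal stops training
        self.stop_requested = True
        # SIGTERM always saves; SIGINT saves only when save_on_sigint is True
        if signum == signal.SIGTERM or (signum == signal.SIGINT and self.save_on_sigint):
            self.save_requested = True

    def install(self):
        # signal.signal must be called from the main thread
        signal.signal(signal.SIGINT, self.handle)
        signal.signal(signal.SIGTERM, self.handle)
```

In a training script, calling install() once before the loop would route both signals through handle().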
- use_torch_compile: bool = None
Whether to use torch.compile to compile the model for faster training.
- forward_properties: list[str] | None = []
Properties (i.e., keys from the batch dict) from the input batch that are used as inputs to the model during the forward pass.
- target_properties: list[str] | None = []
Properties (i.e., keys from the batch dict) from the input batch that are used as targets when computing the loss.
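The following sketch shows how forward_properties and target_properties could be used to select model inputs and loss targets from a batch dict. split_batch and the keys in the usage lines (positions, pressure_target, index) are hypothetical; keys listed in neither property list are simply ignored.

```python
def split_batch(batch: dict, forward_properties: list[str],
                target_properties: list[str]) -> tuple[dict, dict]:
    """Select model inputs and loss targets from a batch dict by key."""
    missing = [k for k in (*forward_properties, *target_properties) if k not in batch]
    if missing:
        raise KeyError(f"batch is missing keys: {missing}")
    forward_inputs = {k: batch[k] for k in forward_properties}
    targets = {k: batch[k] for k in target_properties}
    return forward_inputs, targets


batch = {"positions": [[0.0, 1.0]], "pressure_target": [0.3], "index": 7}
inputs, targets = split_batch(batch, ["positions"], ["pressure_target"])
print(inputs, targets)  # {'positions': [[0.0, 1.0]]} {'pressure_target': [0.3]}
```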
- model_config
Pydantic configuration for this config class (not for the neural network model); should be a dictionary conforming to pydantic.config.ConfigDict.
- dataloader_prefetch_factor: int | None = None
The prefetch_factor to use for the training dataloader. This controls how many batches are prefetched by each worker process in the dataloader. Increasing this can speed up training if data loading is a bottleneck, but also increases memory usage.
- monitor_training_stability: bool = None
Whether to monitor training stability by logging gradient norms, model norms and grad scaler scale at regular intervals using the TrainingStabilityCallback. This can be useful for diagnosing issues with exploding or vanishing gradients.
- monitor_interval: int | None = None
The interval (in updates) at which to monitor training stability when monitor_training_stability is True. This controls how often the TrainingStabilityCallback logs gradient norms, model norms and grad scaler scale.
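The gradient norm mentioned above is commonly the L2 norm of all gradients viewed as a single vector, the same quantity that torch.nn.utils.clip_grad_norm_ returns. The pure-Python sketch below computes it from per-parameter gradient lists; global_grad_norm is a hypothetical name, applying the same formula to parameter values gives a model norm, and how TrainingStabilityCallback computes its values is not specified here.

```python
import math


def global_grad_norm(grads: list[list[float]]) -> float:
    """L2 norm of all gradient entries, treated as one flat vector."""
    return math.sqrt(sum(g * g for grad in grads for g in grad))


# two "parameters" with gradients [3.0] and [4.0]
print(global_grad_norm([[3.0], [4.0]]))  # 5.0
```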
- validate_callback_frequency()
Ensures that exactly one frequency (‘every_n_*’) is specified and that ‘batch_size’ is present if ‘every_n_samples’ is used.
- Return type:
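A plain-Python rendering of the two rules stated for validate_callback_frequency might look like this. validate_frequency and its keyword names are hypothetical stand-ins for the every_n_* and batch_size fields mentioned above; the real validator operates on the pydantic config.

```python
def validate_frequency(every_n_epochs=None, every_n_updates=None,
                       every_n_samples=None, batch_size=None) -> None:
    """Hypothetical check: exactly one every_n_* is set; samples need batch_size."""
    given = [v for v in (every_n_epochs, every_n_updates, every_n_samples) if v is not None]
    if len(given) != 1:
        raise ValueError("specify exactly one of every_n_epochs, every_n_updates, every_n_samples")
    if every_n_samples is not None and batch_size is None:
        raise ValueError("batch_size is required when every_n_samples is used")
```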
- validate_max_training_criteria()
Ensures that exactly one of max_epochs, max_updates, or max_samples is specified.
- Return type:
- class noether.training.trainers.TrainerResult
- total_loss: torch.Tensor
- losses_to_log: dict[str, torch.Tensor] | None = None
- additional_outputs: dict[str, torch.Tensor] | None = None
- class noether.training.trainers.WeightedLossTrainer(trainer_config, **kwargs)
Bases: noether.training.trainers.BaseTrainer
Generic trainer that computes weighted loss per output field.
Expects the model forward to return dict[str, Tensor] with keys matching the field_weights keys, and the batch to contain <field_name>_target keys.
The loss function defaults to MSE and can be changed via the loss_fn config parameter. Use a built-in short name or a dotted import path for custom losses.
Built-in losses:
```python
trainer_params = dict(field_weights={"pressure": 1.0}, loss_fn="l1")
```
Custom loss function from your project:
```python
trainer_params = dict(
    field_weights={"pressure": 1.0},
    loss_fn="my_project.losses.weighted_huber",
)
```
The custom callable must have the signature (input, target) -> Tensor, matching torch.nn.functional loss functions.
- Parameters:
config – Configuration for the trainer. See BaseTrainerConfig for the available options.
data_container – The DataContainer which includes the data and dataloader.
device – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only one device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).
main_sampler_kwargs – Kwargs passed to instantiate the main sampler.
tracker – The tracker to use for training.
path_provider – The PathProvider to use for training.
metric_property_provider – The MetricPropertyProvider to use for training.
trainer_config (WeightedLossTrainerConfig)
- loss_compute(forward_output, targets)
Each trainer that extends this class needs to implement a custom loss computation using the targets and the model output.
- Parameters:
forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.
targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.
- Returns:
A dict with the (weighted) sub-losses to log, or a tuple (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass that is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
- Return type:
Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
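Putting the WeightedLossTrainer contract together, a weighted loss computation could be sketched as below. weighted_loss is hypothetical: it uses plain lists and a hand-written mse in place of tensors and torch.nn.functional, and it assumes that the total loss is the sum of the weighted per-field losses, which are also returned for logging.

```python
def mse(a: list[float], b: list[float]) -> float:
    """Toy stand-in for torch.nn.functional.mse_loss on flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def weighted_loss(forward_output, targets, field_weights, loss_fn=mse):
    """Return (total, per-field losses) using loss_fn(output[f], targets[f + '_target'])."""
    losses = {}
    for field, weight in field_weights.items():
        # model output key == field name; target key == "<field_name>_target"
        losses[field] = weight * loss_fn(forward_output[field], targets[f"{field}_target"])
    # assumption: the total loss is the plain sum of the weighted field losses
    return sum(losses.values()), losses


total, per_field = weighted_loss(
    forward_output={"pressure": [1.0, 2.0], "velocity": [0.0]},
    targets={"pressure_target": [1.0, 0.0], "velocity_target": [2.0]},
    field_weights={"pressure": 1.0, "velocity": 0.5},
)
print(total, per_field)  # 4.0 {'pressure': 2.0, 'velocity': 2.0}
```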