noether.training.trainers.base¶
Classes¶
TrainingContextFilter – Filter instances are used to perform arbitrary filtering of LogRecords.
BaseTrainer – Base class for all trainers that use SGD-based optimizers.
Module Contents¶
- class noether.training.trainers.base.TrainingContextFilter(update_counter)¶
Bases: logging.Filter
Filter instances are used to perform arbitrary filtering of LogRecords.
Loggers and Handlers can optionally use Filter instances to filter records as desired. The base filter class only allows events which are below a certain point in the logger hierarchy. For example, a filter initialized with “A.B” will allow events logged by loggers “A.B”, “A.B.C”, “A.B.C.D”, “A.B.D” etc. but not “A.BB”, “B.A.B” etc. If initialized with the empty string, all events are passed.
Initialize a filter.
Initialize with the name of the logger which, together with its children, will have its events allowed through the filter. If no name is specified, allow every event.
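To make the hierarchy rule concrete, here is a minimal check using only Python's standard logging module (illustrative only; not part of this API):

    import logging

    name_filter = logging.Filter("A.B")

    def record_for(logger_name: str) -> logging.LogRecord:
        # Build a minimal LogRecord, as the logging machinery would.
        return logging.LogRecord(logger_name, logging.INFO, __file__, 0, "msg", None, None)

    assert name_filter.filter(record_for("A.B.C"))     # child of "A.B": allowed
    assert not name_filter.filter(record_for("A.BB"))  # not in the hierarchy: rejected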
- Parameters:
update_counter (noether.core.utils.training.UpdateCounter)
- update_counter¶
- filter(record)¶
Determine if the specified record is to be logged.
Returns True if the record should be logged, or False otherwise. If deemed appropriate, the record may be modified in-place.
- Parameters:
record (logging.LogRecord)
- Return type:
bool
- class noether.training.trainers.base.BaseTrainer(config, data_container, device, tracker, path_provider, main_sampler_kwargs=None, metric_property_provider=None)¶
Base class for all trainers that use SGD-based optimizers.
This class implements the main training loop and provides utility functions for logging, checkpointing, and callbacks. In a downstream subclass you must implement the loss_compute method, which calculates the loss from the model output and the targets. Optionally, you can also override the train_step method to implement a custom training step (e.g., for multi-loss training or custom backward logic). If you only need a custom loss calculation and want to keep the rest of the training loop, overriding loss_compute alone is sufficient. For example:
    class MyTrainer(BaseTrainer):
        def __init__(self, trainer_config: BaseTrainerConfig, **kwargs):
            super().__init__(trainer_config, **kwargs)

        def loss_compute(
            self,
            forward_output: dict[str, torch.Tensor],
            targets: dict[str, torch.Tensor],
        ) -> LossResult:
            # compute loss based on model output and targets
            return loss
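A more concrete sketch of such a subclass, assuming the forward output and targets share a "prediction" key and that a plain dict of named sub-losses is accepted as a LossResult (both are assumptions, not guaranteed by this API):

    import torch
    import torch.nn.functional as F

    from noether.training.trainers.base import BaseTrainer
    from noether.training.trainers.types import LossResult

    class MSETrainer(BaseTrainer):
        def loss_compute(
            self,
            forward_output: dict[str, torch.Tensor],
            targets: dict[str, torch.Tensor],
        ) -> LossResult:
            # "prediction" is a hypothetical key; adapt it to your model's outputs.
            loss = F.mse_loss(forward_output["prediction"], targets["prediction"])
            # Assumes a dict of named (weighted) sub-losses is a valid LossResult.
            return {"mse": loss}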
- Parameters:
config (noether.core.schemas.BaseTrainerConfig) – Configuration for the trainer. See BaseTrainerConfig for the available options.
data_container (noether.data.container.DataContainer) – The DataContainer which includes the data and dataloader.
device (str) – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only 1 device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).
tracker (noether.core.trackers.BaseTracker) – The tracker to use for training.
path_provider (noether.core.providers.PathProvider) – The PathProvider to use for training.
main_sampler_kwargs (dict | None) – Kwargs passed to instantiate the main sampler.
metric_property_provider (noether.core.providers.MetricPropertyProvider | None) – The MetricPropertyProvider to use for training.
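A minimal wiring sketch (variable names are hypothetical; each argument is an instance of the type listed above, assumed to be constructed elsewhere):

    trainer = MyTrainer(
        config=trainer_config,          # BaseTrainerConfig
        data_container=data_container,  # DataContainer
        device="cuda",
        tracker=tracker,                # BaseTracker
        path_provider=path_provider,    # PathProvider
    )
    trainer.train(model)                # model: ModelBase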
- logger¶
- config¶
- data_container¶
- path_provider¶
- main_sampler_kwargs = None¶
- device: torch.device¶
- end_checkpoint¶
- precision¶
- updates_per_epoch¶
- skip_nan_loss_counter = 0¶
- initializer: noether.core.initializers.InitializerBase | None = None¶
- tracker¶
- metric_property_provider = None¶
- update_counter¶
- log_writer¶
- checkpoint_writer¶
- callbacks: list[noether.core.callbacks.CallbackBase] = []¶
- forward_properties¶
- target_properties¶
- batch_keys¶
- get_user_callbacks(model, evaluation=False)¶
- Parameters:
model (noether.core.models.ModelBase)
evaluation (bool)
- Return type:
list[noether.core.callbacks.CallbackBase]
- get_all_callbacks(model)¶
Get all callbacks including default/trainer callbacks.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
list[noether.core.callbacks.CallbackBase]
- get_trainer_callbacks(callback_default_args)¶
Get trainer-specific callbacks. This may optionally be overridden by derived classes.
- Parameters:
callback_default_args
- Return type:
- get_default_callback_intervals()¶
Get default intervals at which callbacks are called.
- get_default_callbacks(default_kwargs)¶
- Parameters:
default_kwargs
- Return type:
- load_state_dict(state_dict)¶
Load the state dict of the trainer.
- apply_resume_initializer(model)¶
Apply the resume initializer to the model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None
- get_data_loader(iterator_callbacks, batch_size, evaluation=False)¶
Get the data loader for training.
- Parameters:
iterator_callbacks (list[noether.core.callbacks.periodic.PeriodicDataIteratorCallback])
batch_size (int)
evaluation (bool)
- Return type:
- abstractmethod loss_compute(forward_output, targets)¶
Each trainer that extends this class needs to implement a custom loss computation using the targets and the model output.
- Parameters:
forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.
targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.
- Returns:
A dict with the (weighted) sub-losses to log, or a tuple (losses, additional_outputs), where additional_outputs is a dict with additional information about the model forward pass that is passed to the track_after_accumulation_step method of the callbacks (e.g., the logits and targets used to calculate a training accuracy in a callback).
- Return type:
noether.training.trainers.types.LossResult | tuple[noether.training.trainers.types.LossResult, dict[str, torch.Tensor]]
Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
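A hedged sketch of the tuple-returning form; the "logits" and "labels" keys are assumptions for illustration:

    import torch.nn.functional as F

    def loss_compute(self, forward_output, targets):
        logits = forward_output["logits"]  # hypothetical output key
        labels = targets["labels"]         # hypothetical target key
        loss = F.cross_entropy(logits, labels)
        # The second tuple element reaches the callbacks as additional_outputs.
        return {"cross_entropy": loss}, {"logits": logits.detach(), "labels": labels}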
- train_step(batch, model)¶
Overriding this function is optional. By default, the train_step of the model will be called and is expected to return a TrainerResult. Trainers can override this method to implement custom training logic.
- Parameters:
batch (dict[str, torch.Tensor]) – Batch of data from which the loss is calculated.
model (torch.nn.Module) – Model to use for processing the data.
- Returns:
TrainerResult dataclass with the loss for backpropagation, (optionally) individual losses if multiple losses are used, and (optionally) additional information about the model forward pass that is passed to the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
- Return type:
TrainerResult
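A hedged override sketch; the TrainerResult field names used below (loss, losses, additional_outputs) are inferred from the description above and are assumptions, not confirmed API:

    def train_step(self, batch, model):
        forward_output = model(batch)
        # Assumes loss_compute returns the tuple form described above.
        losses, extra = self.loss_compute(forward_output, batch)
        total_loss = sum(losses.values())
        # Field names are hypothetical.
        return TrainerResult(loss=total_loss, losses=losses, additional_outputs=extra)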
- wrap_model(model)¶
Wrap the model for training; returns the raw model, the DDP-wrapped model, and the DDP-wrapped and compiled model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
- wrap_ddp(model)¶
Wrap the model with DistributedDataParallel in multi-GPU settings.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
noether.core.models.ModelBase | torch.nn.parallel.DistributedDataParallel
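For reference, the standard PyTorch pattern this corresponds to (illustrative; not necessarily the library's exact implementation):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        model = DDP(model, device_ids=[torch.cuda.current_device()])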
- wrap_compile(ddp_model)¶
Wrap the model with torch.compile.
- Parameters:
ddp_model (noether.core.models.ModelBase | torch.nn.parallel.DistributedDataParallel)
- Return type:
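This corresponds to the standard torch.compile call (PyTorch 2.x); any compile options the trainer applies are configuration-dependent:

    # Default options shown; the trainer's actual compile settings may differ.
    compiled_model = torch.compile(ddp_model)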
- train(model)¶
Train the model.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None
- static drop_metadata(data)¶
- update(batch, dist_model, model=None, accumulation_steps=1, iter_step=0, **kwargs)¶
Perform forward and backward pass.
- Parameters:
batch (dict[str, torch.Tensor])
dist_model (torch.nn.Module)
model (noether.core.models.ModelBase | None)
accumulation_steps (int)
iter_step (int)
- Return type:
tuple[dict[str, torch.Tensor], dict[str, torch.Tensor] | None, dict[str, noether.core.utils.common.stopwatch.Stopwatch]]
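A usage sketch grounded in the documented signature and return type (the data loader and loop are hypothetical scaffolding):

    for i, batch in enumerate(data_loader):
        losses, additional_outputs, stopwatches = trainer.update(
            batch,
            dist_model,
            model=model,
            accumulation_steps=4,  # accumulate gradients over 4 micro-batches
            iter_step=i,
        )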
- call_before_training(callbacks)¶
Hook that is called before training starts.
- Parameters:
callbacks (list[noether.core.callbacks.CallbackBase])
- Return type:
None
- call_after_training(callbacks)¶
Hook that is called after training ends.
- Parameters:
callbacks (list[noether.core.callbacks.CallbackBase])
- Return type:
None
- eval(model)¶
Run evaluation by executing all configured callbacks.
- Parameters:
model (noether.core.models.ModelBase)
- Return type:
None