noether.training.trainers.base
==============================

.. py:module:: noether.training.trainers.base


Attributes
----------

.. autoapisummary::

   noether.training.trainers.base.TRAINING_DATA_WAIT_TIME
   noether.training.trainers.base.TRAINING_UPDATE_TIME


Classes
-------

.. autoapisummary::

   noether.training.trainers.base.TrainingContextFilter
   noether.training.trainers.base.BaseTrainer


Module Contents
---------------

.. py:class:: TrainingContextFilter(update_counter)

   Bases: :py:obj:`logging.Filter`

   Filter instances are used to perform arbitrary filtering of LogRecords.

   Loggers and Handlers can optionally use Filter instances to filter records as desired.
   The base filter class only allows events which are below a certain point in the logger
   hierarchy. For example, a filter initialized with "A.B" will allow events logged by
   loggers "A.B", "A.B.C", "A.B.C.D", "A.B.D" etc. but not "A.BB", "B.A.B" etc. If
   initialized with the empty string, all events are passed.

   Initialize a filter.

   Initialize with the name of the logger which, together with its children, will have its
   events allowed through the filter. If no name is specified, allow every event.


   .. py:attribute:: update_counter


   .. py:method:: filter(record)

      Determine if the specified record is to be logged.

      Returns True if the record should be logged, or False otherwise.
      If deemed appropriate, the record may be modified in-place.


.. py:data:: TRAINING_DATA_WAIT_TIME
   :value: 'data_wait'


.. py:data:: TRAINING_UPDATE_TIME
   :value: 'update'
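``TrainingContextFilter`` follows the standard :py:mod:`logging` filter protocol: an instance is
attached to a logger or handler, and its ``filter`` method is invoked for every ``LogRecord``.
The sketch below illustrates that mechanism with a hypothetical stand-in filter; the attribute
name ``update`` and the dummy counter are assumptions for the example, not part of the
``TrainingContextFilter`` interface.

.. code-block:: python

   import logging


   class UpdateContextFilter(logging.Filter):
       """Hypothetical stand-in for TrainingContextFilter: tags records with training progress."""

       def __init__(self, update_counter):
           super().__init__()
           self.update_counter = update_counter

       def filter(self, record):
           # Attach the current update index so formatters can reference %(update)d.
           record.update = self.update_counter()
           return True  # never drop records; only enrich them


   logger = logging.getLogger("noether.training")
   handler = logging.StreamHandler()
   handler.setFormatter(logging.Formatter("[update %(update)d] %(levelname)s %(message)s"))
   handler.addFilter(UpdateContextFilter(update_counter=lambda: 42))  # dummy counter
   logger.addHandler(handler)

   logger.warning("loss is NaN, skipping update")  # -> "[update 42] WARNING loss is NaN, ..."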
.. py:class:: BaseTrainer(config, data_container, device, tracker, path_provider, main_sampler_kwargs = None, metric_property_provider = None)

   Base class for all trainers that use SGD-based optimizers.

   :param config: The configuration for the trainer. Implements the `BaseTrainerConfig` schema.
   :param data_container: The data container which includes the data and dataloader.
   :param device: The device to use for training (e.g., "cuda"). It is assumed that the process
       was configured such that only 1 device is visible (e.g., via the CUDA_VISIBLE_DEVICES
       environment variable).
   :param main_sampler_kwargs: Kwargs passed to instantiate the main sampler.
   :param tracker: The tracker to use for training.
   :param path_provider: The path provider to use for training.
   :param metric_property_provider: The metric property provider to use for training.


   .. py:attribute:: logger


   .. py:attribute:: config


   .. py:attribute:: data_container


   .. py:attribute:: path_provider


   .. py:attribute:: main_sampler_kwargs
      :value: None


   .. py:attribute:: device
      :type: torch.device


   .. py:attribute:: end_checkpoint


   .. py:attribute:: precision
      :value: Ellipsis


   .. py:attribute:: updates_per_epoch


   .. py:attribute:: skip_nan_loss_counter
      :value: 0


   .. py:attribute:: initializer
      :type: noether.core.initializers.InitializerBase | None
      :value: None


   .. py:attribute:: tracker


   .. py:attribute:: metric_property_provider
      :value: None


   .. py:attribute:: update_counter


   .. py:attribute:: log_writer


   .. py:attribute:: checkpoint_writer


   .. py:attribute:: callbacks
      :type: list[noether.core.callbacks.CallbackBase]
      :value: []


   .. py:attribute:: forward_properties


   .. py:attribute:: target_properties


   .. py:attribute:: batch_keys


   .. py:method:: get_user_callbacks(model, evaluation=False)


   .. py:method:: get_all_callbacks(model)

      Get all callbacks, including default/trainer callbacks.


   .. py:method:: get_trainer_callbacks(callback_default_args)

      Get trainer-specific callbacks. This may optionally be overridden by derived classes.


   .. py:method:: get_default_callback_intervals()

      Get the default intervals at which callbacks are called.


   .. py:method:: get_default_callbacks(default_kwargs)


   .. py:method:: state_dict()

      Get the state dict of the trainer.


   .. py:method:: load_state_dict(state_dict)

      Load the state dict of the trainer.


   .. py:method:: apply_resume_initializer(model)

      Apply the resume initializer to the model.


   .. py:method:: get_data_loader(iterator_callbacks, batch_size, evaluation = False)

      Get the data loader for training.


   .. py:method:: loss_compute(forward_output, targets)
      :abstractmethod:

      Each trainer that extends this class needs to implement a custom loss computation
      using the targets and the output of the model.

      :param forward_output: Output of the model after the forward pass.
      :param targets: Dict with the target tensors needed to compute the loss for this trainer.

      :returns: A dict with the (weighted) sub-losses to log.


   .. py:method:: train_step(batch, dist_model)

      Overriding this function is optional; by default, the `train_step` of the model is
      called and is expected to return a TrainerResult. Trainers can override this method
      to implement custom training logic.

      :param batch: Batch of data from which the loss is calculated.
      :param dist_model: Model to use for processing the data.

      :returns: Loss for backpropagation, (optionally) individual losses if multiple losses
                are used, and (optionally) additional info about the model forward pass that
                is passed to the callbacks (e.g., the logits and targets to calculate a
                training accuracy in a callback).


   .. py:method:: wrap_model(model)

      Wrap the model for training; return the model, the wrapped model, and the DDP+compiled model.


   .. py:method:: wrap_ddp(model)

      Wrap the model with DistributedDataParallel in multi-GPU settings.


   .. py:method:: wrap_compile(ddp_model)

      Wrap the model with torch.compile.


   .. py:method:: train(model)

      Train the model.


   .. py:method:: drop_metadata(data)
      :staticmethod:


   .. py:method:: update(batch, dist_model, model = None, training = True, accumulation_steps = 1, iter_step = 0, **kwargs)

      Perform the forward and backward pass.


   .. py:method:: call_before_training(callbacks)

      Hook that is called before training starts.


   .. py:method:: call_after_training(callbacks)

      Hook that is called after training ends.


   .. py:method:: eval(model)

      Run evaluation by executing all configured callbacks.
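Because ``loss_compute`` is abstract, every concrete trainer must supply it. The following is a
minimal, illustrative sketch only: the dictionary keys ``"prediction"`` and ``"target"``, the choice
of losses, and the 0.1 weight are assumptions for the example, not part of the documented interface.

.. code-block:: python

   import torch.nn.functional as F

   from noether.training.trainers.base import BaseTrainer


   class RegressionTrainer(BaseTrainer):
       """Illustrative trainer: combines an MSE term and a weighted L1 term."""

       def loss_compute(self, forward_output, targets):
           # Key names are assumptions for this sketch; real trainers define their own.
           prediction = forward_output["prediction"]
           target = targets["target"]

           mse = F.mse_loss(prediction, target)
           l1 = F.l1_loss(prediction, target)

           # Return the (weighted) sub-losses to log, as required by loss_compute.
           return {"mse": mse, "l1": 0.1 * l1}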