noether.core.optimizer.optimizer_wrapper

Classes

OptimizerWrapper

Wrapper around a torch.optim.Optimizer that adds weight-decay handling, param groups, scheduling, and gradient clipping.

Module Contents

class noether.core.optimizer.optimizer_wrapper.OptimizerWrapper(model, torch_optim_ctor, optim_wrapper_config, update_counter=None)
Wrapper around a torch.optim.Optimizer that allows:
  • excluding biases and weights of normalization layers from weight decay

  • creating param_groups (e.g., for a layerwise lr scaling)

  • learning rate scheduling

  • gradient clipping

  • weight decay scheduling

See noether.core.schemas.optimizers.OptimizerConfig for the available options.
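The first bullet (excluding biases and normalization weights from weight decay) can be illustrated with a small sketch. The function name, the key layout, and the name-matching rule below are assumptions for illustration, not the library's actual implementation:

```python
import re

def build_param_groups(named_params, weight_decay):
    """Split parameters into a decaying and a non-decaying group.

    Hypothetical sketch: biases (names ending in ".bias") and
    normalization-layer weights (names containing "norm", "bn", or
    "ln" as a path component) are excluded from weight decay.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        if name.endswith(".bias") or re.search(r"\b(norm|bn|ln)\b", name):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

The returned list has the shape torch.optim.Optimizer expects for its `param_groups` argument; further groups (e.g., for layerwise lr scaling) would be produced by splitting on additional criteria.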

Attributes:
schedule: noether.core.utils.training.schedule_wrapper.ScheduleWrapper | None = None
weight_decay_schedule: noether.core.utils.training.schedule_wrapper.ScheduleWrapper | None = None
logger
model
update_counter = None
config
param_idx_to_name
torch_optim
all_parameters = None
step(grad_scaler=None)

Wrapper around torch.optim.Optimizer.step which automatically handles:
  • gradient scaling for mixed precision (including updating the GradScaler state)

  • gradient clipping

  • calling the .step function of the optimizer

Parameters:

grad_scaler (torch.amp.grad_scaler.GradScaler | None)

Return type:

None
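The order of operations matters here: gradients must be unscaled before clipping, and the scaler must be updated after its step. A sketch of this flow, using stand-in objects rather than the actual implementation (the helper name and `clip_grad_fn` hook are assumptions):

```python
def optimizer_step(torch_optim, clip_grad_fn=None, grad_scaler=None):
    """Illustrative step() flow for an optimizer wrapper.

    With a scaler: unscale grads, clip, scaler.step, scaler.update.
    Without one: clip, then call the optimizer's step directly.
    """
    if grad_scaler is not None:
        grad_scaler.unscale_(torch_optim)  # bring grads back to true scale
        if clip_grad_fn is not None:
            clip_grad_fn()                 # clip on unscaled gradients
        grad_scaler.step(torch_optim)      # skips the step on inf/NaN grads
        grad_scaler.update()               # adjust the loss scale
    else:
        if clip_grad_fn is not None:
            clip_grad_fn()
        torch_optim.step()
```

With torch.amp.GradScaler, clipping scaled gradients would use the wrong threshold, which is why `unscale_` precedes the clip.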

schedule_step()

Applies the current state of the schedules to the parameter groups.

Return type:

None
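Conceptually, applying a schedule means rescaling each param group's base values by the current schedule factor. A minimal sketch, assuming each group stores its base values under hypothetical `base_lr` / `base_weight_decay` keys:

```python
def apply_schedules(param_groups, lr_scale=None, wd_scale=None):
    """Sketch of schedule_step(): multiply each group's base values
    by the current schedule factors (key names are illustrative)."""
    for group in param_groups:
        if lr_scale is not None:
            group["lr"] = group["base_lr"] * lr_scale
        if wd_scale is not None:
            # groups excluded from decay have base_weight_decay == 0.0
            group["weight_decay"] = group.get("base_weight_decay", 0.0) * wd_scale
```

Keeping the base values separate from the live `lr` / `weight_decay` entries makes the operation idempotent: calling it twice with the same factors yields the same result.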

zero_grad(set_to_none=True)

Wrapper around torch.optim.Optimizer.zero_grad.

state_dict()

Wrapper around torch.optim.Optimizer.state_dict. Additionally adds the parameter index-to-name mapping.

Return type:

dict[str, Any]
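A sketch of how the index-to-name mapping could be attached to the regular optimizer state (the extra key name is an assumption, not necessarily what the library writes):

```python
def wrapped_state_dict(torch_optim_state, param_idx_to_name):
    """Sketch: return the optimizer state plus an index-to-name
    mapping, so parameters can later be matched by name instead of
    by position."""
    state = dict(torch_optim_state)  # shallow copy of {"state", "param_groups"}
    state["param_idx_to_name"] = dict(param_idx_to_name)
    return state
```

Storing names alongside indices is what makes the edge-case handling in load_state_dict possible: positions may shift between runs, but names are stable.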

load_state_dict(state_dict_to_load)

Wrapper around torch.optim.Optimizer.load_state_dict. Additionally handles edge cases where the parameter groups of the loaded state_dict do not match the current configuration. By default, torch would overwrite the current parameter groups with the ones from the checkpoint. This is undesirable in the following cases:
  • new parameters were added (e.g., something was unfrozen)

  • weight_decay or other param_group properties were changed: load_state_dict would overwrite the actual weight_decay (defined in the constructor of the OptimizerWrapper) with the weight_decay from the checkpoint

Parameters:

state_dict_to_load (dict[str, Any]) – The optimizer state to load.

Return type:

None
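The second edge case above (preserving currently configured hyperparameters over checkpoint values) can be sketched as a per-group merge. The helper name and the set of preserved keys below are assumptions for illustration:

```python
def merge_param_group(current, checkpoint, keep_keys=("lr", "weight_decay")):
    """Sketch: take the checkpoint's param group, but re-apply the
    currently configured hyperparameters, so e.g. a weight_decay
    changed since the checkpoint is not silently overwritten."""
    merged = dict(checkpoint)
    for key in keep_keys:
        if key in current:
            merged[key] = current[key]
    return merged
```

Per-parameter optimizer state (e.g., momentum buffers) still comes from the checkpoint; only the group-level settings defined in the constructor of the OptimizerWrapper take precedence.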