noether.core.optimizer

Submodules

Classes

Lion

Implements the Lion algorithm.

OptimizerWrapper

Wrapper around a torch.optim.Optimizer that adds scheduling, clipping, and parameter-group handling.

LrScaleByNameModifier

Scales the learning rate of a single parameter, selected by name.

ParamGroupModifierBase

Generic implementation to change properties of optimizer parameter groups.

WeightDecayByNameModifier

Changes the weight decay value for a single parameter.

Package Contents

class noether.core.optimizer.Lion(params, lr, betas=(0.9, 0.99), weight_decay=0.0, caution=False, maximize=False, foreach=None)

Bases: torch.optim.optimizer.Optimizer

Implements the Lion algorithm.

Initialize the hyperparameters.

Parameters:
  • params (torch.optim.optimizer.ParamsT) – iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – learning rate

  • betas (tuple[float, float]) – coefficients for interpolating between the momentum and the current gradient (beta1 for computing the update direction, beta2 for updating the momentum)

  • weight_decay (float) – weight decay coefficient

  • caution (bool) – apply the cautious update rule, which masks update components whose sign disagrees with the current gradient

  • maximize (bool) – maximize the objective with respect to the params, instead of minimizing

  • foreach (bool | None) – whether the foreach implementation is used (None lets the implementation decide)

step(closure=None)

Performs a single optimization step.

Parameters:

closure – A closure that reevaluates the model and returns the loss.

Returns:

the loss.
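For intuition, the Lion update rule can be sketched on a single scalar parameter in pure Python. This is an illustrative sketch of the published algorithm, not the actual noether implementation (which operates on tensors and supports caution/maximize/foreach):

```python
def sign(x: float) -> int:
    """Sign of x: -1, 0, or 1."""
    return (x > 0) - (x < 0)

def lion_step(param, grad, momentum, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    """One Lion update on a scalar parameter (illustrative sketch).

    Returns the updated (param, momentum) pair.
    """
    beta1, beta2 = betas
    # Interpolate momentum and current gradient, then take only the sign.
    update = sign(beta1 * momentum + (1 - beta1) * grad)
    # Decoupled weight decay, applied as in AdamW.
    param = param - lr * (update + weight_decay * param)
    # The momentum is updated with beta2 *after* the update was computed.
    momentum = beta2 * momentum + (1 - beta2) * grad
    return param, momentum
```

Because only the sign of the interpolated gradient enters the update, Lion applies a uniform step magnitude of lr (plus weight decay) to every coordinate.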

class noether.core.optimizer.OptimizerWrapper(model, torch_optim_ctor, optim_wrapper_config, update_counter=None)
Wrapper around a torch.optim.Optimizer that allows
  • excluding biases and weights of normalization layers from weight decay

  • creating param_groups (e.g., for a layerwise lr scaling)

  • learning rate scheduling

  • gradient clipping

  • weight decay scheduling

Have a look at noether.core.schemas.optimizers.OptimizerConfig for the available options.

Attributes:
schedule: noether.core.utils.training.schedule_wrapper.ScheduleWrapper | None = None
weight_decay_schedule: noether.core.utils.training.schedule_wrapper.ScheduleWrapper | None = None
logger
model
update_counter = None
config
param_idx_to_name
torch_optim
all_parameters = None
step(grad_scaler=None)

Wrapper around torch.optim.Optimizer.step which automatically handles:
  • gradient scaling for mixed precision (including updating the GradScaler state)

  • gradient clipping

  • calling the .step function of the optimizer

Parameters:

grad_scaler (torch.amp.grad_scaler.GradScaler | None)

Return type:

None

schedule_step()

Applies the current state of the schedules to the parameter groups.

Return type:

None

zero_grad(set_to_none=True)

Wrapper around torch.optim.Optimizer.zero_grad.

state_dict()

Wrapper around torch.optim.Optimizer.state_dict. Additionally adds the index-to-name mapping of the parameters.

Return type:

dict[str, Any]

load_state_dict(state_dict_to_load)

Wrapper around torch.optim.Optimizer.load_state_dict. Additionally handles edge cases where the parameter groups of the loaded state_dict do not match the current configuration. By default, torch would overwrite the current parameter groups with the ones from the checkpoint. This is undesirable in the following cases:
  • adding new parameters (e.g., unfreezing something)

  • changing weight_decay or other param_group properties: load_state_dict would overwrite the actual weight_decay (defined in the constructor of the OptimizerWrapper) with the weight_decay from the checkpoint

Parameters:

state_dict_to_load (dict[str, Any]) – The optimizer state to load.

Return type:

None
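The edge-case handling described above amounts to restoring the checkpoint's optimizer state while keeping the currently configured hyperparameters and parameter membership. A minimal pure-Python sketch of that merge over param_group dicts (illustrative; the actual implementation also has to remap per-parameter state by name):

```python
def merge_param_groups(current_groups, loaded_groups):
    """Merge checkpoint param_groups with the current configuration.

    Torch's default load_state_dict would take the loaded hyperparameters
    wholesale; here the current configuration wins for everything except
    whatever extra bookkeeping keys the checkpoint carries.
    """
    merged = []
    for cur, old in zip(current_groups, loaded_groups):
        group = dict(old)                 # start from the checkpoint's group
        for key, value in cur.items():
            if key != "params":
                group[key] = value        # hyperparameters: current config wins
        group["params"] = cur["params"]   # membership: current model wins
        merged.append(group)
    return merged
```

With this policy, unfrozen parameters stay in the group and a weight_decay changed in the constructor survives loading an old checkpoint.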

class noether.core.optimizer.LrScaleByNameModifier(param_group_modifier_config)

Bases: noether.core.optimizer.param_group_modifiers.base.ParamGroupModifierBase

Scales the learning rate of a single parameter, selected by name.

Parameters:

param_group_modifier_config (noether.core.schemas.optimizers.ParamGroupModifierConfig)

scale
name
param_was_found = False
get_properties(model, name, param)

This method is called with all items of model.named_parameters() to compose the parameter groups for the whole model. If the desired parameter name is found, it returns a property dict that scales the learning rate.

Parameters:
  • model (torch.nn.Module) – Model from which the parameter originates from. Used to extract properties (e.g., number of layers for a layerwise learning rate decay).

  • name (str) – Name of the parameter as stored inside the model.

  • param (torch.Tensor) – The parameter tensor.

Return type:

dict[str, float]

was_applied_successfully()

Check if the parameter was found within the model.

Return type:

bool
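A minimal modifier following this interface can be sketched in pure Python. The class name, the "lr_scale" property key, and the dummy model argument are assumptions for illustration; only the get_properties / was_applied_successfully contract comes from the documented interface:

```python
class LrScaleSketch:
    """Sketch of a name-matching lr-scale modifier (not the noether class)."""

    def __init__(self, name: str, scale: float):
        self.name = name
        self.scale = scale
        self.param_was_found = False

    def get_properties(self, model, name, param):
        # Called once per entry of model.named_parameters().
        if name == self.name:
            self.param_was_found = True
            return {"lr_scale": self.scale}  # assumed property key
        return {}  # no modification for other parameters

    def was_applied_successfully(self):
        # True only if the targeted name appeared during group composition.
        return self.param_was_found
```

Checking was_applied_successfully() after composing the groups catches typos in the configured parameter name, which would otherwise silently leave the learning rate unscaled.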

class noether.core.optimizer.ParamGroupModifierBase

Generic implementation to change properties of optimizer parameter groups.

abstractmethod get_properties(model, name, param)

Returns the modified properties for a given model parameter. This method is called with all items of model.named_parameters() to compose the parameter groups for the whole model.

Parameters:
  • model (torch.nn.Module) – Model from which the parameter originates from. Used to extract properties (e.g., number of layers for a layerwise learning rate decay).

  • name (str) – Name of the parameter as stored inside the model.

  • param (torch.Tensor) – The parameter tensor.

Return type:

dict[str, float]

abstractmethod was_applied_successfully()

Checks if the parameter group modifier was applied successfully.

Return type:

bool

class noether.core.optimizer.WeightDecayByNameModifier(param_group_modifier_config)

Bases: noether.core.optimizer.param_group_modifiers.base.ParamGroupModifierBase

Changes the weight decay value for a single parameter. Use-cases:
  • ViT: exclude the CLS token parameters

  • Transformer: learned positional embeddings

  • learnable query tokens for cross attention (“PerceiverPooling”)

Parameters:

param_group_modifier_config (noether.core.schemas.optimizers.ParamGroupModifierConfig)

name
value
param_was_found = False
get_properties(model, name, param)

This method is called with all items of model.named_parameters() to compose the parameter groups for the whole model. If the desired parameter name is found, it returns a property dict that sets the weight decay.

Parameters:
  • model (torch.nn.Module) – Model from which the parameter originates from. Used to extract properties (e.g., number of layers for a layerwise learning rate decay).

  • name (str) – Name of the parameter as stored inside the model.

  • param (torch.Tensor) – The parameter tensor.

Return type:

dict[str, float]

was_applied_successfully()

Check if the parameter was found within the model.