noether.core.schemas.callbacks

Attributes

Classes

Module Contents

class noether.core.schemas.callbacks.CallBackBaseConfig(/, **data)

Bases: pydantic.BaseModel

Parameters:

data (Any)

name: str
kind: str | None = None
id: str | None = None

Optional unique identifier for this callback instance. Required when multiple stateful callbacks of the same type exist (e.g., two BestCheckpointCallbacks tracking different metrics). Used as the key when saving/loading callback state dicts to ensure correct matching on resume.

every_n_epochs: int | None = None

Epoch-based interval. Invokes the callback after every n epochs. Mutually exclusive with other intervals.

every_n_updates: int | None = None

Update-based interval. Invokes the callback after every n updates. Mutually exclusive with other intervals.

every_n_samples: int | None = None

Sample-based interval. Invokes the callback after every n samples. Mutually exclusive with other intervals.

batch_size: int | None = None

Batch size to use for this callback. Defaults to None, in which case the same batch_size as for training is used.

model_config

Configuration for the model; should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict).

validate_callback_frequency()

Ensures that exactly one frequency (‘every_n_*’) is specified and that ‘batch_size’ is present if ‘every_n_samples’ is used.

Return type:

CallBackBaseConfig
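The two documented frequency rules can be illustrated with a minimal sketch. The real class is a pydantic model. This sketch assumes nothing beyond the rules stated above and reproduces them with a plain dataclass so it runs without extra dependencies. The class name IntervalConfigSketch is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class IntervalConfigSketch:
    """Hypothetical stand-in mirroring the documented rules of validate_callback_frequency."""

    every_n_epochs: Optional[int] = None
    every_n_updates: Optional[int] = None
    every_n_samples: Optional[int] = None
    batch_size: Optional[int] = None

    def __post_init__(self) -> None:
        intervals = [self.every_n_epochs, self.every_n_updates, self.every_n_samples]
        # Exactly one of the every_n_* fields must be set.
        num_set = sum(value is not None for value in intervals)
        if num_set != 1:
            raise ValueError(f"exactly one every_n_* interval must be set, got {num_set}")
        # A sample-based interval additionally requires batch_size.
        if self.every_n_samples is not None and self.batch_size is None:
            raise ValueError("batch_size is required when every_n_samples is used")
```

For example, IntervalConfigSketch(every_n_epochs=1) is accepted, while combining every_n_epochs with every_n_updates, or passing every_n_samples without batch_size, raises a ValueError.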

classmethod check_positive_values(v)

Ensures that all integer-based frequency and batch size fields are positive.

Parameters:

v (int | None)

Return type:

int | None

classmethod check_kind_is_not_empty(v)

Ensures the ‘kind’ field is a non-empty string.

Parameters:

v (str)

Return type:

str
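The two field validators above amount to simple checks. The following is a sketch of those checks as plain functions, not the library's code. It assumes, as the signature of check_positive_values suggests, that None passes through unchanged. The function names are hypothetical.

```python
from typing import Optional


def check_positive(value: Optional[int]) -> Optional[int]:
    """Sketch of check_positive_values: None passes through, integers must be > 0."""
    if value is not None and value <= 0:
        raise ValueError(f"expected a positive integer, got {value}")
    return value


def check_kind_not_empty(kind: str) -> str:
    """Sketch of check_kind_is_not_empty: kind must be a non-empty string."""
    if not kind:
        raise ValueError("kind must be a non-empty string")
    return kind
```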

class noether.core.schemas.callbacks.PeriodicDataIteratorCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: str
dataset_key: str = None

The key of the dataset to be used for the loss calculation. Can be any key that is registered in the DataContainer.

class noether.core.schemas.callbacks.BestCheckpointCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['BestCheckpointCallback'] = None
metric_key: str = None

The key of the metric used to determine the best model.

save_frozen_weights: bool = None

Whether to also save the frozen weights of the model.

tolerances: list[int] | None = None

If provided, this callback produces multiple best models that differ in how many intervals they allow the metric to go without improving. For example, tolerances=[5] with every_n_epochs=1 stores the best checkpoint found up to the point where more than 5 epochs pass without the metric improving, i.e., the checkpoint that early stopping with a patience of 5 epochs would keep. Additionally, the best checkpoint over the whole training is always stored (i.e., tolerance=infinite). By setting several tolerances, one can evaluate different early stopping configurations with a single training run.

model_names: list[str] | None = None

Names of the models to save (e.g., to store only the encoder of an autoencoder, one could pass model_names=[‘encoder’]). This only applies when training a CompositeModel. If None, all models are saved.
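To make the tolerances behaviour concrete, here is a sketch that replays a metric history and reports which invocation each tolerance would keep. It assumes that lower metric values are better (e.g., a loss) and that one list entry corresponds to one callback invocation. Its exact off-by-one convention (a tolerance is exceeded once more than that many invocations pass without improvement) is an assumption, not confirmed by this reference. The function name is hypothetical.

```python
import math
from typing import Dict, List, Sequence


def best_index_per_tolerance(metric: Sequence[float], tolerances: List[int]) -> Dict[float, int]:
    """Return, per tolerance, the index of the checkpoint that would be kept.

    Assumes lower is better; math.inf stands for "best over the whole run".
    """
    result: Dict[float, int] = {}
    for tol in [*tolerances, math.inf]:
        best_value = math.inf
        best_idx = -1
        stale = 0  # invocations since the last improvement
        for idx, value in enumerate(metric):
            if value < best_value:
                best_value, best_idx, stale = value, idx, 0
            else:
                stale += 1
                if stale > tol:
                    break  # this tolerance would have stopped training here
        result[tol] = best_idx
    return result
```

With the history [1.0, 0.8, 0.9, 0.85, 0.7] and tolerances=[1, 2], a tolerance of 1 keeps index 1. Both a tolerance of 2 and the always-stored overall best keep index 4.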

class noether.core.schemas.callbacks.CheckpointCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['CheckpointCallback'] = None
save_weights: bool = None

Whether to save the weights of the model every time this callback is invoked. The checkpoint name will contain the training iteration (e.g., epoch/update/sample) at which the checkpoint was saved.

save_optim: bool = None

Whether to save the optimizer state every time this callback is invoked. The checkpoint name will contain the training iteration (e.g., epoch/update/sample) at which the checkpoint was saved.

save_latest_weights: bool = None

Whether to save the latest weights of the model every time this callback is invoked. Note that the latest weights are always overwritten on the next invocation of this callback.

save_latest_optim: bool = None

Whether to save the latest optimizer state every time this callback is invoked. Note that the latest optimizer state is always overwritten on the next invocation of this callback.

model_names: list[str] | None = None

Names of the models to save. If None, all models are saved.
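The difference between the per-iteration flags (save_weights, save_optim) and the "latest" flags (save_latest_weights, save_latest_optim) can be illustrated with a sketch. The file names below are hypothetical placeholders chosen for illustration, not the files the library actually writes. The sketch only shows that per-iteration checkpoints accumulate while a "latest" checkpoint is overwritten.

```python
from typing import List, Set


def files_written(
    iteration: str,
    save_weights: bool,
    save_optim: bool,
    save_latest_weights: bool,
    save_latest_optim: bool,
) -> List[str]:
    """List the hypothetical checkpoint names written by one invocation."""
    files: List[str] = []
    if save_weights:
        files.append(f"model_{iteration}")  # name contains the training iteration
    if save_optim:
        files.append(f"optim_{iteration}")
    if save_latest_weights:
        files.append("model_latest")  # same name every time, so it is overwritten
    if save_latest_optim:
        files.append("optim_latest")
    return files


# Simulate three epoch-based invocations with save_weights and save_latest_weights enabled.
on_disk: Set[str] = set()
for it in ["E1", "E2", "E3"]:
    on_disk.update(files_written(it, True, False, True, False))
```

After the loop, on_disk holds one weights file per epoch plus a single model_latest entry.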

class noether.core.schemas.callbacks.EmaCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['EmaCallback'] = None
target_factors: list[float] = None

The factors for the EMA.

model_paths: list[str | None] | None = None

The paths to the models to apply the EMA to, i.e., the paths of the PyTorch nn.Modules in the checkpoint (e.g., composite_model.encoder or composite_model.decoder). If None, the EMA is applied to the whole model. When training with a CompositeModel, the paths of the submodules (e.g., ‘encoder’, ‘decoder’) should be provided via this field; otherwise, the EMA is applied to the CompositeModel as a whole, which cannot be restored later on.

save_weights: bool = None

Whether to save the EMA weights.

save_last_weights: bool = None

Whether to save the EMA weights once training is over (i.e., at the end of training).

save_latest_weights: bool = None

Whether to save the latest EMA weights. Note that the latest weights are always overwritten on the next invocation of this callback.
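The standard exponential moving average update is ema ← factor · ema + (1 − factor) · current. The sketch below applies that rule once per entry of target_factors. Keeping one EMA copy per factor is an assumption based on the plural field name, not something this reference states. Plain lists of floats stand in for model parameters, and the function name is hypothetical.

```python
from typing import Dict, List


def ema_update(ema_weights: Dict[float, List[float]], current: List[float]) -> None:
    """Update one EMA copy per target factor in place.

    Rule: ema <- factor * ema + (1 - factor) * current.
    """
    for factor, ema in ema_weights.items():
        for i, value in enumerate(current):
            ema[i] = factor * ema[i] + (1.0 - factor) * value


# One EMA copy per target factor, initialised from the current weights (assumption).
target_factors = [0.9, 0.99]
weights = [1.0, 2.0]
ema_weights = {f: list(weights) for f in target_factors}
ema_update(ema_weights, [2.0, 2.0])
```

A factor closer to 1 makes the copy move more slowly. After this single update, the 0.99 copy has barely changed while the 0.9 copy has moved a tenth of the way toward the new weights.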

class noether.core.schemas.callbacks.OnlineLossCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['OnlineLossCallback'] = None
verbose: bool = None

Whether to also log to the (console) logger. If False, the loss is only logged to the experiment tracker.

class noether.core.schemas.callbacks.BestMetricCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['BestMetricCallback'] = None
source_metric_key: str = None

The metric used to determine whether the current model obtained a new best (e.g., loss/valid/total).

target_metric_keys: list[str] | None = None

The metrics to keep track of (e.g., loss/test/total).

optional_target_metric_keys: list[str] | None = None

The metrics to keep track of if they are present (useful when different model configurations log different evaluation metrics, to avoid reconfiguring the callback).

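The interplay of the three key fields can be sketched as follows: whenever the source metric reaches a new best, the target metrics at that point are recorded, and optional targets are recorded only if they were logged. The sketch assumes lower source values are better (e.g., a validation loss). The class BestMetricTrackerSketch is hypothetical and is not the callback's implementation.

```python
from typing import Dict, List, Optional


class BestMetricTrackerSketch:
    """Hypothetical sketch of the documented BestMetricCallback behaviour (lower is better)."""

    def __init__(
        self,
        source_metric_key: str,
        target_metric_keys: Optional[List[str]] = None,
        optional_target_metric_keys: Optional[List[str]] = None,
    ) -> None:
        self.source_metric_key = source_metric_key
        self.target_metric_keys = target_metric_keys or []
        self.optional_target_metric_keys = optional_target_metric_keys or []
        self.best_source: Optional[float] = None
        self.best_targets: Dict[str, float] = {}

    def update(self, metrics: Dict[str, float]) -> bool:
        """Record targets if the source metric reached a new best; return True if so."""
        value = metrics[self.source_metric_key]
        if self.best_source is not None and value >= self.best_source:
            return False
        self.best_source = value
        # Required targets must be present (a KeyError is raised otherwise).
        self.best_targets = {key: metrics[key] for key in self.target_metric_keys}
        # Optional targets are recorded only if this run logs them.
        for key in self.optional_target_metric_keys:
            if key in metrics:
                self.best_targets[key] = metrics[key]
        return True
```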
class noether.core.schemas.callbacks.TrackAdditionalOutputsCallbackConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['TrackAdditionalOutputsCallback'] = None
keys: list[str] | None = None

List of keys to track in the additional_outputs of the TrainerResult returned by the trainer’s update step.

patterns: list[str] | None = None

List of patterns to track in the additional_outputs of the TrainerResult returned by the trainer’s update step. A pattern matches if it is contained in (i.e., is a substring of) one of the update_outputs keys.

verbose: bool = None

If True, the tracked values are printed via the logger; otherwise, no logger is used.

reduce: Literal['mean', 'last'] = None

The reduction method used to reduce the tracked values to a scalar. Currently supports ‘mean’ and ‘last’.

log_output: bool = None

Whether to log the tracked scalar values.

save_output: bool = None

Whether to save the tracked scalar values to disk.
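Key selection and reduction, as described for keys, patterns, and reduce, can be sketched with two small functions. They assume that exact keys and substring patterns are combined (a key is tracked if either matches) and that the reduction runs over the values collected between two invocations. Both points are assumptions, and the function names are hypothetical.

```python
from typing import List, Literal, Optional


def select_keys(
    output_keys: List[str], keys: Optional[List[str]], patterns: Optional[List[str]]
) -> List[str]:
    """Return output keys that match an exact key or contain one of the patterns."""
    selected = []
    for key in output_keys:
        exact = keys is not None and key in keys
        by_pattern = patterns is not None and any(p in key for p in patterns)
        if exact or by_pattern:
            selected.append(key)
    return selected


def reduce_values(values: List[float], reduce: Literal["mean", "last"]) -> float:
    """Reduce the collected values of one key to a scalar."""
    if reduce == "mean":
        return sum(values) / len(values)
    if reduce == "last":
        return values[-1]
    raise ValueError(f"unsupported reduction: {reduce}")
```

For instance, with keys=["lr"] and patterns=["loss"], the outputs "loss_a" and "lr" are tracked while "grad_norm" is ignored.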

class noether.core.schemas.callbacks.OfflineLossCallbackConfig(/, **data)

Bases: PeriodicDataIteratorCallbackConfig

Parameters:

data (Any)

name: Literal['OfflineLossCallback'] = None
output_patterns_to_log: list[str] | None = None

Patterns that select which outputs are logged; for instance, the output key ‘some_loss’ together with the pattern [‘loss’].

**kwargs: additional arguments passed to the parent class.

class noether.core.schemas.callbacks.MetricEarlyStopperConfig(/, **data)

Bases: CallBackBaseConfig

Parameters:

data (Any)

name: Literal['MetricEarlyStopper'] = None
metric_key: str

The key of the metric to monitor.

tolerance: int

The number of times the metric can stagnate before training is stopped.

classmethod check_tolerance_positive(v)

Ensures that tolerance is at least 1.

Parameters:

v (int)

Return type:

int
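A sketch of tolerance-based early stopping, including the documented tolerance ≥ 1 check, is shown below. Several details are assumptions not confirmed by this reference: lower metric values are better, the stagnation count resets on improvement, and training stops on the tolerance-th consecutive stagnation (the library's off-by-one convention may differ). The function name is hypothetical.

```python
from typing import Sequence


def stop_index(metric: Sequence[float], tolerance: int) -> int:
    """Return the index at which training would stop, or -1 if it never stops."""
    if tolerance < 1:
        raise ValueError("tolerance must be at least 1")  # mirrors check_tolerance_positive
    best = float("inf")
    stagnated = 0
    for idx, value in enumerate(metric):
        if value < best:
            best, stagnated = value, 0  # improvement resets the counter
        else:
            stagnated += 1
            if stagnated >= tolerance:
                return idx
    return -1
```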

class noether.core.schemas.callbacks.FixedEarlyStopperConfig(/, **data)

Bases: pydantic.BaseModel

Parameters:

data (Any)

kind: str | None = None
name: Literal['FixedEarlyStopper'] = None
stop_at_sample: int | None = None
stop_at_update: int | None = None
stop_at_epoch: int | None = None
validate_callback_frequency()

Ensures that exactly one stopping criterion (‘stop_at_*’) is specified.

Return type:

FixedEarlyStopperConfig
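How a single stop_at_* threshold could be checked against the training counters is sketched below. Stopping once the counter reaches the threshold (counter >= limit) is an assumption, and should_stop is a hypothetical name, not part of the library.

```python
from typing import Optional


def should_stop(
    epoch: int,
    update: int,
    sample: int,
    stop_at_epoch: Optional[int] = None,
    stop_at_update: Optional[int] = None,
    stop_at_sample: Optional[int] = None,
) -> bool:
    """Check the single configured stop_at_* threshold against the training counters."""
    criteria = [(epoch, stop_at_epoch), (update, stop_at_update), (sample, stop_at_sample)]
    configured = [(counter, limit) for counter, limit in criteria if limit is not None]
    # Mirrors validate_callback_frequency: exactly one stop_at_* must be given.
    if len(configured) != 1:
        raise ValueError("exactly one stop_at_* must be specified")
    counter, limit = configured[0]
    return counter >= limit
```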

noether.core.schemas.callbacks.CallbacksConfig