noether.core.models.base
Classes
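- ModelBaseConfig: Internal base class for all registry-based configs.
- ModelBase: Base class for models (Model and CompositeModel) that defines the interface for all models trainable by the BaseTrainer.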
Module Contents
- class noether.core.models.base.ModelBaseConfig(/, **data)
Bases: noether.core.schemas.lib._RegistryBase
Internal base class for all registry-based configs.
Provides auto-registration via __init_subclass__. Not meant to be used directly; use the specific config base classes instead.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
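Auto-registration through __init_subclass__ is a common Python pattern: every subclass registers itself in a shared mapping the moment it is defined. The standard-library sketch below illustrates that pattern only; RegistryBase, lookup, and the kind values are hypothetical names and are not the actual interface of noether.core.schemas.lib._RegistryBase.

```python
from typing import Any, ClassVar, Dict, Type


class RegistryBase:
    """Hypothetical registry base; illustrates the pattern, not noether's class."""

    # Shared mapping from a subclass's ``kind`` string to the subclass itself.
    _registry: ClassVar[Dict[str, Type["RegistryBase"]]] = {}

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Register only classes that define ``kind`` themselves, so intermediate
        # bases (and subclasses that merely inherit a kind) are skipped.
        kind = cls.__dict__.get("kind")
        if kind is None:
            return
        if kind in RegistryBase._registry:
            raise ValueError(f"duplicate config kind: {kind!r}")
        RegistryBase._registry[kind] = cls

    @classmethod
    def lookup(cls, kind: str) -> Type["RegistryBase"]:
        """Return the class registered under ``kind`` (KeyError if unknown)."""
        return RegistryBase._registry[kind]


class ModelConfigBase(RegistryBase):
    """Intermediate base class: defines no kind, so it is not registered."""


class MlpConfig(ModelConfigBase):
    kind = "mlp"


class TransformerConfig(ModelConfigBase):
    kind = "transformer"


print(sorted(RegistryBase._registry))  # ['mlp', 'transformer']
```

Because registration happens at class-definition time, simply importing a module that defines a config subclass is enough to make it discoverable by name.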
- optimizer_config: noether.core.optimizer.schemas.AnyOptimizerConfig | None = None
The optimizer configuration to use for training the model. When a model is used for inference only, this can be left as None.
- initializers: list[Annotated[noether.core.initializers.AnyInitializer, Field(discriminator='kind')]] | None = None
List of initializer configs to use for the model.
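The discriminator on kind means each entry of initializers is matched to a concrete initializer config by the value of its kind field; Pydantic performs this dispatch during validation. The standard-library sketch below imitates that dispatch for a list of raw dictionaries. The two config classes and their kind values are hypothetical and do not reflect the actual members of noether.core.initializers.AnyInitializer.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union


@dataclass
class CheckpointInitConfig:
    """Hypothetical: load weights from a checkpoint."""
    path: str
    kind: str = "checkpoint"


@dataclass
class TruncNormalInitConfig:
    """Hypothetical: sample weights from a truncated normal distribution."""
    std: float = 0.02
    kind: str = "trunc_normal"


InitConfig = Union[CheckpointInitConfig, TruncNormalInitConfig]

# Discriminator value -> concrete config class.
_CONFIG_BY_KIND: Dict[str, type] = {
    "checkpoint": CheckpointInitConfig,
    "trunc_normal": TruncNormalInitConfig,
}


def parse_initializers(raw: Optional[List[Dict[str, Any]]]) -> Optional[List[InitConfig]]:
    """Turn raw dicts into typed configs, dispatching on the ``kind`` key."""
    if raw is None:  # mirrors the field's ``None`` default
        return None
    configs: List[InitConfig] = []
    for entry in raw:
        fields = dict(entry)  # copy so the caller's dict is left untouched
        kind = fields.pop("kind", None)
        if kind not in _CONFIG_BY_KIND:
            raise ValueError(f"unknown initializer kind: {kind!r}")
        configs.append(_CONFIG_BY_KIND[kind](**fields))
    return configs


parsed = parse_initializers([
    {"kind": "checkpoint", "path": "checkpoints/last.th"},  # hypothetical path
    {"kind": "trunc_normal", "std": 0.01},
])
print([c.kind for c in parsed])  # ['checkpoint', 'trunc_normal']
```

An explicit tag avoids guessing which union member a dictionary was meant to be, and an unknown kind fails loudly instead of silently matching the wrong config.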
- forward_properties: list[str] | None = []
List of properties to be used as inputs for the forward pass of the model. Only relevant when the train_step of the BaseTrainer is used; if train_step is overridden, this property is ignored.
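One plausible way for a trainer to use forward_properties is to pick the named entries out of a batch and pass them to the model's forward as keyword arguments. That calling convention and the helper select_forward_inputs are assumptions for illustration; the BaseTrainer's actual train_step may assemble its inputs differently.

```python
from typing import Any, Dict, List, Mapping, Optional


def select_forward_inputs(
    batch: Mapping[str, Any], forward_properties: Optional[List[str]]
) -> Dict[str, Any]:
    """Hypothetical helper: keep only the batch entries named in forward_properties."""
    # Assumption: None or an empty list selects nothing from the batch.
    if not forward_properties:
        return {}
    missing = [name for name in forward_properties if name not in batch]
    if missing:
        raise KeyError(f"batch is missing forward properties: {missing}")
    return {name: batch[name] for name in forward_properties}


def toy_forward(x: List[float], pos: List[float]) -> List[float]:
    """Stand-in for a model's forward(); it just adds its two inputs elementwise."""
    return [a + b for a, b in zip(x, pos)]


batch = {"x": [1.0, 2.0], "pos": [10.0, 20.0], "target": [0.0, 0.0]}
inputs = select_forward_inputs(batch, ["x", "pos"])  # "target" is left out
print(toy_forward(**inputs))  # [11.0, 22.0]
```

Keeping the selection in configuration means targets and other batch entries never reach forward unless they are listed explicitly.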
- model_config
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- class noether.core.models.base.ModelBase(model_config, update_counter=None, path_provider=None, data_container=None, initializer_config=None)
Bases: torch.nn.Module
Base class for models (Model and CompositeModel) that defines the interface for all models trainable by the BaseTrainer. Provides methods to initialize the model weights and set up (model-specific) optimizers.
- Parameters:
model_config (ModelBaseConfig) – Model configuration. See ModelBaseConfig for available options.
update_counter (noether.core.utils.training.counter.UpdateCounter | None) – The UpdateCounter provided to the optimizer.
path_provider (noether.core.providers.PathProvider | None) – The PathProvider used by the initializer to store or retrieve checkpoints.
data_container (noether.data.container.DataContainer | None) – The DataContainer, which includes the data and dataloader. It is currently unused, but can be helpful for quick prototyping, e.g., evaluating forward in debug mode.
initializer_config (list[noether.core.initializers.InitializerConfig] | None) – The initializer configs used to initialize the model, e.g., from a checkpoint.
- logger
- name
- update_counter = None
- path_provider = None
- data_container = None
- initializers: list[noether.core.initializers.InitializerBase] = []
- model_config
- is_initialized = False
- property optimizer: noether.core.optimizer.OptimizerWrapper | None
- Return type:
noether.core.optimizer.OptimizerWrapper | None
- abstract property device: torch.device
- Return type:
torch.device
- property trainable_param_count: int
Returns the number of parameters that require gradients (i.e., are trainable).
- Return type:
int
- property frozen_param_count: int
Returns the number of parameters that do not require gradients (i.e., are frozen).
- Return type:
int
- property nograd_paramnames: list[str]
Returns the names of parameters that require gradients but currently have none (i.e., their grad is None).
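The three properties above can be computed from (name, parameter) pairs such as those yielded by torch.nn.Module.named_parameters(). The functions below are a sketch of one plausible implementation, not necessarily how ModelBase computes them; in particular, counting elements via numel() rather than counting tensors is an assumption. So the example runs without PyTorch, FakeParam is a hypothetical stand-in exposing the same three members the functions rely on (numel(), requires_grad, grad), which real torch.nn.Parameter objects also provide.

```python
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple

NamedParams = Iterable[Tuple[str, Any]]


def trainable_param_count(named_params: NamedParams) -> int:
    """Number of parameter elements that require gradients."""
    return sum(p.numel() for _, p in named_params if p.requires_grad)


def frozen_param_count(named_params: NamedParams) -> int:
    """Number of parameter elements that do not require gradients."""
    return sum(p.numel() for _, p in named_params if not p.requires_grad)


def nograd_paramnames(named_params: NamedParams) -> List[str]:
    """Names of trainable parameters whose .grad is still None, e.g. because
    they did not take part in the last backward pass."""
    return [name for name, p in named_params if p.requires_grad and p.grad is None]


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter."""
    size: int
    requires_grad: bool = True
    grad: Optional[List[float]] = None

    def numel(self) -> int:
        return self.size


params = [
    ("encoder.weight", FakeParam(12, grad=[0.0] * 12)),
    ("encoder.bias", FakeParam(3)),  # trainable, but no gradient yet
    ("embedding.weight", FakeParam(20, requires_grad=False)),  # frozen
]
print(trainable_param_count(params))  # 15
print(frozen_param_count(params))  # 20
print(nograd_paramnames(params))  # ['encoder.bias']
```

A non-empty nograd_paramnames after a backward pass usually points at parameters that were never used in the forward pass, which is a frequent source of silent training bugs.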
- initialize()
Initializes weights and optimizer parameters of the model.
- abstractmethod get_named_models()
Returns a dict of {model_name: model}, e.g., to log all learning rates of all models/submodels.
- abstractmethod initialize_weights()
Initialize the weights of the model.
- Return type:
Self
- abstractmethod apply_initializers()
Apply the initializers to the model.
- Return type:
Self
- abstractmethod initialize_optimizer()
Initialize the optimizer of the model.
- Return type:
None
- abstractmethod optimizer_step(grad_scaler)
Perform an optimization step.
- Parameters:
grad_scaler (torch.amp.grad_scaler.GradScaler | None)
- Return type:
None
- abstractmethod optimizer_schedule_step()
Perform the optimizer learning rate scheduler step.
- Return type:
None
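To show how the abstract methods of ModelBase fit together, the following torch-free sketch mirrors that interface with a toy one-weight model. The class names SketchModelBase and ToyModel, the order in which initialize() calls the three initialization hooks, and the learning-rate decay are assumptions for illustration; a real subclass would wrap torch modules and an OptimizerWrapper instead.

```python
import abc
from typing import Any, Dict, Optional


class SketchModelBase(abc.ABC):
    """Hypothetical, torch-free mirror of the ModelBase lifecycle."""

    def __init__(self) -> None:
        self.is_initialized = False

    def initialize(self) -> None:
        # Assumed order: default weights, then initializers (which may
        # overwrite them, e.g. from a checkpoint), then the optimizer.
        self.initialize_weights()
        self.apply_initializers()
        self.initialize_optimizer()
        self.is_initialized = True

    @abc.abstractmethod
    def get_named_models(self) -> Dict[str, Any]: ...

    @abc.abstractmethod
    def initialize_weights(self) -> "SketchModelBase": ...

    @abc.abstractmethod
    def apply_initializers(self) -> "SketchModelBase": ...

    @abc.abstractmethod
    def initialize_optimizer(self) -> None: ...

    @abc.abstractmethod
    def optimizer_step(self, grad_scaler: Optional[Any]) -> None: ...

    @abc.abstractmethod
    def optimizer_schedule_step(self) -> None: ...


class ToyModel(SketchModelBase):
    """Minimizes f(w) = w**2 with plain gradient descent on a single weight."""

    def get_named_models(self) -> Dict[str, Any]:
        return {"toy": self}

    def initialize_weights(self) -> "ToyModel":
        self.weight = 1.0
        return self

    def apply_initializers(self) -> "ToyModel":
        return self  # no initializers configured in this toy example

    def initialize_optimizer(self) -> None:
        self.lr = 0.25
        self.grad = 0.0

    def optimizer_step(self, grad_scaler: Optional[Any]) -> None:
        if grad_scaler is not None:
            # With torch.amp this would be grad_scaler.step(optimizer) followed
            # by grad_scaler.update(); gradient scaling is not modeled here.
            raise NotImplementedError("grad scaling is not modeled in this sketch")
        self.weight -= self.lr * self.grad

    def optimizer_schedule_step(self) -> None:
        self.lr *= 0.5  # halve the learning rate after every update


model = ToyModel()
model.initialize()
for _ in range(3):
    model.grad = 2 * model.weight  # gradient of w**2
    model.optimizer_step(grad_scaler=None)
    model.optimizer_schedule_step()
print(model.weight, model.lr)  # 0.328125 0.03125
```

Keeping initialize() concrete while the individual hooks stay abstract lets a trainer drive every model through the same sequence, with subclasses only filling in the model-specific steps.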