noether.core.schemas.models.upt

Classes

UPTConfig

Configuration for a UPT model.

Module Contents

class noether.core.schemas.models.upt.UPTConfig(/, **data)

Bases: noether.core.schemas.models.base.ModelBaseConfig, noether.core.schemas.mixins.InjectSharedFieldFromParentMixin

Configuration for a UPT model.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

model_config

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

num_heads: int = None

Number of attention heads in the model.

hidden_dim: int = None

Hidden dimension of the model.

mlp_expansion_factor: int = None

Expansion factor for the MLP in the feed-forward (FF) layers.

approximator_depth: int = None

Number of approximator layers.

use_rope: bool = None

supernode_pooling_config: Annotated[noether.core.schemas.modules.SupernodePoolingConfig, noether.core.schemas.mixins.Shared]

approximator_config: Annotated[noether.core.schemas.modules.blocks.TransformerBlockConfig, noether.core.schemas.mixins.Shared]

decoder_config: Annotated[noether.core.schemas.modules.DeepPerceiverDecoderConfig, noether.core.schemas.mixins.Shared]

bias_layers: bool = None

data_specs: noether.core.schemas.dataset.AeroDataSpecs

linear_output_projection_config()

Return type:

noether.core.schemas.modules.layers.LinearProjectionConfig

rope_frequency_config()

Return type:

noether.core.schemas.modules.layers.RopeFrequencyConfig

validate_rope_usage()

Ensure that if use_rope is True in the main config, it is also True in the approximator_config.

Return type:

UPTConfig
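
The rule checked by validate_rope_usage can be sketched without any dependencies. The snippet below is only an illustration: it uses plain dataclasses instead of the actual pydantic model, and the names BlockConfigSketch and ParentConfigSketch are hypothetical stand-ins, not noether classes. It assumes the check is one-directional, as the description states: the parent flag implies the approximator flag, but not the reverse.

```python
from dataclasses import dataclass


@dataclass
class BlockConfigSketch:
    """Hypothetical stand-in for the approximator's block config."""

    use_rope: bool = False


@dataclass
class ParentConfigSketch:
    """Hypothetical stand-in for UPTConfig, reduced to the fields the check needs."""

    use_rope: bool
    approximator_config: BlockConfigSketch

    def __post_init__(self) -> None:
        # Documented rule: use_rope=True on the parent requires
        # use_rope=True on approximator_config as well.
        if self.use_rope and not self.approximator_config.use_rope:
            raise ValueError(
                "use_rope=True on the parent requires "
                "approximator_config.use_rope=True"
            )


# Consistent: both flags enabled.
ok = ParentConfigSketch(use_rope=True, approximator_config=BlockConfigSketch(use_rope=True))

# Inconsistent: the parent enables RoPE but the approximator does not.
try:
    ParentConfigSketch(use_rope=True, approximator_config=BlockConfigSketch(use_rope=False))
except ValueError as err:
    message = str(err)
```

In the real class this check presumably runs as a pydantic validator after field parsing, so a violation would surface as a validation error at construction time rather than as a bare ValueError.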

update_supernode_pooling_config()

Inject shared fields into supernode_pooling_config.

Return type:

UPTConfig

pos_embedding_config()

Return type:

noether.core.schemas.modules.layers.ContinuousSincosEmbeddingConfig

validate_parameters()

Validate parameters across the model and its submodules.

Ensures that:

1. hidden_dim is divisible by num_heads in the parent and in all submodules that define num_heads.
2. hidden_dim is consistent across the parent and all submodules.

Return type:

UPTConfig
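
The two conditions checked by validate_parameters can be illustrated with a small dependency-free sketch. SubmoduleSketch and check_dimensions are hypothetical names introduced here for illustration; they are not part of noether, and the actual validator may report errors differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SubmoduleSketch:
    """Hypothetical stand-in for a submodule config (not a noether class)."""

    hidden_dim: int
    num_heads: Optional[int] = None  # None: the submodule has no attention heads


def check_dimensions(
    hidden_dim: int, num_heads: int, submodules: dict[str, SubmoduleSketch]
) -> None:
    """Raise ValueError if either documented condition is violated."""
    # Condition 1 (parent): hidden_dim must split evenly across the heads.
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by num_heads={num_heads}"
        )
    for name, sub in submodules.items():
        # Condition 2: every submodule must use the parent's hidden_dim.
        if sub.hidden_dim != hidden_dim:
            raise ValueError(
                f"{name}.hidden_dim={sub.hidden_dim} differs from parent hidden_dim={hidden_dim}"
            )
        # Condition 1 (submodule): only checked where num_heads is defined.
        if sub.num_heads is not None and sub.hidden_dim % sub.num_heads != 0:
            raise ValueError(
                f"{name}.hidden_dim={sub.hidden_dim} is not divisible by num_heads={sub.num_heads}"
            )


# Passes: 192 is divisible by 3 and all hidden dimensions agree.
check_dimensions(
    192,
    3,
    {
        "approximator_config": SubmoduleSketch(hidden_dim=192, num_heads=3),
        "decoder_config": SubmoduleSketch(hidden_dim=192),
    },
)
```

The divisibility condition reflects the usual multi-head attention layout, where each head operates on hidden_dim / num_heads channels.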