noether.data.base.dataset

Attributes

Classes

DatasetBaseConfig

Internal base class for all registry-based configs.

StandardDatasetConfig

Base config for datasets with fixed splits.

DatasetSplitIDs

Base class for dataset split ID validation with overlap checking.

Dataset

Noether dataset implementation, which is a wrapper around torch.utils.data.Dataset that can hold a dataset_config_provider.

Functions

with_normalizers([_func_or_key])

Decorator to apply a normalizer to the output of a getitem_* function of the implemented Dataset class.

Module Contents

noether.data.base.dataset.logger
noether.data.base.dataset.TPipelineConfig
class noether.data.base.dataset.DatasetBaseConfig[TPipelineConfig: noether.data.pipeline.PipelineConfig](/, **data)

Bases: noether.core.schemas.lib._RegistryBase

Internal base class for all registry-based configs.

Provides auto-registration via __init_subclass__. Not meant to be used directly - use specific config base classes instead.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

kind: str | None = None

Kind of dataset to use.

pipeline: Annotated[TPipelineConfig | None, Discriminated(PipelineConfig)] = None

Config of the pipeline to use for the dataset.

dataset_normalizers: dict[str, list[Annotated[Any, Discriminated(NormalizerConfig)]] | Annotated[Any, Discriminated(NormalizerConfig)]] | None = None

Normalizers to apply to the dataset, keyed by data source name. Each value is either a single normalizer config or a list of normalizer configs.

dataset_wrappers: list[noether.data.base.wrappers.DatasetWrappers] | None = None
included_properties: set[str] | None = None

Set of properties of this dataset that will be loaded (i.e., the getitem_* methods that are called). If not set, all properties are loaded.

excluded_properties: set[str] | None = None

Set of properties of this dataset that will NOT be loaded, even if they are listed in included_properties.
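
The interaction between included_properties and excluded_properties can be sketched in plain Python. This is a minimal illustration of the rule described above (exclusion wins over inclusion), not the library's code; the helper name resolve_properties and the property names are hypothetical.

```python
def resolve_properties(all_names, included=None, excluded=None):
    """Return the property names to load, keeping the original order."""
    # No include set means every property is a candidate.
    selected = [n for n in all_names if included is None or n in included]
    # Exclusion is applied last, so it overrides inclusion.
    excluded = excluded or set()
    return [n for n in selected if n not in excluded]


# Hypothetical property names, i.e. getitem_* methods without the prefix.
names = ["surface_pressure", "surface_geometry", "volume_velocity"]

everything = resolve_properties(names)
only_pressure = resolve_properties(
    names,
    included={"surface_pressure", "surface_geometry"},
    excluded={"surface_geometry"},
)
```

Here surface_geometry is dropped even though it appears in the included set, which mirrors the documented precedence.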

model_config

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

class noether.data.base.dataset.StandardDatasetConfig(/, **data)

Bases: DatasetBaseConfig, abc.ABC

Base config for datasets with fixed splits.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

root: str

Root directory of the dataset.

split: Literal['train', 'val', 'test']

Which split of the dataset to use. Must be one of “train”, “val”, or “test”.

class noether.data.base.dataset.DatasetSplitIDs(/, **data)

Bases: pydantic.BaseModel, abc.ABC

Base class for dataset split ID validation with overlap checking.

This base class provides:
  • Automatic validation that train/val/test splits don’t have overlapping IDs

  • Optional size validation for datasets that have expected split sizes

Subclasses can optionally define class variables for size validation:
  • EXPECTED_TRAIN_SIZE: Expected number of training samples

  • EXPECTED_VAL_SIZE: Expected number of validation samples

  • EXPECTED_TEST_SIZE: Expected number of test samples

  • DATASET_NAME: Name of the dataset for error messages

If these are not defined, only overlap checking will be performed.
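
The two checks described above can be sketched with the standard library alone. This is an illustrative approximation, not the validator that DatasetSplitIDs actually runs; the helper names find_overlaps and check_sizes and the sample IDs are hypothetical.

```python
from itertools import combinations


def find_overlaps(splits):
    """Return {(split_a, split_b): shared_ids} for every overlapping pair."""
    overlaps = {}
    for (name_a, ids_a), (name_b, ids_b) in combinations(splits.items(), 2):
        shared = set(ids_a) & set(ids_b)
        if shared:
            overlaps[(name_a, name_b)] = shared
    return overlaps


def check_sizes(splits, expected):
    """Return one message per split whose size differs from its expected size."""
    return [
        f"{name}: expected {size} IDs, got {len(splits[name])}"
        for name, size in expected.items()
        # A size of None means "no expectation", so that split is skipped.
        if size is not None and len(splits[name]) != size
    ]


splits = {"train": [0, 1, 2, 3], "val": [4, 5], "test": [5, 6]}
overlaps = find_overlaps(splits)
size_errors = check_sizes(splits, {"train": 4, "val": None, "test": 3})
```

In this toy input, ID 5 appears in both val and test, and test has two IDs where three were expected, so both checks report a problem.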

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

EXPECTED_TRAIN_SIZE: ClassVar[int | None] = None
EXPECTED_VAL_SIZE: ClassVar[int | None] = None
EXPECTED_TEST_SIZE: ClassVar[int | None] = None
EXPECTED_HIDDEN_TEST_SIZE: ClassVar[int | None] = None
DATASET_NAME: ClassVar[str | None] = None
train: list[int]
val: list[int]
test: list[int]
extrap: list[int] = []
interp: list[int] = []
train_subset: list[int] = []
validate_splits()

Validate splits and check for overlaps.

noether.data.base.dataset.with_normalizers(_func_or_key=None)

Decorator to apply a normalizer to the output of a getitem_* function of the implemented Dataset class.

This decorator will look for a normalizer registered under the specified key and apply it to the output of the decorated function. If no key is provided, the key is automatically inferred from the function name by removing the ‘getitem_’ prefix.

Example usage:

# Inferred key: "surface_pressure"
@with_normalizers
def getitem_surface_pressure(self, idx):
    return torch.load(f"{self.path}/surface_pressure/{idx}.pt")


# Explicit key: "pressure"
@with_normalizers("pressure")
def getitem_surface_pressure(self, idx):
    return torch.load(f"{self.path}/surface_pressure/{idx}.pt")

Parameters:

_func_or_key (str | Any | None) – The normalizer key (str) or the function being decorated. If used as @with_normalizers (no arguments), this will be the decorated function. If used as @with_normalizers("key"), this will be the string key.

Returns:

The decorated function with normalization applied.

Raises:
  • ValueError – If the normalizer key cannot be resolved from the function name.

  • AttributeError – If the class instance does not have a ‘normalizers’ attribute.

  • KeyError – If the requested normalizer key is not found in the ‘normalizers’ dictionary.
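
To make the two calling styles concrete, here is a minimal sketch of how a decorator can accept both the bare form and the keyed form. It is a stand-in written for illustration, not the noether implementation; with_normalizers_sketch, ToyDataset, and the lambda normalizers are all hypothetical.

```python
import functools


def with_normalizers_sketch(_func_or_key=None):
    """Illustrative stand-in for a decorator usable with or without a key."""

    def decorate(func, key=None):
        if key is None:
            if not func.__name__.startswith("getitem_"):
                raise ValueError(f"Cannot infer a normalizer key from {func.__name__!r}")
            key = func.__name__[len("getitem_"):]

        @functools.wraps(func)
        def wrapper(self, idx, *args, **kwargs):
            # AttributeError if `normalizers` is missing, KeyError if the key is.
            normalizer = self.normalizers[key]
            return normalizer(func(self, idx, *args, **kwargs))

        return wrapper

    if callable(_func_or_key):
        # Bare usage: the argument is the decorated function itself.
        return decorate(_func_or_key)
    # Keyed usage: the argument is the key, so return the real decorator.
    return lambda func: decorate(func, key=_func_or_key)


class ToyDataset:
    # Hypothetical normalizers; real ones would be preprocessor objects.
    normalizers = {"pressure": lambda x: x / 10, "temperature": lambda x: x - 273}

    @with_normalizers_sketch
    def getitem_temperature(self, idx):
        return 300 + idx

    @with_normalizers_sketch("pressure")
    def getitem_surface_pressure(self, idx):
        return 50 * idx


ds = ToyDataset()
temperature = ds.getitem_temperature(1)    # inferred key "temperature"
pressure = ds.getitem_surface_pressure(2)  # explicit key "pressure"
```

The key is resolved once at decoration time, while the normalizer lookup happens on every call, so a missing key only surfaces when the getter is actually used.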

class noether.data.base.dataset.Dataset(dataset_config)

Bases: torch.utils.data.Dataset

Noether dataset implementation, which is a wrapper around torch.utils.data.Dataset that can hold a dataset_config_provider. A dataset should map a key (i.e., an index) to its corresponding data.

Each subclass should implement individual getitem_* methods, where * is the name of an item in the dataset. Each getitem_* method loads an individual tensor/data sample from disk. For example, if your dataset consists of images and targets/labels (stored as tensors), the subclass should implement getitem_image(idx) and getitem_target(idx).

The __getitem__ method of this class loops over all getitem_* methods implemented by the child class and returns their results. Optionally, you can configure which getitem_* methods are called.

Example: Car aerodynamics dataset

import torch

from noether.data.base.dataset import Dataset


class CarAeroDynamicsDataset(Dataset):
    def __init__(self, dataset_config, **kwargs):
        super().__init__(dataset_config=dataset_config, **kwargs)
        # Assumes the config class used for this dataset defines a `path` field.
        self.path = dataset_config.path

    def __len__(self):
        return 100  # Example length

    def getitem_surface_pressure(self, idx):
        # Load surface pressure tensor
        return torch.load(f"{self.path}/surface_pressure_tensor/{idx}.pt")

    def getitem_surface_geometry(self, idx):
        # Load surface geometry tensor
        return torch.load(f"{self.path}/surface_geometry_tensor/{idx}.pt")


# `dataset_config` is an instance of the dataset's config class
# (a DatasetBaseConfig subclass), created elsewhere.
dataset = CarAeroDynamicsDataset(dataset_config)
sample0 = dataset[0]
surface_pressure_0 = sample0["surface_pressure"]
surface_geometry_0 = sample0["surface_geometry"]

In many cases, the data returned by a getitem_* method should be normalized. To apply normalization, add the with_normalizers decorator to the getitem_* method. For example:

@with_normalizers("surface_pressure")
def getitem_surface_pressure(self, idx):
    # Load surface pressure tensor
    return torch.load(f"{self.path}/surface_pressure_tensor/{idx}.pt")

"surface_pressure" is the key in the self.normalizers dictionary. This key maps to a preprocessor that should implement the correct data normalization.

Example configuration for dataset normalizers:

# dummy example configuration for the CarAeroDynamicsDataset
dataset:
    kind: noether.data.datasets.CarAeroDynamicsDataset
    pipeline:  # configure the data pipeline to collate individual samples into batches
    dataset_normalizers:
        surface_pressure:
            - kind: noether.data.preprocessors.normalizers.MeanStdNormalization
              mean: [1., 2., 3.]
              std: [0.1, 0.2, 0.3]

Parameters:

dataset_config (DatasetBaseConfig) – Configuration for the dataset. See DatasetBaseConfig for available options including dataset normalizers.

logger
config
normalizers: dict[str, noether.data.preprocessors.ComposePreProcess]
compute_statistics = False
fetch_statistics()

Load and cache dataset statistics from the dataset’s STATS_FILE.

By default looks for a STATS_FILE class attribute on the dataset class (or its ancestors). The file should be a YAML file mapping stat names to scalar or list values.

Returns:

Dict mapping stat names to float values or lists of floats.

Return type:

dict[str, list[float] | float] | None

property pipeline: noether.data.pipeline.Collator | None

Returns the pipeline for the dataset.

Return type:

noether.data.pipeline.Collator | None

pre_getitem(idx)

Optional hook called once before the individual getitem_* methods.

Override this to load shared data (e.g. an HDF5 file that contains multiple fields) and return it as a dictionary. The returned dict is forwarded as keyword arguments to every getitem_* call for the same sample, so each getter can pull its field without re-opening the file.

The default implementation returns an empty dict.

Parameters:

idx (int)

Return type:

dict[str, Any] | None

post_getitem(idx, pre)

Optional hook called once after all getitem_* methods have run.

Override this to perform per-sample cleanup (e.g. closing a file handle that was opened in pre_getitem()).

The pre argument is the value originally returned by pre_getitem() so that the cleanup logic can access the same resources.

The default implementation does nothing.

Parameters:
  • idx (int)

  • pre (dict[str, Any] | None)

Return type:

None
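
The call order of the two hooks can be illustrated with a toy class. This sketch only mimics the behaviour documented above (one pre_getitem call, its dict forwarded as keyword arguments, one post_getitem call); it is not the noether implementation. The class name, the `record` keyword, and the use of try/finally are assumptions made for the example.

```python
class HookedDatasetSketch:
    """Toy class that records the order in which hooks and getters run."""

    def __init__(self):
        self.calls = []

    def pre_getitem(self, idx):
        self.calls.append("pre")
        # Stand-in for opening one shared file per sample.
        return {"record": {"pressure": idx * 1.5, "geometry": [idx, idx + 1]}}

    def getitem_pressure(self, idx, record):
        self.calls.append("pressure")
        return record["pressure"]

    def getitem_geometry(self, idx, record):
        self.calls.append("geometry")
        return record["geometry"]

    def post_getitem(self, idx, pre):
        self.calls.append("post")
        pre.clear()  # stand-in for closing the shared resource

    def __getitem__(self, idx):
        # Treat a None return value like an empty dict.
        pre = self.pre_getitem(idx) or {}
        try:
            # The dict from pre_getitem is forwarded as keyword arguments.
            return {
                "pressure": self.getitem_pressure(idx, **pre),
                "geometry": self.getitem_geometry(idx, **pre),
            }
        finally:
            self.post_getitem(idx, pre)


ds = HookedDatasetSketch()
sample = ds[2]
```

Both getters read from the same `record` object, so the shared resource is opened once per sample rather than once per field.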

get_all_getitem_names()

Returns the names of all implemented getitem_* functions, without the getitem_ prefix. For example, an image classification dataset that implements getitem_x and getitem_class yields ["x", "class"].

Return type:

list[str]
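
One way such name discovery could work is by inspecting the class for attributes that start with the getitem_ prefix. The sketch below is an assumption about the mechanism, shown on a hypothetical toy class; the ordering it produces (definition order, base classes first) may differ from what the library returns.

```python
class NameDiscoverySketch:
    """Toy class that discovers its own getitem_* methods."""

    def getitem_x(self, idx):
        return [idx, idx + 1]

    def getitem_class(self, idx):
        return idx % 2

    def get_all_getitem_names(self):
        prefix = "getitem_"
        names = []
        # Walk base classes first so inherited getters keep their definition order.
        for cls in reversed(type(self).__mro__):
            for attr, value in vars(cls).items():
                name = attr[len(prefix):]
                if attr.startswith(prefix) and callable(value) and name not in names:
                    names.append(name)
        return names

    def __getitem__(self, idx):
        # Call every discovered getter and collect the results by name.
        return {
            name: getattr(self, f"getitem_{name}")(idx)
            for name in self.get_all_getitem_names()
        }


ds = NameDiscoverySketch()
names = ds.get_all_getitem_names()
sample = ds[3]
```

The resulting sample is a dict keyed by the discovered names, which matches the way the earlier example indexes a sample by property name.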

denormalize(key, data)

Denormalize data using the appropriate normalizer.

This method finds the specific normalizer for the given key and uses it to denormalize, instead of calling pipeline.denormalize which would process the entire pipeline.

Parameters:
  • key (str) – Key to identify the normalizer for denormalization

  • data – Data to denormalize

Returns:

Denormalized data

Raises:

KeyError – If no normalizer is found for the given key
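
A per-key round trip can be sketched as follows. The MeanStd class is a simplified, hypothetical stand-in for a mean/std normalizer, and its denormalize method name is an assumption; only the lookup-by-key behaviour and the KeyError follow the description above.

```python
class MeanStd:
    """Hypothetical element-wise mean/std normalizer operating on lists."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, values):
        return [(v - m) / s for v, m, s in zip(values, self.mean, self.std)]

    def denormalize(self, values):
        # Exact inverse of __call__.
        return [v * s + m for v, m, s in zip(values, self.mean, self.std)]


class DenormalizeSketch:
    """Toy dataset that only knows how to denormalize by key."""

    def __init__(self, normalizers):
        self.normalizers = normalizers

    def denormalize(self, key, data):
        if key not in self.normalizers:
            raise KeyError(f"No normalizer found for key {key!r}")
        # Only the normalizer registered for this key is applied.
        return self.normalizers[key].denormalize(data)


ds = DenormalizeSketch(
    {"surface_pressure": MeanStd(mean=[1.0, 2.0, 3.0], std=[0.5, 0.25, 2.0])}
)
normalized = ds.normalizers["surface_pressure"]([2.0, 2.5, 7.0])
restored = ds.denormalize("surface_pressure", normalized)
```

Looking up the normalizer by key keeps denormalization limited to the one field of interest, for example when converting a model prediction back to physical units.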