noether.inference.run

Notebook-friendly Python API for loading a trained run.

The non-Hydra counterpart to noether-eval: instead of spinning up an InferenceRunner (with trainer context, callbacks, tracker, etc.), it gives you a single handle to a run from which you can pull the resolved config, an instantiated dataset, and a model with checkpoint weights loaded.

Two ways to build a Run:

  • Run(run_dir): open a full training output directory (hp_resolved.yaml plus a checkpoints/ subdirectory). Gives access to the resolved config, the dataset, the normalizers, and the model.

  • Run.from_checkpoint(path): open a single ..._model.th file on its own. Every checkpoint written by noether’s CheckpointWriter embeds the model config, the discriminator kind, and the per-field normalizer specs and statistics, which is enough for model() and normalizers() without the run directory. dataset(), config, and statistics are unavailable in this mode.

from noether.inference import Run

# Full run directory. Override training-time dataset roots before calling dataset().
run = Run("/outputs/2026-04-09_abc12")
for ds in run.config.datasets.values():
    ds.root = "/local/path/to/data"
dataset = run.dataset("test")
model = run.model(checkpoint="latest", device="cuda")

# Single checkpoint file — no run dir, no hp_resolved.yaml, no stats file.
run = Run.from_checkpoint("/outputs/.../checkpoints/ab_upt_cp=last_model.th")
model = run.model(device="cuda")
norms = run.normalizers()

For reproducible eval with metrics, callbacks, and full logging, use noether-eval instead.
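To make the two construction modes concrete, here is a minimal stdlib-only sketch of how a handle with two constructors could be organized. MiniRun, dataset_root, and the hard-coded config dict are hypothetical illustrations, not noether's implementation (the real Run loads and validates hp_resolved.yaml into a ConfigSchema). Caching config on first access is the assumption that makes the "mutate config, then call the lazy methods" pattern work.

```python
from __future__ import annotations

import functools
from pathlib import Path


class MiniRun:
    """Hypothetical stand-in for a two-mode run handle (not noether's Run)."""

    def __init__(self, run_dir: Path | str) -> None:
        run_dir = Path(run_dir).resolve()
        # Run-dir mode requires the resolved config file to be present.
        if not (run_dir / "hp_resolved.yaml").is_file():
            raise FileNotFoundError(f"no hp_resolved.yaml in {run_dir}")
        self.run_dir: Path | None = run_dir
        self.checkpoint_path: Path | None = None

    @classmethod
    def from_checkpoint(cls, checkpoint_path: Path | str) -> MiniRun:
        path = Path(checkpoint_path).resolve()
        if not path.is_file():
            raise FileNotFoundError(path)
        # Bypass __init__: checkpoint-only mode has no run directory to validate.
        run = cls.__new__(cls)
        run.run_dir = None
        run.checkpoint_path = path
        return run

    @property
    def is_checkpoint_only(self) -> bool:
        return self.run_dir is None

    @functools.cached_property
    def config(self) -> dict:
        if self.is_checkpoint_only:
            raise RuntimeError("checkpoint-only mode has no resolved config")
        # Hypothetical placeholder: a real handle would parse hp_resolved.yaml here.
        # The result is cached, so later mutations persist across accesses.
        return {"datasets": {"test": {"root": "/train/machine/data"}}}

    def dataset_root(self, split: str = "test") -> str:
        # Lazy method: reads the cached config, so earlier mutations are honored.
        return self.config["datasets"][split]["root"]
```

Because config is computed once and then reused, overriding a dataset root after construction is visible to every lazy method called afterwards; a property that re-read the file on each access would silently discard such overrides.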

Classes

  • Run: Handle to a trained run.

Functions

  • sanitize_hp_resolved(hp_resolved_path): Write a tag-free copy of hp_resolved.yaml to a temp file.

Module Contents

noether.inference.run.sanitize_hp_resolved(hp_resolved_path)

Write a tag-free copy of hp_resolved.yaml to a temp file.

Equivalent to _load_hp_resolved_as_plain_dict() followed by a yaml.safe_dump of the result into a fresh temp directory. Kept for external callers; internal loading uses the in-memory helper instead, so it never leaves temp directories behind.

Parameters:

hp_resolved_path (pathlib.Path) – Path to the hp_resolved.yaml file to copy.

Return type:

pathlib.Path

class noether.inference.run.Run(run_dir)

Handle to a trained run.

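Two construction modes, picked by which constructor you use:

  • Run(run_dir): run-dir mode, in which config, statistics, dataset(), normalizers(), and model() are all available.

  • Run.from_checkpoint(path): checkpoint-only mode, in which only model() and normalizers() are available.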

Mutate config after construction and before calling the lazily evaluated methods (dataset(), model(), normalizers()) to override training-time values, typically dataset roots when the run was produced on a different machine. Only meaningful in run-dir mode.

Parameters:

run_dir (pathlib.Path | str) – Path to the training run output directory (the one that contains hp_resolved.yaml and a checkpoints/ subdirectory). Typically output_path/run_id or output_path/run_id/stage_name.

run_dir

Resolved absolute path to the run directory in run-dir mode; None in checkpoint-only mode.

checkpoint_path

Resolved absolute path to the .th file in checkpoint-only mode; None in run-dir mode.

Raises:

FileNotFoundError – If run_dir does not exist or doesn’t contain hp_resolved.yaml.

Example

import torch
from noether.inference import Run

# Bring-your-own-data flow: apply the trained model to a custom input dict,
# then denormalize the predictions.
run = Run.from_checkpoint("/outputs/.../ab_upt_cp=last_model.th")
model = run.model(device="cuda")
norms = run.normalizers()
# my_inputs is a placeholder for your own dict of forward-pass keyword arguments.
with torch.inference_mode():
    pred = model(**my_inputs)
pred_phys = norms["surface_pressure"].inverse(pred["surface_pressure"])

run_dir: pathlib.Path | None

checkpoint_path: pathlib.Path | None = None

classmethod from_checkpoint(checkpoint_path)

Build a Run from a single ..._model.th file.

Reads the model config (CheckpointKeys.MODEL_CONFIG), the discriminator kind (CheckpointKeys.CONFIG_KIND), and — if present — the per-field normalizer payload (CheckpointKeys.NORMALIZER_CONFIGS / CheckpointKeys.NORMALIZER_STATISTICS) that CheckpointWriter embeds in every checkpoint.

The model class itself must still be importable in the current process — the kind string points at a class, not at its implementation. If the checkpoint references a recipe-specific model, make sure that recipe is installed (or on sys.path) before calling.

Parameters:

checkpoint_path (pathlib.Path | str) – Path to a ..._model.th file written by noether.

Returns:

A Run in checkpoint-only mode. model() and normalizers() are usable; dataset(), config, and statistics raise RuntimeError.

Raises:
  • FileNotFoundError – If the checkpoint file does not exist.

  • KeyError – If the checkpoint is missing any of state_dict, model_config, or config_kind (older checkpoints predate the embedded config — fall back to Run(run_dir)).

Return type:

Run
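The two checks described for from_checkpoint() (required embedded keys, and an importable class behind the kind string) can be sketched with the standard library alone. This is a hypothetical illustration, not noether's code: the string keys mirror the names listed under KeyError above (the real code reads them through CheckpointKeys constants), and treating the kind as a dotted "package.module.ClassName" path is an assumption.

```python
from __future__ import annotations

import importlib
from collections.abc import Mapping
from typing import Any

# Assumed key names, taken from the KeyError description above.
REQUIRED_KEYS = ("state_dict", "model_config", "config_kind")


def check_payload(payload: Mapping[str, Any]) -> None:
    """Raise KeyError naming every required key the checkpoint lacks."""
    missing = [key for key in REQUIRED_KEYS if key not in payload]
    if missing:
        raise KeyError(
            f"checkpoint is missing {missing}; older checkpoints predate the "
            "embedded config, so open the run directory instead"
        )


def resolve_kind(kind: str) -> type:
    """Import a class from a dotted 'package.module.ClassName' string (assumed format)."""
    module_name, _, class_name = kind.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {kind!r}")
    # Raises ModuleNotFoundError if the defining package is not importable,
    # which is why a recipe-specific model needs its recipe installed first.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```

The kind string only names where the class lives; the class body itself comes from whatever package is importable in the current process.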

property is_checkpoint_only: bool

True if this Run was built via from_checkpoint() (no run dir, no resolved config).

Return type:

bool

property config: noether.core.schemas.schema.ConfigSchema

Validated ConfigSchema loaded from hp_resolved.yaml.

Safe to mutate before calling dataset() / model() / normalizers().

Raises:

RuntimeError – If this Run was built via from_checkpoint() — no run directory means no resolved config.

Return type:

noether.core.schemas.schema.ConfigSchema

property statistics: dict[str, list[float | int]]

Training-time dataset statistics (config.dataset_statistics or {}).

Convenience accessor for the stat values the training run computed — typically per-field means/stds used by the trainer’s pipeline. Returns an empty dict if the run didn’t compute any stats.

Note: this is separate from the dataset class’s static STATS_FILE, which normalizers() reads in run-dir mode.

Raises:

RuntimeError – In checkpoint-only mode (no resolved config).

Return type:

dict[str, list[float | int]]

normalizers(split='test')

Build the trained run’s field normalizers without instantiating its dataset.

In run-dir mode, reads the dataset class’s STATS_FILE (looked up from config.datasets[split].kind) and constructs each normalizer from config.datasets[split].dataset_normalizers. The data root is never touched.

In checkpoint-only mode, reads the per-field preprocessor configs and resolved statistics that CheckpointWriter embeds in every checkpoint (NORMALIZER_CONFIGS / NORMALIZER_STATISTICS). The split argument is ignored — only the writer-side split (typically test) was captured.

Parameters:

split (str) – Dataset key to source the normalizer configs from. Splits typically share normalizers; the arg is provided for parity with dataset(). Ignored in checkpoint-only mode.

Returns:

Dict mapping field name (e.g. "surface_pressure") to a ComposePreProcess. Empty dict if no normalizers are available for this split.

Raises:

KeyError – In run-dir mode, if split is not in self.config.datasets. In checkpoint-only mode, if the checkpoint predates the embedded normalizer keys.

Return type:

dict[str, noether.data.preprocessors.compose.ComposePreProcess]
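The denormalization in the Run example calls inverse() on a composed normalizer. As a hypothetical scalar stand-in for ComposePreProcess (the class names MeanStd, Scale, and Compose, and the reverse-order semantics, are assumptions for illustration), the sketch below shows why a composed inverse has to undo its steps in reverse order.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, Sequence


class Invertible(Protocol):
    def __call__(self, x: float) -> float: ...
    def inverse(self, x: float) -> float: ...


@dataclass
class MeanStd:
    """Hypothetical per-field normalizer built from training-time statistics."""

    mean: float
    std: float

    def __call__(self, x: float) -> float:
        return (x - self.mean) / self.std

    def inverse(self, x: float) -> float:
        return x * self.std + self.mean


@dataclass
class Scale:
    """Hypothetical extra step, to make the composition non-trivial."""

    factor: float

    def __call__(self, x: float) -> float:
        return x * self.factor

    def inverse(self, x: float) -> float:
        return x / self.factor


class Compose:
    """Apply steps in order; undo them in reverse order."""

    def __init__(self, steps: Sequence[Invertible]) -> None:
        self.steps = list(steps)

    def __call__(self, x: float) -> float:
        for step in self.steps:
            x = step(x)
        return x

    def inverse(self, x: float) -> float:
        for step in reversed(self.steps):
            x = step.inverse(x)
        return x


# Shaped like the return value of normalizers(): field name -> composed normalizer.
norms = {"surface_pressure": Compose([MeanStd(mean=100.0, std=20.0), Scale(0.5)])}
```

Undoing the steps in forward order would apply the mean/std inverse to a value that was still scaled, so the round trip would not recover the physical value.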

dataset(split='test')

Instantiate the dataset for split.

Wires up the collator (dataset.pipeline) the same way the trainer does, so the dataset can be plugged into a torch.utils.data.DataLoader for batched forward passes.

Parameters:

split (str) – Dataset key (e.g. "train", "val", "test").

Raises:
  • RuntimeError – In checkpoint-only mode (the checkpoint doesn’t know about the original dataset configuration).

  • KeyError – If split is not in self.config.datasets.

Return type:

noether.data.base.dataset.Dataset
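For a map-style dataset without shuffling, torch.utils.data.DataLoader fetches batch_size consecutive samples and passes the list to its collate_fn. The stdlib sketch below mimics that loop so it runs without torch; stack_fields is a hypothetical collator standing in for dataset.pipeline.

```python
from __future__ import annotations

from typing import Any, Callable, Iterator, Sequence


def iterate_batches(
    dataset: Sequence[Any],
    batch_size: int,
    collate_fn: Callable[[list[Any]], Any],
) -> Iterator[Any]:
    """Mimic DataLoader(dataset, batch_size=..., collate_fn=...) without shuffling."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        # The collator turns a list of per-sample items into one batch.
        yield collate_fn(samples)


def stack_fields(samples: list[dict[str, float]]) -> dict[str, list[float]]:
    """Hypothetical collator: turn a list of per-sample dicts into one dict of lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}
```

With torch installed, the analogous call would be DataLoader(ds, batch_size=..., collate_fn=ds.pipeline); passing ds.pipeline as collate_fn is an assumption about how the wired-up collator is meant to be consumed.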

model(*, checkpoint='latest', device='cpu')

Instantiate the model and load checkpoint weights.

Unlike the training/eval flow, this does not set up an optimizer, apply initializers, or attach the model to a trainer — it just builds the model, loads the state dict, moves it to device, and puts it in eval mode.

Parameters:
  • checkpoint (str) – Checkpoint tag (run-dir mode only). Defaults to "latest". Other examples: "E10", "best_model.loss.test.total". Ignored in checkpoint-only mode — the file was already fixed at from_checkpoint() time.

  • device (str | torch.device) – Torch device (or string) to move the model to.

Returns:

The model in eval mode with weights loaded.

Raises:
  • FileNotFoundError – If the checkpoint file does not exist (run-dir mode).

  • KeyError – If the checkpoint is missing state_dict.

  • RuntimeError – If loading the state dict did not actually change the model weights (sanity check against silently missing or mismatched keys).

Return type:

torch.nn.Module
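The RuntimeError sanity check above guards against a non-strict load that matches no keys and leaves the freshly built weights in place. A hypothetical stdlib sketch of that idea follows, with dicts of float lists standing in for a module's parameters so it runs without torch; load_state and load_and_verify are illustrative names, and raising only when no weight changed at all is an assumed policy.

```python
from __future__ import annotations

import copy


def load_state(model: dict[str, list[float]], state: dict[str, list[float]]) -> list[str]:
    """Copy matching entries into model (non-strict); return checkpoint keys that were skipped."""
    skipped = []
    for name, value in state.items():
        if name in model and len(model[name]) == len(value):
            model[name] = list(value)
        else:
            skipped.append(name)
    return skipped


def load_and_verify(model: dict[str, list[float]], state: dict[str, list[float]]) -> None:
    """Load weights, then fail loudly if nothing actually changed."""
    before = copy.deepcopy(model)
    skipped = load_state(model, state)
    if model == before:
        raise RuntimeError(
            f"loading the state dict left every weight unchanged (skipped keys: {skipped}); "
            "check for missing, renamed, or mismatched keys"
        )
```

A common way keys fail to match is a prefix mismatch such as "module.layer.weight" versus "layer.weight". A check of this kind would also fire if the checkpoint weights happened to equal the fresh initialization exactly, which is practically impossible for trained weights.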