noether.modeling.diffusion

Diffusion / flow-matching schedules.

A tensor-only flow-matching implementation behind a common DiffusionSchedule abstract base class. Use it together with the configuration classes in noether.core.schemas.diffusion.

Use build_schedule() to instantiate the matching schedule from any config parsed through the AnyDiffusionScheduleConfig discriminated union.

Example:

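import torch
from noether.core.schemas.diffusion import FlowMatchingConfig
from noether.modeling.diffusion import build_schedule

# Build the schedule from its config and move its tensors to the CPU.
schedule = build_schedule(FlowMatchingConfig()).to("cpu")
# Toy batch of clean data: 4 samples with 16 features each.
x0 = torch.randn(4, 16)

# Stand-in network that predicts zero velocity everywhere.
def model_fn(xt, t, condition):
    return torch.zeros_like(xt)

# Scalar flow-matching loss for this batch.
loss = schedule.training_losses(x0, model_fn)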


Attributes

AnyDiffusionScheduleConfig

Discriminated union of all built-in diffusion schedule configurations.

Classes

DiffusionSchedule

Abstract base for diffusion paradigms.

FlowMatchingConfig

Rectified flow matching with optional minibatch optimal transport.

FlowMatchingSchedule

Rectified flow matching with optional minibatch optimal transport.

Functions

build_schedule(config)

Instantiate the right DiffusionSchedule for config.

Package Contents

class noether.modeling.diffusion.DiffusionSchedule

Bases: abc.ABC

Abstract base for diffusion paradigms.

All schedule state (alphas, sigmas, etc.) is stored as plain tensor attributes. Because the class is not a torch.nn.Module, call to() before use to move these tensors to the target device.

device: torch.device
to(device)

Move every Tensor attribute to device in place.

Returns self for chainability.

Parameters:

device (torch.device | str)

Return type:

Self

abstractmethod training_losses(x0, model_fn, condition=None)

Compute scalar training loss given clean samples x0.

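Parameters:
  • x0 – Clean training samples.

  • model_fn – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.

  • condition – Optional conditioning tensor passed through to model_fn.

Returns: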

Scalar training loss.

Return type:

torch.Tensor

abstractmethod sample(shape, model_fn, condition=None, steps=50)

Generate samples from noise.

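Parameters:
  • shape – Output tensor shape.

  • model_fn – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.

  • condition – Optional conditioning tensor.

  • steps – Number of solver steps.

Returns:

Clean samples x0 with the requested shape.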

Return type:

torch.Tensor

noether.modeling.diffusion.AnyDiffusionScheduleConfig

Discriminated union of all built-in diffusion schedule configurations.

Pydantic resolves the right variant by inspecting the kind field. Pair with build_schedule() to materialize the schedule object.
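Pydantic performs this lookup for you. To illustrate what a discriminated union does, here is a minimal stdlib-only sketch: read kind, pick the matching class, and build it from the remaining fields. The FlowMatchingConfig dataclass, the _CONFIG_BY_KIND registry and parse_schedule_config are hypothetical stand-ins, not the noether API, and unlike pydantic the sketch performs no field validation.

```python
from dataclasses import dataclass
from typing import Any, Literal


# Hypothetical stand-in for the real pydantic FlowMatchingConfig.
@dataclass
class FlowMatchingConfig:
    kind: Literal["flow_matching"] = "flow_matching"
    continuous_time: bool = True
    minibatch_ot: bool = False


# Hypothetical registry keyed on the discriminator value; a real union lists every variant.
_CONFIG_BY_KIND: dict[str, type] = {"flow_matching": FlowMatchingConfig}


def parse_schedule_config(raw: dict[str, Any]) -> FlowMatchingConfig:
    """Choose the config variant from raw["kind"], then build it from all fields."""
    kind = raw.get("kind")
    if kind not in _CONFIG_BY_KIND:
        raise ValueError(f"unknown schedule kind: {kind!r}")
    return _CONFIG_BY_KIND[kind](**raw)


cfg = parse_schedule_config({"kind": "flow_matching", "minibatch_ot": True})
```

The point of the discriminator is that only one field has to be inspected to choose the variant, so error messages can name the offending kind instead of listing failures for every candidate class.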

noether.modeling.diffusion.build_schedule(config)

Instantiate the right DiffusionSchedule for config.

Parameters:

config (AnyDiffusionScheduleConfig) – Any variant of AnyDiffusionScheduleConfig.

Returns:

A DiffusionSchedule matching the variant’s kind.

Raises:

ValueError – If config is not a recognised schedule config.

Return type:

noether.modeling.diffusion.base.DiffusionSchedule
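One plausible way to implement this dispatch is a table from config class to schedule class, raising ValueError for anything unrecognised. The sketch below assumes exactly that; the stand-in classes and the _SCHEDULE_FOR_CONFIG table are hypothetical and only mirror the documented contract (config in, matching DiffusionSchedule out, ValueError otherwise).

```python
from dataclasses import dataclass


@dataclass
class FlowMatchingConfig:  # hypothetical stand-in for the pydantic config
    kind: str = "flow_matching"
    continuous_time: bool = True
    minibatch_ot: bool = False


class DiffusionSchedule:  # hypothetical stand-in for the abstract base
    pass


class FlowMatchingSchedule(DiffusionSchedule):
    def __init__(self, config: FlowMatchingConfig) -> None:
        self.config = config


# Hypothetical dispatch table: config class -> schedule class.
_SCHEDULE_FOR_CONFIG: dict[type, type] = {FlowMatchingConfig: FlowMatchingSchedule}


def build_schedule(config: object) -> DiffusionSchedule:
    """Instantiate the schedule registered for type(config)."""
    schedule_cls = _SCHEDULE_FOR_CONFIG.get(type(config))
    if schedule_cls is None:
        raise ValueError(f"not a recognised schedule config: {type(config).__name__}")
    return schedule_cls(config)


schedule = build_schedule(FlowMatchingConfig())
```

Keeping the table next to the union definition makes it easy to spot a new config variant that has no schedule registered yet.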

class noether.modeling.diffusion.FlowMatchingConfig(/, **data)

Bases: pydantic.BaseModel

Rectified flow matching with optional minibatch optimal transport.

Discriminator: kind = "flow_matching". Linear interpolation path xt = t * x1 + (1-t) * x0, where x0 is the noise endpoint and x1 the clean data (t = 0 is pure noise, t = 1 is data); the network predicts the velocity v = x1 - x0.
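Differentiating the interpolation path with respect to t shows why v is the regression target: the velocity is constant along each straight-line path, and either endpoint can be recovered from xt and v. This is a worked consequence of the path formula, not an excerpt from the library.

```latex
x_t = t\,x_1 + (1-t)\,x_0,
\qquad
\frac{\mathrm{d}x_t}{\mathrm{d}t} = x_1 - x_0 = v,
\qquad
x_1 = x_t + (1-t)\,v,
\qquad
x_0 = x_t - t\,v.
```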

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

model_config

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

kind: Literal['flow_matching'] = 'flow_matching'
continuous_time: bool = True

If True, sample the training time t from a logit-normal distribution; otherwise sample t uniformly on [0, 1].
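A logit-normal draw is a Gaussian passed through the sigmoid, which concentrates training times around t = 0.5 instead of spreading them evenly. A minimal NumPy sketch, assuming a standard normal (loc=0, scale=1) before the sigmoid; the actual location and scale used by noether are not documented here, and sample_t is a hypothetical helper.

```python
import numpy as np


def sample_t(batch_size: int, continuous_time: bool, rng: np.random.Generator,
             loc: float = 0.0, scale: float = 1.0) -> np.ndarray:
    """Draw one training time per example (hypothetical helper)."""
    if continuous_time:
        # Logit-normal: sigmoid of a Gaussian; loc/scale are assumed defaults.
        z = rng.normal(loc, scale, size=batch_size)
        return 1.0 / (1.0 + np.exp(-z))
    # Uniform alternative on [0, 1).
    return rng.uniform(0.0, 1.0, size=batch_size)


rng = np.random.default_rng(0)
t_ln = sample_t(10_000, continuous_time=True, rng=rng)
t_u = sample_t(10_000, continuous_time=False, rng=rng)
```

With these defaults roughly 73% of logit-normal draws fall in (0.25, 0.75), against 50% for the uniform option.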

minibatch_ot: bool = False

If True, reorder the noise samples within a minibatch via optimal transport against the data (Pooladian et al. 2023). Requires SciPy.
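Minibatch optimal transport re-pairs each data sample with a noise sample so that the total transport cost inside the batch is minimal, which yields straighter paths. The sketch below assumes a squared Euclidean cost and solves the assignment by exhaustive search over permutations, which is exact but only feasible for tiny batches. At realistic batch sizes one would use scipy.optimize.linear_sum_assignment(cost), whose returned column indices give the same reordering; that the library uses this particular SciPy routine and cost is an assumption, since the docs only state that SciPy is required. ot_pair_noise is a hypothetical helper.

```python
import itertools

import numpy as np


def ot_pair_noise(x1: np.ndarray, noise: np.ndarray) -> np.ndarray:
    """Reorder noise rows to minimise total squared distance to the x1 rows.

    Exhaustive search over permutations: exact, but O(n!) in the batch size.
    """
    n = x1.shape[0]
    # cost[i, j] = squared Euclidean distance between data row i and noise row j.
    cost = ((x1[:, None, :] - noise[None, :, :]) ** 2).sum(axis=-1)
    best = min(itertools.permutations(range(n)),
               key=lambda perm: sum(cost[i, perm[i]] for i in range(n)))
    return noise[list(best)]


rng = np.random.default_rng(0)
x1 = rng.normal(size=(5, 2))
noise = rng.normal(size=(5, 2))
paired = ot_pair_noise(x1, noise)
```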

class noether.modeling.diffusion.FlowMatchingSchedule(config)

Bases: noether.modeling.diffusion.base.DiffusionSchedule

Rectified flow matching with optional minibatch optimal transport.

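Linear interpolation path xt = t * x1 + (1-t) * x0, where x0 is the noise endpoint and x1 the clean data; the network predicts the velocity v = x1 - x0. Because training_losses() and sample() keep the DiffusionSchedule convention of calling clean data x0, the clean input to those methods plays the role of x1 in this path. Logit-normal time sampling is used for training when continuous_time=True.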

Parameters:

config (FlowMatchingConfig)

config
noise_pair(x1, t)

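Noise clean data x1 at time t.

Parameters:
  • x1 – Clean data samples.

  • t – Time at which to noise x1.

Returns:

Tuple (xt, target_velocity), where xt lies on the interpolation path and target_velocity = x1 - x0.

Return type:

tuple[torch.Tensor, torch.Tensor]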
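A minimal NumPy sketch of what noise_pair computes, assuming Gaussian noise and one t per example broadcast over the remaining dimensions. The extra rng argument is a hypothetical addition for reproducibility; the documented method takes only (x1, t).

```python
import numpy as np


def noise_pair(x1: np.ndarray, t: np.ndarray, rng: np.random.Generator):
    """Return (xt, target_velocity) for clean data x1 and per-example times t."""
    x0 = rng.standard_normal(x1.shape)  # noise endpoint (Gaussian assumed)
    # Reshape t from (batch,) to (batch, 1, ..., 1) so it broadcasts over x1.
    t_b = t.reshape((-1,) + (1,) * (x1.ndim - 1))
    xt = t_b * x1 + (1.0 - t_b) * x0
    target_velocity = x1 - x0
    return xt, target_velocity


rng = np.random.default_rng(0)
x1 = rng.normal(size=(4, 3))
t = np.array([0.0, 0.25, 0.5, 1.0])
xt, v = noise_pair(x1, t, rng)
```

At t = 1 the noisy input equals the clean data, and x1 can always be recovered as xt + (1 - t) * v.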

training_losses(x0, model_fn, condition=None)

Compute scalar training loss given clean samples x0.

Parameters:
  • x0 – Clean training samples.

  • model_fn – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.

  • condition – Optional conditioning tensor passed through to model_fn.

Returns:

Scalar training loss.
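The loss regresses the model output onto the target velocity with a mean squared error. A minimal NumPy sketch, assuming Gaussian noise and, for brevity, uniform t rather than the logit-normal default; training_loss is a hypothetical stand-in, and its x_clean argument corresponds to the documented x0 (clean data), which is x1 in the interpolation path.

```python
from typing import Callable, Optional

import numpy as np

ModelFn = Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], np.ndarray]


def training_loss(x_clean: np.ndarray, model_fn: ModelFn, rng: np.random.Generator,
                  condition: Optional[np.ndarray] = None) -> float:
    """MSE between predicted and target velocity (hypothetical sketch)."""
    batch = x_clean.shape[0]
    t = rng.uniform(0.0, 1.0, size=batch)  # uniform t for brevity
    noise = rng.standard_normal(x_clean.shape)
    t_b = t.reshape((-1,) + (1,) * (x_clean.ndim - 1))
    xt = t_b * x_clean + (1.0 - t_b) * noise  # point on the straight path
    target = x_clean - noise  # v = data - noise
    pred = model_fn(xt, t, condition)
    return float(np.mean((pred - target) ** 2))


rng = np.random.default_rng(0)
x_clean = np.ones((2000, 8))
loss = training_loss(x_clean, lambda xt, t, c: np.zeros_like(xt), rng)
```

With constant data of ones and a zero predictor, the target is 1 - noise, so the loss is close to E[(1 - n)^2] = 2.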

sample(shape, model_fn, condition=None, steps=10)

Generate samples from noise.

Parameters:
  • shape – Output tensor shape.

  • model_fn – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.

  • condition – Optional conditioning tensor.

  • steps – Number of solver steps.

Returns:

Clean samples x0 with the requested shape.
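Sampling integrates the learned velocity field from noise at t = 0 to data at t = 1. The sketch assumes an explicit Euler solver with uniform step size (the solver documented for sample_joint) and Gaussian starting noise; sample here is a hypothetical NumPy stand-in with an extra rng argument.

```python
from typing import Callable, Optional

import numpy as np


def sample(shape: tuple, model_fn: Callable, rng: np.random.Generator,
           condition: Optional[np.ndarray] = None, steps: int = 10) -> np.ndarray:
    """Explicit Euler integration of dx/dt = model_fn(x, t) from t = 0 to t = 1."""
    x = rng.standard_normal(shape)  # start from noise at t = 0
    dt = 1.0 / steps
    for i in range(steps):
        t = np.full(shape[0], i * dt)  # one time value per example
        x = x + dt * model_fn(x, t, condition)  # Euler step along the velocity
    return x


# For a single target point d, the exact rectified-flow velocity is (d - x) / (1 - t).
d = np.array([1.0, -2.0, 3.0])
out = sample((4, 3), lambda x, t, c: (d - x) / (1.0 - t[:, None]), np.random.default_rng(0))
```

With that exact field, the final Euler step (t = 1 - dt) lands on d regardless of the starting noise, which makes a convenient sanity check.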

training_losses_joint(x0_list, model_fn, condition=None)

Joint flow-matching loss over multiple clean tensors sharing a batch dimension.

All tensors are noised at the same t for a given example. model_fn receives (xt_list, t, condition) and must return a list of velocity predictions aligned with x0_list. Returns the per-tensor MSE averaged across tensors.
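A minimal NumPy sketch of the joint loss, assuming Gaussian noise, uniform t for brevity, and independent noise per tensor; training_loss_joint is a hypothetical stand-in. The key point is that a single t vector is drawn once and reused for every tensor.

```python
from typing import Callable, Optional

import numpy as np


def training_loss_joint(x_list: list, model_fn: Callable, rng: np.random.Generator,
                        condition: Optional[np.ndarray] = None) -> float:
    """Average of per-tensor velocity MSEs, with one t per example shared by all tensors."""
    batch = x_list[0].shape[0]
    t = rng.uniform(0.0, 1.0, size=batch)  # shared across every tensor
    xt_list, targets = [], []
    for x in x_list:
        if x.shape[0] != batch:
            raise ValueError("all tensors must share the batch dimension")
        noise = rng.standard_normal(x.shape)
        t_b = t.reshape((-1,) + (1,) * (x.ndim - 1))
        xt_list.append(t_b * x + (1.0 - t_b) * noise)
        targets.append(x - noise)
    preds = model_fn(xt_list, t, condition)
    return float(np.mean([np.mean((p - y) ** 2) for p, y in zip(preds, targets)]))


seen = {}


def zero_model(xt_list, t, condition):
    seen["n"], seen["t_shape"] = len(xt_list), t.shape  # record what the model received
    return [np.zeros_like(x) for x in xt_list]


rng = np.random.default_rng(0)
loss = training_loss_joint([np.ones((1000, 4)), np.ones((1000, 2, 3))], zero_model, rng)
```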

sample_joint(shapes, model_fn, condition=None, steps=10)

Joint Euler sampling over multiple tensors sharing a batch dimension.

model_fn receives (xt_list, t, condition) and must return a list of velocities aligned with shapes. Returns the list of clean samples, one per entry in shapes.
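The joint sampler advances every tensor with the same Euler step and the same t. A minimal NumPy sketch, assuming Gaussian starting noise and uniform steps; sample_joint here is a hypothetical stand-in with an extra rng argument.

```python
from typing import Callable, Optional

import numpy as np


def sample_joint(shapes: list, model_fn: Callable, rng: np.random.Generator,
                 condition: Optional[np.ndarray] = None, steps: int = 10) -> list:
    """Euler-integrate several tensors together from noise (t = 0) to data (t = 1)."""
    batch = shapes[0][0]
    xs = [rng.standard_normal(s) for s in shapes]
    dt = 1.0 / steps
    for i in range(steps):
        t = np.full(batch, i * dt)  # one shared time for all tensors
        vs = model_fn(xs, t, condition)  # list of velocities aligned with xs
        xs = [x + dt * v for x, v in zip(xs, vs)]
    return xs


shapes = [(2, 3), (2, 5)]
out = sample_joint(shapes, lambda xs, t, c: [np.ones_like(x) for x in xs],
                   np.random.default_rng(1))
```

With a constant unit velocity, each output equals its starting noise plus one, which is easy to verify by replaying the same seed.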