noether.modeling.diffusion.base

Classes

DiffusionSchedule

Abstract base for diffusion paradigms.

Module Contents

class noether.modeling.diffusion.base.DiffusionSchedule

Bases: abc.ABC

Abstract base for diffusion paradigms.

All schedule state (alphas, sigmas, etc.) is stored as plain tensors, not registered buffers, so it is not moved along with a model. Call to() before use to move these tensors to the target device.
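For illustration, here is a minimal sketch of how such state might be built. It assumes a variance-preserving schedule with a linear beta ramp (the DDPM defaults of 1e-4 to 0.02 over 1000 steps). The helper name linear_vp_schedule is hypothetical and not part of noether. It only shows that the resulting alphas and sigmas are ordinary CPU tensors until they are moved explicitly.

```python
import torch


def linear_vp_schedule(num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Return (alphas, sigmas) for a variance-preserving linear-beta schedule.

    Hypothetical helper for illustration; noether's concrete schedules may differ.
    """
    betas = torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float64)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative signal retention
    alphas = alpha_bar.sqrt()          # signal scale at each step
    sigmas = (1.0 - alpha_bar).sqrt()  # noise scale at each step
    return alphas, sigmas


alphas, sigmas = linear_vp_schedule()
# Plain tensors: they stay on the CPU until something moves them explicitly.
print(alphas.device, alphas.shape)
```

In a variance-preserving schedule, alphas**2 + sigmas**2 equals 1 at every step.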

device: torch.device
to(device)

Move every Tensor attribute to device in place.

Returns self so that calls can be chained.

Parameters:

device (torch.device | str)

Return type:

Self
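The to() contract (move every tensor attribute in place, then return self) can be sketched on a stand-in class. DiffusionScheduleSketch and ToySchedule are hypothetical names, not noether classes. Scanning vars(self) for tensors is likewise an assumed implementation strategy.

```python
import abc
from typing import Union

import torch


class DiffusionScheduleSketch(abc.ABC):
    """Hypothetical stand-in reproducing only the documented to() contract."""

    device: torch.device

    def to(self, device: Union[torch.device, str]) -> "DiffusionScheduleSketch":
        # Normalize strings such as "cpu" or "cuda:0" to a torch.device.
        self.device = torch.device(device)
        # Snapshot the attributes first so reassignment cannot disturb iteration.
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(self.device))
        return self  # returning self allows chaining


class ToySchedule(DiffusionScheduleSketch):
    def __init__(self, num_steps: int = 10):
        self.num_steps = num_steps  # non-tensor attributes are left untouched
        self.alphas = torch.linspace(1.0, 0.1, num_steps)
        self.sigmas = (1.0 - self.alphas**2).sqrt()


device = "cuda" if torch.cuda.is_available() else "cpu"
schedule = ToySchedule().to(device)  # construct and move in one chained call
print(schedule.device, schedule.alphas.device)
```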

abstractmethod training_losses(x0, model_fn, condition=None)

Compute a scalar training loss for the clean samples x0.

Parameters:
Returns:

Scalar training loss.

Return type:

torch.Tensor
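The sketch below shows one possible concrete implementation. It uses DDPM-style noise (epsilon) prediction with a mean-squared-error loss. This technique, the class name EpsilonVPSchedule, and the model_fn(x_t, t, condition) call signature are all assumptions for illustration. noether's subclasses may use a different objective or signature.

```python
import torch
import torch.nn.functional as F


class EpsilonVPSchedule:
    """Hypothetical schedule: DDPM-style epsilon prediction on a linear beta ramp."""

    def __init__(self, num_steps: int = 1000):
        betas = torch.linspace(1e-4, 0.02, num_steps)
        alpha_bar = torch.cumprod(1.0 - betas, dim=0)
        self.num_steps = num_steps
        self.alphas = alpha_bar.sqrt()          # signal scale per timestep
        self.sigmas = (1.0 - alpha_bar).sqrt()  # noise scale per timestep

    def training_losses(self, x0, model_fn, condition=None):
        # Assumes the schedule tensors already live on x0's device (see to()).
        batch = x0.shape[0]
        # One random timestep per sample.
        t = torch.randint(0, self.num_steps, (batch,), device=x0.device)
        # Reshape per-sample scales so they broadcast over the remaining dims of x0.
        view = (batch,) + (1,) * (x0.dim() - 1)
        alpha = self.alphas[t].view(view)
        sigma = self.sigmas[t].view(view)
        noise = torch.randn_like(x0)
        x_t = alpha * x0 + sigma * noise  # forward (noising) process
        # Assumed signature: model_fn(noisy input, timesteps, condition) -> predicted noise.
        pred = model_fn(x_t, t, condition)
        return F.mse_loss(pred, noise)  # mean over all elements -> 0-dim tensor


def zero_model(x_t, t, condition):
    return torch.zeros_like(x_t)  # trivial stand-in for a denoiser


torch.manual_seed(0)
schedule = EpsilonVPSchedule()
x0 = torch.randn(8, 3, 16, 16)
loss = schedule.training_losses(x0, zero_model)
print(loss.shape, loss.item())
```

With the all-zero stand-in model the loss reduces to the mean squared noise, so it lands close to 1.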

abstractmethod sample(shape, model_fn, condition=None, steps=50)

Generate samples from noise.

Parameters:
Returns:

Clean samples x0, shaped according to the shape argument.

Return type:

torch.Tensor
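A subclass could honor the steps argument with a deterministic DDIM sampler (eta = 0). Such a sampler visits steps evenly spaced timesteps instead of every training timestep. The class name DDIMSamplerSketch, the linear beta schedule, and the model_fn(x_t, t, condition) signature are assumptions for illustration, not noether's documented behavior.

```python
import torch


class DDIMSamplerSketch:
    """Hypothetical schedule with a deterministic DDIM sampler (eta = 0)."""

    def __init__(self, num_steps: int = 1000):
        betas = torch.linspace(1e-4, 0.02, num_steps)
        alpha_bar = torch.cumprod(1.0 - betas, dim=0)
        self.num_steps = num_steps
        self.alphas = alpha_bar.sqrt()
        self.sigmas = (1.0 - alpha_bar).sqrt()

    @torch.no_grad()
    def sample(self, shape, model_fn, condition=None, steps=50):
        device = self.alphas.device
        # `steps` timesteps evenly spaced from T - 1 down to 0.
        ts = torch.linspace(self.num_steps - 1, 0, steps, device=device).round().long()
        x = torch.randn(shape, device=device)  # start from pure Gaussian noise
        for i in range(steps):
            t = ts[i]
            t_batch = torch.full((shape[0],), int(t), device=device, dtype=torch.long)
            # Assumed signature: model_fn(x_t, t, condition) -> predicted noise.
            eps = model_fn(x, t_batch, condition)
            x0_pred = (x - self.sigmas[t] * eps) / self.alphas[t]
            if i + 1 < steps:
                t_prev = ts[i + 1]
                # Re-noise the x0 estimate to the next (lower) noise level.
                x = self.alphas[t_prev] * x0_pred + self.sigmas[t_prev] * eps
            else:
                x = x0_pred  # last step: return the clean estimate
        return x


def zero_model(x_t, t, condition):
    return torch.zeros_like(x_t)  # trivial stand-in for a denoiser


torch.manual_seed(0)
schedule = DDIMSamplerSketch()
samples = schedule.sample((4, 2, 8, 8), zero_model, steps=10)
print(samples.shape)
```

Because eta = 0 injects no fresh noise, the same starting noise always yields the same samples. Setting steps equal to the number of training steps visits every timestep.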