noether.modeling.diffusion.base
Classes
| DiffusionSchedule | Abstract base for diffusion paradigms. |
Module Contents
- class noether.modeling.diffusion.base.DiffusionSchedule
Bases: abc.ABC
Abstract base for diffusion paradigms.
All schedule state (alphas, sigmas, etc.) is stored as plain tensor attributes. Call to() before use to move these tensors to the target device.
- device: torch.device
- to(device)¶
Move every Tensor attribute to device in place. Returns self for chainability.
- Parameters:
device (torch.device | str)
- Return type:
Self
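The to() contract can be sketched as follows. This is a hypothetical stand-in class, not noether's implementation: the linear beta schedule, the attribute names betas and alphas_cumprod, and updating self.device alongside the tensors are all assumptions made for illustration.

```python
from typing import TypeVar, Union

import torch

T = TypeVar("T", bound="TensorStateSketch")


class TensorStateSketch:
    """Hypothetical stand-in illustrating the documented to() contract."""

    def __init__(self, num_steps: int = 10) -> None:
        # Assumed schedule state: a linear beta schedule and its cumulative alpha product.
        self.betas = torch.linspace(1e-4, 0.02, num_steps)
        self.alphas_cumprod = torch.cumprod(1.0 - self.betas, dim=0)
        self.num_steps = num_steps  # non-tensor attributes are left untouched by to()
        self.device = torch.device("cpu")

    def to(self: T, device: Union[torch.device, str]) -> T:
        device = torch.device(device)
        # Snapshot the attribute dict, then rebind each tensor attribute to its moved copy.
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(device))
        self.device = device  # assumption: the device attribute tracks the last move
        return self  # returning self enables chaining, e.g. sched.to("cpu").betas


# Pick a target device and move the schedule state there before use.
target = "cuda" if torch.cuda.is_available() else "cpu"
sched = TensorStateSketch().to(target)
```

Because to() rebinds attributes on the same object, the schedule instance itself is modified in place; only tensor attributes are moved, so plain values such as num_steps are unaffected.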
- abstractmethod training_losses(x0, model_fn, condition=None)
Compute a scalar training loss given clean samples x0.
- Parameters:
x0 (torch.Tensor) – Clean training samples.
model_fn (collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor | None], torch.Tensor]) – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.
condition (torch.Tensor | None) – Optional conditioning tensor passed through to model_fn.
- Returns:
Scalar training loss.
- Return type:
torch.Tensor
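As an illustration of the training_losses contract, the sketch below implements a DDPM-style noise-prediction (epsilon) loss. ToyDDPM, its linear beta schedule, the uniform timestep sampling, and the MSE objective are assumptions for this example; concrete noether schedules may use other parameterizations, such as sigma-based ones.

```python
from typing import Callable, Optional

import torch
import torch.nn.functional as F

# (noisy_input, timestep_or_sigma, condition) -> prediction, as documented above.
ModelFn = Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]


class ToyDDPM:
    """Hypothetical epsilon-prediction schedule; not part of noether."""

    def __init__(self, num_steps: int = 100) -> None:
        betas = torch.linspace(1e-4, 0.02, num_steps)
        self.num_steps = num_steps
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def training_losses(
        self,
        x0: torch.Tensor,
        model_fn: ModelFn,
        condition: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch = x0.shape[0]
        # Draw one integer timestep per sample (assumes to() already moved the state to x0's device).
        t = torch.randint(0, self.num_steps, (batch,), device=x0.device)
        noise = torch.randn_like(x0)
        # Reshape alpha_bar_t to (batch, 1, ..., 1) so it broadcasts over x0.
        a_bar = self.alphas_cumprod[t].view(batch, *([1] * (x0.dim() - 1)))
        # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise.
        x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
        pred = model_fn(x_t, t, condition)
        # Reduce to a scalar: mean squared error between predicted and true noise.
        return F.mse_loss(pred, noise)


sched = ToyDDPM()
x0 = torch.randn(8, 3)
# A trivial model_fn that ignores its inputs and predicts zero noise.
loss = sched.training_losses(x0, lambda x, t, c: torch.zeros_like(x))
```

Returning a 0-dim tensor rather than a Python float keeps the loss attached to the autograd graph, so the caller can call loss.backward() directly.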
- abstractmethod sample(shape, model_fn, condition=None, steps=50)
Generate samples from noise.
- Parameters:
shape – Shape of the samples to generate.
model_fn (collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor | None], torch.Tensor]) – Callable with signature (noisy_input, timestep_or_sigma, condition) -> prediction.
condition (torch.Tensor | None) – Optional conditioning tensor.
steps (int) – Number of solver steps.
- Returns:
Clean samples x0 of shape shape.
- Return type:
torch.Tensor
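A matching sketch of sample() is a deterministic DDIM sampler (eta = 0) over the same kind of assumed epsilon-prediction schedule, where model_fn is taken to predict the added noise. ToyDDIM, the evenly spaced timestep grid, and the convention that shape[0] is the batch size are assumptions; a sigma-based paradigm would instead integrate over noise levels.

```python
from typing import Callable, Optional, Sequence

import torch

ModelFn = Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]


class ToyDDIM:
    """Hypothetical epsilon-prediction schedule with a DDIM sampler; not part of noether."""

    def __init__(self, num_steps: int = 100) -> None:
        betas = torch.linspace(1e-4, 0.02, num_steps)
        self.num_steps = num_steps
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    @torch.no_grad()
    def sample(
        self,
        shape: Sequence[int],
        model_fn: ModelFn,
        condition: Optional[torch.Tensor] = None,
        steps: int = 50,
    ) -> torch.Tensor:
        device = self.alphas_cumprod.device
        x = torch.randn(tuple(shape), device=device)  # start from pure Gaussian noise
        # `steps` timesteps spaced evenly from T-1 down to 0.
        ts = torch.linspace(self.num_steps - 1, 0, steps).round().long()
        for i in range(steps):
            a_t = self.alphas_cumprod[ts[i]]
            # After the last step, alpha_bar = 1 corresponds to a clean sample.
            if i + 1 < steps:
                a_prev = self.alphas_cumprod[ts[i + 1]]
            else:
                a_prev = torch.ones((), device=device)
            t_batch = torch.full((shape[0],), int(ts[i]), dtype=torch.long, device=device)
            eps = model_fn(x, t_batch, condition)
            # Predict x0, then jump deterministically to the next (lower) noise level.
            x0_hat = (x - (1.0 - a_t).sqrt() * eps) / a_t.sqrt()
            x = a_prev.sqrt() * x0_hat + (1.0 - a_prev).sqrt() * eps
        return x


sched = ToyDDIM()
samples = sched.sample((4, 2), lambda x, t, c: torch.zeros_like(x), steps=10)
```

Decorating sample() with torch.no_grad() avoids building an autograd graph across all solver steps, which inference does not need.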