noether.core.utils.torch
Submodules
Classes
NoopContext | A no-operation context manager that does nothing. |
NoopGradScaler | A no-operation gradient scaler that performs no scaling. |
Functions
get_grad_scaler_and_autocast_context | Returns the appropriate gradient scaler and autocast context manager for the given precision. |
get_supported_precision | Returns desired_precision if it is supported and backup_precision otherwise. For example, bfloat16 is not supported by all GPUs. |
is_bfloat16_compatible | Checks if bfloat16 precision is supported on the given device. |
is_float16_compatible | Checks if float16 precision is supported on the given device. |
| Moves everything in the batch to the given device. |
Package Contents
- class noether.core.utils.torch.NoopContext
A no-operation context manager that does nothing.
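A no-op context manager is useful when code is always written as a `with` block but the context is only sometimes needed, for instance when autocast is switched off. The sketch below is a hypothetical stand-in (`SketchNoopContext` and `run_in_context` are illustrative names, not part of noether) that shows the behaviour the description implies; the actual NoopContext implementation may differ. The standard library's `contextlib.nullcontext` serves the same purpose and is shown for comparison.

```python
from contextlib import nullcontext


class SketchNoopContext:
    """Hypothetical stand-in for a context manager that does nothing."""

    def __enter__(self):
        # Nothing to set up; hand back the manager itself.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to tear down; returning False lets exceptions propagate.
        return False


def run_in_context(fn, context):
    """Call fn() inside the given context manager and return its result."""
    with context:
        return fn()


# Both managers leave the wrapped computation untouched.
result_sketch = run_in_context(lambda: 2 + 2, SketchNoopContext())
result_stdlib = run_in_context(lambda: 2 + 2, nullcontext())
```

Because the caller only sees a context manager, the same `with` block can receive either a real autocast context or a no-op one without branching.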
- class noether.core.utils.torch.NoopGradScaler
Bases: torch.amp.grad_scaler.GradScaler
A no-operation gradient scaler that performs no scaling.
- scale(outputs)
- Parameters:
outputs (Any)
- Return type:
Any
- unscale_(optimizer)
- Parameters:
optimizer (torch.optim.Optimizer)
- Return type:
None
- static step(optimizer, *args, **kwargs)
- Parameters:
optimizer (torch.optim.Optimizer)
- Return type:
None
- update(new_scale=None)
- Parameters:
new_scale (float | torch.Tensor | None)
- Return type:
None
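The four methods above mirror the usual gradient-scaler call sequence, so a training step can be written once and work whether or not scaling is active. The following framework-free sketch illustrates that idea. `SketchNoopScaler`, `CountingOptimizer`, and `train_step` are hypothetical names, and the assumption that `step` simply delegates to `optimizer.step` is not confirmed by this page.

```python
class SketchNoopScaler:
    """Hypothetical stand-in using the documented method names."""

    def scale(self, outputs):
        # No scaling: return the outputs unchanged.
        return outputs

    def unscale_(self, optimizer):
        # Nothing was scaled, so there is nothing to undo.
        return None

    @staticmethod
    def step(optimizer, *args, **kwargs):
        # Assumption: forward directly to the optimizer.
        optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        # No scale factor to maintain.
        return None


class CountingOptimizer:
    """Dummy optimizer that only counts how often step() is called."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


def train_step(loss, optimizer, scaler):
    """Run the scaler call sequence; a real loop would call backward()."""
    scaled = scaler.scale(loss)
    # In real training: scaled.backward() would go here.
    scaler.unscale_(optimizer)
    scaler.step(optimizer)
    scaler.update()
    return scaled


optimizer = CountingOptimizer()
scaled_loss = train_step(1.5, optimizer, SketchNoopScaler())
```

With a no-op scaler the loss passes through unchanged and the optimizer steps exactly once per call, which is what a full-precision run needs.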
- noether.core.utils.torch.get_grad_scaler_and_autocast_context(precision, device)
Returns the appropriate gradient scaler and autocast context manager for the given precision.
- Parameters:
precision (torch.dtype) – The desired precision.
device (torch.device) – The device where computation occurs.
- Returns:
The corresponding scaler and autocast context.
- Return type:
tuple[torch.amp.grad_scaler.GradScaler, torch.autocast | NoopContext]
- noether.core.utils.torch.get_supported_precision(desired_precision, device)
Returns desired_precision if it is supported and backup_precision otherwise. For example, bfloat16 is not supported by all GPUs.
- Parameters:
desired_precision (str) – The desired precision format.
device (torch.device) – The selected device (e.g., torch.device("cuda")).
- Returns:
The most suitable precision supported by the device.
- Return type:
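The fallback rule described for get_supported_precision can be illustrated without any GPU. The sketch below is only a model of that rule: `pick_precision`, the `is_supported` callback, the `"float32"` default backup, and the capability set are all hypothetical, and the real function takes a device and may return a different type.

```python
def pick_precision(desired, is_supported, backup="float32"):
    """Return `desired` if is_supported(desired) is true, else `backup`."""
    return desired if is_supported(desired) else backup


# Hypothetical capability set for a device without bfloat16 support.
SUPPORTED_ON_OLD_GPU = {"float32", "float16"}

chosen = pick_precision("bfloat16", SUPPORTED_ON_OLD_GPU.__contains__)
kept = pick_precision("float16", SUPPORTED_ON_OLD_GPU.__contains__)
```

Here the bfloat16 request falls back to the backup precision, while the float16 request is returned as asked.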
- noether.core.utils.torch.is_bfloat16_compatible(device)
Checks if bfloat16 precision is supported on the given device.
- Parameters:
device (torch.device) – The device to check.
- Returns:
True if bfloat16 is supported, False otherwise.
- Return type:
bool
- noether.core.utils.torch.is_float16_compatible(device)
Checks if float16 precision is supported on the given device.
- Parameters:
device (torch.device) – The device to check.
- Returns:
True if float16 is supported, False otherwise.
- Return type:
bool