noether.core.utils.torch.amp

Attributes

FLOAT32_ALIASES, FLOAT16_ALIASES, BFLOAT16_ALIASES, VALID_PRECISIONS, logger

Classes

NoopContext

A no-operation context manager that does nothing.

NoopGradScaler

A no-operation gradient scaler that performs no scaling.

Functions

get_supported_precision(desired_precision, device)

Returns desired_precision if it is supported and a backup precision otherwise. For example, bfloat16 is not supported by all GPUs.

is_compatible(device, dtype)

Checks if a given dtype is supported on a device.

is_bfloat16_compatible(device)

Checks if bfloat16 precision is supported on the given device.

is_float16_compatible(device)

Checks if float16 precision is supported on the given device.

get_grad_scaler_and_autocast_context(precision, device)

Returns the appropriate gradient scaler and autocast context manager for the given precision.

disable(device_type)

Disables AMP for the given device.

Module Contents

noether.core.utils.torch.amp.FLOAT32_ALIASES = ['float32', 'fp32']
noether.core.utils.torch.amp.FLOAT16_ALIASES = ['float16', 'fp16']
noether.core.utils.torch.amp.BFLOAT16_ALIASES = ['bfloat16', 'bf16']
noether.core.utils.torch.amp.VALID_PRECISIONS = ['float32', 'fp32', 'float16', 'fp16', 'bfloat16', 'bf16']
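
These alias lists let callers pass either a full dtype name or its short form. A minimal sketch of how such aliases could be normalized to a canonical name (the helper `normalize_precision` is hypothetical, not part of the module; the alias values mirror the constants above):

```python
# Alias lists mirroring the module constants documented above.
FLOAT32_ALIASES = ["float32", "fp32"]
FLOAT16_ALIASES = ["float16", "fp16"]
BFLOAT16_ALIASES = ["bfloat16", "bf16"]
VALID_PRECISIONS = FLOAT32_ALIASES + FLOAT16_ALIASES + BFLOAT16_ALIASES


def normalize_precision(precision: str) -> str:
    """Map any accepted alias (e.g. 'fp16') to its canonical name.

    Hypothetical helper for illustration only.
    """
    if precision in FLOAT32_ALIASES:
        return "float32"
    if precision in FLOAT16_ALIASES:
        return "float16"
    if precision in BFLOAT16_ALIASES:
        return "bfloat16"
    raise ValueError(
        f"unknown precision {precision!r}; expected one of {VALID_PRECISIONS}"
    )
```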
noether.core.utils.torch.amp.logger
noether.core.utils.torch.amp.get_supported_precision(desired_precision, device)

Returns desired_precision if it is supported and a backup precision otherwise. For example, bfloat16 is not supported by all GPUs.

Parameters:
  • desired_precision (str) – The desired precision format.

  • device (torch.device) – The selected device (e.g., torch.device("cuda")).

Returns:

The most suitable precision supported by the device.

Return type:

torch.dtype
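
A sketch of the fallback decision, using plain strings in place of torch.dtype values. The boolean flags stand in for device-capability checks such as is_bfloat16_compatible / is_float16_compatible, and the fallback chain (bf16 → fp16 → fp32) is an assumption about the policy, not a statement of this function's exact behavior:

```python
def pick_precision(desired: str, bf16_ok: bool, fp16_ok: bool) -> str:
    """Illustrative fallback logic for an unsupported desired precision.

    The flags would come from device checks in practice; the
    bf16 -> fp16 -> fp32 fallback chain is an assumption.
    """
    if desired in ("bfloat16", "bf16"):
        if bf16_ok:
            return "bfloat16"
        return "float16" if fp16_ok else "float32"
    if desired in ("float16", "fp16"):
        return "float16" if fp16_ok else "float32"
    return "float32"  # float32 is supported everywhere
```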

noether.core.utils.torch.amp.is_compatible(device, dtype)

Checks if a given dtype is supported on a device.

Parameters:
  • device (torch.device) – The device to check.

  • dtype (torch.dtype) – The dtype to check.

Returns:

True if the dtype is supported, False otherwise.

Return type:

bool
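
A common way to implement such a check (an assumption about the approach, not a statement about this module's internals) is to attempt a tiny operation in the given dtype on the device and treat a failure as incompatibility:

```python
def works_on_device(probe) -> bool:
    """Return True if the zero-argument callable `probe` runs without
    raising. In a torch-based check the probe would be something like
    `lambda: torch.zeros(1, dtype=dtype, device=device) + 1`.
    """
    try:
        probe()
        return True
    except (RuntimeError, TypeError):
        # Unsupported dtype/device combinations typically surface
        # as a RuntimeError from the backend.
        return False
```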

noether.core.utils.torch.amp.is_bfloat16_compatible(device)

Checks if bfloat16 precision is supported on the given device.

Parameters:

device (torch.device) – The device to check.

Returns:

True if bfloat16 is supported, False otherwise.

Return type:

bool
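
On CUDA GPUs, native bfloat16 support generally requires compute capability 8.0 (Ampere) or newer. A check of this shape, where the capability tuple would in practice come from torch.cuda.get_device_capability (the helper itself is illustrative, not this module's implementation):

```python
def cuda_supports_bf16(compute_capability: tuple) -> bool:
    """bfloat16 is natively supported on CUDA devices with compute
    capability 8.0 (Ampere) or newer; e.g. (7, 5) is a Turing GPU
    without native bf16, (8, 0) is A100."""
    return compute_capability >= (8, 0)
```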

noether.core.utils.torch.amp.is_float16_compatible(device)

Checks if float16 precision is supported on the given device.

Parameters:

device (torch.device) – The device to check.

Returns:

True if float16 is supported, False otherwise.

Return type:

bool

class noether.core.utils.torch.amp.NoopContext

A no-operation context manager that does nothing.

class noether.core.utils.torch.amp.NoopGradScaler

Bases: torch.amp.grad_scaler.GradScaler

A no-operation gradient scaler that performs no scaling.

scale(outputs)
Parameters:

outputs (Any)

Return type:

Any

unscale_(optimizer)
Parameters:

optimizer (torch.optim.Optimizer)

Return type:

None

static step(optimizer, *args, **kwargs)
Parameters:

optimizer (torch.optim.Optimizer)

Return type:

None

update(new_scale=None)
Parameters:

new_scale (float | torch.Tensor | None)

Return type:

None
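
The point of a no-op scaler is that a training loop can call the GradScaler API unconditionally, whether or not loss scaling is active. A pure-Python sketch of that identity behavior (not the actual implementation, which subclasses torch.amp.grad_scaler.GradScaler; delegating step to optimizer.step is an assumption consistent with the documented signature):

```python
class NoopGradScalerSketch:
    """Identity stand-in for a gradient scaler: every method is a
    no-op, so a training loop written against the GradScaler API
    needs no branches when scaling is disabled."""

    def scale(self, outputs):
        return outputs  # the loss passes through unscaled

    def unscale_(self, optimizer):
        pass  # gradients were never scaled, nothing to undo

    @staticmethod
    def step(optimizer, *args, **kwargs):
        optimizer.step(*args, **kwargs)  # just step the optimizer

    def update(self, new_scale=None):
        pass  # there is no scale factor to adjust
```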

noether.core.utils.torch.amp.get_grad_scaler_and_autocast_context(precision, device)

Returns the appropriate gradient scaler and autocast context manager for the given precision.

Parameters:
  • precision (torch.dtype) – The desired precision.

  • device (torch.device) – The device where computation occurs.

Returns:

The corresponding scaler and autocast context.

Return type:

tuple[torch.amp.grad_scaler.GradScaler, torch.autocast | NoopContext]
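
A plausible dispatch behind this pairing, sketched with string labels instead of real torch objects so the decision logic stands alone (the exact policy is an assumption): float16 needs real loss scaling because fp16 gradients underflow easily, bfloat16 has enough dynamic range that a no-op scaler suffices alongside autocast, and float32 needs neither.

```python
def choose_amp_components(precision: str) -> tuple:
    """Illustrative scaler/context dispatch; labels stand in for
    the real (GradScaler, autocast/NoopContext) objects."""
    if precision == "float16":
        # fp16 gradients can underflow, so real loss scaling is needed.
        return ("GradScaler", "autocast")
    if precision == "bfloat16":
        # bf16 keeps fp32's exponent range; a no-op scaler suffices.
        return ("NoopGradScaler", "autocast")
    # float32: AMP is effectively off.
    return ("NoopGradScaler", "NoopContext")
```

In a training loop, the returned pair is then used uniformly: compute the forward pass and loss inside the context, then `scaler.scale(loss).backward()`, `scaler.step(optimizer)`, `scaler.update()`, regardless of which precision was selected.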

noether.core.utils.torch.amp.disable(device_type)

Disables AMP for the given device.

Parameters:

device_type (str) – The device type to disable AMP for.