ksuit.utils.torch.amp

Attributes

FLOAT32_ALIASES

FLOAT16_ALIASES

BFLOAT16_ALIASES

VALID_PRECISIONS

logger

Classes

NoopContext

A no-operation context manager that does nothing.

NoopGradScaler

A no-operation gradient scaler that performs no scaling.

Functions

get_supported_precision(desired_precision, device)

Returns desired_precision if it is supported and a backup precision otherwise. For example, bfloat16 is not supported by all GPUs.

is_compatible(device, dtype)

Checks if a given dtype is supported on a device.

is_bfloat16_compatible(device)

Checks if bfloat16 precision is supported on the given device.

is_float16_compatible(device)

Checks if float16 precision is supported on the given device.

get_grad_scaler_and_autocast_context(precision, device)

Returns the appropriate gradient scaler and autocast context manager for the given precision.

disable(device_type)

Disables AMP for the given device.

Module Contents

ksuit.utils.torch.amp.FLOAT32_ALIASES = ['float32', 'fp32']
ksuit.utils.torch.amp.FLOAT16_ALIASES = ['float16', 'fp16']
ksuit.utils.torch.amp.BFLOAT16_ALIASES = ['bfloat16', 'bf16']
ksuit.utils.torch.amp.VALID_PRECISIONS = ['float32', 'fp32', 'float16', 'fp16', 'bfloat16', 'bf16']
ksuit.utils.torch.amp.logger
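The alias lists above map two spellings onto each of three precisions. A minimal sketch of how a precision string could be normalized against them. The helper `canonical_precision` is hypothetical and not part of this module, and treating the first alias of each list as the canonical name is an assumption:

```python
FLOAT32_ALIASES = ["float32", "fp32"]
FLOAT16_ALIASES = ["float16", "fp16"]
BFLOAT16_ALIASES = ["bfloat16", "bf16"]
# Concatenation reproduces the documented VALID_PRECISIONS order.
VALID_PRECISIONS = FLOAT32_ALIASES + FLOAT16_ALIASES + BFLOAT16_ALIASES


def canonical_precision(name: str) -> str:
    """Hypothetical helper: map any accepted alias to the first entry of its group."""
    for aliases in (FLOAT32_ALIASES, FLOAT16_ALIASES, BFLOAT16_ALIASES):
        if name in aliases:
            return aliases[0]
    raise ValueError(f"invalid precision {name!r}, expected one of {VALID_PRECISIONS}")


print(canonical_precision("bf16"))  # bfloat16
```

Matching is exact and case-sensitive here; whether the real module accepts other spellings is not stated.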
ksuit.utils.torch.amp.get_supported_precision(desired_precision, device)

Returns desired_precision if it is supported on device and a backup precision otherwise. For example, bfloat16 is not supported by all GPUs.

Parameters:
  • desired_precision (str) – The desired precision format.

  • device (torch.device) – The selected device (e.g., torch.device("cuda")).

Returns:

The most suitable precision supported by the device.

Return type:

torch.dtype
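The fallback rule can be sketched without a GPU by standing in a set of supported names for the device query. This is a hedged illustration, not the ksuit implementation: `choose_precision`, the `supported` argument (a stand-in for is_compatible()), and the default backup of "float32" are all assumptions, and the real function returns a torch.dtype rather than a string:

```python
from typing import Collection


def choose_precision(desired: str, supported: Collection[str], backup: str = "float32") -> str:
    """Hypothetical sketch: keep `desired` if the device supports it, else use `backup`.

    `supported` stands in for a per-device capability check; names are canonical strings.
    """
    if desired in supported:
        return desired
    if backup not in supported:
        raise ValueError(f"neither {desired!r} nor backup {backup!r} is supported")
    return backup


# A GPU without bfloat16 support, falling back to float16.
older_gpu = {"float32", "float16"}
print(choose_precision("bfloat16", older_gpu, backup="float16"))  # float16
print(choose_precision("float16", older_gpu))  # float16
```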

ksuit.utils.torch.amp.is_compatible(device, dtype)

Checks if a given dtype is supported on a device.

Parameters:
  • device (torch.device) – The device to check compatibility against.

  • dtype (torch.dtype) – The data type to check.

Returns:

True if the dtype is supported, False otherwise.

Return type:

bool

ksuit.utils.torch.amp.is_bfloat16_compatible(device)

Checks if bfloat16 precision is supported on the given device.

Parameters:

device (torch.device) – The device to check.

Returns:

True if bfloat16 is supported, False otherwise.

Return type:

bool

ksuit.utils.torch.amp.is_float16_compatible(device)

Checks if float16 precision is supported on the given device.

Parameters:

device (torch.device) – The device to check.

Returns:

True if float16 is supported, False otherwise.

Return type:

bool

class ksuit.utils.torch.amp.NoopContext

A no-operation context manager that does nothing.
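A context manager of this kind lets calling code write `with ctx:` unconditionally, whether or not autocast is active. A minimal sketch of what such a class typically looks like, assuming it only needs `__enter__` and `__exit__`; `NoopContextSketch` is a hypothetical name, not the ksuit source. The standard library's `contextlib.nullcontext()` provides the same behavior:

```python
import contextlib


class NoopContextSketch:
    """Hypothetical context manager whose enter and exit do nothing."""

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning False lets exceptions raised inside the block propagate.
        return False


with NoopContextSketch():
    x = 1 + 1

# Standard-library equivalent.
with contextlib.nullcontext():
    y = 2 + 2
```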

class ksuit.utils.torch.amp.NoopGradScaler(device='cuda', init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)

Bases: torch.amp.grad_scaler.GradScaler

A no-operation gradient scaler that performs no scaling.
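Because the class subclasses GradScaler, a training loop can keep the same `scale(loss).backward()`, `step(optimizer)`, `update()` calls regardless of precision. A minimal pure-Python sketch of that pass-through behavior, assuming each method simply forwards or does nothing; `NoopGradScalerSketch` and `CountingOptimizer` are hypothetical, and skipping the inf/NaN check in `step` is an assumption about the real class:

```python
class NoopGradScalerSketch:
    """Hypothetical stand-in mirroring the GradScaler call pattern without scaling."""

    def scale(self, outputs):
        return outputs  # no multiplication by a scale factor

    def unscale_(self, optimizer):
        pass  # gradients were never scaled, so there is nothing to undo

    @staticmethod
    def step(optimizer, *args, **kwargs):
        # Assumption: always steps, with no inf/NaN check.
        return optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass  # no scale factor to adjust


class CountingOptimizer:
    """Hypothetical duck-typed optimizer that records how often step() runs."""

    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


scaler = NoopGradScalerSketch()
opt = CountingOptimizer()
scaled = scaler.scale(0.25)  # returned unchanged
scaler.unscale_(opt)
scaler.step(opt)
scaler.update()
```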

scale(outputs)

Multiplies (‘scales’) a tensor or list of tensors by the scale factor.

Returns scaled outputs. If this instance of GradScaler is not enabled, outputs are returned unmodified.

Parameters:

outputs (Tensor or iterable of Tensors) – Outputs to scale.

Return type:

Any

unscale_(optimizer)

Divides (“unscales”) the optimizer’s gradient tensors by the scale factor.

unscale_() is optional, serving cases where you need to modify or inspect gradients between the backward pass(es) and step(). If unscale_() is not called explicitly, gradients will be unscaled automatically during step().

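A simple example that calls unscale_() so that gradient clipping operates on the unscaled gradients:

...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients now hold their true, unscaled values
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # clip in the unscaled domain
scaler.step(optimizer)  # does not unscale again, since unscale_() was already called
scaler.update()

Parameters: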

optimizer (torch.optim.Optimizer) – Optimizer that owns the gradients to be unscaled.

Return type:

None

Note

unscale_() does not incur a CPU-GPU sync.

Warning

unscale_() should only be called once per optimizer per step() call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_() twice for a given optimizer between each step() triggers a RuntimeError.

Warning

unscale_() may unscale sparse gradients out of place, replacing the .grad attribute.

static step(optimizer, *args, **kwargs)

Invokes unscale_(optimizer), followed by the parameter update if the gradients contain no infs or NaNs.

step() carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of the unscale_(), gradients are checked for infs/NaNs.

  2. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.

*args and **kwargs are forwarded to optimizer.step().

Returns the return value of optimizer.step(*args, **kwargs).

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Positional arguments forwarded to optimizer.step().

  • kwargs – Keyword arguments forwarded to optimizer.step().

Return type:

None

Warning

Closure use is not currently supported.
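The two-stage logic of step() and the once-per-iteration rule for unscale_() can be illustrated in pure Python, with gradients as plain lists of floats. This is a hedged sketch of the described behavior, not PyTorch's implementation: `MiniScaler`, `ListOptimizer`, and `end_iteration` are hypothetical names, and the error message is illustrative:

```python
import math


class MiniScaler:
    """Hypothetical sketch of the unscale_/step bookkeeping described above.

    An "optimizer" is any object with a `grads` list of floats and a `step()` method.
    """

    def __init__(self, scale: float = 2.0**16):
        self.scale = scale
        self._found_inf = {}  # id(optimizer) -> inf/NaN flag for this iteration

    def unscale_(self, optimizer):
        if id(optimizer) in self._found_inf:
            raise RuntimeError("unscale_() already called for this optimizer this iteration")
        optimizer.grads[:] = [g / self.scale for g in optimizer.grads]
        # Check the unscaled gradients for infs/NaNs.
        self._found_inf[id(optimizer)] = not all(math.isfinite(g) for g in optimizer.grads)

    def step(self, optimizer):
        if id(optimizer) not in self._found_inf:
            self.unscale_(optimizer)  # implicit unscale if the caller did not do it
        if self._found_inf[id(optimizer)]:
            return None  # inf/NaN found: skip the update to protect the parameters
        return optimizer.step()

    def end_iteration(self):
        self._found_inf.clear()  # in this sketch, per-optimizer state resets each iteration


class ListOptimizer:
    """Hypothetical optimizer holding plain-float gradients."""

    def __init__(self, grads):
        self.grads = grads
        self.steps = 0

    def step(self):
        self.steps += 1


scaler = MiniScaler(scale=4.0)
good = ListOptimizer([8.0, -4.0])
scaler.step(good)  # unscales to [2.0, -1.0], then steps
bad = ListOptimizer([float("inf"), 1.0])
scaler.step(bad)  # inf found, so the step is skipped
```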

update(new_scale=None)

Update the scale factor.

If any optimizer steps were skipped, the scale is multiplied by backoff_factor to reduce it. If growth_interval consecutive iterations occurred without a skipped step, the scale is multiplied by growth_factor to increase it.

Passing new_scale sets the new scale value manually. new_scale is not used directly; it is used to fill GradScaler’s internal scale tensor. So if new_scale is a tensor, later in-place changes to that tensor will not affect the scale GradScaler uses internally.

Parameters:

new_scale (float or torch.Tensor, optional, default=None) – New scale factor.

Return type:

None

Warning

update() should only be called at the end of the iteration, after scaler.step(optimizer) has been invoked for all optimizers used this iteration.

Warning

For performance reasons, we do not check the scale factor value to avoid synchronizations, so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or you are seeing NaNs in your gradients or loss, something is likely wrong. For example, bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
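The scale-update rule described in this docstring can be sketched in pure Python with the same defaults as the signature. This is an illustration under assumptions: `DynamicScaleSketch` is hypothetical, the inf/NaN outcome is passed in explicitly as `found_inf` rather than tracked internally, and resetting the counter on backoff is assumed:

```python
class DynamicScaleSketch:
    """Hypothetical sketch of dynamic loss-scale updates (backoff and growth)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
        self.scale = float(init_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0  # consecutive iterations without a skipped step

    def update(self, found_inf: bool, new_scale=None):
        if new_scale is not None:
            self.scale = float(new_scale)  # manual override
            return
        if found_inf:
            self.scale *= self.backoff_factor  # a step was skipped: shrink the scale
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self.scale *= self.growth_factor  # stable long enough: grow the scale
                self._growth_tracker = 0


s = DynamicScaleSketch(init_scale=8.0, growth_interval=2)
s.update(found_inf=True)   # 8.0 -> 4.0
s.update(found_inf=False)  # one clean iteration, still 4.0
s.update(found_inf=False)  # second clean iteration, 4.0 -> 8.0
```

Repeated backoffs without intervening growth are exactly how the scale can drift below 1, as the warning above describes.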

ksuit.utils.torch.amp.get_grad_scaler_and_autocast_context(precision, device)

Returns the appropriate gradient scaler and autocast context manager for the given precision.

Parameters:
  • precision (torch.dtype) – The desired precision.

  • device (torch.device) – The device where computation occurs.

Returns:

The corresponding scaler and autocast context.

Return type:

tuple[torch.amp.grad_scaler.GradScaler, torch.autocast | NoopContext]
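The return type suggests a dispatch on precision. A hedged pure-Python sketch of one plausible mapping, using strings in place of the real objects. The mapping itself is an assumption, not confirmed by this page: fp16 gets real loss scaling because its narrow exponent range lets small gradients underflow, bf16 shares float32's exponent range so scaling is usually unnecessary, and fp32 needs neither autocast nor scaling. `AmpPlan` and `plan_for` are hypothetical, and the real function takes a torch.dtype rather than a string:

```python
from typing import NamedTuple, Optional


class AmpPlan(NamedTuple):
    """Hypothetical description of what a training loop would construct."""

    scaler: str                    # "GradScaler" or "NoopGradScaler"
    context: str                   # "autocast" or "NoopContext"
    autocast_dtype: Optional[str]  # dtype used inside the autocast region, if any


def plan_for(precision: str) -> AmpPlan:
    if precision in ("float16", "fp16"):
        # Narrow exponent range: scale the loss so small gradients do not underflow.
        return AmpPlan("GradScaler", "autocast", "float16")
    if precision in ("bfloat16", "bf16"):
        # Same exponent range as float32: loss scaling is typically unnecessary (assumption).
        return AmpPlan("NoopGradScaler", "autocast", "bfloat16")
    if precision in ("float32", "fp32"):
        # Full precision: no autocast region and nothing to scale.
        return AmpPlan("NoopGradScaler", "NoopContext", None)
    raise ValueError(f"unsupported precision {precision!r}")


print(plan_for("bf16"))
```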

ksuit.utils.torch.amp.disable(device_type)

Disables AMP for the given device type.

Parameters:

device_type (str) – The device type to disable AMP for (e.g., "cuda").