ksuit.utils.torch.amp
=====================

.. py:module:: ksuit.utils.torch.amp


Attributes
----------

.. autoapisummary::

   ksuit.utils.torch.amp.FLOAT32_ALIASES
   ksuit.utils.torch.amp.FLOAT16_ALIASES
   ksuit.utils.torch.amp.BFLOAT16_ALIASES
   ksuit.utils.torch.amp.VALID_PRECISIONS
   ksuit.utils.torch.amp.logger


Classes
-------

.. autoapisummary::

   ksuit.utils.torch.amp.NoopContext
   ksuit.utils.torch.amp.NoopGradScaler


Functions
---------

.. autoapisummary::

   ksuit.utils.torch.amp.get_supported_precision
   ksuit.utils.torch.amp.is_compatible
   ksuit.utils.torch.amp.is_bfloat16_compatible
   ksuit.utils.torch.amp.is_float16_compatible
   ksuit.utils.torch.amp.get_grad_scaler_and_autocast_context
   ksuit.utils.torch.amp.disable


Module Contents
---------------

.. py:data:: FLOAT32_ALIASES
   :value: ['float32', 'fp32']

.. py:data:: FLOAT16_ALIASES
   :value: ['float16', 'fp16']

.. py:data:: BFLOAT16_ALIASES
   :value: ['bfloat16', 'bf16']

.. py:data:: VALID_PRECISIONS
   :value: ['float32', 'fp32', 'float16', 'fp16', 'bfloat16', 'bf16']

.. py:data:: logger

.. py:function:: get_supported_precision(desired_precision, device)

   Returns ``desired_precision`` if it is supported by ``device`` and a supported fallback precision otherwise.
   For example, bfloat16 is not supported by all GPUs.

   :param desired_precision: The desired precision format.
   :param device: The selected device (e.g., torch.device("cuda")).
   :returns: The most suitable precision supported by the device.
   :rtype: torch.dtype

.. py:function:: is_compatible(device, dtype)

   Checks if a given dtype is supported on a device.

   :param device: The device to check compatibility for.
   :param dtype: The data type to check.
   :returns: True if the dtype is supported, False otherwise.
   :rtype: bool

.. py:function:: is_bfloat16_compatible(device)

   Checks if bfloat16 precision is supported on the given device.

   :param device: The device to check.
   :returns: True if bfloat16 is supported, False otherwise.
   :rtype: bool

.. py:function:: is_float16_compatible(device)

   Checks if float16 precision is supported on the given device.

   :param device: The device to check.
   :returns: True if float16 is supported, False otherwise.
   :rtype: bool
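The snippet below is a minimal sketch of resolving a requested precision against the
capabilities of the current device, using only the functions documented above. It assumes
that ``get_supported_precision`` accepts the string aliases listed in ``VALID_PRECISIONS``
(e.g., ``"bf16"``)::

   import torch

   from ksuit.utils.torch.amp import get_supported_precision, is_bfloat16_compatible

   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

   # Request bfloat16; the helper falls back to a supported precision
   # on devices without bfloat16 support (e.g., older GPUs).
   precision = get_supported_precision(desired_precision="bf16", device=device)

   if not is_bfloat16_compatible(device):
       print(f"bfloat16 unavailable on {device}, falling back to {precision}")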
.. py:class:: NoopContext

   A no-operation context manager that does nothing.

.. py:class:: NoopGradScaler(device = 'cuda', init_scale = 2.0**16, growth_factor = 2.0, backoff_factor = 0.5, growth_interval = 2000, enabled = True)

   Bases: :py:obj:`torch.amp.grad_scaler.GradScaler`

   A no-operation gradient scaler that performs no scaling.

   .. py:method:: scale(outputs)

      Multiplies ('scales') a tensor or list of tensors by the scale factor.

      Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled,
      outputs are returned unmodified.

      :param outputs: Outputs to scale.
      :type outputs: Tensor or iterable of Tensors

   .. py:method:: unscale_(optimizer)

      Divides ("unscales") the optimizer's gradient tensors by the scale factor.

      :meth:`unscale_` is optional, serving cases where you need to modify or inspect gradients
      between the backward pass(es) and :meth:`step`. If :meth:`unscale_` is not called explicitly,
      gradients will be unscaled automatically during :meth:`step`.

      Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

          ...
          scaler.scale(loss).backward()
          scaler.unscale_(optimizer)
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
          scaler.step(optimizer)
          scaler.update()

      :param optimizer: Optimizer that owns the gradients to be unscaled.
      :type optimizer: torch.optim.Optimizer

      .. note::
         :meth:`unscale_` does not incur a CPU-GPU sync.

      .. warning::
         :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
         and only after all gradients for that optimizer's assigned parameters have been accumulated.
         Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

      .. warning::
         :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.

   .. py:method:: step(optimizer, *args, **kwargs)
      :staticmethod:

      Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN.

      :meth:`step` carries out the following two operations:

      1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for
         ``optimizer`` earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked
         for infs/NaNs.
      2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled gradients.
         Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

      ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

      Returns the return value of ``optimizer.step(*args, **kwargs)``.

      :param optimizer: Optimizer that applies the gradients.
      :type optimizer: torch.optim.Optimizer
      :param args: Any arguments.
      :param kwargs: Any keyword arguments.

      .. warning::
         Closure use is not currently supported.

   .. py:method:: update(new_scale = None)

      Update the scale factor.

      If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
      to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
      the scale is multiplied by ``growth_factor`` to increase it.

      Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
      used directly; it is used to fill GradScaler's internal scale tensor. So if
      ``new_scale`` was a tensor, later in-place changes to that tensor will not further
      affect the scale GradScaler uses internally.)

      :param new_scale: New scale factor.
      :type new_scale: float or :class:`torch.Tensor`, optional, default=None

      .. warning::
         :meth:`update` should only be called at the end of the iteration, after
         ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration.

      .. warning::
         For performance reasons, we do not check the scale factor value to avoid synchronizations,
         so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
         you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
         bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.

.. py:function:: get_grad_scaler_and_autocast_context(precision, device)

   Returns the appropriate gradient scaler and autocast context manager for the given precision.

   :param precision: The desired precision.
   :type precision: torch.dtype
   :param device: The device where computation occurs.
   :type device: torch.device
   :returns: The corresponding scaler and autocast context.

.. py:function:: disable(device_type)

   Disables AMP for the given device type.

   :param device_type: The device type to disable AMP for.
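The following is a minimal training-step sketch that combines ``get_supported_precision``
and ``get_grad_scaler_and_autocast_context``. It assumes the latter returns a
``(grad_scaler, autocast_context)`` pair, as the name suggests, and that the string aliases
from ``VALID_PRECISIONS`` are accepted as above; ``model``, ``optimizer``, and the batch
tensors are placeholders::

   import torch

   from ksuit.utils.torch.amp import (
       get_grad_scaler_and_autocast_context,
       get_supported_precision,
   )

   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
   model = torch.nn.Linear(16, 4).to(device)
   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   criterion = torch.nn.CrossEntropyLoss()

   # Resolve the precision first, then fetch the matching scaler/context pair.
   precision = get_supported_precision(desired_precision="fp16", device=device)
   grad_scaler, autocast_context = get_grad_scaler_and_autocast_context(precision, device)

   x = torch.randn(8, 16, device=device)
   y = torch.randint(0, 4, (8,), device=device)

   optimizer.zero_grad()
   with autocast_context:
       loss = criterion(model(x), y)
   # With float32 (or on devices without float16 support) the no-op scaler and
   # context make this equivalent to a plain full-precision step.
   grad_scaler.scale(loss).backward()
   grad_scaler.step(optimizer)
   grad_scaler.update()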