ksuit.optimizer.optimizer_wrapper

Classes

OptimizerWrapper

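Wrapper around a torch.optim.Optimizer that adds weight decay exclusion, parameter groups, learning rate and weight decay scheduling, and gradient clipping.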

Module Contents

class ksuit.optimizer.optimizer_wrapper.OptimizerWrapper(model, torch_optim_ctor, optim_wrapper_config, update_counter=None)

Wrapper around a torch.optim.Optimizer that allows:
  • excluding biases and weights of normalization layers from weight decay
  • creating param_groups (e.g., for layer-wise learning rate scaling)
  • learning rate scheduling
  • gradient clipping
  • weight decay scheduling

Parameters:
  • model (ksuit.models.Model) – Parameters of this model will be optimized.

  • torch_optim_ctor (collections.abc.Callable[[collections.abc.Iterable[dict[str, Any]]], torch.optim.Optimizer]) – The torch.optim.Optimizer that should be wrapped. Needs to be a callable because it requires the parameters of the model for initialization.

  • optim_wrapper_config (ksuit.schemas.optim.OptimizerConfig) – The configuration for the optimizer wrapper.

  • update_counter (ksuit.utils.training.counter.UpdateCounter | None) – Object that provides the current training progress to enable scheduling of the learning rate or the weight decay.
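As a rough illustration of the weight decay exclusion, the sketch below splits parameters into a decayed and a non-decayed param group. It is not the ksuit implementation: the rule that every parameter with at most one dimension (biases, normalization scales and shifts) is excluded is a common heuristic assumed here, and parameters are represented by hypothetical (name, ndim) pairs instead of tensors. For a torch.nn.Module the same loop would run over model.named_parameters(), and the resulting list of dicts has the shape that torch_optim_ctor receives, e.g. functools.partial(torch.optim.AdamW, lr=1e-3).

```python
from typing import Any


def build_param_groups(
    named_param_ndims: list[tuple[str, int]],
    weight_decay: float,
) -> list[dict[str, Any]]:
    """Split parameters into a decayed and a non-decayed param group.

    Assumed heuristic (not necessarily ksuit's rule): parameters with
    ndim <= 1 are biases or normalization parameters and get no weight decay.
    """
    decay: list[str] = []
    no_decay: list[str] = []
    for name, ndim in named_param_ndims:
        # matrices / conv kernels are decayed, vectors and scalars are not
        (no_decay if ndim <= 1 else decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


groups = build_param_groups(
    [("encoder.weight", 2), ("encoder.bias", 1), ("norm.weight", 1)],
    weight_decay=0.05,
)
```

Keeping the excluded parameters in their own group means a later weight decay schedule only has to touch the groups whose decay is meant to be non-zero.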

logger
model
update_counter = None
config
param_idx_to_name
torch_optim
all_parameters = None
schedule = None
weight_decay_schedule = None
step(grad_scaler=None)

Wrapper around torch.optim.Optimizer.step that automatically handles:
  • gradient scaling for mixed precision (including updating the GradScaler state)
  • gradient clipping
  • calling the .step function of the optimizer

Parameters:

grad_scaler (torch.amp.grad_scaler.GradScaler | None) – Gradient scaler used for mixed-precision training, or None if gradients are not scaled.

Return type:

None
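The gradient clipping part of step can be pictured with the following pure-Python sketch of clipping by global L2 norm, modelled on the behaviour of torch.nn.utils.clip_grad_norm_ (including its small eps term). Gradients are hypothetical lists of floats rather than tensors, and whether OptimizerWrapper clips by norm or by value is not stated here. When a GradScaler is involved, the pattern from the PyTorch AMP documentation is scaler.unscale_(optimizer), then clipping, then scaler.step(optimizer) and scaler.update(), so that the clipping threshold applies to unscaled gradients.

```python
import math


def clip_grad_norm(grads: list[list[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the global norm measured before clipping.
    """
    # one norm over *all* parameters, not one norm per parameter
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # shrink every gradient by the same factor, preserving the direction
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]
norm_before = clip_grad_norm(grads, max_norm=1.0)
```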

schedule_step()

Applies the current state of the schedules to the parameter groups.

Return type:

None
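A minimal sketch of what applying schedules to parameter groups can look like, assuming cosine schedules driven by a step count such as the one an update_counter would provide. The schedule shape, the default values, and the per-group lr_scale / wd_scale keys are hypothetical choices for illustration. The underlying optimizer picks up whatever lr and weight_decay values are stored in each param group at its next step.

```python
import math
from typing import Any


def cosine_value(step: int, total_steps: int, start: float, end: float) -> float:
    """Cosine interpolation from `start` (step 0) to `end` (step >= total_steps > 0)."""
    progress = min(step / total_steps, 1.0)
    return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * progress))


def schedule_step(
    param_groups: list[dict[str, Any]],
    step: int,
    total_steps: int,
    lr: tuple[float, float] = (1e-3, 0.0),
    weight_decay: tuple[float, float] = (0.04, 0.4),
) -> None:
    """Write the scheduled lr and weight_decay into every param group in place."""
    cur_lr = cosine_value(step, total_steps, *lr)
    cur_wd = cosine_value(step, total_steps, *weight_decay)
    for group in param_groups:
        # hypothetical per-group multiplier, e.g. for layer-wise lr scaling
        group["lr"] = cur_lr * group.get("lr_scale", 1.0)
        # wd_scale == 0.0 keeps excluded groups (biases, norms) at zero decay
        group["weight_decay"] = cur_wd * group.get("wd_scale", 1.0)


groups = [{"lr_scale": 1.0, "wd_scale": 1.0}, {"lr_scale": 0.5, "wd_scale": 0.0}]
schedule_step(groups, step=0, total_steps=100)
```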

zero_grad(set_to_none=True)

Wrapper around torch.optim.Optimizer.zero_grad.

state_dict()

Wrapper around torch.optim.Optimizer.state_dict. Additionally stores the mapping from parameter indices to parameter names (param_idx_to_name).

Return type:

dict[str, Any]
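torch.optim.Optimizer.state_dict refers to parameters only by integer index: parameters are numbered 0, 1, 2, ... across all param groups in order. Those indices shift whenever the group layout changes, which is why storing names alongside them is useful. The sketch below reproduces that numbering for hypothetical lists of parameter names and attaches the mapping to a copy of the state dict. The key name "param_idx_to_name" is assumed to match the attribute above and may differ from what ksuit actually stores.

```python
from typing import Any


def build_param_idx_to_name(param_group_names: list[list[str]]) -> dict[int, str]:
    """Number parameters 0, 1, 2, ... across all groups, like torch's packing."""
    idx_to_name: dict[int, str] = {}
    for group in param_group_names:
        for name in group:
            idx_to_name[len(idx_to_name)] = name
    return idx_to_name


def state_dict_with_names(
    optim_state: dict[str, Any], idx_to_name: dict[int, str]
) -> dict[str, Any]:
    """Return a copy of a torch-style optimizer state dict with the mapping attached."""
    return {**optim_state, "param_idx_to_name": dict(idx_to_name)}


mapping = build_param_idx_to_name([["encoder.weight"], ["encoder.bias", "norm.weight"]])
sd = state_dict_with_names({"state": {}, "param_groups": []}, mapping)
```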

load_state_dict(state_dict_to_load)

Wrapper around torch.optim.Optimizer.load_state_dict. Additionally handles edge cases where the parameter groups of the loaded state_dict do not match the current configuration. By default, torch overwrites the settings of the current parameter groups with those from the checkpoint. This is undesirable in the following cases:
  • adding new parameters (e.g., unfreezing part of the model)
  • changing weight_decay or other param_group properties: load_state_dict would overwrite the actual weight_decay (defined in the constructor of the OptimizerWrapper) with the weight_decay from the checkpoint

Parameters:

state_dict_to_load (dict[str, Any]) – The optimizer state to load.

Return type:

None
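Plain torch.optim.Optimizer.load_state_dict raises a ValueError when the number of parameters in a group differs from the checkpoint, and otherwise copies every param group setting except params from the checkpoint. One way around both issues, sketched below on plain dicts, is to match per-parameter state by name and keep the current param groups. It assumes the checkpoint carries a hypothetical "param_idx_to_name" entry mapping indices to parameter names; it is an illustration, not ksuit's actual logic.

```python
from typing import Any


def merge_state_by_name(
    loaded: dict[str, Any],
    current: dict[str, Any],
    current_idx_to_name: dict[int, str],
) -> dict[str, Any]:
    """Combine a checkpointed optimizer state with the current configuration.

    param_groups (lr, weight_decay, ...) come from `current`; per-parameter
    state (e.g. Adam moments) comes from `loaded`, matched by parameter name.
    """
    # hypothetical key stored next to the usual "state"/"param_groups" entries
    loaded_idx_to_name = loaded["param_idx_to_name"]
    state_by_name = {loaded_idx_to_name[idx]: s for idx, s in loaded["state"].items()}
    new_state = {
        idx: state_by_name[name]
        for idx, name in current_idx_to_name.items()
        # parameters absent from the checkpoint (e.g. newly unfrozen) get no state
        if name in state_by_name
    }
    return {
        "state": new_state,
        "param_groups": [dict(group) for group in current["param_groups"]],
    }


loaded = {
    "state": {0: {"step": 10}, 1: {"step": 10}},
    "param_groups": [{"params": [0, 1], "weight_decay": 0.05}],
    "param_idx_to_name": {0: "encoder.weight", 1: "encoder.bias"},
}
current = {
    "state": {},
    "param_groups": [
        {"params": [0], "weight_decay": 0.1},
        {"params": [1, 2], "weight_decay": 0.0},
    ],
}
names = {0: "encoder.weight", 1: "encoder.bias", 2: "head.weight"}
merged = merge_state_by_name(loaded, current, names)
```

Because the merged dict reuses the current group layout, it could then be passed to the wrapped optimizer's own load_state_dict without a size mismatch, and the current weight_decay values survive.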