ksuit.optimizer.optimizer_wrapper¶
Classes¶
OptimizerWrapper: Wrapper around a torch.optim.Optimizer that allows excluding parameters from weight decay, creating param_groups, scheduling the learning rate and weight decay, and gradient clipping.
Module Contents¶
- class ksuit.optimizer.optimizer_wrapper.OptimizerWrapper(model, torch_optim_ctor, optim_wrapper_config, update_counter=None)¶
Wrapper around a torch.optim.Optimizer that allows:
- excluding biases and the weights of normalization layers from weight decay
- creating param_groups (e.g., for layerwise lr scaling)
- learning rate scheduling
- gradient clipping
- weight decay scheduling
- Parameters:
model (ksuit.models.Model) – Parameters of this model will be optimized.
torch_optim_ctor (collections.abc.Callable[[collections.abc.Iterable[dict[str, Any]]], torch.optim.Optimizer]) – Constructor of the torch.optim.Optimizer that should be wrapped. It has to be a callable rather than an instance because the optimizer requires the parameters of the model for initialization.
optim_wrapper_config (ksuit.schemas.optim.OptimizerConfig) – The configuration for the optimizer wrapper.
update_counter (ksuit.utils.training.counter.UpdateCounter | None) – Object that provides the current training progress to enable scheduling of the learning rate or the weight decay.
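A minimal construction sketch is shown below. The concrete fields of OptimizerConfig and the way model, optim_wrapper_config, and update_counter are created are assumptions made for illustration, not part of this API.

```python
from functools import partial

import torch

from ksuit.optimizer.optimizer_wrapper import OptimizerWrapper

# torch_optim_ctor must be a callable that receives the param_groups and returns
# a torch.optim.Optimizer; functools.partial conveniently binds the remaining
# hyperparameters in advance.
torch_optim_ctor = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.05)

# model, optim_wrapper_config, and update_counter are placeholders here:
# a ksuit.models.Model, a ksuit.schemas.optim.OptimizerConfig, and a
# ksuit.utils.training.counter.UpdateCounter created elsewhere in the setup.
optim = OptimizerWrapper(
    model=model,
    torch_optim_ctor=torch_optim_ctor,
    optim_wrapper_config=optim_wrapper_config,
    update_counter=update_counter,  # optional; enables lr/weight_decay scheduling
)
```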
- logger¶
- model¶
- update_counter = None¶
- config¶
- param_idx_to_name¶
- torch_optim¶
- all_parameters = None¶
- schedule = None¶
- weight_decay_schedule = None¶
- step(grad_scaler=None)¶
Wrapper around torch.optim.Optimizer.step which automatically handles:
- gradient scaling for mixed precision (including updating the GradScaler state)
- gradient clipping
- calling the .step function of the wrapped optimizer
- Parameters:
grad_scaler (torch.amp.grad_scaler.GradScaler | None)
- Return type:
None
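A hedged training-loop sketch showing step together with a GradScaler; the dataloader, forward pass, and loss are placeholders and not part of this module.

```python
import torch

scaler = torch.amp.GradScaler("cuda")

for batch in dataloader:  # placeholder dataloader
    optim.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)  # placeholder forward pass returning a scalar loss
    scaler.scale(loss).backward()
    # step unscales the gradients, applies gradient clipping (if configured),
    # calls the wrapped optimizer's .step, and updates the GradScaler state.
    optim.step(grad_scaler=scaler)
    # advance the lr/weight_decay schedules (see schedule_step below)
    optim.schedule_step()
```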
- schedule_step()¶
Applies the current state of the schedules to the parameter groups.
- Return type:
None
- zero_grad(set_to_none=True)¶
Wrapper around torch.optim.Optimizer.zero_grad.
- state_dict()¶
Wrapper around torch.optim.Optimizer.state_dict. Additionally adds the parameter index-to-name mapping.
- load_state_dict(state_dict_to_load)¶
Wrapper around torch.optim.Optimizer.load_state_dict. Additionally handles edge cases where the parameter groups of the loaded state_dict do not match the current configuration. By default, torch would overwrite the current parameter groups with the ones from the checkpoint, which is undesirable in the following cases:
- new parameters were added (e.g., something was unfrozen)
- weight_decay or other param_group properties were changed: load_state_dict would overwrite the actual weight_decay (defined in the constructor of the OptimizerWrapper) with the weight_decay from the checkpoint
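A checkpoint save/resume sketch; the file path is a placeholder. Presumably the index-to-name mapping stored by state_dict is what allows load_state_dict to reconcile mismatched parameter groups.

```python
import torch

# save: includes the parameter index-to-name mapping on top of the regular
# torch.optim.Optimizer state.
torch.save(optim.state_dict(), "optimizer.pt")  # placeholder path

# resume: param_group settings such as weight_decay keep the values defined in
# the OptimizerWrapper constructor instead of being overwritten by the
# checkpoint, and newly added (e.g. unfrozen) parameters are handled as well.
optim.load_state_dict(torch.load("optimizer.pt"))
```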