ksuit.optimizer.optimizer_wrapper
=================================

.. py:module:: ksuit.optimizer.optimizer_wrapper


Classes
-------

.. autoapisummary::

   ksuit.optimizer.optimizer_wrapper.OptimizerWrapper


Module Contents
---------------

.. py:class:: OptimizerWrapper(model, torch_optim_ctor, optim_wrapper_config, update_counter = None)

   Wrapper around a `torch.optim.Optimizer` that allows

   - excluding biases and weights of normalization layers from weight decay
   - creating param_groups (e.g., for layer-wise learning rate scaling)
   - learning rate scheduling
   - gradient clipping
   - weight decay scheduling

   :param model: Parameters of this model will be optimized.
   :param torch_optim_ctor: The `torch.optim.Optimizer` that should be wrapped. Needs to be a callable
       because it requires the parameters of the model for initialization.
   :param optim_wrapper_config: The configuration for the optimizer wrapper.
   :param update_counter: Object that provides the current training progress to enable scheduling of
       the learning rate or the weight decay.


   .. py:attribute:: logger


   .. py:attribute:: model


   .. py:attribute:: update_counter
      :value: None


   .. py:attribute:: config


   .. py:attribute:: param_idx_to_name


   .. py:attribute:: torch_optim


   .. py:attribute:: all_parameters
      :value: None


   .. py:attribute:: schedule
      :value: None


   .. py:attribute:: weight_decay_schedule
      :value: None


   .. py:method:: step(grad_scaler = None)

      Wrapper around `torch.optim.Optimizer.step` which automatically handles:

      - gradient scaling for mixed precision (including updating the GradScaler state)
      - gradient clipping
      - calling the ``step`` function of the wrapped optimizer


   .. py:method:: schedule_step()

      Applies the current state of the schedules to the parameter groups.


   .. py:method:: zero_grad(set_to_none=True)

      Wrapper around `torch.optim.Optimizer.zero_grad`.


   .. py:method:: state_dict()

      Wrapper around `torch.optim.Optimizer.state_dict`. Additionally adds info about the
      index-to-name mapping of the parameters.


   .. py:method:: load_state_dict(state_dict_to_load)

      Wrapper around `torch.optim.Optimizer.load_state_dict`. Additionally handles edge cases where the
      parameter groups of the loaded state_dict do not match the current configuration. By default,
      torch would overwrite the current parameter groups with the ones from the checkpoint, which is
      undesirable in the following cases:

      - new parameters were added (e.g., something was unfrozen)
      - weight_decay or other param_group properties were changed: ``load_state_dict`` would overwrite
        the actual weight_decay (defined in the constructor of the OptimizerWrapper) with the
        weight_decay from the checkpoint

      :param state_dict_to_load: The optimizer state to load.
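
Example
-------

A minimal usage sketch based on the signatures documented above. The keys used for
``optim_wrapper_config`` are placeholder assumptions chosen for illustration; the actual configuration
schema (and whether an ``update_counter`` is required for scheduling) is defined elsewhere in ksuit.

.. code-block:: python

   from functools import partial

   import torch
   from torch import nn

   from ksuit.optimizer.optimizer_wrapper import OptimizerWrapper

   model = nn.Linear(8, 2)

   # torch_optim_ctor must be a callable that still expects the model parameters,
   # e.g. a partially applied torch.optim.Optimizer constructor.
   torch_optim_ctor = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.05)

   # NOTE: placeholder config keys for illustration only; the real schema is
   # defined by ksuit's optimizer wrapper configuration.
   optim_wrapper_config = dict(clip_grad_norm=1.0, exclude_norm_and_bias_from_wd=True)

   optimizer = OptimizerWrapper(
       model=model,
       torch_optim_ctor=torch_optim_ctor,
       optim_wrapper_config=optim_wrapper_config,
       update_counter=None,  # no lr/weight-decay scheduling in this sketch
   )

   # one training step
   x, y = torch.randn(4, 8), torch.randn(4, 2)
   loss = nn.functional.mse_loss(model(x), y)
   loss.backward()
   optimizer.step()           # clips gradients and steps the wrapped optimizer
   optimizer.schedule_step()  # applies the current schedule state to the param groups
   optimizer.zero_grad()

   # For mixed precision, a torch.cuda.amp.GradScaler can be passed to step:
   #   scaler = torch.cuda.amp.GradScaler()
   #   optimizer.step(grad_scaler=scaler)

   # Checkpointing: state_dict additionally stores the index-to-name mapping, and
   # load_state_dict preserves the currently configured param_group properties
   # (e.g. weight_decay) rather than overwriting them with checkpoint values.
   ckpt = optimizer.state_dict()
   optimizer.load_state_dict(ckpt)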