ksuit.callbacks.checkpoint_callbacks.ema
Classes
EmaCallback – Callback for exponential moving average (EMA) of model weights.
Module Contents
- class ksuit.callbacks.checkpoint_callbacks.ema.EmaCallback(callback_config, **kwargs)
Bases: ksuit.callbacks.base.PeriodicCallback
Callback for exponential moving average (EMA) of model weights.
Initializes the EmaCallback.
- Parameters:
callback_config (ksuit.schemas.callbacks.callbacks_config.EmaCallbackConfig) – configuration of the EmaCallback. Implements the CallBackBaseConfig schema.
**kwargs – additional arguments passed to the parent class.
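For reference, the standard EMA update of a weight vector is given below. This page does not define the semantics of target_factors, so the formula assumes each target factor acts as the decay β in [0, 1), with θ_t the model weights after training step t and θ^EMA the averaged copy.

```latex
\theta^{\mathrm{EMA}}_{t} = \beta\,\theta^{\mathrm{EMA}}_{t-1} + (1 - \beta)\,\theta_{t}
```

Under this assumption, a larger β gives a smoother, slower-moving average: the EMA effectively spans roughly 1/(1 - β) updates, e.g. about 1000 updates for β = 0.999.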
- model_paths
- target_factors
- save_weights
- save_last_weights
- save_latest_weights
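The attributes model_paths and target_factors suggest that several EMA copies can be tracked at once. The sketch below is a hypothetical illustration of that bookkeeping, not the ksuit implementation: it assumes a model_path is a dotted prefix into a flat state_dict (None meaning the whole model), keeps one EMA copy per (model_path, target_factor) pair, and uses plain floats instead of tensors. The helpers select_submodule, init_ema_copies, and update_all are invented for illustration.

```python
from __future__ import annotations

from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for a flat state_dict (parameter name -> value).
# Real weights would be tensors; plain floats keep the sketch dependency-free.
StateDict = Dict[str, float]
EmaKey = Tuple[Optional[str], float]


def select_submodule(state: StateDict, model_path: Optional[str]) -> StateDict:
    """Return entries under a dotted prefix; None selects the whole model (assumed)."""
    if model_path is None:
        return dict(state)
    prefix = model_path + "."
    return {name: value for name, value in state.items() if name.startswith(prefix)}


def init_ema_copies(
    state: StateDict,
    model_paths: List[Optional[str]],
    target_factors: List[float],
) -> Dict[EmaKey, StateDict]:
    """Create one EMA copy per (model_path, target_factor) pair from the current weights."""
    # select_submodule returns a fresh dict, so the copies never alias each other
    # (with tensors you would additionally clone the values).
    return {
        (path, factor): select_submodule(state, path)
        for path in model_paths
        for factor in target_factors
    }


def update_all(emas: Dict[EmaKey, StateDict], state: StateDict) -> None:
    """Blend the current weights into every EMA copy using that copy's own factor."""
    for (path, factor), ema in emas.items():
        for name, value in select_submodule(state, path).items():
            ema[name] = factor * ema[name] + (1.0 - factor) * value
```

For example, init_ema_copies(state, ["encoder"], [0.9, 0.999]) yields two independent copies of the encoder weights that drift toward new values at different speeds after each update_all call.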
- resume_from_checkpoint(resumption_paths, model)
If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be unnecessarily wasteful to also store that state in the callback's state_dict. Therefore, resume_from_checkpoint is called when a run is resumed, which allows the callback to load its state from the files it previously stored on disk.
- Parameters:
resumption_paths (ksuit.providers.path_provider.PathProvider) – PathProvider instance to access paths from the checkpoint to resume from.
model – model of the current training run.
- Return type:
None
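The design described above can be illustrated with a small, dependency-free sketch. Everything here is hypothetical (the class EmaCheckpointSketch, the JSON file format, and the file naming scheme are assumptions, and the real method receives a PathProvider and the model rather than a directory): the point is only that state_dict() stays empty while the EMA weights are written to, and later rebuilt from, their own file.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


class EmaCheckpointSketch:
    """Hypothetical stand-in for a stateful callback that keeps its state on disk."""

    def __init__(self, model_path: str, target_factor: float) -> None:
        self.model_path = model_path
        self.target_factor = target_factor
        self.ema: Dict[str, float] = {}

    def _file(self, ckpt_dir: Path) -> Path:
        # Hypothetical naming scheme: one file per (model_path, target_factor).
        return ckpt_dir / f"{self.model_path}.ema={self.target_factor}.json"

    def state_dict(self) -> dict:
        # The (potentially large) EMA weights are deliberately left out here;
        # they already live in their own checkpoint file.
        return {}

    def save(self, ckpt_dir: Path) -> None:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        self._file(ckpt_dir).write_text(json.dumps(self.ema))

    def resume_from_checkpoint(self, ckpt_dir: Path) -> None:
        # Rebuild the EMA state from the file written by save();
        # raises FileNotFoundError if that file does not exist.
        self.ema = json.loads(self._file(ckpt_dir).read_text())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        original = EmaCheckpointSketch("encoder", 0.999)
        original.ema = {"w": 1.5}
        original.save(Path(tmp))

        resumed = EmaCheckpointSketch("encoder", 0.999)
        resumed.resume_from_checkpoint(Path(tmp))
        print(resumed.ema)
```

Keeping the weights out of state_dict() avoids writing the same large tensors twice per checkpoint; the cost is that resuming depends on the separate file still being present.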
- apply_ema(cur_model, model_path, target_factor)
Fused, in-place implementation of the EMA update.
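The in-place semantics can be sketched on plain Python lists as below; apply_ema_inplace is a hypothetical helper, and treating target_factor as the decay of the old EMA value is an assumption. The sketch overwrites the existing EMA buffer instead of allocating a new one. It is not fused: in PyTorch, a fused variant typically batches this update across all parameter tensors with the private multi-tensor ops torch._foreach_mul_ and torch._foreach_add_ (or applies Tensor.lerp_(current, 1 - target_factor) per tensor), avoiding a Python-level loop over elements. Which approach ksuit uses is not stated here.

```python
from typing import List


def apply_ema_inplace(ema: List[float], current: List[float], target_factor: float) -> None:
    """Overwrite ema with target_factor * ema + (1 - target_factor) * current, in place."""
    if len(ema) != len(current):
        raise ValueError("ema and current must have the same length")
    if not 0.0 <= target_factor <= 1.0:
        raise ValueError("target_factor must lie in [0, 1]")
    one_minus = 1.0 - target_factor
    # Mutate the existing buffer element by element; no new list is created.
    for i, value in enumerate(current):
        ema[i] = target_factor * ema[i] + one_minus * value
```

Because the function mutates its first argument and returns None, any other reference to the same EMA buffer sees the updated values immediately.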