ksuit.callbacks.checkpoint_callbacks.ema

Classes

EmaCallback

Callback for exponential moving average (EMA) of model weights.

Module Contents

class ksuit.callbacks.checkpoint_callbacks.ema.EmaCallback(callback_config, **kwargs)

Bases: ksuit.callbacks.base.PeriodicCallback

Callback for exponential moving average (EMA) of model weights.
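For reference, the standard EMA update for a weight tensor θ after training step t is written out below. The sketch assumes that target_factor plays the role of the decay β, i.e. the weight kept on the running average. This is a common convention, but this page does not state it. The last line unrolls the recursion and shows that older weights are discounted geometrically. With β = 0.999, the effective averaging window is roughly 1/(1 − β) = 1000 updates.

```latex
\begin{aligned}
\theta^{\mathrm{EMA}}_{t} &= \beta\,\theta^{\mathrm{EMA}}_{t-1} + (1-\beta)\,\theta_{t} \\
&= \theta^{\mathrm{EMA}}_{t-1} + (1-\beta)\,\bigl(\theta_{t} - \theta^{\mathrm{EMA}}_{t-1}\bigr) \\
&= \beta^{t}\,\theta^{\mathrm{EMA}}_{0} + (1-\beta)\sum_{k=0}^{t-1} \beta^{k}\,\theta_{t-k}
\end{aligned}
```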

Initializes the EmaCallback.

Parameters:
  • callback_config (ksuit.schemas.callbacks.callbacks_config.EmaCallbackConfig) – configuration of the EmaCallback. Implements the CallBackBaseConfig schema.

  • **kwargs – additional arguments passed to the parent class.

model_paths
target_factors
save_weights
save_last_weights
save_latest_weights
parameters: dict[str | None, dict[str, torch.Tensor]]
buffers: dict[str | None, dict[str, torch.Tensor]]
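The type annotations of parameters and buffers describe a two-level mapping. The outer key is a model path (str | None), and the inner key is a weight name. The minimal sketch below shows that layout. It rests on several assumptions. Plain float lists stand in for torch.Tensor so that it runs without PyTorch. The None key is taken to mean "the whole model". The snapshot helper is hypothetical and not part of ksuit.

```python
from typing import Optional

# Stand-in for torch.Tensor so the sketch runs without PyTorch (assumption for illustration).
Tensor = list[float]
# Outer key: model path (None assumed to mean the whole model); inner key: parameter/buffer name.
EmaStore = dict[Optional[str], dict[str, Tensor]]


def snapshot(named_weights: EmaStore) -> EmaStore:
    """Hypothetical helper: deep-copy the current weights so the EMA store starts from them."""
    return {
        path: {name: list(value) for name, value in weights.items()}
        for path, weights in named_weights.items()
    }


# Usage: later in-place changes to the live weights do not leak into the EMA copy.
live = {None: {"encoder.weight": [0.1, 0.2], "head.bias": [0.0]}}
ema_parameters = snapshot(live)
live[None]["encoder.weight"][0] = 9.9
```

Copying at initialization matters because the EMA must evolve separately from the live weights. If it aliased them, every optimizer step would overwrite the average.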
resume_from_checkpoint(resumption_paths, model)

If a callback is stateful and stores large files to disk (e.g., an EMA of the model), storing the same state again in the callback's state_dict would be unnecessarily wasteful. Instead, resume_from_checkpoint is called when a run is resumed, which lets the callback restore its state from the files it previously wrote to disk.

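Parameters:
  • resumption_paths – paths of the checkpoint from which the run is resumed.

  • model – the model of the resumed run.

Return type: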

None
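The design described above keeps large EMA weights out of state_dict and reloads them from their own files on resume. The sketch below illustrates that pattern. DiskBackedEma, its file-name pattern, the pickle format, and the simplified resume_from_checkpoint(ckpt_dir) signature are all hypothetical. They are not ksuit's checkpoint format or API.

```python
import pickle
from pathlib import Path
from typing import Optional


class DiskBackedEma:
    """Hypothetical, simplified EMA holder; not ksuit's EmaCallback."""

    def __init__(self, target_factor: float) -> None:
        self.target_factor = target_factor
        # Outer key: model path (None = whole model); inner key: weight name.
        self.weights: dict[Optional[str], dict[str, list[float]]] = {}

    def state_dict(self) -> dict:
        # The large EMA weights are deliberately left out; they live in their own file.
        return {}

    def _file(self, ckpt_dir: Path) -> Path:
        # Assumed file-name pattern, chosen only for this sketch.
        return ckpt_dir / f"ema_{self.target_factor}.pkl"

    def save(self, ckpt_dir: Path) -> None:
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        with self._file(ckpt_dir).open("wb") as f:
            pickle.dump(self.weights, f)

    def resume_from_checkpoint(self, ckpt_dir: Path) -> None:
        # Restore the state from the file on disk instead of from state_dict().
        with self._file(ckpt_dir).open("rb") as f:
            self.weights = pickle.load(f)
```

Because the weights are written once to their own file, the callback's state_dict stays tiny, and the checkpoint does not hold a second copy of the EMA.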

apply_ema(cur_model, model_path, target_factor)

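Updates the EMA weights for the given model_path from cur_model with decay target_factor, using a fused, in-place implementation.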
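One common way to write a fused, in-place EMA update in PyTorch uses the multi-tensor torch._foreach_* ops. These update a whole list of tensors per call instead of looping over tensors one by one. The sketch below follows that approach and assumes the same decay convention (ema ← f·ema + (1 − f)·cur). Note that torch._foreach_mul_ and torch._foreach_add_ are private (underscore-prefixed) PyTorch APIs. The function name apply_ema_foreach is hypothetical, and this is not necessarily how ksuit implements apply_ema.

```python
import torch


def apply_ema_foreach(
    ema_params: list[torch.Tensor],
    cur_params: list[torch.Tensor],
    target_factor: float,
) -> None:
    """Hypothetical helper: in-place EMA update, ema <- f * ema + (1 - f) * cur."""
    with torch.no_grad():
        # Scale every EMA tensor by the target factor in one multi-tensor call.
        torch._foreach_mul_(ema_params, target_factor)
        # Add (1 - f) * current weights to every EMA tensor in one multi-tensor call.
        torch._foreach_add_(ema_params, cur_params, alpha=1.0 - target_factor)


# Usage: the EMA list must match the model's parameters in order and shape.
model = torch.nn.Linear(2, 1)
ema = [p.detach().clone() for p in model.parameters()]
apply_ema_foreach(ema, [p.detach() for p in model.parameters()], target_factor=0.999)
```

On GPU, the multi-tensor ops can reduce kernel-launch overhead compared with a Python loop of per-tensor mul_/add_ calls. The result is the same up to floating-point rounding.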