ksuit.writers.checkpoint_writer =============================== .. py:module:: ksuit.writers.checkpoint_writer Classes ------- .. autoapisummary:: ksuit.writers.checkpoint_writer.CheckpointKeys ksuit.writers.checkpoint_writer.CheckpointWriter Module Contents --------------- .. py:class:: CheckpointKeys Bases: :py:obj:`enum.StrEnum` Defines the standard, possible keys in the checkpoint dict. Initialize self. See help(type(self)) for accurate signature. .. py:attribute:: state_dict The pytorch state dict of the model. I.e. the model weights/tensors/buffers. .. py:attribute:: checkpoint_tag The checkpoint tag, e.g., "E10_U200_S800" or "latest". .. py:attribute:: training_iteration The detailed information about training iteration as a dict with keys 'epoch', 'update', and 'sample'. .. py:attribute:: run_id The ID of the run from which this checkpoint was created. .. py:attribute:: model_config The model configuration used to instantiate the model. A serialized dict of the pydantic model config. .. py:attribute:: config_kind The kind (i.e., class path) of the model configuration. Used to instantiate the correct model configuration. .. py:attribute:: callback_state_dicts The state dicts of the callbacks. .. py:attribute:: grad_scaler The state dict of the grad scaler (if used). .. py:class:: CheckpointWriter(path_provider, update_counter) Class to easily write checkpoints in a structured way to the disk. Each `Model` will be stored in a separate file where additionally weights and optimizer state are also separate files. This allows flexible storing of model states without producing files that are never needed after training. For example, to resume runs, one need the model weights and optimizer states. However, storing optimizer states for all checkpoints is expensive as optimizer states are commonly 2x as large as only the weights. To illustrate the flexibility, consider the use-case of training an autoencoder model where the goal is to train a good encoder that should then be used for another task. This model is implemented via a class `Autoencoder` that inherits from `CompositeModel` and contains two submodels, an encoder and decoder (both which inherit from `Model`). During training, we want to store the following files to the disk: - The encoder weights after every 10 epochs to evaluate performance at various training lengths. - The latest weights and optimizer states of encoder and decoder to allow resuming a run if it crashes. The CheckpointWriter provides functionality to store the following files: - `autoencoder.encoder_cp=E10_... model.th`: encoder weights after 10 epochs - `autoencoder.encoder_cp=E20_... model.th`: encoder weights after 20 epochs - `autoencoder.encoder_cp=E30_... model.th`: encoder weights after 30 epochs - `autoencoder.encoder_cp=last_model.th`: latest encoder weights - `autoencoder.encoder_cp=last_optim.th`: latest encoder optimizer state - `autoencoder.decoder_cp=last_model.th`: latest decoder weights - `autoencoder.decoder_cp=last_optim.th`: latest decoder optimizer state Each model checkpoint is populated with metadata. Each checkpoint will be a dictionary containing the keys: - "state_dict": Weights of the model. - "model_config": The model configuration used to instantiate the model. A serialized dict of the pydantic model config. - "checkpoint_tag": The name of the checkpoint. E.g., E10_U200_S800 for a progress-based checkpoint or "latest" for a string-based checkpoint. - "training_iteration": The detailed information about training iteration as a dict with keys 'epoch', 'update', and 'sample'. E.g., for the "latest" checkpoint you would not know from which epoch the checkpoint is, therefore the "training_iteration" field of that checkpoint contains "E13_U..._S...". - "run_id": The ID of the run from which it was created. .. py:attribute:: logger .. py:attribute:: path_provider .. py:attribute:: update_counter .. py:method:: save_model_checkpoint(output_name, state_dict, checkpoint_tag, model_config = None, **extra) Save a checkpoint to disk. :param output_name: Output name of the checkpoint (including an extension). :param state_dict: Model state dict to save. :param checkpoint_tag: Checkpoint tag, for example "latest" or "E10_U200_S800". :param model_config: Model configuration. Defaults to None. :param \*\*extra: :raises RuntimeError: in case of an unexpected error while parsing `model_config`. .. py:method:: save(model, checkpoint_tag, trainer = None, save_weights = True, save_optim = True, save_latest_weights = False, save_latest_optim = False, model_names_to_save = None, save_frozen_weights = True) Saves a model to the disk. :param model: Model to save. :param checkpoint_tag: Checkpoint tag, for example "latest" or "E10_U200_S800". :param trainer: If defined, also stores the state_dict of the trainer (and callbacks). :param save_weights: If true, stores model weights. :param save_optim: If true, stores optimizer states. :param save_latest_weights: If true, also stores the weights with the checkpoint identifier "latest". This file will be repeatedly overwritten throughout a training procedure to save storage. :param save_latest_optim: If true, also stores the optimizer states with the checkpoint identifier "latest". This file will be repeatedly overwritten throughout a training procedure to save storage. :param model_names_to_save: If defined, only store some of the submodels of a `CompositeModel`. :param save_frozen_weights: If true, also stores the weights of frozen models.