ksuit.writers.checkpoint_writer

Classes

CheckpointKeys

Defines the standard keys that may appear in the checkpoint dict.

CheckpointWriter

Class to easily write checkpoints to disk in a structured way.

Module Contents

class ksuit.writers.checkpoint_writer.CheckpointKeys

Bases: enum.StrEnum

Defines the standard keys that may appear in the checkpoint dict.

state_dict

The PyTorch state dict of the model, i.e., the model weights, tensors, and buffers.

checkpoint_tag

The checkpoint tag, e.g., “E10_U200_S800” or “latest”.

training_iteration

Detailed information about the training iteration, as a dict with keys ‘epoch’, ‘update’, and ‘sample’.

run_id

The ID of the run from which this checkpoint was created.

model_config

The model configuration used to instantiate the model. A serialized dict of the pydantic model config.

config_kind

The kind (i.e., class path) of the model configuration. Used to instantiate the correct model configuration.

callback_state_dicts

The state dicts of the callbacks.

grad_scaler

The state dict of the grad scaler (if used).
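As an illustration of how these keys can be used, the sketch below defines a hypothetical stand-in enum, CheckpointKeysSketch, with the documented member names. It assumes each member's value equals its name (this page does not show the values) and mixes str into Enum so it also runs before Python 3.11; the real class derives from enum.StrEnum. The checkpoint contents are made-up placeholders.

```python
from enum import Enum


class CheckpointKeysSketch(str, Enum):
    """Hypothetical stand-in for CheckpointKeys; values assumed equal to the names."""

    state_dict = "state_dict"
    checkpoint_tag = "checkpoint_tag"
    training_iteration = "training_iteration"
    run_id = "run_id"
    model_config = "model_config"
    config_kind = "config_kind"
    callback_state_dicts = "callback_state_dicts"
    grad_scaler = "grad_scaler"


# Placeholder contents; a real checkpoint holds tensors under state_dict.
checkpoint = {
    CheckpointKeysSketch.state_dict: {"linear.weight": [[0.1, 0.2]]},
    CheckpointKeysSketch.checkpoint_tag: "E10_U200_S800",
    CheckpointKeysSketch.training_iteration: {"epoch": 10, "update": 200, "sample": 800},
    CheckpointKeysSketch.run_id: "run_abc123",
}

# str-based enum members compare and hash like their string values,
# so the same entries can be read back with plain string keys.
print(checkpoint["checkpoint_tag"])  # → E10_U200_S800
print(checkpoint["training_iteration"]["epoch"])  # → 10
```

Because the members behave like strings, code that only knows the plain key names (for example, a script inspecting a loaded checkpoint) can still read the same entries.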

class ksuit.writers.checkpoint_writer.CheckpointWriter(path_provider, update_counter)

Class to easily write checkpoints to disk in a structured way.

Each model is stored in a separate file, and its weights and optimizer state are also stored in separate files. This allows flexible storing of model states without producing files that are never needed after training. For example, resuming a run requires both the model weights and the optimizer states, but storing optimizer states for every checkpoint is expensive because they are commonly about 2x as large as the weights alone.

To illustrate the flexibility, consider training an autoencoder where the goal is a good encoder that is later used for another task. This model is implemented via a class Autoencoder that inherits from CompositeModel and contains two submodels, an encoder and a decoder (both of which inherit from Model). During training, we want to store the following files to disk:

  • The encoder weights after every 10 epochs, to evaluate performance at various training lengths.

  • The latest weights and optimizer states of the encoder and decoder, to allow resuming a run if it crashes.

The CheckpointWriter provides functionality to store the following files:

  • autoencoder.encoder_cp=E10_…_model.th: encoder weights after 10 epochs

  • autoencoder.encoder_cp=E20_…_model.th: encoder weights after 20 epochs

  • autoencoder.encoder_cp=E30_…_model.th: encoder weights after 30 epochs

  • autoencoder.encoder_cp=last_model.th: latest encoder weights

  • autoencoder.encoder_cp=last_optim.th: latest encoder optimizer state

  • autoencoder.decoder_cp=last_model.th: latest decoder weights

  • autoencoder.decoder_cp=last_optim.th: latest decoder optimizer state
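The file names above follow a visible pattern: the model name, then "_cp=" and the checkpoint tag, then "_model.th" or "_optim.th". The sketch below reproduces that pattern with a hypothetical helper, checkpoint_filename, which is not part of ksuit; the progress tag "E10_U200_S800" is illustrative, and the "last" tag is copied from the list above.

```python
def checkpoint_filename(model_name: str, checkpoint_tag: str, kind: str) -> str:
    """Hypothetical helper reproducing the '<model>_cp=<tag>_<kind>.th' pattern."""
    if kind not in ("model", "optim"):
        raise ValueError(f"kind must be 'model' or 'optim', got {kind!r}")
    return f"{model_name}_cp={checkpoint_tag}_{kind}.th"


# Periodic checkpoint: encoder weights only (progress tag is illustrative).
files = [checkpoint_filename("autoencoder.encoder", "E10_U200_S800", "model")]

# Resume checkpoint: latest weights and optimizer state of both submodels.
for submodel in ("autoencoder.encoder", "autoencoder.decoder"):
    for kind in ("model", "optim"):
        files.append(checkpoint_filename(submodel, "last", kind))

for name in files:
    print(name)
```

Only the periodic encoder weights accumulate over training; the four "last" files are overwritten in place, which is what keeps the optimizer-state storage bounded.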

Each model checkpoint is populated with metadata. Each checkpoint is a dictionary containing the keys:

  • “state_dict”: Weights of the model.

  • “model_config”: The model configuration used to instantiate the model. A serialized dict of the pydantic model config.

  • “checkpoint_tag”: The name of the checkpoint, e.g., “E10_U200_S800” for a progress-based checkpoint or “latest” for a string-based checkpoint.

  • “training_iteration”: Detailed information about the training iteration, as a dict with keys ‘epoch’, ‘update’, and ‘sample’. This is needed because a tag such as “latest” does not reveal the epoch at which the checkpoint was saved; the “training_iteration” field of that checkpoint records it (e.g., “E13_U…_S…”).

  • “run_id”: The ID of the run from which the checkpoint was created.
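A minimal sketch of assembling this metadata follows. The helpers format_checkpoint_tag and build_checkpoint are hypothetical, and the tag format is inferred from the example tag "E10_U200_S800"; the counter values are made up. It shows why "training_iteration" is useful: a checkpoint tagged "latest" still carries the progress from which a descriptive tag can be recovered.

```python
from typing import Any, Optional


def format_checkpoint_tag(epoch: int, update: int, sample: int) -> str:
    """Format progress counters as a tag such as 'E10_U200_S800' (hypothetical helper)."""
    return f"E{epoch}_U{update}_S{sample}"


def build_checkpoint(
    state_dict: dict,
    checkpoint_tag: str,
    epoch: int,
    update: int,
    sample: int,
    run_id: str,
    model_config: Optional[dict] = None,
) -> dict:
    """Assemble the documented metadata keys into one checkpoint dict."""
    checkpoint: dict[str, Any] = {
        "state_dict": state_dict,
        "checkpoint_tag": checkpoint_tag,
        # Stored separately so that a "latest" checkpoint still records its real progress.
        "training_iteration": {"epoch": epoch, "update": update, "sample": sample},
        "run_id": run_id,
    }
    if model_config is not None:
        checkpoint["model_config"] = model_config
    return checkpoint


# Made-up counters for a "latest" checkpoint.
latest = build_checkpoint({}, "latest", epoch=7, update=140, sample=560, run_id="run_abc123")
it = latest["training_iteration"]
print(latest["checkpoint_tag"])  # → latest
print(format_checkpoint_tag(it["epoch"], it["update"], it["sample"]))  # → E7_U140_S560
```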

Parameters:
  • path_provider (ksuit.providers.PathProvider)

  • update_counter (ksuit.utils.training.UpdateCounter)

logger
path_provider
update_counter
save_model_checkpoint(output_name, state_dict, checkpoint_tag, model_config=None, **extra)

Save a checkpoint to disk.

Parameters:
  • output_name (str) – Output name of the checkpoint (including an extension).

  • state_dict (dict[str, Any]) – Model state dict to save.

  • checkpoint_tag (str) – Checkpoint tag, for example “latest” or “E10_U200_S800”.

  • model_config (ksuit.schemas.models.ModelBaseConfig | None) – Model configuration. Defaults to None.

  • **extra

Raises:

RuntimeError – in case of an unexpected error while parsing model_config.

Return type:

None
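The following is a rough, self-contained sketch of what a call like save_model_checkpoint might do, written as a hypothetical function save_model_checkpoint_sketch rather than the ksuit implementation. Several points are assumptions: the config is serialized through a pydantic-style model_dump(), config_kind is built from the config's module and class name, the **extra keyword arguments become additional checkpoint entries, and pickle stands in for torch.save so the example runs without PyTorch. TinyConfig is a made-up class that only mimics model_dump().

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_model_checkpoint_sketch(
    directory: Path,
    output_name: str,
    state_dict: dict,
    checkpoint_tag: str,
    model_config: Any = None,
    **extra: Any,
) -> Path:
    """Hypothetical re-creation of the documented behavior; not the ksuit source."""
    checkpoint: dict = {"state_dict": state_dict, "checkpoint_tag": checkpoint_tag}
    if model_config is not None:
        try:
            # Assumption: the config is a pydantic v2 model exposing model_dump().
            checkpoint["model_config"] = model_config.model_dump()
            cls = type(model_config)
            checkpoint["config_kind"] = f"{cls.__module__}.{cls.__qualname__}"
        except Exception as exc:
            raise RuntimeError(f"could not serialize model_config: {exc}") from exc
    # Assumption: extra keyword arguments become additional checkpoint entries.
    checkpoint.update(extra)
    path = Path(directory) / output_name
    with path.open("wb") as f:
        pickle.dump(checkpoint, f)  # stand-in for torch.save
    return path


class TinyConfig:
    """Minimal object mimicking a pydantic config's model_dump()."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def model_dump(self) -> dict:
        return {"dim": self.dim}


with tempfile.TemporaryDirectory() as tmp:
    path = save_model_checkpoint_sketch(
        Path(tmp), "encoder_cp=latest_model.th", {"w": [1.0]}, "latest",
        model_config=TinyConfig(8), run_id="run_abc123",
    )
    with path.open("rb") as f:
        loaded = pickle.load(f)

print(loaded["model_config"], loaded["run_id"])  # → {'dim': 8} run_abc123
```

Wrapping any serialization failure in a RuntimeError mirrors the documented "Raises" entry, though the exact conditions under which ksuit raises it are not shown here.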

save(model, checkpoint_tag, trainer=None, save_weights=True, save_optim=True, save_latest_weights=False, save_latest_optim=False, model_names_to_save=None, save_frozen_weights=True)

Saves a model to disk.

Parameters:
  • model (ksuit.models.ModelBase) – Model to save.

  • checkpoint_tag (str) – Checkpoint tag, for example “latest” or “E10_U200_S800”.

  • trainer (ksuit.trainers.BaseTrainer | None) – If defined, also stores the state_dict of the trainer (and callbacks).

  • save_weights (bool) – If true, stores model weights.

  • save_optim (bool) – If true, stores optimizer states.

  • save_latest_weights (bool) – If true, also stores the weights with the checkpoint identifier “latest”. This file will be repeatedly overwritten throughout a training procedure to save storage.

  • save_latest_optim (bool) – If true, also stores the optimizer states with the checkpoint identifier “latest”. This file will be repeatedly overwritten throughout a training procedure to save storage.

  • model_names_to_save (list[str] | None) – If defined, only the listed submodels of a CompositeModel are stored.

  • save_frozen_weights (bool) – If true, also stores the weights of frozen models.

Return type:

None
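To make the interplay of the save() flags concrete, the sketch below computes which files a call would plausibly write for the autoencoder example from the class description. Everything here is hypothetical: SubmodelInfo and plan_checkpoint_files are illustrative names, the file-name pattern is taken from the class-level example, the "latest" tag follows the parameter descriptions above, and frozen submodels are assumed to have no optimizer state.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SubmodelInfo:
    """Hypothetical description of one submodel of a CompositeModel."""

    name: str
    is_frozen: bool = False


def plan_checkpoint_files(
    submodels: list[SubmodelInfo],
    checkpoint_tag: str,
    save_weights: bool = True,
    save_optim: bool = True,
    save_latest_weights: bool = False,
    save_latest_optim: bool = False,
    model_names_to_save: list[str] | None = None,
    save_frozen_weights: bool = True,
) -> list[str]:
    """Return the file names a save() call with these flags would plausibly write."""
    files: list[str] = []
    for sub in submodels:
        if model_names_to_save is not None and sub.name not in model_names_to_save:
            continue
        if sub.is_frozen and not save_frozen_weights:
            continue
        # Assumption: a frozen submodel has no optimizer, so no optimizer state exists.
        has_optim = not sub.is_frozen
        if save_weights:
            files.append(f"{sub.name}_cp={checkpoint_tag}_model.th")
        if save_optim and has_optim:
            files.append(f"{sub.name}_cp={checkpoint_tag}_optim.th")
        if save_latest_weights:
            files.append(f"{sub.name}_cp=latest_model.th")
        if save_latest_optim and has_optim:
            files.append(f"{sub.name}_cp=latest_optim.th")
    return files


parts = [SubmodelInfo("autoencoder.encoder"), SubmodelInfo("autoencoder.decoder")]

# Every 10 epochs: encoder weights only.
print(plan_checkpoint_files(parts, "E10_U200_S800", save_optim=False,
                            model_names_to_save=["autoencoder.encoder"]))
# → ['autoencoder.encoder_cp=E10_U200_S800_model.th']

# For resuming: latest weights and optimizer states of both submodels.
resume = plan_checkpoint_files(parts, "E10_U200_S800", save_weights=False, save_optim=False,
                               save_latest_weights=True, save_latest_optim=True)
```

Splitting the two needs into separate calls keeps the periodic checkpoints small (weights only), while the resume files carry the full state but are overwritten each time.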