ksuit.callbacks.base.callback_base¶
Classes¶
CallbackBase – Base class for callbacks that log something before/after training.
Module Contents¶
- class ksuit.callbacks.base.callback_base.CallbackBase(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)¶
Base class for callbacks that log something before/after training.
Allows overwriting _before_training and _after_training.
If the callback is stateful (i.e., it tracks something across the training process that needs to be loaded if the run is resumed), there are two ways to implement loading the callback state (a minimal sketch follows this list):
- state_dict: write the current state into a state dict. When the trainer saves the current checkpoint to disk, it also stores the state_dict of all callbacks within the trainer state_dict. Once a run is resumed, a callback can load its state from the previously stored state_dict by overwriting load_state_dict.
- resume_from_checkpoint: if a callback stores large files on the disk, it would be redundant to also store them within its state_dict. Therefore, this method is called on resume to allow callbacks to load their state from files on the disk.
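As a minimal sketch of the state_dict route (the EvalCounterCallback class and its num_evals field are hypothetical, invented for illustration; only CallbackBase and the state_dict/load_state_dict hooks come from this module):

```
from ksuit.callbacks.base.callback_base import CallbackBase


class EvalCounterCallback(CallbackBase):
    """Hypothetical callback that counts how often it has evaluated."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_evals = 0  # state that has to survive a resume

    def state_dict(self):
        # stored within the trainer state_dict whenever a checkpoint is written
        return dict(num_evals=self.num_evals)

    def load_state_dict(self, state_dict):
        # restore the previously stored state when the run is resumed
        self.num_evals = state_dict["num_evals"]
```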
Callbacks have access to a LogWriter, with which callbacks can log metrics. The LogWriter is a singleton.

Examples

```
# log only to experiment tracker, not stdout
self.writer.add_scalar(key="classification_accuracy", value=0.2)
# log to experiment tracker and stdout (as "0.24")
self.writer.add_scalar(
    key="classification_accuracy",
    value=0.23623,
    logger=self.logger,
    format_str=".2f",
)
```
Classes that inherit from CallbackBase get access to the following dependencies:
- trainer (SgdTrainer): access to the current trainer
- model (ModelBase): access to the current model
- data_container (DataContainer): access to all datasets that were initialized for the current run
- tracker (BaseTracker): access to the tracker object to log metrics to stdout/disk/online platform
- path_provider (PathProvider): access to paths (e.g., output_path: where checkpoints/logs are stored)
- metric_property_provider (MetricPropertyProvider): defines properties of metrics (e.g., for a loss, lower values are better whereas for an accuracy, higher values are better)
- writer (LogWriter): access to the log writer to log metrics to stdout/disk/online platform
- checkpoint_writer (CheckpointWriter): access to the checkpoint writer to store checkpoints during training
As evaluations are almost always done in torch.no_grad() contexts, the hooks implemented by callbacks automatically apply the torch.no_grad() context. The CallbackBase class therefore makes use of the "template method" design pattern, where templates (e.g., before_training) implement the invariant behavior (e.g., applying torch.no_grad()). The template implementations start with an underscore (e.g., _before_training) and only these underscore-prefixed methods should be implemented by child classes. Templates (e.g., before_training) should not be overwritten. A sketch of such a child class follows.
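A minimal sketch of a child class, assuming _before_training/_after_training mirror the update_counter argument of their public templates and that ModelBase exposes the usual torch parameters() iterator (ParamCountCallback and the "num_params" key are made up for illustration):

```
from ksuit.callbacks.base.callback_base import CallbackBase


class ParamCountCallback(CallbackBase):
    """Hypothetical callback that logs the model's parameter count."""

    def _before_training(self, update_counter):
        # runs inside the torch.no_grad() context applied by before_training;
        # self.model and self.writer are provided by CallbackBase
        num_params = sum(p.numel() for p in self.model.parameters())
        self.writer.add_scalar(key="num_params", value=num_params)

    def _after_training(self, update_counter):
        # also wrapped in torch.no_grad() by the after_training template
        self.logger.info("training finished")
```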
- param trainer:
Trainer of the current run, subclass of SgdTrainer.
- param model:
Model of the current run.
- param data_container:
DataContainer instance that provides access to all datasets.
- param tracker:
Tracker instance to log metrics to stdout/disk/online platform.
- param log_writer:
LogWriter instance to log metrics to stdout/disk/online platform.
- param checkpoint_writer:
CheckpointWriter instance to store checkpoints during training.
- param metric_property_provider:
MetricPropertyProvider instance to access properties of metrics.
- param name:
Name of the callback.
- name = None¶
- trainer¶
- model¶
- data_container¶
- tracker¶
- writer¶
- metric_property_provider¶
- checkpoint_writer¶
- state_dict()¶
If a callback is stateful, the state returned by this method is stored whenever a checkpoint is written to the disk.
- load_state_dict(state_dict)¶
If a callback is stateful, the state that was stored alongside a checkpoint can be loaded with this method upon resuming a run.
- resume_from_checkpoint(resumption_paths, model)¶
If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be unnecessarily wasteful to also store the state in the callback's state_dict. Therefore, resume_from_checkpoint is called when resuming a run, which allows callbacks to load their state from any file that was stored on the disk (see the sketch below).
- Parameters:
resumption_paths (ksuit.providers.path_provider.PathProvider) – PathProvider instance to access paths from the checkpoint to resume from.
model (ksuit.models.ModelBase) – model of the current training run.
- Return type:
None
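A hedged sketch of the resume_from_checkpoint route, using the EMA example from the description above. The EmaCallback class, the "ema.th" file name, and reading resumption_paths.output_path as a pathlib.Path are all assumptions for illustration; only the hook signature comes from this reference:

```
import copy

import torch

from ksuit.callbacks.base.callback_base import CallbackBase


class EmaCallback(CallbackBase):
    """Hypothetical callback that maintains an EMA of the model weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.ema_model = None

    def resume_from_checkpoint(self, resumption_paths, model):
        # the EMA weights are large, so they are assumed to live in their own
        # file on disk instead of being duplicated inside the state_dict
        self.ema_model = copy.deepcopy(model)
        ema_weights = torch.load(resumption_paths.output_path / "ema.th")
        self.ema_model.load_state_dict(ema_weights)
```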
- property logger: logging.Logger¶
Logger for logging to stdout.
- Return type:
logging.Logger
- before_training(update_counter)¶
Hook that is called before training starts. Applies torch.no_grad() context.
- Parameters:
update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- after_training(update_counter)¶
Hook that is called after training ends. Applies torch.no_grad() context.
- Parameters:
update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.
- Return type:
None
- Parameters:
trainer (ksuit.trainers.BaseTrainer)
model (ksuit.models.ModelBase)
data_container (ksuit.utils.data.data_container.DataContainer)
tracker (ksuit.trackers.BaseTracker)
log_writer (ksuit.writers.LogWriter)
checkpoint_writer (ksuit.writers.CheckpointWriter)
metric_property_provider (ksuit.providers.metric_property_provider.MetricPropertyProvider)
name (str | None)