ksuit.callbacks.base.callback_base
==================================

.. py:module:: ksuit.callbacks.base.callback_base


Classes
-------

.. autoapisummary::

   ksuit.callbacks.base.callback_base.CallbackBase


Module Contents
---------------

.. py:class:: CallbackBase(trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name = None)

   Base class for callbacks that log something before/after training.
   Allows overriding `_before_training` and `_after_training`.

   If the callback is stateful (i.e., it tracks something across the training process that needs to be
   loaded if the run is resumed), there are two ways to implement loading the callback state:

   - `state_dict`: write the current state into a state dict. When the trainer saves the current
     checkpoint to disk, it will also store the `state_dict` of all callbacks within the trainer
     `state_dict`. Once a run is resumed, a callback can load its state from the previously stored
     `state_dict` by overriding `load_state_dict`.
   - `resume_from_checkpoint`: if a callback stores large files on disk, it would be redundant to also
     store them within its `state_dict`. Therefore, this method is called on resume to allow callbacks
     to load their state from files on disk.

   Callbacks have access to a `LogWriter`, with which callbacks can log metrics.
   The `LogWriter` is a singleton.

   .. rubric:: Examples

   ```
   # log only to experiment tracker, not stdout
   self.writer.add_scalar(key="classification_accuracy", value=0.2)
   # log to experiment tracker and stdout (as "0.24")
   self.writer.add_scalar(
       key="classification_accuracy",
       value=0.23623,
       logger=self.logger,
       format_str=".2f",
   )
   ```

   Classes that inherit from `CallbackBase` get access to the following dependencies:

   - trainer (SgdTrainer): access to the current trainer
   - model (ModelBase): access to the current model
   - data_container (DataContainer): access to all datasets that were initialized for the current run
   - tracker (BaseTracker): access to the tracker object to log metrics to stdout/disk/online platform
   - path_provider (PathProvider): access to paths (e.g., output_path: where checkpoints/logs are stored)
   - metric_property_provider (MetricPropertyProvider): defines properties of metrics (e.g., for a loss,
     lower values are better, whereas for an accuracy, higher values are better)
   - writer (LogWriter): access to the log writer to log metrics to stdout/disk/online platform
   - checkpoint_writer (CheckpointWriter): access to the checkpoint writer to store checkpoints during training

   As evaluations are pretty much always done in `torch.no_grad()` contexts, the hooks implemented by
   callbacks automatically apply the `torch.no_grad()` context. Therefore, the `CallbackBase` class
   makes use of the "template method" design pattern, where templates (e.g., `before_training`)
   implement the invariant behavior (e.g., applying `torch.no_grad()`). The template implementations
   start with an underscore (e.g., `_before_training`) and only these methods should be implemented by
   child classes. Templates (e.g., `before_training`) should not be overridden.

   :param trainer: Trainer of the current run, subclass of `SgdTrainer`.
   :param model: Model of the current run.
   :param data_container: DataContainer instance that provides access to all datasets.
   :param tracker: Tracker instance to log metrics to stdout/disk/online platform.
   :param log_writer: LogWriter instance to log metrics to stdout/disk/online platform.
   :param checkpoint_writer: CheckpointWriter instance to store checkpoints during training.
   :param metric_property_provider: MetricPropertyProvider instance to access properties of metrics.
   :param name: Name of the callback.
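   A minimal sketch of a stateful callback, assuming the underscore hooks receive the same
   `update_counter` argument as their templates; the class name, metric key, and tracked value are
   hypothetical:

   ```
   from ksuit.callbacks.base.callback_base import CallbackBase


   class BestAccuracyCallback(CallbackBase):
       """Hypothetical callback that tracks the best accuracy seen during training."""

       def __init__(self, *args, **kwargs):
           super().__init__(*args, **kwargs)
           self.best_accuracy = 0.0

       def _after_training(self, update_counter):
           # runs inside the torch.no_grad() context applied by the after_training template
           self.writer.add_scalar(key="best_accuracy", value=self.best_accuracy)

       def state_dict(self):
           # stored within the trainer state_dict whenever a checkpoint is written
           return dict(best_accuracy=self.best_accuracy)

       def load_state_dict(self, state_dict):
           # restores the tracked state when the run is resumed
           self.best_accuracy = state_dict["best_accuracy"]
   ```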
   .. py:attribute:: name
      :value: None


   .. py:attribute:: trainer


   .. py:attribute:: model


   .. py:attribute:: data_container


   .. py:attribute:: tracker


   .. py:attribute:: writer


   .. py:attribute:: metric_property_provider


   .. py:attribute:: checkpoint_writer


   .. py:method:: state_dict()

      If a callback is stateful, its state will be stored when a checkpoint is written to disk.

      :returns: State of the callback. By default, callbacks are non-stateful and return None.


   .. py:method:: load_state_dict(state_dict)

      If a callback is stateful, its state will be stored when a checkpoint is written to disk and can
      be loaded with this method upon resuming a run.

      :param state_dict: State to be loaded. By default, callbacks are non-stateful and
          `load_state_dict` does nothing.


   .. py:method:: resume_from_checkpoint(resumption_paths, model)

      If a callback stores large files to disk and is stateful (e.g., an EMA of the model), it would be
      unnecessarily wasteful to also store the state in the callback's `state_dict`. Therefore,
      `resume_from_checkpoint` is called when resuming a run, which allows callbacks to load their
      state from any file that was stored on disk.

      :param resumption_paths: PathProvider instance to access paths from the checkpoint to resume from.
      :param model: Model of the current training run.


   .. py:property:: logger
      :type: logging.Logger

      Logger for logging to stdout.


   .. py:method:: before_training(update_counter)

      Hook that is called before training starts. Applies the `torch.no_grad()` context.

      :param update_counter: UpdateCounter instance to access the current training progress.


   .. py:method:: after_training(update_counter)

      Hook that is called after training ends. Applies the `torch.no_grad()` context.

      :param update_counter: UpdateCounter instance to access the current training progress.
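   For a callback whose state lives in large files on disk, a sketch of `resume_from_checkpoint` under
   stated assumptions: the file name `ema_model.th` is hypothetical, and `output_path` (an attribute
   the class docstring mentions for `PathProvider`) is assumed to be a `pathlib.Path`:

   ```
   import copy

   import torch

   from ksuit.callbacks.base.callback_base import CallbackBase


   class EmaCallback(CallbackBase):
       """Hypothetical callback that maintains an EMA copy of the model weights."""

       def __init__(self, *args, **kwargs):
           super().__init__(*args, **kwargs)
           self.ema_model = None

       def resume_from_checkpoint(self, resumption_paths, model):
           # the EMA weights would be redundant inside state_dict, so they are
           # reloaded from the checkpoint directory of the run being resumed
           self.ema_model = copy.deepcopy(model)
           weights = torch.load(resumption_paths.output_path / "ema_model.th", map_location="cpu")
           self.ema_model.load_state_dict(weights)
   ```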