ksuit.callbacks.base.periodic_iterator_callback
Classes
PeriodicIteratorCallback | A base class for callbacks that perform periodic iterations over a dataset.
Module Contents
- class ksuit.callbacks.base.periodic_iterator_callback.PeriodicIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)
Bases: ksuit.callbacks.base.periodic_callback.PeriodicCallback
A base class for callbacks that perform periodic iterations over a dataset.
Periodic callbacks are typically used to calculate a metric on a dataset other than the training data, e.g., a test dataset. This class therefore provides the functionality needed to hook into the dataloading pipeline via the _register_sampler_configs and _iterate_over_dataset methods.
A basic example of a callback that calculates the test accuracy of a classification model:
```
class AccuracyCallback(PeriodicIteratorCallback):
    def _register_sampler_config(self, trainer) -> None:
        return self._sampler_config_from_key(key=self.dataset_key)

    def _forward(self, batch, *, trainer_model):
        y_hat = trainer_model(batch["x"].to(trainer_model.device))
        return y_hat, batch["class"].clone()

    def _process_results(self, results, *, interval_type, update_counter, **_):
        y_hat, y = results
        accuracy = (y_hat.argmax(dim=1) == y).float().mean()
        …
```
Initializes the PeriodicCallback.
- Parameters:
callback_config (ksuit.schemas.callbacks.callbacks_config.CallBackBaseConfig) – Configuration of the PeriodicCallback. Implements the CallBackBaseConfig schema.
trainer (ksuit.trainers.base.BaseTrainer) – Trainer of the current run, subclass of SgdTrainer.
model (ksuit.models.model_base.ModelBase) – Model of the current run.
data_container (ksuit.utils.data.data_container.DataContainer) – DataContainer instance that provides access to all datasets.
tracker (ksuit.trackers.base_tracker.BaseTracker) – Tracker instance to log metrics to stdout/disk/online platform.
log_writer (ksuit.writers.LogWriter) – LogWriter instance to log metrics.
checkpoint_writer (ksuit.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.
metric_property_provider (ksuit.providers.metric_property_provider.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.
name (str | None) – Name of the callback.
- total_data_time = 0.0
- register_sampler_config()
Registers the datasets that are used for this callback into the dataloading pipeline.
- Parameters:
trainer – Trainer of the current run.
- Returns:
The registered sampler_config
- Return type:
ksuit.data.samplers.interleaved_sampler.SamplerIntervalConfig
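To illustrate how registration, per-batch forwarding, and result processing could fit together, here is a minimal pure-Python sketch. It does not use ksuit: `ToyPeriodicIteratorCallback`, `ToyAccuracyCallback`, `periodic_step`, and `toy_model` are hypothetical names invented for illustration, and the "sampler config" is reduced to a plain dict. The real base class may collate results and drive iteration differently.

```python
from typing import Any


class ToyPeriodicIteratorCallback:
    """Hypothetical stand-in for illustration only, NOT the ksuit base class."""

    def __init__(self, datasets: dict[str, list[dict[str, Any]]], dataset_key: str):
        self.datasets = datasets  # maps a dataset key to a list of batches
        self.dataset_key = dataset_key
        self.sampler_config = None

    def register_sampler_config(self) -> dict[str, str]:
        # Assumption: the "sampler config" is just a dict naming the dataset.
        self.sampler_config = {"dataset_key": self.dataset_key}
        return self.sampler_config

    def _forward(self, batch, *, model):
        raise NotImplementedError

    def _process_results(self, results, *, update_counter):
        raise NotImplementedError

    def periodic_step(self, model, *, update_counter):
        # Run _forward on every batch of the registered dataset.
        batches = self.datasets[self.sampler_config["dataset_key"]]
        per_batch = [self._forward(batch, model=model) for batch in batches]
        # Concatenate each output position across batches:
        # [(preds_0, targets_0), (preds_1, targets_1)] -> (all_preds, all_targets)
        results = tuple(
            [item for part in parts for item in part] for parts in zip(*per_batch)
        )
        return self._process_results(results, update_counter=update_counter)


class ToyAccuracyCallback(ToyPeriodicIteratorCallback):
    def _forward(self, batch, *, model):
        # Return predictions and targets for one batch.
        return [model(x) for x in batch["x"]], list(batch["class"])

    def _process_results(self, results, *, update_counter):
        y_hat, y = results
        correct = sum(int(pred == target) for pred, target in zip(y_hat, y))
        return correct / len(y)


def toy_model(x: float) -> int:
    # Hypothetical classifier: predicts class 1 for positive inputs, else 0.
    return int(x > 0)


test_batches = [
    {"x": [-1.0, 2.0], "class": [0, 1]},
    {"x": [3.0, -4.0], "class": [1, 1]},
]
callback = ToyAccuracyCallback(datasets={"test": test_batches}, dataset_key="test")
callback.register_sampler_config()
accuracy = callback.periodic_step(toy_model, update_counter=1)
print(accuracy)
```

In this sketch the subclass only decides what to compute per batch and how to reduce the collected results. Dataset selection and the iteration loop stay in the base class, which mirrors the division of labor in the AccuracyCallback example.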