ksuit.callbacks.base.periodic_iterator_callback

Classes

PeriodicIteratorCallback

A base class for callbacks that perform periodic iterations over a dataset.

Module Contents

class ksuit.callbacks.base.periodic_iterator_callback.PeriodicIteratorCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: ksuit.callbacks.base.periodic_callback.PeriodicCallback

A base class for callbacks that perform periodic iterations over a dataset.

Periodic callbacks are typically used to calculate a metric on a dataset, e.g., a test set. This class therefore provides the functionality to integrate into the dataloading pipeline via the _register_sampler_config and _iterate_over_dataset methods.

A basic example of a callback that calculates the test accuracy of a classification model:

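```
from ksuit.callbacks.base.periodic_iterator_callback import PeriodicIteratorCallback
from ksuit.data.samplers.interleaved_sampler import SamplerIntervalConfig


class AccuracyCallback(PeriodicIteratorCallback):
    def _register_sampler_config(self, trainer) -> SamplerIntervalConfig:
        # Build a sampler config for the dataset registered under self.dataset_key.
        return self._sampler_config_from_key(key=self.dataset_key)

    def _forward(self, batch, *, trainer_model):
        # Predict on the inputs and return the predictions together with the labels.
        y_hat = trainer_model(batch["x"].to(trainer_model.device))
        return y_hat, batch["class"].clone()

    def _process_results(self, results, *, interval_type, update_counter, **_):
        y_hat, y = results
        accuracy = (y_hat.argmax(dim=1) == y).float().mean()
        ...
```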
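In the example above, the base class presumably calls _forward on every batch and passes the collected outputs to _process_results. The following pure-Python sketch imitates that flow so it runs without ksuit or PyTorch. ToyPeriodicIterator, ToyAccuracy, parity_model and the list-concatenation strategy are hypothetical assumptions for illustration; the actual _iterate_over_dataset in ksuit may collate results differently (for example, by concatenating tensors).

```python
from typing import Any, Callable, Iterable, List, Sequence


class ToyPeriodicIterator:
    """Hypothetical stand-in for the iteration step; not the ksuit implementation."""

    def _forward(self, batch: dict, *, trainer_model: Callable[[int], int]) -> Sequence[list]:
        raise NotImplementedError

    def _process_results(self, results: tuple, *, interval_type: str, update_counter: Any, **_: Any) -> Any:
        raise NotImplementedError

    def iterate_over_dataset(self, batches: Iterable[dict], trainer_model: Callable[[int], int]) -> Any:
        collected: List[list] = []
        for batch in batches:
            # _forward returns one sequence per output (e.g. predictions, labels).
            outputs = self._forward(batch, trainer_model=trainer_model)
            if not collected:
                collected = [[] for _ in outputs]
            # Concatenate each output across batches.
            for store, values in zip(collected, outputs):
                store.extend(values)
        # Pass the concatenated outputs to _process_results once per evaluation.
        # The keyword values are placeholders for this sketch.
        return self._process_results(tuple(collected), interval_type="epoch", update_counter=None)


class ToyAccuracy(ToyPeriodicIterator):
    def _forward(self, batch, *, trainer_model):
        # Predicted class per sample, plus the ground-truth labels.
        return [trainer_model(x) for x in batch["x"]], list(batch["class"])

    def _process_results(self, results, *, interval_type, update_counter, **_):
        y_hat, y = results
        # Fraction of samples whose prediction matches the label.
        return sum(p == t for p, t in zip(y_hat, y)) / len(y)


def parity_model(x: int) -> int:
    # Toy "classifier" that predicts the parity of an integer input.
    return x % 2


batches = [{"x": [1, 2], "class": [1, 0]}, {"x": [3, 4], "class": [1, 1]}]
acc = ToyAccuracy().iterate_over_dataset(batches, parity_model)
print(acc)  # 0.75
```

Collecting per-batch outputs and computing the metric once at the end keeps _process_results free of batching concerns: it sees the whole evaluation set, as the accuracy computation in the example expects.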

Initializes the PeriodicCallback.

total_data_time = 0.0

register_sampler_config()

Registers the datasets used by this callback with the dataloading pipeline.

Parameters:

trainer – Trainer of the current run.

Returns:

The registered sampler_config

Return type:

ksuit.data.samplers.interleaved_sampler.SamplerIntervalConfig
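register_sampler_config returns a SamplerIntervalConfig, which the interleaved sampler presumably uses to decide when this callback's dataset is iterated. The sketch below shows, under assumptions, how such interval configs could drive an interleaved schedule of training updates and evaluation passes. ToyIntervalConfig, its fields dataset_key and every_n_updates, and the interleave function are hypothetical; they do not reproduce the real SamplerIntervalConfig or interleaved sampler.

```python
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass(frozen=True)
class ToyIntervalConfig:
    # Hypothetical fields; the real SamplerIntervalConfig may differ.
    dataset_key: str
    every_n_updates: int


def interleave(num_updates: int, configs: List[ToyIntervalConfig]) -> Iterator[Tuple[str, str]]:
    """Yield training updates and, whenever a config's interval is due, an evaluation pass."""
    for update in range(1, num_updates + 1):
        yield ("train", f"update {update}")
        for cfg in configs:
            # Schedule an evaluation pass after every `every_n_updates` training updates.
            if update % cfg.every_n_updates == 0:
                yield ("eval", cfg.dataset_key)


schedule = list(interleave(4, [ToyIntervalConfig("test", every_n_updates=2)]))
print(schedule)
```

Using a generator keeps the schedule lazy, so an evaluation pass is only produced at the point where it is due.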