ksuit.callbacks.base.periodic_callback

Classes

PeriodicCallback

Base class for callbacks that are invoked periodically during training. Epoch-, update- and sample-based intervals are supported.

Module Contents

class ksuit.callbacks.base.periodic_callback.PeriodicCallback(callback_config, trainer, model, data_container, tracker, log_writer, checkpoint_writer, metric_property_provider, name=None)

Bases: ksuit.callbacks.base.callback_base.CallbackBase

Base class for callbacks that are invoked periodically during training. Epoch-, update- and sample-based intervals are supported.

Initializes the PeriodicCallback.

Parameters:
  • callback_config (ksuit.schemas.callbacks.callbacks_config.CallBackBaseConfig) – Configuration of the PeriodicCallback. Implements the CallBackBaseConfig schema.

  • trainer (ksuit.trainers.BaseTrainer) – Trainer of the current run, subclass of SgdTrainer.

  • model (ksuit.models.model_base.ModelBase) – Model of the current run.

  • data_container (ksuit.utils.data.data_container.DataContainer) – DataContainer instance that provides access to all datasets.

  • tracker (ksuit.trackers.BaseTracker) – Tracker instance to log metrics to stdout, disk, or an online platform.

  • log_writer (ksuit.writers.LogWriter) – LogWriter instance to log metrics.

  • checkpoint_writer (ksuit.writers.CheckpointWriter) – CheckpointWriter instance to save checkpoints.

  • metric_property_provider (ksuit.providers.metric_property_provider.MetricPropertyProvider) – MetricPropertyProvider instance to access properties of metrics.

  • name (str | None) – Name of the callback.

every_n_epochs
every_n_updates
every_n_samples
batch_size
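The interval attributes above can be pictured with the following minimal sketch. It assumes that exactly one of every_n_epochs, every_n_updates and every_n_samples is set (as the "depending on which one is not None" wording of the string helpers suggests) and that intervals are positive integers. IntervalSpec is a hypothetical name, not part of ksuit, and batch_size is left out because its exact role is not described in this section.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class IntervalSpec:
    """Hypothetical stand-in for the interval attributes of a PeriodicCallback."""

    every_n_epochs: Optional[int] = None
    every_n_updates: Optional[int] = None
    every_n_samples: Optional[int] = None

    def __post_init__(self) -> None:
        values = (self.every_n_epochs, self.every_n_updates, self.every_n_samples)
        configured = [value for value in values if value is not None]
        # Assumption: exactly one interval kind is active per callback.
        if len(configured) != 1:
            raise ValueError("set exactly one of every_n_epochs, every_n_updates, every_n_samples")
        # Reject zero or negative intervals.
        if configured[0] <= 0:
            raise ValueError("the interval must be a positive integer")
```

Under these assumptions, IntervalSpec(every_n_updates=100) is valid, while IntervalSpec(every_n_epochs=1, every_n_updates=100) raises a ValueError.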
should_log_after_epoch(training_iteration)

Checks after every epoch whether the PeriodicCallback should be invoked.

Parameters:

training_iteration (ksuit.utils.training.training_iteration.TrainingIteration) – TrainingIteration to check.

Return type:

bool

should_log_after_update(training_iteration)

Checks after every update whether the PeriodicCallback should be invoked.

Parameters:

training_iteration (ksuit.utils.training.training_iteration.TrainingIteration) – TrainingIteration to check.

Return type:

bool
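A minimal sketch of the epoch and update checks, assuming that TrainingIteration exposes integer epoch, update and sample counters and that a check passes whenever the counter is a multiple of the configured interval. The interval is passed as an argument instead of being read from self, and TrainingIteration below is a simplified, hypothetical stand-in, so this is not the ksuit implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingIteration:
    """Hypothetical stand-in for ksuit's TrainingIteration (assumed counter fields)."""

    epoch: int = 0
    update: int = 0
    sample: int = 0


def should_log_after_epoch(every_n_epochs: Optional[int], training_iteration: TrainingIteration) -> bool:
    # The epoch interval is inactive when it is not configured.
    if every_n_epochs is None:
        return False
    return training_iteration.epoch % every_n_epochs == 0


def should_log_after_update(every_n_updates: Optional[int], training_iteration: TrainingIteration) -> bool:
    # The update interval is inactive when it is not configured.
    if every_n_updates is None:
        return False
    return training_iteration.update % every_n_updates == 0
```

For example, should_log_after_epoch(2, TrainingIteration(epoch=4)) passes, while epoch 3 does not.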

should_log_after_sample(training_iteration, effective_batch_size)

Checks after every sample whether the PeriodicCallback should be invoked.

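Parameters:
  • training_iteration (ksuit.utils.training.training_iteration.TrainingIteration) – TrainingIteration to check.

  • effective_batch_size (int) – Effective batch size of the current update.

Return type: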

bool
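Samples advance in steps of the effective batch size, so a plain sample % every_n_samples == 0 test would miss most boundaries whenever every_n_samples is not a multiple of the batch size. One way to handle this, sketched below under the assumption that sample is the sample counter after the current update, is to check whether the update crossed a multiple of every_n_samples. The free function is a hypothetical illustration, not the ksuit method.

```python
from typing import Optional


def should_log_after_sample(every_n_samples: Optional[int], sample: int, effective_batch_size: int) -> bool:
    """Return True if the last update crossed a multiple of every_n_samples."""
    if every_n_samples is None:
        return False
    # Number of samples seen before the update that just finished.
    previous_sample = sample - effective_batch_size
    # Compare which interval "bucket" the counter was in before and after the update.
    return sample // every_n_samples > previous_sample // every_n_samples
```

With every_n_samples=1000 and an effective batch size of 64, the check first passes at sample 1024, i.e. on the update that moves the counter from 960 to 1024.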

track_after_accumulation_step(*, update_counter, batch, losses, update_outputs, accumulation_steps, accumulation_step)

Invoked after every gradient accumulation step. May be used to track metrics. Applies torch.no_grad().

Parameters:
  • update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • batch (Any) – Current batch.

  • losses (dict[str, torch.Tensor]) – Losses computed for the current batch.

  • update_outputs (dict[str, torch.Tensor] | None) – Outputs of the model for the current batch.

  • accumulation_steps (int) – Total number of accumulation steps.

  • accumulation_step (int) – Current accumulation step.

Return type:

None

track_after_update_step(update_counter, times)

Invoked after every update step. May be used to track metrics. Applies torch.no_grad().

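Parameters:
  • update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

  • times

Return type: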

None
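The two tracking hooks can be combined, for example, to average each loss over the accumulation steps of one update. The class below is a hypothetical, framework-free illustration of that pattern: it does not inherit from PeriodicCallback, accepts plain floats (or 0-dimensional tensors) as losses, and ignores the arguments it does not need.

```python
from collections import defaultdict
from typing import Any, Optional


class LossAveragingTracker:
    """Hypothetical PeriodicCallback-style hooks that average losses per update."""

    def __init__(self) -> None:
        self._sums = defaultdict(float)  # loss name -> sum over accumulation steps
        self._steps = 0
        self.history = []  # one dict of mean losses per finished update

    def track_after_accumulation_step(
        self,
        *,
        update_counter: Any,
        batch: Any,
        losses: dict,
        update_outputs: Optional[dict],
        accumulation_steps: int,
        accumulation_step: int,
    ) -> None:
        # float() also accepts 0-dimensional torch tensors.
        for name, value in losses.items():
            self._sums[name] += float(value)
        self._steps += 1

    def track_after_update_step(self, update_counter: Any, times: Any) -> None:
        # The optimizer step is done: store the mean over its accumulation steps.
        if self._steps == 0:
            return
        self.history.append({name: total / self._steps for name, total in self._sums.items()})
        self._sums.clear()
        self._steps = 0
```

With two accumulation steps reporting total losses of 1.0 and 3.0, the following update step records {"total": 2.0}.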

after_epoch(update_counter, **kwargs)

Invoked after every epoch to check whether the callback should be invoked. Applies torch.no_grad().

Parameters:

update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None

after_update(update_counter, **kwargs)

Invoked after every update to check whether the callback should be invoked. Applies torch.no_grad().

Parameters:

update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Return type:

None
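after_epoch and after_update run after every epoch and update; the interval check decides whether the callback actually does its periodic work. The simulation below sketches the resulting schedule, assuming 1-based counters, a fixed number of updates per epoch and modulo-based checks. simulate_invocations is a hypothetical helper, not a ksuit function.

```python
from typing import List, Optional, Tuple


def simulate_invocations(
    num_epochs: int,
    updates_per_epoch: int,
    every_n_epochs: Optional[int] = None,
    every_n_updates: Optional[int] = None,
) -> List[Tuple[str, int]]:
    """Hypothetical simulation of when after_update/after_epoch would fire a callback."""
    fired = []
    update = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(updates_per_epoch):
            update += 1
            # after_update runs after every update and gates on the update counter.
            if every_n_updates is not None and update % every_n_updates == 0:
                fired.append(("update", update))
        # after_epoch runs at the end of every epoch and gates on the epoch counter.
        if every_n_epochs is not None and epoch % every_n_epochs == 0:
            fired.append(("epoch", epoch))
    return fired
```

For example, with 5 updates per epoch and every_n_updates=4, a two-epoch run invokes the callback after updates 4 and 8 only.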

updates_till_next_log(update_counter)

Calculates how many updates remain until this callback is invoked.

Parameters:

update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates remaining until the next callback invocation.

Return type:

int
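Assuming a fixed number of updates per epoch and a constant batch size (for example, the batch_size attribute), the remaining number of updates can be derived for each interval kind as sketched below. The function and its parameters are hypothetical, and the sample branch assumes the same boundary-crossing rule as a crossing-based should_log_after_sample.

```python
from typing import Optional


def updates_till_next_log(
    update: int,
    updates_per_epoch: int,
    batch_size: int,
    every_n_epochs: Optional[int] = None,
    every_n_updates: Optional[int] = None,
    every_n_samples: Optional[int] = None,
) -> int:
    """Hypothetical sketch: updates remaining until the next periodic invocation."""
    if every_n_updates is not None:
        return every_n_updates - update % every_n_updates
    if every_n_epochs is not None:
        # Assumes every epoch has the same number of updates.
        interval = every_n_epochs * updates_per_epoch
        return interval - update % interval
    if every_n_samples is not None:
        # Assumes a constant batch size, so the sample counter is update * batch_size.
        sample = update * batch_size
        next_log_sample = (sample // every_n_samples + 1) * every_n_samples
        remaining_samples = next_log_sample - sample
        # Ceiling division: the first update whose samples reach the boundary.
        return (remaining_samples + batch_size - 1) // batch_size
    raise ValueError("no interval is configured")
```

For instance, at update 15 with a batch size of 64 and every_n_samples=1000, 960 samples have been seen and one more update crosses the 1000-sample boundary.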

updates_per_log_interval(update_counter)

Calculates how many updates lie between two consecutive invocations of this callback.

Parameters:

update_counter (ksuit.utils.training.counter.UpdateCounter) – UpdateCounter instance to access current training progress.

Returns:

Number of updates between callback invocations.

Return type:

int
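A minimal sketch of the interval length in updates, under the same hypothetical assumptions of a fixed number of updates per epoch and a constant batch size. For sample-based intervals the spacing is not constant when every_n_samples is not a multiple of the batch size, so the sketch returns the rounded-up value; that rounding is a design choice of this example, not documented ksuit behavior.

```python
from typing import Optional


def updates_per_log_interval(
    updates_per_epoch: int,
    batch_size: int,
    every_n_epochs: Optional[int] = None,
    every_n_updates: Optional[int] = None,
    every_n_samples: Optional[int] = None,
) -> int:
    """Hypothetical sketch: number of updates between two invocations."""
    if every_n_updates is not None:
        return every_n_updates
    if every_n_epochs is not None:
        # Assumes every epoch has the same number of updates.
        return every_n_epochs * updates_per_epoch
    if every_n_samples is not None:
        # Rounded up; the actual gaps can vary by one update when
        # every_n_samples is not a multiple of batch_size.
        return (every_n_samples + batch_size - 1) // batch_size
    raise ValueError("no interval is configured")
```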

get_interval_string_verbose()

Returns every_n_epochs, every_n_updates or every_n_samples, depending on which one is not None.

Returns:

Interval as, e.g., “every_n_epochs=1” for epoch-based intervals.

Return type:

str

to_short_interval_string()

Returns every_n_epochs, every_n_updates or every_n_samples, depending on which one is not None.

Returns:

Interval as, e.g., “E1” if every_n_epochs=1 for epoch-based intervals.

Return type:

str
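Both string helpers pick the interval that is not None and format it. A minimal sketch, passing the intervals as keyword arguments instead of reading them from self; the function names mirror the methods above, but _active_interval and the prefixes "U" and "S" are assumptions (only "E", as in “E1”, is documented).

```python
from typing import Optional, Tuple

# "E" matches the documented "E1" example; "U" and "S" are assumed prefixes.
_INTERVALS = (("every_n_epochs", "E"), ("every_n_updates", "U"), ("every_n_samples", "S"))


def _active_interval(**intervals: Optional[int]) -> Tuple[str, str, int]:
    # Return the first configured interval as (name, prefix, value).
    for name, prefix in _INTERVALS:
        value = intervals.get(name)
        if value is not None:
            return name, prefix, value
    raise ValueError("no interval is configured")


def get_interval_string_verbose(**intervals: Optional[int]) -> str:
    name, _, value = _active_interval(**intervals)
    return f"{name}={value}"


def to_short_interval_string(**intervals: Optional[int]) -> str:
    _, prefix, value = _active_interval(**intervals)
    return f"{prefix}{value}"
```

For example, get_interval_string_verbose(every_n_epochs=1) yields "every_n_epochs=1" and to_short_interval_string(every_n_epochs=1) yields "E1".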