ksuit.trainers.base

Classes

BaseTrainer

Base class for all trainers that use SGD-based optimizers.

Module Contents

class ksuit.trainers.base.BaseTrainer(config, data_container, device, tracker, path_provider, main_sampler_kwargs=None, metric_property_provider=None)

Base class for all trainers that use SGD-based optimizers.

Parameters:
  • config (ksuit.schemas.BaseTrainerConfig) – The configuration for the trainer. Implements the BaseTrainerConfig schema.

  • data_container (ksuit.utils.data.data_container.DataContainer) – The data container which includes the data and dataloader.

  • device (str) – The device to use for training (e.g., “cuda”). It is assumed that the process was configured such that only 1 device is visible (e.g., via the CUDA_VISIBLE_DEVICES environment variable).

  • precision – The precision to use for training (e.g., “float32”).

  • main_sampler_kwargs (dict | None) – Kwargs passed to instantiate the main sampler.

  • tracker (ksuit.trackers.BaseTracker) – The tracker to use for training.

  • path_provider (ksuit.providers.PathProvider) – The path provider to use for training.

  • metric_property_provider (ksuit.providers.MetricPropertyProvider | None) – The metric property provider to use for training.
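The single-visible-device assumption on device is commonly met by setting CUDA_VISIBLE_DEVICES before any CUDA-using library initializes. The following is a minimal sketch, assuming one process per GPU. The helper name pin_visible_device and its local_rank argument are hypothetical and not part of ksuit.

```python
import os


def pin_visible_device(local_rank: int) -> str:
    """Expose exactly one GPU to this process (hypothetical helper, not a ksuit API).

    This has to run before CUDA is initialized in the process; changing the
    variable afterwards does not affect an already-initialized CUDA runtime.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    # With a single visible device, the process can address it simply as "cuda".
    return "cuda"


device = pin_visible_device(local_rank=0)
print(device, os.environ["CUDA_VISIBLE_DEVICES"])
```

The returned string is what would then be passed as the device argument, so that each process sees its own GPU under the same name.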

logger
config
data_container
path_provider
main_sampler_kwargs = None
device: torch.device
end_checkpoint
precision = Ellipsis
updates_per_epoch
skip_nan_loss_counter = 0
initializer: ksuit.initializers.ResumeInitializer | None = None
tracker
metric_property_provider = None
update_counter
log_writer
checkpoint_writer
callbacks: list[ksuit.callbacks.CallbackBase] = []
forward_properties
target_properties
batch_keys
get_all_callbacks(model)

Get all callbacks including default/trainer callbacks.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

list[ksuit.callbacks.CallbackBase]

get_trainer_callbacks(callback_default_args)

Get trainer-specific callbacks. This may optionally be overridden by derived classes.

Parameters:

callback_default_args (dict[str, Any])

Return type:

list[ksuit.callbacks.CallbackBase]
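To illustrate how a derived trainer might contribute its own callbacks, here is a self-contained sketch that does not import ksuit. ToyCallback, ToyBaseTrainer and ToyClassificationTrainer are hypothetical stand-ins, and the way get_all_callbacks combines default and trainer callbacks (simple concatenation) is an assumption, not the documented implementation.

```python
from typing import Any


class ToyCallback:
    """Stand-in for ksuit.callbacks.CallbackBase (hypothetical, illustration only)."""

    def __init__(self, name: str, **kwargs: Any) -> None:
        self.name = name
        self.kwargs = kwargs


class ToyBaseTrainer:
    """Stand-in trainer that mirrors only the method names documented here."""

    def get_default_callbacks(self, default_kwargs: dict[str, Any]) -> list[ToyCallback]:
        return [ToyCallback("progress", **default_kwargs)]

    def get_trainer_callbacks(self, callback_default_args: dict[str, Any]) -> list[ToyCallback]:
        # Base behaviour in this sketch: no trainer-specific callbacks.
        return []

    def get_all_callbacks(self, model: Any) -> list[ToyCallback]:
        # Assumption: default callbacks come first, trainer callbacks are appended.
        kwargs = {"model": model}
        return self.get_default_callbacks(kwargs) + self.get_trainer_callbacks(kwargs)


class ToyClassificationTrainer(ToyBaseTrainer):
    def get_trainer_callbacks(self, callback_default_args: dict[str, Any]) -> list[ToyCallback]:
        # Derived trainers override this hook to add their own callbacks.
        return [ToyCallback("train_accuracy", **callback_default_args)]


callbacks = ToyClassificationTrainer().get_all_callbacks(model="toy-model")
print([cb.name for cb in callbacks])
```

Overriding only get_trainer_callbacks keeps the default callbacks intact while letting each trainer add task-specific ones.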

get_default_callback_kwargs(model)

Get default kwargs for callbacks constructor.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

dict[str, Any]

get_default_callback_intervals()

Get default intervals at which callbacks are called.

Return type:

dict[str, Any]

get_default_callbacks(default_kwargs)

Get default callbacks.

Parameters:

default_kwargs (dict[str, Any])

Return type:

list[ksuit.callbacks.CallbackBase]

state_dict()

Get the state dict of the trainer.

Return type:

dict[str, Any]

load_state_dict(state_dict)

Load the state dict of the trainer.

Parameters:

state_dict (dict[str, Any])

Return type:

None
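The state_dict and load_state_dict pair follows the usual save-and-restore round trip. The sketch below is a hedged illustration using a hypothetical ToyTrainer. Which fields the real trainer stores is not documented here, so the two counters (kept as plain integers) are an assumption made for the example.

```python
from typing import Any


class ToyTrainer:
    """Hypothetical stand-in; the real state dict contents may differ."""

    def __init__(self) -> None:
        self.update_counter = 0
        self.skip_nan_loss_counter = 0

    def state_dict(self) -> dict[str, Any]:
        # Collect everything needed to resume training later.
        return {
            "update_counter": self.update_counter,
            "skip_nan_loss_counter": self.skip_nan_loss_counter,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore the fields written by state_dict().
        self.update_counter = state_dict["update_counter"]
        self.skip_nan_loss_counter = state_dict["skip_nan_loss_counter"]


trained = ToyTrainer()
trained.update_counter = 120
trained.skip_nan_loss_counter = 2

restored = ToyTrainer()
restored.load_state_dict(trained.state_dict())
print(restored.update_counter, restored.skip_nan_loss_counter)
```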

apply_resume_initializer(model)

Apply the resume initializer to the model.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

None

get_data_loader(iterator_callbacks, batch_size)

Get the data loader for training.

Parameters:
  • iterator_callbacks (list[ksuit.callbacks.PeriodicIteratorCallback])

  • batch_size (int)

Return type:

torch.utils.data.DataLoader

abstractmethod loss_compute(forward_output, targets)

Each trainer that extends this class must implement a custom loss computation using the targets and the output of the model.

Returns:

A dict with the (weighted) sub-losses to log.

Parameters:
  • forward_output (dict[str, torch.Tensor]) – Output of the model after the forward pass.

  • targets (dict[str, torch.Tensor]) – Dict with target tensors needed to compute the loss for this trainer.

Return type:

ksuit.trainers.types.LossResult | tuple[ksuit.trainers.types.LossResult, dict[str, torch.Tensor]]
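The contract of loss_compute (an abstract method that returns a dict of weighted sub-losses) can be sketched without ksuit or torch. In this illustration, plain floats stand in for torch.Tensor, the class names, field names and weights are hypothetical, and a plain dict is returned rather than the LossResult type, whose exact shape is not shown here.

```python
from abc import ABC, abstractmethod


class ToyBaseTrainer(ABC):
    """Hypothetical stand-in for the base trainer; only loss_compute is modelled."""

    @abstractmethod
    def loss_compute(
        self, forward_output: dict[str, float], targets: dict[str, float]
    ) -> dict[str, float]:
        """Return the (weighted) sub-losses to log."""


class ToyRegressionTrainer(ToyBaseTrainer):
    def __init__(self, pressure_weight: float = 1.0, velocity_weight: float = 0.5) -> None:
        self.pressure_weight = pressure_weight
        self.velocity_weight = velocity_weight

    def loss_compute(
        self, forward_output: dict[str, float], targets: dict[str, float]
    ) -> dict[str, float]:
        # One squared-error term per target, each scaled by its weight.
        pressure_error = forward_output["pressure"] - targets["pressure"]
        velocity_error = forward_output["velocity"] - targets["velocity"]
        return {
            "pressure": self.pressure_weight * pressure_error**2,
            "velocity": self.velocity_weight * velocity_error**2,
        }


losses = ToyRegressionTrainer().loss_compute(
    forward_output={"pressure": 1.5, "velocity": 2.0},
    targets={"pressure": 1.0, "velocity": 1.0},
)
print(losses)
```

Because loss_compute is abstract, instantiating the base class directly raises a TypeError, which is how a missing implementation surfaces early.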

train_step(batch, dist_model)

Overriding this method is optional. By default, the model's train_step is called and is expected to return a TrainerResult. Trainers can override this method to implement custom training logic.

Returns:

Loss for backpropagation, (optionally) individual losses if multiple losses are used, and (optionally) additional information about the model forward pass that is passed to the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).

Parameters:
  • batch (dict[str, torch.Tensor]) – Batch of data from which the loss is calculated.

  • dist_model (torch.nn.Module) – Model to use for processing the data.

Return type:

ksuit.trainers.types.TrainerResult
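The difference between the default delegation and an overridden train_step can be sketched as follows. Everything here is a hypothetical stand-in: ToyTrainerResult only mirrors the three pieces described in the return value (its field names are assumptions, not the real TrainerResult fields), and floats replace tensors so the example runs without torch.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyTrainerResult:
    """Hypothetical stand-in for ksuit.trainers.types.TrainerResult."""

    total_loss: float
    losses_to_log: dict[str, float] = field(default_factory=dict)
    additional_outputs: dict[str, Any] = field(default_factory=dict)


class ToyModel:
    def __call__(self, x: float) -> float:
        return 2.0 * x  # pretend forward pass

    def train_step(self, batch: dict[str, float]) -> ToyTrainerResult:
        error = self(batch["x"]) - batch["y"]
        return ToyTrainerResult(total_loss=error**2)


class ToyBaseTrainer:
    def train_step(self, batch: dict[str, float], dist_model: ToyModel) -> ToyTrainerResult:
        # Default in this sketch: delegate to the model's own train_step.
        return dist_model.train_step(batch)


class ToyCustomTrainer(ToyBaseTrainer):
    def train_step(self, batch: dict[str, float], dist_model: ToyModel) -> ToyTrainerResult:
        # Custom logic: two sub-losses plus extra outputs for callbacks.
        prediction = dist_model(batch["x"])
        error = prediction - batch["y"]
        losses = {"l2": error**2, "l1": abs(error)}
        return ToyTrainerResult(
            total_loss=sum(losses.values()),
            losses_to_log=losses,
            additional_outputs={"prediction": prediction},
        )


batch = {"x": 1.0, "y": 3.0}
default_result = ToyBaseTrainer().train_step(batch, ToyModel())
custom_result = ToyCustomTrainer().train_step(batch, ToyModel())
print(default_result.total_loss, custom_result.total_loss)
```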

wrap_model(model)

Wrap the model for training and return the model, the wrapped model, and the DDP+compiled model.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

torch.nn.Module

wrap_ddp(model)

Wrap the model with DistributedDataParallel in multi-GPU settings.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

ksuit.models.model_base.ModelBase | torch.nn.parallel.DistributedDataParallel

wrap_compile(ddp_model)

Wrap the model with torch.compile.

Parameters:

ddp_model (ksuit.models.model_base.ModelBase | torch.nn.parallel.DistributedDataParallel)

Return type:

torch.nn.Module

train(model)

Train the model.

Parameters:

model (ksuit.models.model_base.ModelBase)

Return type:

None

update(batch, dist_model, model=None, training=True, accumulation_steps=1, iter_step=0, **kwargs)

Perform forward and backward pass.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor] | None]

call_before_training(callbacks)

Hook that is called before training starts.

Parameters:

callbacks (list[ksuit.callbacks.CallbackBase])

Return type:

None

call_after_training(callbacks)

Hook that is called after training ends.

Parameters:

callbacks (list[ksuit.callbacks.CallbackBase])

Return type:

None