ksuit.data.samplers.interleaved_sampler

Classes

SamplerIntervalConfig

Configuration dataclass for setting up the dataloading pipeline, which is structured to load data from a "main" dataset interleaved with other datasets at regular intervals.

InterleavedSamplerConfig

!!! abstract "Usage Documentation"

InterleavedSampler

Sampler to allow efficient dataloading by using a single large dataset containing train/test/... datasets all at once.

Module Contents

class ksuit.data.samplers.interleaved_sampler.SamplerIntervalConfig

Configuration dataclass for setting up the dataloading pipeline, which is structured to load data from a “main” dataset (i.e., the dataset used for training), which is interleaved by iterations over other datasets (e.g., a test dataset to calculate a metric in a callback) in regular intervals.

Parameters:
  • sampler (SizedIterable) – Any sampler that would be used in torch.utils.data.DataLoader(sampler=…). Examples: RandomSampler for a training dataset or SequentialSampler for evaluation.

  • every_n_epochs (int | None) – Epoch-based interval. Invokes the callback after every n epochs. Mutually exclusive with other intervals.

  • every_n_updates (int | None) – Update-based interval. Invokes the callback after every n updates. Mutually exclusive with other intervals.

  • every_n_samples (int | None) – Sample-based interval. Invokes the callback after every n samples. Mutually exclusive with other intervals.

  • pipeline (Optional[callable]) – Any function that would be used in torch.utils.data.DataLoader(collate_fn=…).

  • batch_size (int | None) – Batch size to use for this callback. Default: None (which will use the same batch_size as used for the “main” sampler, i.e., the one used for training).

sampler: ksuit.utils.common.SizedIterable
pipeline: collections.abc.Callable | None
every_n_epochs: int | None = None
every_n_updates: int | None = None
every_n_samples: int | None = None
batch_size: int | None = None
validate_frequency()

Ensures that exactly one frequency (‘every_n_*’) is specified and that ‘batch_size’ is present if ‘every_n_samples’ is used.

Return type:

SamplerIntervalConfig
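The exclusivity check performed by validate_frequency can be sketched in plain Python. This is a simplified stand-in for the actual validator, not the library's implementation; only the field names follow the dataclass above:

```python
def validate_frequency(every_n_epochs=None, every_n_updates=None,
                       every_n_samples=None, batch_size=None):
    """Simplified stand-in for SamplerIntervalConfig.validate_frequency."""
    intervals = [every_n_epochs, every_n_updates, every_n_samples]
    # Exactly one 'every_n_*' interval must be specified.
    if sum(v is not None for v in intervals) != 1:
        raise ValueError(
            "exactly one of every_n_epochs/every_n_updates/every_n_samples must be set"
        )
    # A sample-based interval needs a batch_size to convert samples into batches.
    if every_n_samples is not None and batch_size is None:
        raise ValueError("batch_size is required when every_n_samples is used")
```
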

classmethod check_positive_values(v)

Ensures that all integer-based frequency and batch size fields are positive.

Parameters:

v (int | None)

Return type:

int | None

class ksuit.data.samplers.interleaved_sampler.InterleavedSamplerConfig(/, **data)

Bases: pydantic.BaseModel

!!! abstract "Usage Documentation"

[Models](../concepts/models.md)

A base class for creating Pydantic models.

Parameters:

data (Any)

__class_vars__

The names of the class variables defined on the model.

__private_attributes__

Metadata about the private attributes of the model.

__signature__

The synthesized __init__ [Signature][inspect.Signature] of the model.

__pydantic_complete__

Whether model building is completed, or if there are still undefined fields.

__pydantic_core_schema__

The core schema of the model.

__pydantic_custom_init__

Whether the model has a custom __init__ function.

__pydantic_decorators__

Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.

__pydantic_generic_metadata__

Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.

__pydantic_parent_namespace__

Parent namespace of the model, used for automatic rebuilding of models.

__pydantic_post_init__

The name of the post-init method for the model, if defined.

__pydantic_root_model__

Whether the model is a [RootModel][pydantic.root_model.RootModel].

__pydantic_serializer__

The pydantic-core SchemaSerializer used to dump instances of the model.

__pydantic_validator__

The pydantic-core SchemaValidator used to validate instances of the model.

__pydantic_fields__

A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.

__pydantic_computed_fields__

A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.

__pydantic_extra__

A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to ‘allow’.

__pydantic_fields_set__

The names of fields explicitly set during instantiation.

__pydantic_private__

Values of private attributes set on the model instance.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

batch_size: int

batch_size to use for creating batches of the main_sampler indices.

drop_last: bool = True

Whether to drop the last non-full batch of the main_sampler.

max_epochs: int | None = None

How many epochs to sample at most from the main_sampler. Whatever limit is reached first (epochs/updates/samples) will stop the sampling.

max_updates: int | None = None

How many updates to sample at most from the main_sampler. Whatever limit is reached first (epochs/updates/samples) will stop the sampling.

max_samples: int | None = None

How many samples to sample at most from the main_sampler. Whatever limit is reached first (epochs/updates/samples) will stop the sampling.
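The "whatever limit is reached first" rule can be illustrated by converting each limit into a number of updates and taking the minimum. This is a hypothetical helper for illustration only, assuming drop_last=True so each epoch contributes len(dataset) // batch_size full batches:

```python
def total_updates(dataset_len, batch_size,
                  max_epochs=None, max_updates=None, max_samples=None):
    """Hypothetical helper: updates until the first limit is reached (drop_last=True)."""
    updates_per_epoch = dataset_len // batch_size
    candidates = []
    if max_epochs is not None:
        candidates.append(max_epochs * updates_per_epoch)
    if max_updates is not None:
        candidates.append(max_updates)
    if max_samples is not None:
        candidates.append(max_samples // batch_size)
    # The tightest limit wins.
    return min(candidates)
```
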

start_epoch: int | None = None

At which epoch to start (used for resuming training). Mutually exclusive with start_update and start_sample.

start_update: int | None = None

At which update to start (used for resuming training). Mutually exclusive with start_epoch and start_sample.

start_sample: int | None = None

At which sample to start (used for resuming training). Mutually exclusive with start_epoch and start_update.

classmethod check_positive_values(v)

Ensures that all integer-based frequency and batch size fields are positive.

Parameters:

v (int | None)

Return type:

int | None

validate_stop()

Ensures that at least one stopping criterion (‘max_*’) is specified.

Return type:

InterleavedSamplerConfig

validate_start()

Ensures that at most one start (‘start_*’) is specified, since start_epoch, start_update, and start_sample are mutually exclusive.

Return type:

InterleavedSamplerConfig

class ksuit.data.samplers.interleaved_sampler.InterleavedSampler(train_sampler, config, train_collator=None, callback_samplers=None)

Sampler to allow efficient dataloading by using a single large dataset containing train/test/… datasets all at once. The sampler will sample from different regions in the dataset according to its specification. For example, consider a training dataset of length 100 and a test dataset of length 10. If the sampler is configured with a RandomSampler over the training dataset indices as main_sampler, it will repeatedly iterate over the training dataset. If the test dataset is configured with a sequential sampler that should be invoked after every epoch, the sampler will first return indices for the 100 training samples (randomly sampled) and then indices for the 10 test samples (in sequential order).

Parameters:
  • train_sampler (ksuit.utils.common.SizedIterable) – Sampler that is invoked by default (e.g., randomly sample from the trainset).

  • config (InterleavedSamplerConfig) – Configuration for the InterleavedSampler.

  • train_collator (collections.abc.Callable | None) – Collator used to collate samples from indices sampled from the train sampler.

  • callback_samplers (list[SamplerIntervalConfig] | None) – Configurations when the train_sampler should be paused and indices from other samplers (e.g., from a testset) should be returned. Also configures the interval and optionally a different batch_size to use for the interleaved batches.
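The 100-train/10-test example above can be simulated with plain Python. This is a conceptual sketch of the index stream, not the actual implementation; test indices are offset by the train dataset length because both datasets live in one concatenated dataset:

```python
import random

def interleaved_indices(train_len, test_len, epochs, seed=0):
    """Yield indices into a concatenated [train | test] dataset:
    one random pass over the train region, then one sequential pass
    over the test region after every epoch."""
    rng = random.Random(seed)
    for _ in range(epochs):
        train_indices = list(range(train_len))
        rng.shuffle(train_indices)  # stands in for a RandomSampler over the train region
        yield from train_indices
        # stands in for a SequentialSampler over the test region,
        # offset past the train samples
        yield from range(train_len, train_len + test_len)

indices = list(interleaved_indices(train_len=100, test_len=10, epochs=2))
```

Each "epoch" of the stream is 110 indices long: 100 shuffled train indices followed by the test indices 100..109 in order.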

config
main_sampler
extra_samplers = []
index_offsets
dataset
collator
batch_sampler
batch_size
static calculate_start(config, sampler_len)

Calculates the starting position from the configured start_epoch/start_update/start_sample (used for resuming training).

Parameters:
  • config (InterleavedSamplerConfig) – Configuration for the InterleavedSampler.

  • sampler_len (int) – Length of the main sampler.
get_data_loader(num_workers=0, pin_memory=False)

Creates the DataLoader that uses the InterleavedSampler with the accordingly configured dataset.

Parameters:
  • num_workers (int) – Number of workers to use.

  • pin_memory (bool) – Whether to use pin memory.

Returns:

DataLoader that uses the InterleavedSampler with the accordingly configured dataset.

Return type:

torch.utils.data.DataLoader
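Putting the pieces together, usage might look like the following. This is a hypothetical sketch: the import paths, the choice of RandomSampler/SequentialSampler, and the train_dataset/test_dataset variables are assumptions for illustration, not confirmed by this page:

```python
from torch.utils.data import RandomSampler, SequentialSampler

from ksuit.data.samplers.interleaved_sampler import (
    InterleavedSampler,
    InterleavedSamplerConfig,
    SamplerIntervalConfig,
)

# Assumed to be defined elsewhere: train_dataset, test_dataset.
# Randomly sample the train dataset; iterate the test dataset after every epoch.
sampler = InterleavedSampler(
    train_sampler=RandomSampler(train_dataset),
    config=InterleavedSamplerConfig(batch_size=32, max_epochs=10),
    callback_samplers=[
        SamplerIntervalConfig(
            sampler=SequentialSampler(test_dataset),
            every_n_epochs=1,
        ),
    ],
)
loader = sampler.get_data_loader(num_workers=4, pin_memory=True)
```
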