ksuit.data.pipeline.collator

Attributes

CollatorType

Classes

Collator

Base object that uses torch.utils.data.default_collate in its __call__ function. Derived classes can override __call__ to implement a custom collate function.

Module Contents

ksuit.data.pipeline.collator.CollatorType
class ksuit.data.pipeline.collator.Collator

Base object that uses torch.utils.data.default_collate in its __call__ function. Derived classes can override the __call__ implementation to implement a custom collate function. The collator can be passed to torch.utils.data.DataLoader via the collate_fn argument, e.g. DataLoader(dataset, batch_size=2, collate_fn=Collator()).
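A minimal sketch of a derived collator that overrides __call__, assuming samples are dicts whose "data" tensors differ in length along the first dimension (where default_collate would fail to stack them). The class name PaddingCollator and the "lengths" output key are hypothetical, chosen only for illustration. If ksuit is not importable, the block falls back to a stand-in base class that simply mirrors the documented behavior (delegating to default_collate); that stand-in is an assumption, not the library's code.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, default_collate

try:
    from ksuit.data.pipeline.collator import Collator
except ImportError:
    # Hypothetical stand-in mirroring the documented base behavior.
    class Collator:
        def __call__(self, samples):
            return default_collate(samples)


class PaddingCollator(Collator):
    """Hypothetical collator for variable-length "data" tensors."""

    # Value written into padded positions (assumed default).
    padding_value: float = 0.0

    def __call__(self, samples):
        sequences = [sample["data"] for sample in samples]
        # Record the original lengths so padded positions can be masked later.
        lengths = torch.tensor([seq.shape[0] for seq in sequences])
        # pad_sequence stacks to shape (batch, max_len, *feature_dims).
        data = pad_sequence(sequences, batch_first=True, padding_value=self.padding_value)
        return {"data": data, "lengths": lengths}


# Usage: pass an instance via collate_fn, exactly as with the base Collator.
samples = [{"data": torch.randn(n, 4)} for n in (3, 5, 2)]
loader = DataLoader(samples, batch_size=3, collate_fn=PaddingCollator())
batch = next(iter(loader))
# batch["data"] has shape (3, 5, 4); batch["lengths"] holds [3, 5, 2]
```

Overriding only __call__ keeps the collator a plain callable, so it plugs into DataLoader the same way as the base class.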

Example

>>> import torch
>>> from ksuit.data.pipeline.collator import Collator
>>> collator = Collator()
>>> num_samples = 2
>>> samples = [{"data": torch.randn(3, 256, 256)} for _ in range(num_samples)]
>>> batch = collator(samples)
>>> batch["data"].shape
torch.Size([2, 3, 256, 256])
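Because the base Collator delegates to default_collate, a quick way to see how it treats non-tensor fields is to call default_collate directly. In this sketch the "label" and "name" keys are hypothetical; it shows that tensors are stacked along a new batch dimension, Python ints become an int64 tensor, and strings are kept as a list.

```python
import torch
from torch.utils.data import default_collate

# Two samples with a tensor, an int label, and a string (illustrative keys).
samples = [
    {"data": torch.randn(3, 8, 8), "label": 0, "name": "a"},
    {"data": torch.randn(3, 8, 8), "label": 1, "name": "b"},
]
batch = default_collate(samples)
# batch["data"]: stacked tensor of shape (2, 3, 8, 8)
# batch["label"]: int64 tensor holding [0, 1]
# batch["name"]: plain list ["a", "b"]
```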