ksuit.data.pipeline.collator
Attributes
- CollatorType
Classes
- Collator: Base object that uses torch.utils.data.default_collate in its __call__ function.
Module Contents
- ksuit.data.pipeline.collator.CollatorType¶
- class ksuit.data.pipeline.collator.Collator¶
Base object that uses torch.utils.data.default_collate in its __call__ function. Derived classes can override __call__ to implement a custom collate function. A collator can be passed to torch.utils.data.DataLoader via the collate_fn argument, e.g. DataLoader(dataset, batch_size=2, collate_fn=Collator()).
Example
>>> import torch
>>> from ksuit.data.pipeline.collator import Collator
>>> collator = Collator()
>>> num_samples = 2
>>> samples = [{"data": torch.randn(3, 256, 256)} for _ in range(num_samples)]
>>> batch = collator(samples)
>>> batch["data"].shape
torch.Size([2, 3, 256, 256])
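The following is a minimal sketch of how a derived class might override __call__, as described above, and then be passed to DataLoader via collate_fn. PadSequenceCollator, its pad_value parameter, and the "tokens"/"lengths"/"label" keys are hypothetical and chosen only for illustration. The fallback Collator class is an assumption: it only mirrors the documented base behavior (forwarding to default_collate) so that the sketch runs where ksuit is not installed.

```python
import torch
from torch.utils.data import DataLoader, default_collate

try:
    from ksuit.data.pipeline.collator import Collator
except ImportError:
    # Assumption: a minimal stand-in that mirrors the documented base behavior
    # (forward the list of samples to torch.utils.data.default_collate).
    class Collator:
        def __call__(self, batch):
            return default_collate(batch)


class PadSequenceCollator(Collator):
    """Hypothetical subclass: right-pads 1D "tokens" tensors to the longest
    sample in the batch and collates all other keys via the base class."""

    def __init__(self, pad_value: int = 0):
        super().__init__()
        self.pad_value = pad_value

    def __call__(self, batch):
        tokens = [sample["tokens"] for sample in batch]
        lengths = torch.tensor([t.shape[0] for t in tokens])
        # Pre-fill with pad_value, then copy each sequence into its row.
        padded = torch.full(
            (len(tokens), int(lengths.max())), self.pad_value, dtype=tokens[0].dtype
        )
        for i, t in enumerate(tokens):
            padded[i, : t.shape[0]] = t
        # Every key except "tokens" goes through the default collate behavior.
        rest = super().__call__(
            [{k: v for k, v in sample.items() if k != "tokens"} for sample in batch]
        )
        rest["tokens"] = padded
        rest["lengths"] = lengths
        return rest


# Toy map-style dataset (a plain list) with variable-length sequences.
dataset = [{"tokens": torch.arange(n), "label": n % 2} for n in (3, 5, 2, 4)]
loader = DataLoader(dataset, batch_size=2, collate_fn=PadSequenceCollator(pad_value=-1))
batch = next(iter(loader))
print(batch["tokens"].shape)  # torch.Size([2, 5])
print(batch["lengths"])  # tensor([3, 5])
```

Because the override is still a plain callable, DataLoader treats it like any other collate_fn. Delegating the remaining keys to super().__call__() keeps the stacking behavior shown in the example above for everything that does not need special handling.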