ksuit.data.pipeline.collator
============================

.. py:module:: ksuit.data.pipeline.collator


Attributes
----------

.. autoapisummary::

   ksuit.data.pipeline.collator.CollatorType


Classes
-------

.. autoapisummary::

   ksuit.data.pipeline.collator.Collator


Module Contents
---------------

.. py:data:: CollatorType

.. py:class:: Collator

   Base object whose ``__call__`` function uses ``torch.utils.data.default_collate``.
   Derived classes can override ``__call__`` to implement a custom collate function.
   The collator can be passed to ``torch.utils.data.DataLoader`` via the ``collate_fn``
   argument, e.g. ``DataLoader(dataset, batch_size=2, collate_fn=Collator())``.

   .. rubric:: Example

   >>> import torch
   >>> from ksuit.data.pipeline.collator import Collator
   >>> collator = Collator()
   >>> num_samples = 2
   >>> samples = [{"data": torch.randn(3, 256, 256)} for _ in range(num_samples)]
   >>> batch = collator(samples)
   >>> batch["data"].shape
   torch.Size([2, 3, 256, 256])
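
   The following is a minimal sketch of a derived collator that overrides ``__call__``, relying only on the behavior described above (the base ``__call__`` delegates to ``torch.utils.data.default_collate``). To stay self-contained, it defines a stand-in ``Collator`` instead of importing the ksuit class. The subclass ``PadSequenceCollator``, its ``pad_value`` parameter, and the ``"tokens"`` and ``"label"`` sample keys are hypothetical and chosen for illustration only.

   .. code-block:: python

      import torch
      from torch.nn.utils.rnn import pad_sequence
      from torch.utils.data import default_collate


      class Collator:
          """Stand-in for the documented base class: delegates to default_collate."""

          def __call__(self, samples):
              return default_collate(samples)


      class PadSequenceCollator(Collator):
          """Hypothetical subclass: pads variable-length 1D "tokens" tensors."""

          def __init__(self, pad_value=0):
              self.pad_value = pad_value

          def __call__(self, samples):
              # Pad the variable-length key to the longest sequence in the batch.
              tokens = pad_sequence(
                  [sample["tokens"] for sample in samples],
                  batch_first=True,
                  padding_value=self.pad_value,
              )
              # Collate all remaining keys with the default behavior of the base class.
              rest = super().__call__(
                  [{k: v for k, v in s.items() if k != "tokens"} for s in samples]
              )
              rest["tokens"] = tokens
              return rest


      collator = PadSequenceCollator()
      samples = [
          {"tokens": torch.tensor([1, 2, 3]), "label": 0},
          {"tokens": torch.tensor([4, 5]), "label": 1},
      ]
      batch = collator(samples)
      # batch["tokens"] -> tensor([[1, 2, 3], [4, 5, 0]])
      # batch["label"]  -> tensor([0, 1])

   Plain ``default_collate`` would fail on the samples above, because it stacks tensors and therefore requires them to share a shape. Overriding ``__call__`` lets the subclass handle the one problematic key and reuse the default behavior for everything else. Such an instance would be passed to ``DataLoader`` through ``collate_fn``, just like the base class.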