ksuit.data.pipeline.multistage

Attributes

SampleProcessorType

BatchProcessorType

Classes

MultiStagePipeline

A Collator that processes a list of samples into a batch in multiple stages.

Module Contents

ksuit.data.pipeline.multistage.SampleProcessorType
ksuit.data.pipeline.multistage.BatchProcessorType
class ksuit.data.pipeline.multistage.MultiStagePipeline(collators=None, sample_processors=None, batch_processors=None)

Bases: ksuit.data.pipeline.collator.Collator

A Collator that processes the list of samples into a batch in multiple stages:
  • sample_processors: process the data on a per-sample level before collation.

  • collators: convert the list of samples into a batch (usually a dict of tensors).

  • batch_processors: process the collated batch on a per-batch level.
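To make the order of the three stages concrete, here is a minimal pure-Python sketch. `ToyMultiStagePipeline`, `default_collate`, `subsample_points` and `add_batch_size` are hypothetical names, not part of ksuit. Samples are assumed to be dicts, and merging the outputs of several collators into one dict is an assumption, since this page does not say how multiple collators combine their results. The stand-in `default_collate` gathers values into Python lists rather than producing tensors as the PyTorch default collator would.

```python
from __future__ import annotations

from typing import Any, Callable

Sample = dict[str, Any]
Batch = dict[str, Any]


def default_collate(samples: list[Sample]) -> Batch:
    # Stand-in for a default collator: gathers each key's values into a list.
    return {key: [sample[key] for sample in samples] for key in samples[0]}


class ToyMultiStagePipeline:
    """Hypothetical three-stage collator, not the ksuit implementation."""

    def __init__(
        self,
        collators: list[Callable[[list[Sample]], Batch]] | None = None,
        sample_processors: list[Callable[[Sample], Sample]] | None = None,
        batch_processors: list[Callable[[Batch], Batch]] | None = None,
    ) -> None:
        self.sample_processors = sample_processors or []
        # Fall back to the stand-in collator only when collators is None.
        self.collators = collators if collators is not None else [default_collate]
        self.batch_processors = batch_processors or []

    def __call__(self, samples: list[Sample]) -> Batch:
        # Stage 1: apply each sample processor, in order, to every sample.
        for processor in self.sample_processors:
            samples = [processor(sample) for sample in samples]
        # Stage 2: every collator sees all samples; their outputs are assumed
        # to use disjoint keys and are merged into one batch dict.
        batch: Batch = {}
        for collator in self.collators:
            batch.update(collator(samples))
        # Stage 3: apply each batch processor, in order, to the whole batch.
        for processor in self.batch_processors:
            batch = processor(batch)
        return batch


def subsample_points(sample: Sample) -> Sample:
    # Hypothetical sample processor: keep only the first two points
    # (a deterministic stand-in for random point-cloud subsampling).
    return {**sample, "pos": sample["pos"][:2]}


def add_batch_size(batch: Batch) -> Batch:
    # Hypothetical batch processor: record how many samples were collated.
    return {**batch, "batch_size": len(batch["pos"])}


pipeline = ToyMultiStagePipeline(
    sample_processors=[subsample_points],
    batch_processors=[add_batch_size],
)
batch = pipeline([{"pos": [0, 1, 2]}, {"pos": [3, 4, 5]}])
print(batch)  # {'pos': [[0, 1], [3, 4]], 'batch_size': 2}
```

Per-sample work happens before collation, so each processor only ever sees one sample; batch processors, by contrast, see the already-collated structure.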

Most of the work is usually done by the sample_processors. One or two collators are typically sufficient, and batch processors are often not needed, although this depends on the use case.

Example

>>> from ksuit.data.pipeline.multistage import MultiStagePipeline
>>> # MySampleProcessor1, MyCollator1, etc. are placeholders for your own callables.
>>> sample_processors = [MySampleProcessor1(), MySampleProcessor2()]
>>> collators = [MyCollator1(), MyCollator2()]
>>> batch_processors = [MyBatchProcessor1(), MyBatchProcessor2()]
>>> multistage_pipeline = MultiStagePipeline(
...     sample_processors=sample_processors,
...     collators=collators,
...     batch_processors=batch_processors,
... )
>>> batch = multistage_pipeline(samples)

Parameters:
  • sample_processors (dict[str, SampleProcessorType] | list[SampleProcessorType] | None) – A list (or dict) of callables that are applied sequentially to pre-process each sample individually (e.g., subsample a point cloud).

  • collators (dict[str, ksuit.data.pipeline.collator.CollatorType] | list[ksuit.data.pipeline.collator.CollatorType] | None) – A list (or dict) of callables that are applied sequentially to convert the list of individual samples into a batch. If None, the default PyTorch collator is used.

  • batch_processors (dict[str, BatchProcessorType] | list[BatchProcessorType] | None) – A list (or dict) of callables that are applied sequentially to process the collated batch.
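Each of these parameters accepts a dict, a list, or None. The following is a minimal sketch of how the three forms could be reduced to one ordered sequence, assuming that dict values are applied in insertion order and that the keys only serve as names. `normalize_processors`, `double` and `increment` are hypothetical and not documented ksuit functions.

```python
from __future__ import annotations

from typing import Any, Callable

Processor = Callable[..., Any]


def normalize_processors(
    processors: dict[str, Processor] | list[Processor] | None,
) -> list[Processor]:
    # Hypothetical helper: reduce the accepted argument forms to an ordered list.
    if processors is None:
        return []
    if isinstance(processors, dict):
        # Dicts keep insertion order (Python 3.7+); the keys only name the steps.
        return list(processors.values())
    return list(processors)


def double(x: int) -> int:
    return 2 * x


def increment(x: int) -> int:
    return x + 1


# The dict and list forms describe the same two-step sequence.
named_steps = normalize_processors({"double": double, "increment": increment})
listed_steps = normalize_processors([double, increment])

# Apply the steps sequentially: (3 * 2) + 1.
value = 3
for step in named_steps:
    value = step(value)
print(value)  # 7
```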

sample_processors = []
batch_processors = []
get_sample_processor(predicate)

Retrieves a sample processor using a predicate function.

Examples:
  • Search by type (assumes the sample processor type occurs only once in the list of sample processors):

    pipeline.get_sample_processor(lambda p: isinstance(p, MySampleProcessorType))

  • Search by type and member:

    pipeline.get_sample_processor(lambda p: isinstance(p, PointSamplingSampleProcessor) and "input_pos" in p.items)

Parameters:

predicate (collections.abc.Callable[[Any], bool]) – A function that is called on each sample processor and returns True for the one to retrieve.

Returns:

The matching sample processor.

Return type:

Any

Raises:

ValueError – If there are no sample processors, if no sample processor matches, or if multiple sample processors match.
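The lookup rules listed under Raises can be sketched as a free function. This illustrates the documented behavior only and is not the ksuit source: this `get_sample_processor` takes the processor list explicitly, and the `PointSamplingSampleProcessor` stand-in with an `items` set is hypothetical (only the `"input_pos" in p.items` test comes from the examples for this method).

```python
from __future__ import annotations

from typing import Any, Callable


def get_sample_processor(sample_processors: list[Any], predicate: Callable[[Any], bool]) -> Any:
    # Sketch of the documented lookup rules, written as a free function.
    if not sample_processors:
        raise ValueError("there are no sample processors")
    matches = [p for p in sample_processors if predicate(p)]
    if not matches:
        raise ValueError("no matching sample processor found")
    if len(matches) > 1:
        raise ValueError(f"expected one matching sample processor, found {len(matches)}")
    return matches[0]


class PointSamplingSampleProcessor:
    # Hypothetical stand-in; only the `items` attribute mirrors the docs.
    def __init__(self, items: set[str]) -> None:
        self.items = items


processors = [
    PointSamplingSampleProcessor(items={"input_pos"}),
    PointSamplingSampleProcessor(items={"output_pos"}),
]

# Search by type and member: exactly one processor matches.
found = get_sample_processor(
    processors,
    lambda p: isinstance(p, PointSamplingSampleProcessor) and "input_pos" in p.items,
)

# Search by type alone is ambiguous here (two matches), so it raises.
try:
    get_sample_processor(processors, lambda p: isinstance(p, PointSamplingSampleProcessor))
except ValueError as error:
    print(error)  # expected one matching sample processor, found 2
```

Raising on multiple matches makes an ambiguous predicate fail loudly instead of silently returning whichever processor happens to come first.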