ksuit.data.base.dataset
Classes
Dataset: Ksuit dataset implementation, which is a wrapper around torch.utils.data.Dataset that can hold a dataset_config_provider.
Functions
with_normalizers: Decorator to apply a normalizer to the output of a getitem_* function of the implemented Dataset class.
Module Contents
- ksuit.data.base.dataset.with_normalizers(normalizer_key)
Decorator to apply a normalizer to the output of a getitem_* function of the implemented Dataset class.
This decorator looks up the normalizer registered under the specified key and applies it to the output of the decorated function.
Example usage:
>>> @with_normalizers("image")
>>> def getitem_image(self, idx):
>>>     # Load image tensor
>>>     return torch.load(f"{self.path}/image_tensor/{idx}.pt")
- Parameters:
normalizer_key (str) – The key of the normalizer to apply. This key should be present in the self.normalizers dictionary of the Dataset class.
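To make the lookup-and-apply behaviour concrete, here is a minimal sketch of how a decorator of this kind could be written. It is not the ksuit source: `with_normalizers_sketch` and `ToyDataset` are hypothetical names, a normalizer is assumed to be any callable that maps a sample to a sample, and plain Python lists stand in for tensors so the sketch runs without torch.

```python
import functools


def with_normalizers_sketch(normalizer_key):
    """Hypothetical stand-in for with_normalizers (not the ksuit implementation)."""

    def decorator(getitem_fn):
        @functools.wraps(getitem_fn)
        def wrapper(self, idx):
            # Load the raw sample with the undecorated getitem_* method.
            sample = getitem_fn(self, idx)
            # Look up the normalizer registered under the key on the instance.
            normalizer = self.normalizers[normalizer_key]
            return normalizer(sample)

        return wrapper

    return decorator


class ToyDataset:
    """Hypothetical dataset; only the self.normalizers attribute matters here."""

    def __init__(self):
        # Assumption: each normalizer is a callable taking and returning a sample.
        self.normalizers = {"image": lambda xs: [(x - 0.5) / 0.5 for x in xs]}

    @with_normalizers_sketch("image")
    def getitem_image(self, idx):
        # A list of floats stands in for a loaded image tensor.
        return [0.0, 0.5, 1.0]


normalized = ToyDataset().getitem_image(0)
```

Because the lookup happens at call time on `self`, a key that is missing from `self.normalizers` would surface as a `KeyError` in this sketch only when the item is first loaded.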
- class ksuit.data.base.dataset.Dataset(dataset_config)
Bases: torch.utils.data.Dataset
Ksuit dataset implementation, which is a wrapper around torch.utils.data.Dataset that can hold a dataset_config_provider.
A dataset should map a key (i.e., an index) to its corresponding data. Each subclass should implement individual getitem_* methods, where * is the name of an item in the dataset. Each getitem_* method loads an individual tensor/data sample from disk. For example, if your dataset consists of images and targets/labels (stored as tensors), the subclass should implement a getitem_image(idx) and a getitem_target(idx) method.
The __getitem__ method of this class loops over all the individual getitem_* methods implemented by the child class and returns their results. Optionally, you can configure which getitem_* methods are called.
- Example: Image classification datasets
>>> import torch
>>> from ksuit.data.base.dataset import Dataset
>>>
>>> class ImageDataset(Dataset):
>>>     def __init__(self, path, dataset_normalizers, **kwargs):
>>>         super().__init__(dataset_normalizers=dataset_normalizers, **kwargs)
>>>         self.path = path
>>>     def __len__(self):
>>>         return 100  # Example length
>>>     def getitem_image(self, idx):
>>>         # Load image tensor
>>>         return torch.load(f"{self.path}/image_tensor/{idx}.pt")
>>>     def getitem_target(self, idx):
>>>         # Load target tensor
>>>         return torch.load(f"{self.path}/target_tensor/{idx}.pt")
>>>
>>> dataset = ImageDataset("path/to/dataset")
>>> sample0 = dataset[0]
>>> image_0 = sample0["image"]
>>> target_0 = sample0["target"]
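The way __getitem__ collects the results of the getitem_* methods can be illustrated with a small sketch. It makes several assumptions: `DispatchingDatasetSketch`, its `items` argument, and `ToyImageDataset` are hypothetical names, method discovery via `dir()` is only one possible mechanism, and plain Python values stand in for tensors. The actual ksuit base class may be implemented differently.

```python
class DispatchingDatasetSketch:
    """Hypothetical base class showing getitem_* dispatch (not the ksuit implementation)."""

    PREFIX = "getitem_"

    def __init__(self, items=None):
        # Discover every item name for which a getitem_<name> method exists.
        discovered = sorted(
            name[len(self.PREFIX):]
            for name in dir(self)
            if name.startswith(self.PREFIX) and callable(getattr(self, name))
        )
        # Assumption: an optional list restricts which getitem_* methods are called.
        self.items = discovered if items is None else list(items)

    def __getitem__(self, idx):
        # Call each selected getitem_* method and key its result by item name.
        return {item: getattr(self, self.PREFIX + item)(idx) for item in self.items}


class ToyImageDataset(DispatchingDatasetSketch):
    def __len__(self):
        return 3

    def getitem_image(self, idx):
        return [idx, idx + 1]  # stands in for an image tensor

    def getitem_target(self, idx):
        return idx % 2  # stands in for a label tensor


sample = ToyImageDataset()[1]
only_targets = ToyImageDataset(items=["target"])[1]
```

Restricting the item list, as in `only_targets`, skips the image load entirely, which is useful when only labels are needed.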
In many cases, the data returned by a getitem_* method should be normalized. To apply normalization, add the with_normalizers decorator to the getitem_* method. For example:
>>> @with_normalizers("image")
>>> def getitem_image(self, idx):
>>>     # Load image tensor
>>>     return torch.load(f"{self.path}/image_tensor/{idx}.pt")
Here "image" is a key in the self.normalizers dictionary; it maps to a preprocessor that should implement the correct data normalization.
Args:
    dataset_config_provider: Optional provider for dataset configuration.
    dataset_normalizers: A dictionary that contains normalization ComposePreProcess(ers) for each data type. The key for each normalizer can be used in the with_normalizers decorator.
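As an illustration of what a dataset_normalizers dictionary might look like, the sketch below maps each data-type key to a composed preprocessor. `ComposeSketch`, `shift`, and `scale` are hypothetical helpers written for this example; they are not the ksuit ComposePreProcess class, whose interface is not shown here, and plain Python lists stand in for tensors.

```python
class ComposeSketch:
    """Hypothetical composed preprocessor that applies its steps in order."""

    def __init__(self, steps):
        self.steps = list(steps)

    def __call__(self, sample):
        for step in self.steps:
            sample = step(sample)
        return sample


def shift(offset):
    # Returns a step that subtracts a constant from every value.
    return lambda xs: [x - offset for x in xs]


def scale(factor):
    # Returns a step that divides every value by a constant.
    return lambda xs: [x / factor for x in xs]


# One entry per data type; the keys match those passed to the decorator.
dataset_normalizers = {
    "image": ComposeSketch([shift(128.0), scale(128.0)]),
    "target": ComposeSketch([]),  # no steps: values pass through unchanged
}

normalized_image = dataset_normalizers["image"]([0.0, 128.0, 256.0])
unchanged_target = dataset_normalizers["target"]([3])
```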
- Parameters:
dataset_config (ksuit.schemas.data.dataset_config.DatasetBaseConfig)
- logger
- config