Source code for torchdata.datasets

"""**Concrete implementations of** `torchdata.Dataset` **and** `torchdata.Iterable`.

Classes below extend and/or make it easier for users to implement common functionality.
To use standard PyTorch datasets defined by, for example, `torchvision`, you can
use `WrapDataset` or `WrapIterable` like this::

    import torchdata
    import torchvision

    dataset = torchdata.datasets.WrapDataset(
        torchvision.datasets.MNIST("./data", download=True)
    )

After that you can use `map`, `apply` and other functionalities like you normally would with
either `torchdata.Dataset` or `torchdata.Iterable`.
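
For instance, a per-sample transform could be mapped over the wrapped dataset
(the lambda below is purely illustrative; MNIST samples are `(PIL.Image, label)`
tuples)::

    dataset = dataset.map(lambda sample: (sample[0].rotate(90), sample[1]))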
"""

import abc
import functools
import pathlib
import typing

from torch.utils.data import ChainDataset as TorchChain
from torch.utils.data import ConcatDataset as TorchConcatDataset
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import IterableDataset as TorchIterable
from torch.utils.data import TensorDataset as TorchTensorDataset

from ._base import Base, MetaDataset, MetaIterable
from .cachers import Memory


class _DatasetBase(Base):
    def __init__(self, concat_object, chain_object):
        self._maps = []
        self._concat_object = concat_object
        self._chain_object = chain_object

    def map(self, function: typing.Callable):
        r"""**Map function to each element of dataset.**

        Function has no specified signature; it is the user's responsibility to ensure
        it takes the correct arguments as returned from `__getitem__` (in case of `Dataset`)
        or `__iter__` (in case of `Iterable`).

        Parameters
        ----------
        function: typing.Callable
                Function (or functor) taking arguments returned from `__getitem__`
                and returning anything.

        Returns
        -------
        self
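
        **Example** (a minimal sketch, using the `Dataset` defined in `apply` below)::

            dataset = Dataset(5).map(lambda value: value * 2)
            dataset[2]  # Returns 4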

        """
        self._maps.append(function)
        return self

    def apply(self, function, *args, **kwargs):
        r"""**Apply function to every element of the dataset.**

        The specified function has to take a Python generator as its first argument.
        This generator yields consecutive samples from the dataset and the function is free
        to do whatever it wants with them.

        Other arguments will be forwarded to function.

        **WARNING:**

        This method returns whatever `function` returns,
        and it is up to the user to ensure the pipeline keeps working correctly
        after using this transformation.

        **Example**::

            class Dataset(torchdata.Dataset):
                def __init__(self, max: int):
                    super().__init__() # This is necessary
                    self.range = list(range(max))

                def __getitem__(self, index):
                    return self.range[index]

                def __len__(self):
                    return len(self.range)

            def summation(generator):
                return sum(value for value in generator)

            summed_dataset = Dataset(101).apply(summation) # Returns 5050


        Parameters
        ----------
        function : typing.Callable
                Function (or functional object) taking an item generator as its first
                argument and a variable list of other arguments (if necessary).

        Returns
        -------
        typing.Any
                Value returned by function

        """
        return function((value for value in self), *args, **kwargs)

    def __or__(self, other):
        r"""**Concatenate {self} and another {self} compatible object.**

        During iteration, items from both datasets will be returned as a `tuple`.
        The other object may also be an instance of this object's PyTorch base class.

        Length of the resulting dataset is equal to `min(len(self), len(other))`.

        Parameters
        ----------
        other : {self} or PyTorch's base counterpart
                Dataset instance whose samples will be iterated over together with this one.

        Returns
        -------
        {concat_object}
                Proxy object responsible for concatenation between samples.
                Can be used in the same manner as this object.
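
        **Example** (a sketch; assumes two datasets yielding single values)::

            for first, second in dataset1 | dataset2:
                print(first, second)  # pairs, truncated to the shorter dataset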

        """.format(
            self=self, concat_object=self._concat_object
        )
        return self._concat_object((self, other))

    def __add__(self, other):
        r"""**Chain {self} and another {self} compatible object.**

        During iteration, items from self will be returned first and items
        from the other dataset after those.

        Length of such a dataset is equal to `len(self) + len(other)`.

        Parameters
        ----------
        other : {self} or PyTorch's base counterpart
                Dataset whose samples will be yielded after this dataset's.

        Returns
        -------
        {chain_object}
                Proxy object responsible for chaining datasets.
                Can be used in the same manner as this object.
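
        **Example** (a sketch; assumes two map-compatible datasets)::

            chained = dataset1 + dataset2
            # len(chained) == len(dataset1) + len(dataset2)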

        """.format(
            self=self, chain_object=self._chain_object
        )
        return self._chain_object((self, other))


class Iterable(TorchIterable, _DatasetBase, metaclass=MetaIterable):
    r"""`torch.utils.data.IterableDataset` **dataset with extended capabilities.**

    This class inherits from
    `torch.utils.data.IterableDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset>`__,
    so it can be used in the same manner after inheritance.
    It allows user to perform the following operations:

    - `map` - apply function to each element of dataset
    - `apply` - apply function to **all** elements of dataset
    - `filter` - return elements for which `predicate` returns `True`

    **Example**::

        # Based on original PyTorch example
        class Dataset(torchdata.Iterable):
            def __init__(self, start: int, end: int):
                super().__init__()  # This is necessary
                self.start: int = start
                self.end: int = end

            def __iter__(self):
                return iter(range(self.start, self.end))

        # range(1, 25) originally, mapped to range(13, 37)
        dataset = Dataset(1, 25).map(lambda value: value + 12)
        # Sample-wise concatenation yields pairs from range(13, 37) and range(1, 25)
        for first, second in dataset | Dataset(1, 25):
            print(first, second)  # "13 1" up to "36 24"

    """

    @abc.abstractmethod
    def __iter__(self):
        pass

    def __init__(self):
        _DatasetBase.__init__(self, ConcatIterable, ChainIterable)
        self._filters = []
        self._which = [0]

    def filter(self, predicate: typing.Callable):
        r"""**Filter data according to** `predicate`.

        Values are filtered based on the value returned after every operation
        (including `map`) specified before `filter`, for example::

            dataset = (
                ExampleIterable(0, 100)
                .map(lambda value: value + 50)
                .filter(lambda elem: elem % 2 == 0)
            )

        Above will return the even elements of the mapped range `[50, 150)`,
        that is `50, 52, ..., 148`.

        Parameters
        ----------
        predicate : Callable -> bool
                Function returning bool and taking a single argument (which is
                whatever is returned from the dataset when `filter` is applied).
                If `True`, sample will be returned, otherwise it is skipped.

        Returns
        -------
        Iterable
                Returns self

        """
        self._which.append(len(self._maps))
        self._filters.append(predicate)
        return self

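
# Example (a sketch): chaining ``map`` and ``filter`` on the ``Dataset``
# iterable defined in the class docstring above; ``filter`` sees the
# already-mapped values:
#
#     dataset = Dataset(0, 10).map(lambda value: value * 2).filter(lambda value: value > 10)
#     list(dataset)  # [12, 14, 16, 18]
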

class Dataset(TorchDataset, _DatasetBase, metaclass=MetaDataset):
    r"""`torch.utils.data.Dataset` **with extended capabilities.**

    This class inherits from
    `torch.utils.data.Dataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`__,
    so it can be used in the same manner after inheritance.
    It allows user to perform the following operations:

    - `cache` - cache all/part of data in memory or on disk
    - `map` - apply function to each element of dataset
    - `apply` - apply function to **all** elements of dataset
    - `reduce` - reduce dataset to single value with specified function

    **Important:**

    - The last cache able to hold a sample is used, no matter whether it is
      in-memory, on-disk or user-specified.
    - Although multiple `cache` calls in different parts of the `map` pipeline
      should work, users are encouraged to use caching as rarely and as late
      as possible for best performance.

    **Example**::

        import torchvision

        from PIL import Image

        # Image loading dataset (use Files for more serious business)
        class Dataset(torchdata.Dataset):
            def __init__(self, path: pathlib.Path):
                super().__init__()  # This is necessary
                self.files = [file for file in path.glob("*")]

            def __getitem__(self, index):
                return Image.open(self.files[index])

            def __len__(self):
                return len(self.files)

        # Map PIL to Tensor and cache dataset
        dataset = Dataset(pathlib.Path("data")).map(torchvision.transforms.ToTensor()).cache()
        # Create DataLoader as normally
        dataloader = torch.utils.data.DataLoader(dataset)

    """

    def __init__(self):
        _DatasetBase.__init__(self, ConcatDataset, ChainDataset)
        self._cachers = []
        self._which = []

    @abc.abstractmethod
    def __len__(self):
        pass

    @abc.abstractmethod
    def __getitem__(self, index):
        pass

    def cache(self, cacher: typing.Callable = None):
        r"""**Cache data in memory, on disk or with custom caching.**

        By default all samples are cached in memory. To change this behaviour,
        specify the `cacher` argument. Some `cacher` implementations can be found
        in the `torchdata.cacher` module, or you can provide your own by
        inheriting from `torchdata.cacher.Cacher` and implementing
        appropriate methods.

        Parameters
        ----------
        cacher : torchdata.cacher.Cacher, optional
                Instance of `torchdata.cacher.Cacher` (or any other object with
                a compatible interface). Check `cacher` module documentation for
                more information. Default: `torchdata.cacher.Memory`, which
                caches data in-memory.

        Returns
        -------
        Dataset
                Returns self

        """
        if cacher is None:
            cacher = Memory()
        self._cachers.append(cacher)
        self._which.append(len(self._maps))
        return self

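    # Example (a sketch): cache results of an expensive ``map`` in memory;
    # ``Memory`` is the default cacher imported at the top of this module,
    # while ``MyDataset`` and ``expensive_transform`` are illustrative stand-ins:
    #
    #     dataset = MyDataset().map(expensive_transform).cache(Memory())
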
    def reduce(self, function: typing.Callable, initializer=None):
        r"""**Reduce dataset to a single element with function.**

        Works like
        `functools.reduce <https://docs.python.org/3/library/functools.html#functools.reduce>`__.

        **Example**::

            class Dataset(torchdata.Dataset):
                def __init__(self, max: int):
                    super().__init__()  # This is necessary
                    self.range = list(range(max))

                def __getitem__(self, index):
                    return self.range[index]

                def __len__(self):
                    return len(self.range)

            summed_dataset = Dataset(10).reduce(lambda x, y: x + y)  # Returns 45

        Parameters
        ----------
        function : typing.Callable
                Two argument function returning single value used to `reduce` dataset.
        initializer : typing.Any, optional
                Value with which reduction will start.

        Returns
        -------
        typing.Any
                Reduced value

        """
        if initializer is None:
            return functools.reduce(function, (item for item in self))
        return functools.reduce(function, (item for item in self), initializer)

    def reset(self, cache: bool = True, maps: bool = True):
        r"""**Reset dataset state.**

        `cache` and `maps` can be reset separately.

        Parameters
        ----------
        cache : bool, optional
                Reset current cache. Default: `True`
        maps : bool, optional
                Reset current maps. Default: `True`

        """
        if cache:
            self._cachers = []
        if maps:
            self._maps = []


################################################################################
#
#                           Dataset Concatenations
#
################################################################################


class ConcatDataset(Dataset):
    r"""**Concrete** `torchdata.Dataset` **responsible for sample-wise concatenation.**

    This class is returned when `|` (bitwise or operator) is used on an instance
    of `torchdata.Dataset` (original
    `torch.utils.data.Dataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`__
    can be used as well).

    **Important:** This class is meant to be more of a proxy for the `|` operator,
    but you can use it directly as well.

    **Example**::

        dataset = (
            torchdata.ConcatDataset([dataset1, dataset2, dataset3])
            .map(lambda sample: sample[0] + sample[1] + sample[2])
        )

    Any `Dataset` methods can be used normally.

    Attributes
    ----------
    datasets : List[Union[torchdata.Dataset, torch.utils.data.Dataset]]
            List of datasets to be concatenated sample-wise.

    """

    def __init__(self, datasets: typing.List):
        super().__init__()
        self.datasets = datasets

    def __getitem__(self, index):
        return tuple(dataset[index] for dataset in self.datasets)

    def __len__(self):
        return min(len(dataset) for dataset in self.datasets)


class ConcatIterable(Iterable):
    r"""**Concrete** `Iterable` **responsible for sample-wise concatenation.**

    This class is returned when `|` (bitwise or operator) is used on an instance
    of `Iterable` (original
    `torch.utils.data.IterableDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset>`__
    can be used as well).

    **Important:** This class is meant to be more of a proxy for the `|` operator,
    but you can use it directly as well.

    **Example**::

        dataset = (
            torchdata.ConcatIterable([dataset1, dataset2, dataset3])
            .map(lambda sample: (sample[0] + sample[1], sample[2]))
        )

    Any `IterableDataset` methods can be used normally.

    Attributes
    ----------
    datasets : List[Union[torchdata.Iterable, torch.utils.data.IterableDataset]]
            List of datasets to be concatenated sample-wise.

    """

    def __init__(self, datasets: typing.List):
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        yield from zip(*self.datasets)

    def __getitem__(self, index):
        return tuple(dataset[index] for dataset in self.datasets)

    def __len__(self):
        return min(len(dataset) for dataset in self.datasets)


class ChainDataset(TorchConcatDataset, Dataset):
    r"""**Concrete** `torchdata.Dataset` **responsible for chaining multiple datasets.**

    This class is returned when `+` (addition operator) is used on an instance
    of `torchdata.Dataset` (original `torch.utils.data.Dataset` can be used as well).
    Acts just like PyTorch's `+`, or rather
    `torch.utils.data.ConcatDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.ConcatDataset>`__.

    **Important:** This class is meant to be more of a proxy for the `+` operator,
    but you can use it directly as well.

    **Example**::

        # Iterate over 3 datasets consecutively
        dataset = torchdata.ChainDataset([dataset1, dataset2, dataset3])

    Any `Dataset` methods can be used normally.

    Attributes
    ----------
    datasets : List[Union[torchdata.Dataset, torch.utils.data.Dataset]]
            List of datasets to be chained.

    """

    def __init__(self, datasets):
        Dataset.__init__(self)
        TorchConcatDataset.__init__(self, datasets)


class ChainIterable(TorchChain, Iterable):
    r"""**Concrete** `torchdata.Iterable` **responsible for chaining multiple datasets.**

    This class is returned when `+` (addition operator) is used on an instance
    of `torchdata.Iterable` (original `torch.utils.data.IterableDataset` can be
    used as well). Acts just like PyTorch's `+` and
    `ChainDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.ChainDataset>`__.

    **Important:** This class is meant to be more of a proxy for the `+` operator,
    but you can use it directly as well.

    **Example**::

        # Iterate over 3 iterable datasets consecutively
        dataset = torchdata.ChainIterable([dataset1, dataset2, dataset3])

    Any `Iterable` methods can be used normally.

    Attributes
    ----------
    datasets : List[Union[torchdata.Iterable, torch.utils.data.IterableDataset]]
            List of datasets to be chained.

    """

    def __init__(self, datasets):
        Iterable.__init__(self)
        TorchChain.__init__(self, datasets)


################################################################################
#
#                               CONCRETE CLASSES
#
################################################################################


class Files(Dataset):
    r"""**Create** `Dataset` **from a list of files.**

    Each file is a separate sample. Users can use this class directly
    as all necessary methods are implemented.

    `__getitem__` uses Python's
    `open <https://docs.python.org/3/library/functions.html#open>`__
    and returns a file. Its implementation looks like::

        # You can modify open behaviour by passing args and kwargs to __init__
        with open(self.files[index], *self.args, **self.kwargs) as file:
            return file

    You can use the `map` method in order to modify the returned `file`, or you
    can overload `__getitem__` (image opening example below)::

        import torchdata
        import torchvision

        from PIL import Image

        # Image loading dataset
        class ImageDataset(torchdata.datasets.Files):
            def __getitem__(self, index):
                return Image.open(self.files[index])

        # Useful class methods are inherited as well
        dataset = ImageDataset.from_folder(pathlib.Path("./data"), regex="*.png").map(
            torchvision.transforms.ToTensor()
        )

    The `from_folder` class method is available for the common case of creating
    a dataset from files in a folder.

    Parameters
    ----------
    files : List[pathlib.Path]
            List of files to be used.
    *args
            Arguments saved for `__getitem__`
    **kwargs
            Keyword arguments saved for `__getitem__`

    """

    @classmethod
    def from_folder(cls, path: pathlib.Path, regex: str = "*", *args, **kwargs):
        r"""**Create dataset from a** `pathlib.Path` **-like object.**

        Path should be a directory and will be extended via the `glob` method
        taking `regex` (if specified). Varargs and kwargs will be saved for
        use by the `__getitem__` method.

        Parameters
        ----------
        path : pathlib.Path
                Path object (directory) containing samples.
        regex : str, optional
                Regex to be used for filtering. Default: `*` (all files)
        *args
                Arguments saved for `__getitem__`
        **kwargs
                Keyword arguments saved for `__getitem__`

        Returns
        -------
        Files
                Instance of your file based dataset.

        """
        files = [file for file in path.glob(regex)]
        return cls(files, *args, **kwargs)

    def __init__(self, files: typing.List[pathlib.Path], *args, **kwargs):
        super().__init__()
        self.files = files
        self.args = args
        self.kwargs = kwargs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        with open(self.files[index], *self.args, **self.kwargs) as file:
            return file

    def filter(self, predicate: typing.Callable):
        r"""**Remove** `files` **for which predicate returns** `False`**.**

        **Note:** This is different from `torchdata.Iterable`'s `filter` method,
        as the filtering here is done when the method is called, not during iteration.

        Parameters
        ----------
        predicate : Callable
                Function-like object taking file as argument and returning boolean
                indicating whether to keep a file.

        Returns
        -------
        Files
                Modified self

        """
        self.files = [file for file in self.files if predicate(file)]
        return self

    def sort(self, key=None, reverse=False):
        r"""**Sort files using Python's built-in** `sorted` **method.**

        Arguments are passed directly to `sorted`.

        Parameters
        ----------
        key : Callable, optional
                Specifies a function of one argument that is used to extract
                a comparison key from each element. Default: `None`
                (compare the elements directly).
        reverse : bool, optional
                Whether sorting should be descending. Default: `False`

        Returns
        -------
        Files
                Modified self

        """
        self.files = sorted(self.files, key=key, reverse=reverse)
        return self

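
# Example (a sketch): building a ``Files`` dataset, then filtering and sorting
# its file list eagerly; assumes a ``./data`` directory containing PNG files:
#
#     dataset = (
#         Files.from_folder(pathlib.Path("./data"), regex="*.png")
#         .filter(lambda file: file.stat().st_size > 0)
#         .sort(key=lambda file: file.name)
#     )
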

class TensorDataset(TorchTensorDataset, Dataset):
    r"""**Dataset wrapping** `torch.Tensor` **objects.**

    `cache`, `map` etc. enabled version of
    `torch.utils.data.TensorDataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.TensorDataset>`__.

    Parameters
    ----------
    *tensors : torch.Tensor
            List of `tensors` to be wrapped.

    """

    def __init__(self, *tensors):
        Dataset.__init__(self)
        TorchTensorDataset.__init__(self, *tensors)


class Generator(Iterable):
    r"""**Iterable wrapping any generator expression.**

    Parameters
    ----------
    expression : Generator expression
            Generator from which one can `yield` via the `yield from` syntax.

    """

    def __init__(self, expression):
        super().__init__()
        self.expression = expression

    def __iter__(self):
        yield from self.expression

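
# Example (a sketch): wrapping a finite iterable and mapping over it;
# maps are applied while iterating:
#
#     dataset = Generator(range(5)).map(lambda value: value ** 2)
#     list(dataset)  # [0, 1, 4, 9, 16]
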

class _Wrap:
    def __getattr__(self, name):
        return getattr(self.dataset, name)


class WrapDataset(_Wrap, Dataset):
    r"""**Dataset wrapping standard** `torch.utils.data.Dataset` **and making it**
    `torchdata.Dataset` **compatible.**

    All attributes of the wrapped dataset can be used normally, for example::

        dataset = torchdata.datasets.WrapDataset(
            torchvision.datasets.MNIST("./data")
        )
        dataset.train  # True, has all MNIST attributes

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
            Dataset to be wrapped

    """

    def __init__(self, dataset):
        self.dataset = dataset
        Dataset.__init__(self)

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)


class WrapIterable(_Wrap, Iterable):
    r"""**Iterable wrapping standard** `torch.utils.data.IterableDataset` **and making it**
    `torchdata.Iterable` **compatible.**

    All attributes of the wrapped dataset can be used normally, as is the case for
    `torchdata.datasets.WrapDataset`.

    Parameters
    ----------
    dataset : torch.utils.data.IterableDataset
            Dataset to be wrapped

    """

    def __init__(self, dataset):
        Iterable.__init__(self)
        self.dataset = dataset

    def __iter__(self):
        yield from self.dataset
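
# Example (a sketch): making a standard ``IterableDataset`` filterable;
# ``SomeIterableDataset`` is an illustrative stand-in for any
# ``torch.utils.data.IterableDataset``:
#
#     dataset = WrapIterable(SomeIterableDataset()).filter(
#         lambda sample: sample is not None
#     )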