# Source code for torchdata.cachers

r"""**This module contains interface needed for** `cachers` **(used in** `cache` **method of** `td.Dataset` **) .**

To cache on disk all samples using Python's `pickle <https://docs.python.org/3/library/pickle.html>`__ in folder `cache`
(assuming you have already created `td.Dataset` instance named `dataset`)::

    import torchdata as td

    ...
    dataset.cache(td.cachers.Pickle("./cache"))

Users are encouraged to write their custom `cachers` if the ones provided below
are too slow or not good enough for their purposes (see `Cacher` abstract interface below).

"""

import abc
import pathlib
import pickle
import shutil
import typing

from ._base import Base


class Cacher(Base):
    r"""**Interface to fulfil to make object compatible with** `torchdata.Dataset.cache` **method.**

    If you want to implement your own `caching` functionality, inherit from
    this class and implement the methods described below.

    """

    @abc.abstractmethod
    def __contains__(self, index: int) -> bool:
        r"""**Return true if sample under** `index` **is cached.**

        If `False` returned, cacher's `__setitem__` will be called, hence if you
        are not going to cache sample under this `index`, you should describe this
        operation at that method.

        This is simply a boolean indicator whether sample is cached.

        If `True` cacher's `__getitem__` will be called and it's users responsibility
        to return correct value in such case.

        Parameters
        ----------
        index : int
            Index of sample

        """

    # Save if doesn't contain
    @abc.abstractmethod
    def __setitem__(self, index: int, data: typing.Any) -> None:
        r"""**Saves sample under index in cache or do nothing.**

        This function should save sample under `index` to be later retrieved by
        `__getitem__`.
        If you don't want to save specific `index`, you can implement this
        functionality in `cacher` or create separate `modifier` solely for this
        purpose (second approach is highly recommended).

        Parameters
        ----------
        index : int
            Index of sample
        data : Any
            Data generated by dataset.

        """

    @abc.abstractmethod
    def __getitem__(self, index: int) -> typing.Any:
        r"""**Retrieve sample from cache.**

        **This function MUST return valid data sample and it's users responsibility
        if custom cacher is implemented**.

        Return from this function datasample which lies under it's respective
        `index`.

        Parameters
        ----------
        index : int
            Index of sample

        """
class Pickle(Cacher):
    r"""**Save and load data from disk using** `pickle` **module.**

    Data will be saved as `.pkl` in specified path. If path does not exist,
    it will be created.

    **This object can be used as a** `context manager` **and it will delete**
    `path` **at the end of block**::

        with td.cachers.Pickle(pathlib.Path("./disk")) as pickler:
            dataset = dataset.map(lambda x: x + 1).cache(pickler)
            ...  # Do something with dataset
        ...  # Folder removed

    You can also issue `clean()` method manually for the same effect
    (though it's discouraged as you might crash `__setitem__` method).

    **Important:** This `cacher` can act between consecutive runs, just don't use
    `clean()` method or don't delete the folder manually. If so, **please ensure
    correct sampling** (same seed and sampling order) for reproducible behaviour
    between runs.

    Attributes
    ----------
    path: pathlib.Path
        Path to the folder where samples will be saved and loaded from.
    extension: str
        Extension to use for saved pickle files. Default: `.pkl`

    """

    def __init__(self, path: pathlib.Path, extension: str = ".pkl"):
        self.path = path
        # Create the cache folder eagerly so __setitem__ never has to
        # handle a missing directory mid-run.
        self.path.mkdir(parents=True, exist_ok=True)
        self.extension = extension

    def _filepath(self, index: int) -> pathlib.Path:
        # Single source of truth for the on-disk naming scheme
        # "{path}/{index}{extension}", shared by all methods below.
        return (self.path / str(index)).with_suffix(self.extension)

    def __contains__(self, index: int) -> bool:
        """**Check whether file exists on disk.**

        If file is available it is considered cached, hence you can cache data
        between multiple runs (if you ensure repeatable sampling).

        """
        return self._filepath(index).is_file()

    def __setitem__(self, index: int, data: typing.Any) -> None:
        """**Save** `data` **in specified folder.**

        Name of the item will be equal to `{self.path}/{index}{extension}`.

        Parameters
        ----------
        index : int
            Index of sample
        data : Any
            Data generated by dataset.

        """
        with open(self._filepath(index), "wb") as file:
            pickle.dump(data, file)

    def __getitem__(self, index: int) -> typing.Any:
        """**Retrieve** `data` **specified by** `index`.

        Name of the item will be equal to `{self.path}/{index}{extension}`.

        Parameters
        ----------
        index : int
            Index of sample

        """
        with open(self._filepath(index), "rb") as file:
            return pickle.load(file)

    def clean(self) -> None:
        """**Remove recursively folder** `self.path`.

        Behaves just like `shutil.rmtree`, but won't act
        if directory does not exist.

        """
        if self.path.is_dir():
            shutil.rmtree(self.path)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Context-manager exit removes the cache folder; exceptions are
        # not suppressed (implicit None return).
        self.clean()
class Memory(Cacher):
    r"""**Save and load data in Python dictionary**.

    This `cacher` is used by default inside `torchdata.Dataset`.

    """

    def __init__(self):
        # index -> sample mapping; lives as long as this object does.
        self.cache = {}

    def __contains__(self, index: int) -> bool:
        """True if index in dictionary."""
        return index in self.cache

    def __setitem__(self, index: int, data: typing.Any) -> None:
        """Adds data to dictionary."""
        self.cache[index] = data

    def __getitem__(self, index: int) -> typing.Any:
        """Retrieve data from dictionary."""
        return self.cache[index]