Source code for torchdata.modifiers
r"""**This module allows you to modify behaviour of** `torchdata.cachers`.
To cache in `memory` only `20` first samples you could do (assuming you have already created
`torchdata.Dataset` instance named `dataset`)::
    dataset.cache(td.modifiers.UpToIndex(20, td.cachers.Memory()))
Modifers could be mixed intuitively as well using logical operators `|` (or) and
`&` (and).
**Example** (cache to disk `20` first or samples with index `1000` and upwards)::
    dataset.cache(
        td.modifiers.UpToIndex(20, td.cachers.Memory())
        | td.modifiers.FromIndex(1000, td.cachers.Memory())
    )
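Another example (cache the first half of the samples in memory; the dataset's
length is passed so the modifier knows how many samples there are)::
    dataset = dataset.cache(
        td.modifiers.UpToPercentage(0.5, len(dataset), td.cachers.Memory())
    )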
You can mix the provided modifiers or extend them by inheriting from `Modifier`
and implementing the `condition` method (interface described below).
For most cases the `Lambda` modifier should be sufficient, for example
(assuming `cacher` is an instance of `torchdata.cachers.Cacher`)::
    # Cache only elements with index below 25 which are also divisible by 2
    dataset = dataset.cache(
        td.modifiers.UpToIndex(25, cacher)
        & td.modifiers.Lambda(lambda index: index % 2 == 0, cacher)
    )
"""
import abc
import typing
from ._base import Base
class Modifier(Base):
    r"""**Interface for all modifiers.**
    Most methods are pre-configured, so the user should not override them.
    In fact, only `condition` has to be overridden and `__init__` implemented.
    The constructor should assign a `cacher` to `self.cacher` in order for everything
    to work; see the example below.
    Example implementation of a modifier caching only elements `0` to `99`
    with any `td.cachers.Cacher`::
        import torchdata as td
        class ExampleModifier(td.modifiers.Modifier):
            # You have to assign cacher to self.cacher so modifier works.
            def __init__(self, cacher):
                self.cacher = cacher
            def condition(self, index):
                return index < 100 # Cache if index smaller than 100
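    Such a modifier can then be used like any provided one (assuming a
    `torchdata.Dataset` instance named `dataset`)::
        dataset = dataset.cache(ExampleModifier(td.cachers.Memory()))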
    """
    @abc.abstractmethod
    def condition(self, index: int) -> bool:
        r"""**Based on index, decide whether cache should interact with the sample.**
        Only this function should be implemented by the user.
        If `True` is returned, `cacher` will act on the sample normally (e.g. saving or loading it).
        Parameters
        ----------
        index : int
                Index of sample
        Returns
        -------
        bool
                Whether to act on sample with given index
        """
    def __contains__(self, index: int) -> bool:
        r"""**Acts as invisible proxy for** `cacher`'s `__contains__` **method.**
        **User should not override this method.**
        For more information check the `torchdata.cachers.Cacher` interface.
        Parameters
        ----------
        index : int
                Index of sample
        """
        if self.condition(index):
            return index in self.cacher
        return False 
    def __setitem__(self, index: int, data: typing.Any) -> None:
        r"""**Acts as invisible proxy for** `cacher`'s `__setitem__` **method.**
        **User should not override this method.**
        For more information check the `torchdata.cachers.Cacher` interface.
        Parameters
        ----------
        index : int
                Index of sample
        data : typing.Any
                Data generated by dataset.
        """
        if self.condition(index):
            self.cacher[index] = data 
    def __getitem__(self, index: int):
        r"""**Acts as invisible proxy for** `cacher`'s `__getitem__` **method.**
        **User should not override this method.**
        For more information check the `torchdata.cachers.Cacher` interface.
        Parameters
        ----------
        index : int
                Index of sample
        """
        return self.cacher[index] 
    def __or__(self, other):
        r"""**If self or other returns True, then use** `cacher`.
        **User should not override this method.**
        **Important:** `self` and `other` should have the same `cacher` wrapped.
        The cacher of the first modifier is used no matter what.
        Parameters
        ----------
        other : Modifier
                Another modifier
        Returns
        -------
        Any
                Modifier combining both modifiers with logical *or*.
        """
        return Any(self, other) 
    def __and__(self, other):
        r"""**If self and other returns True, then use** `cacher`.
        **Important:** `self` and `other` should have the same `cacher` wrapped.
        Cacher of first modifier is used no matter what.
        Parameters
        ----------
        other : Modifier
                Another modifier
        Returns
        -------
        All
                Modifier combining both modifiers with logical *and*.
        """
        return All(self, other)  
class _Mix(Modifier):
    r"""**{}**
    Parameters
    ----------
    *modifiers: List[torchdata.modifiers.Modifier]
            List of modifiers
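    Instances of these classes are created for you when modifiers are
    combined with `|` (which builds `Any`) or `&` (which builds `All`);
    they can also be instantiated directly, for example (assuming `cacher`
    is an instance of `torchdata.cachers.Cacher`)::
        # Equivalent to: UpToIndex(25, cacher) & Lambda(lambda index: index % 2 == 0, cacher)
        modifier = td.modifiers.All(
            td.modifiers.UpToIndex(25, cacher),
            td.modifiers.Lambda(lambda index: index % 2 == 0, cacher),
        )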
    """
    def __init__(self, *modifiers):
        self.modifiers = modifiers
        self.cacher = modifiers[0].cacher
class All(_Mix):
    __doc__ = _Mix.__doc__.format(
        r"Return True if all modifiers return True on given sample."
    )
    def condition(self, index):
        return all(modifier.condition(index) for modifier in self.modifiers) 
class Any(_Mix):
    __doc__ = _Mix.__doc__.format(
        r"Return True if any modifier returns True on given sample."
    )
    def condition(self, index):
        return any(modifier.condition(index) for modifier in self.modifiers) 
class _Percent(Modifier):
    r"""**{}**
    Parameters
    ----------
    p : float
            Percentage specified as a float between `0` and `1` (exclusive).
    length : int
            How many samples are in dataset. You can pass `len(dataset)`.
    cacher : torchdata.cachers.Cacher
            Instance of cacher
    """
    @abc.abstractmethod
    def condition(self, index):
        pass
    def __init__(self, p: float, length: int, cacher):
        if not 0 < p < 1:
            raise ValueError(
                "Percentage has to be between 0 and 1, but got {}".format(p)
            )
        self.threshold = int(length * p)
        self.cacher = cacher
class UpToPercentage(_Percent):
    __doc__ = _Percent.__doc__.format(
        r"""Cache up to the specified percentage of samples, leaving the rest untouched."""
    )
    def condition(self, index):
        return index < self.threshold 
class FromPercentage(_Percent):
    __doc__ = _Percent.__doc__.format(
        r"""Cache samples past the specified percentage, leaving the rest untouched."""
    )
    def condition(self, index):
        return index > self.threshold 
class _Index(Modifier):
    r"""**{}**
    Parameters
    ----------
    index : int
            Index acting as the caching threshold
    cacher : torchdata.cachers.Cacher
            Instance of cacher
    """
    @abc.abstractmethod
    def condition(self, index):
        pass
    def __init__(self, index: int, cacher):
        self.index = index
        self.cacher = cacher
class UpToIndex(_Index):
    __doc__ = _Index.__doc__.format(
        r"""Cache samples with index smaller than the specified one, leaving the rest untouched."""
    )
    def condition(self, index):
        return index < self.index 
class FromIndex(_Index):
    __doc__ = _Index.__doc__.format(
        r"""Cache samples with index greater than the specified one, leaving the rest untouched."""
    )
    def condition(self, index):
        return index > self.index 
class Indices(Modifier):
    r"""**Cache samples whose index is one of those specified.**
    Parameters
    ----------
    cacher : torchdata.cachers.Cacher
            Instance of cacher
    *indices : int
            Indices of samples on which `cacher` will act
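    Example (assuming a `torchdata.Dataset` instance named `dataset`; cache
    samples `0`, `5` and `10` in memory)::
        dataset = dataset.cache(
            td.modifiers.Indices(td.cachers.Memory(), 0, 5, 10)
        )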
    """
    def __init__(self, cacher, *indices):
        self.cacher = cacher
        self.indices = indices 
    def condition(self, index):
        return index in self.indices 
class Lambda(Modifier):
    r"""**Cache samples if the specified function returns** `True`.
    Parameters
    ----------
    function : typing.Callable
            Callable taking a single `int` argument (index of the sample) and
            returning a `bool`. If `True` is returned, the sample will be cached.
    cacher : torchdata.cachers.Cacher
            Instance of cacher
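    Example (assuming a `torchdata.Dataset` instance named `dataset`; cache
    every third sample in memory)::
        dataset = dataset.cache(
            td.modifiers.Lambda(lambda index: index % 3 == 0, td.cachers.Memory())
        )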
    """
    def __init__(self, function: typing.Callable, cacher):
        self.function = function
        self.cacher = cacher
    def condition(self, index):
        return self.function(index)