Source code for torchdata.samplers

r"""**This module implements samplers to be used in conjunction with** `torch.utils.data.DataLoader` **instances**.

These can be used just like PyTorch's `torch.utils.data.Sampler` instances.

See `PyTorch tutorial <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__
for more examples and information.

"""

import builtins

import torch
from torch.utils.data import RandomSampler, Sampler, SubsetRandomSampler

from ._base import Base


# Source of mixed class below:
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py#L68
class RandomSubsetSampler(Base, RandomSampler):
    r"""**Sample elements randomly from a given list of indices.**

    If without `replacement`, then sample from a shuffled dataset.
    If with `replacement`, then user can specify :attr:`num_samples` to draw.

    Similar to PyTorch's `SubsetRandomSampler`, but this one allows you to
    specify `indices` which will be sampled in random order, rather than
    subsampling a `range`.

    Parameters
    ----------
    indices : typing.Iterable
        A sequence of indices
    replacement : bool, optional
        Samples are drawn with replacement if `True`. Default: `False`
    num_samples : int, optional
        Number of samples to draw, default=`len(dataset)`. This argument
        is supposed to be specified only when `replacement` is `True`.
        Default: `None`

    """

    def __init__(self, indices, replacement=False, num_samples=None):
        RandomSampler.__init__(self, indices, replacement, num_samples)

    def __iter__(self):
        for index in RandomSampler.__iter__(self):
            yield self.data_source[index]

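# Usage sketch (an illustration added to this listing, not part of the
# original module; the index list is made up):
#
#   >>> sampler = RandomSubsetSampler([10, 20, 30])
#   >>> sorted(sampler)  # yields a random permutation of the given indices
#   [10, 20, 30]
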
class _Equalizer(Sampler):
    def __init__(self, labels: torch.Tensor, function):
        # Gather the dataset indices belonging to each unique label
        tensors = [
            torch.nonzero(labels == i, as_tuple=False).flatten()
            for i in torch.unique(labels)
        ]
        # `function` is "max" (oversample) or "min" (undersample)
        self.samples_per_label = getattr(builtins, function)(map(len, tensors))
        # One sampler per label; classes smaller than the target count
        # sample with replacement to reach `samples_per_label`
        self.samplers = [
            iter(
                RandomSubsetSampler(
                    tensor,
                    replacement=len(tensor) < self.samples_per_label,
                    num_samples=self.samples_per_label
                    if len(tensor) < self.samples_per_label
                    else None,
                )
            )
            for tensor in tensors
        ]

    @property
    def num_samples(self):
        return self.samples_per_label * len(self.samplers)

    def __iter__(self):
        # Interleave classes: each round yields one index per label,
        # visiting the labels in a freshly shuffled order
        for _ in range(self.samples_per_label):
            for index in torch.randperm(len(self.samplers)).tolist():
                yield next(self.samplers[index])

    def __len__(self):
        return self.num_samples

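# How the per-label buckets above are formed (a standalone sketch of the same
# torch.unique / torch.nonzero pattern, with made-up labels):
#
#   >>> labels = torch.tensor([0, 1, 0, 2])
#   >>> [torch.nonzero(labels == i, as_tuple=False).flatten().tolist()
#   ...  for i in torch.unique(labels)]
#   [[0, 2], [1], [3]]
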
class RandomOverSampler(_Equalizer):
    r"""**Sample elements randomly with underrepresented classes upsampled.**

    Length is equal to `max_samples_per_class * classes`.

    Parameters
    ----------
    labels : torch.Tensor
        Tensor containing labels for respective samples.

    """

    def __init__(self, labels):
        super().__init__(labels, "max")

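# Usage sketch (an illustration added to this listing; the labels are made up):
#
#   >>> labels = torch.tensor([0, 0, 0, 0, 1])
#   >>> sampler = RandomOverSampler(labels)
#   >>> len(sampler)  # majority class has 4 samples, 2 classes -> 4 * 2
#   8
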
class RandomUnderSampler(_Equalizer):
    r"""**Sample elements randomly with overrepresented classes downsampled.**

    Length is equal to `min_samples_per_class * classes`.

    Parameters
    ----------
    labels : torch.Tensor
        Tensor containing labels for respective samples.

    """

    def __init__(self, labels: torch.Tensor):
        super().__init__(labels, "min")

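# Usage sketch (an illustration added to this listing; the labels are made up):
#
#   >>> labels = torch.tensor([0, 0, 0, 0, 1])
#   >>> sampler = RandomUnderSampler(labels)
#   >>> len(sampler)  # minority class has 1 sample, 2 classes -> 1 * 2
#   2
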
class Distribution(Sampler):
    r"""**Sample** `num_samples` **indices from distribution object.**

    Parameters
    ----------
    distribution : torch.distributions.distribution.Distribution
        Distribution-like object implementing the `sample()` method.
    num_samples : int
        Number of samples to be yielded.

    """

    def __init__(
        self,
        distribution: torch.distributions.distribution.Distribution,
        num_samples: int,
    ):
        self.distribution = distribution
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            yield self.distribution.sample()

    def __len__(self):
        return self.num_samples

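# Usage sketch (an illustration added to this listing; the distribution and
# sample count are made up):
#
#   >>> dist = torch.distributions.Categorical(torch.tensor([0.25, 0.75]))
#   >>> sampler = Distribution(dist, num_samples=4)
#   >>> indices = list(sampler)  # four 0-dim tensors drawn from `dist`
#   >>> len(indices)
#   4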