
Source code for torchtraining

"""Root module of `torchtraining` containing common operations.

.. note::

    **IMPORTANT**: This module is one of the most important and
    is used in almost any DL task so be sure to understand it!

Operations in this module can be used to:

    * control pipeline flow
    * select output from `step`s
    * send data to multiple operations

See below for more info.
"""

import typing

import loguru
import torch
import yaml

from . import (accelerators, accumulators, callbacks, cast, device, epochs,
               exceptions, functional, iterations, loss, metrics, pytorch,
               quantization, steps)
from ._base import (Accumulator, Epoch, GeneratorProducer, Iteration,
                    Operation, ReturnProducer, Step)
from ._version import __version__

# Register a custom loguru level named "NONE" with severity number 0
# (the lowest severity value). Runs as an import-time side effect of this
# module; where the level is consumed is not visible from here.
loguru.logger.level("NONE", no=0)

class Select(Operation):
    """Select output item(s) returned from `step` or `iteration` objects.

    Allows users to focus on specific part of output and pipe specified values
    to other operations (like `metrics`, `loggers` etc.).

    .. note::

        **IMPORTANT**: This operation is run in almost any case
        so be sure to understand how it works.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                # Generate loss and other necessary items
                ...
                return loss, predictions, targets


        step = TrainStep(criterion, device)
        # Select `loss` and perform backpropagation
        # Only single value will be forwarded to backward from
        # (loss, predictions, targets) tuple
        step ** tt.Select(loss=0) ** tt.pytorch.Backward()

    .. note::

        Name of keyword argument can be arbitrary but **you are really
        encouraged** to name it like the variable returned from `step`
        (or at least make it understandable to others).

    Parameters
    ----------
    **output_selection : int
        `name`: output_index mapping selecting which element from step
        returned `tuple` to choose. `name` can be arbitrary, but should be
        named like the variable returned from `step`. See example above.

    Returns
    -------
    Any | List[Any]
        If a single index is passed, returns the single element found at
        that index. Otherwise returns chosen elements as `list` (in the
        order the keyword arguments were given).

    Raises
    ------
    ValueError
        If no output index was specified.

    """

    def __init__(self, **output_selection: int):
        if len(output_selection) < 1:
            # Do not format `self` here: `__str__` reads `self.output_selection`,
            # which is not assigned yet, so the intended `ValueError` would be
            # masked by an `AttributeError`.
            raise ValueError(
                "{}: At least one output index has to be specified, got {} output indices.".format(
                    type(self).__name__, len(output_selection)
                )
            )

        super().__init__()
        self.output_selection = output_selection

    def forward(self, data: typing.Sequence[typing.Any]) -> typing.Any:
        """
        Arguments
        ---------
        data: Sequence[Any]
            Indexable container (e.g. `tuple`, `list`) with any elements.

        Returns
        -------
        Any | List[Any]
            Single element if exactly one index was specified, `list` of
            selected elements otherwise.

        """
        indices = list(self.output_selection.values())
        if len(indices) == 1:
            # Single selection is returned bare, not wrapped in a list.
            return data[indices[0]]
        return [data[index] for index in indices]

    def __str__(self) -> str:
        return yaml.dump({super().__str__(): self.output_selection})
class Split(Operation):
    """Fan `data` out to several operations.

    .. note::

        **IMPORTANT**: This operation is run in almost any case
        so be sure to understand how it works.

    Useful when users wish to use a calculated result in multiple places.
    Example calculating metrics and logging them in multiple places::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                # Generate loss and other necessary items
                ...
                # Assume binary classification
                return loss, logits, targets


        step = TrainStep(criterion, device)
        # Push (logits, targets) to Precision and Recall
        # and log those values after calculating metrics
        step ** tt.Select(logits=1, targets=2) ** tt.Split(
            tt.metrics.classification.binary.Precision() ** tt.callbacks.Logger("Precision"),
            tt.metrics.classification.binary.Recall() ** tt.callbacks.Logger("Recall"),
        )

    Parameters
    ----------
    *operations: Callable(data) -> Any
        Operations to which `data` will be passed, called in the given order.
    return_modified: bool, optional
        Return outputs from `operations` as a `list` if `True`. If `False`,
        returns original `data` passed into `Split`. Default: `False`

    Returns
    -------
    data | List[modified data]
        Returns `data` passed originally or `list` containing
        values returned from `operations`.

    Raises
    ------
    ValueError
        If no operation was provided.

    """

    def __init__(
        self,
        *operations: typing.Callable[[typing.Any,], typing.Any],
        return_modified: bool = False
    ):
        if len(operations) == 0:
            raise ValueError("Split requires at least one operation to pass data into.")

        super().__init__()
        self.return_modified = return_modified
        self.operations = operations

    def forward(
        self, data: typing.Any
    ) -> typing.Union[typing.Any, typing.List[typing.Any]]:
        """
        Arguments
        ---------
        data: Any
            Data which will be passed to each of the provided `operations`.

        """
        if not self.return_modified:
            # Results are intentionally discarded; operations are run only
            # for whatever they do with `data`.
            for branch in self.operations:
                branch(data)
            return data

        return [branch(data) for branch in self.operations]

    def __str__(self) -> str:
        description = {super().__str__(): self.operations}
        return yaml.dump(description)
class OnSplittedTensor(Operation):
    """Split tensor along dimension and apply operation on each element.

    By default, `torch.Tensor` will be splitted along batch (`dim=0`).

    .. note::

        **IMPORTANT**: After splitting, every chunk is passed through
        `torch.squeeze` called without a `dim` argument, so **all**
        dimensions of size `1` are removed from it (the split dimension
        as well as any other singleton dimension).

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                # Dummy step
                images, labels = sample
                return images


        step = TrainStep(criterion, device)
        # Assume summary_writer is instance of torch.utils.tensorboard.SummaryWriter()
        step ** tt.OnSplittedTensor(tt.callbacks.tensorboard.Image(summary_writer))
        # Each image will be saved separately

    Parameters
    ----------
    operation: tt.Operation | Callable(data) -> Any
        Operation which will be applied to each element of `torch.Tensor`.
    dim: int, optional
        Dimension along which `data` `torch.Tensor` will be splitted.
        Default: `0`

    Returns
    -------
    Tuple[torch.Tensor]
        Splitted `data` along `dim` (unmodified by `operation`).

    """

    def __init__(self, operation: Operation, dim: int = 0):
        super().__init__()
        self.operation = operation
        self.dim = dim

    def forward(self, data):
        """
        Arguments
        ---------
        data: torch.Tensor
            Tensor to split and apply `operation` on.

        Returns
        -------
        Tuple[torch.Tensor]
            Squeezed chunks of size `1` taken along `dim`. Values returned
            by `operation` are discarded.

        """
        splitted = tuple(map(torch.squeeze, torch.split(data, 1, self.dim)))
        for tensor in splitted:
            # Apply the operation to each chunk separately (not to the
            # whole tuple of chunks).
            self.operation(tensor)
        return splitted
class Flatten(Operation):
    r"""Flatten arbitrarily nested data.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                module1, module2 = module
                ...
                return ((logits1, targets1), (logits2, targets2), module1, module2)


        step = TrainStep(criterion, device)
        # List [logits1, targets1, logits2, targets2, module1, module2]
        step ** tt.Flatten()

    Parameters
    ----------
    types : Tuple[type], optional
        Types to be considered non-flat. Those will be recursively flattened.
        Default: `(list, tuple)`

    Returns
    -------
    List[samples]
        Single `list` with all elements (none of them being an instance of
        `types`). If the input itself is not an instance of `types`, it is
        returned unchanged.

    """

    def __init__(self, types: typing.Tuple = (list, tuple)):
        super().__init__()
        self.types = types

    def forward(self, data) -> typing.Any:
        """
        Parameters
        ----------
        data: Iterable[Iterable ... Iterable[Any]]
            Arbitrarily nested data being one of type provided in `types`.

        Returns
        -------
        List[Any] | Any
            New flat `list`, or `data` itself when it is not an instance
            of `types`. The input container is never modified.

        """
        if not isinstance(data, self.types):
            return data
        return Flatten._flatten(data, self.types)

    @staticmethod
    def _flatten(
        items: typing.Iterable[typing.Any], types: typing.Tuple
    ) -> typing.List[typing.Any]:
        """Return a new flat list with elements of `items` in depth-first order.

        Uses an explicit stack of iterators instead of in-place slice
        assignment, so a `list` passed by the caller is left untouched
        and deep nesting cannot hit the recursion limit.
        """
        flat = []
        pending = [iter(items)]
        while pending:
            for element in pending[-1]:
                if isinstance(element, types):
                    # Descend into the nested container; the parent iterator
                    # keeps its position and is resumed afterwards.
                    pending.append(iter(element))
                    break
                flat.append(element)
            else:
                # Innermost iterator exhausted, go back to its parent.
                pending.pop()
        return flat
class If(Operation):
    """Run operation only if `condition` is `True`.

    `condition` can also be a single argument callable, in this case it can be
    dependent on data, see below::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss


        step = TrainStep(criterion, device)
        step ** tt.If(lambda loss: loss > 10, tt.callbacks.Logger("VERY HIGH LOSS!!!"))

    Parameters
    ----------
    condition: bool | Callable(Any) -> bool
        If boolean value and if `True`, run underlying Operation (or other
        Callable). If Callable, should take data as argument and return
        decision based on that as single `bool`.
    operation: torchtraining.Operation | Callable
        Operation to run if `True`

    Returns
    -------
    Any
        If `True`, returns value from `operation`, otherwise passes
        original `data`

    """

    def __init__(
        self,
        condition: typing.Union[bool, typing.Callable[[typing.Any], bool,]],
        operation: typing.Callable[[typing.Any,], typing.Any],
    ):
        super().__init__()
        self.condition = condition
        self.operation = operation

    def _should_run(self, data: typing.Any) -> bool:
        """Evaluate `condition`: call it with `data` if callable, else use as-is."""
        if callable(self.condition):
            return self.condition(data)
        return self.condition

    def forward(self, data: typing.Any) -> typing.Any:
        """
        Arguments
        ---------
        data: Any
            Anything you want (usually `torch.Tensor` like stuff).

        Returns
        -------
        Any
            Output of `operation` when the condition holds, `data` otherwise.

        """
        # A plain `bool` condition must not be called; only callables
        # are evaluated against `data`.
        if self._should_run(data):
            return self.operation(data)
        return data

    def __str__(self) -> str:
        # Uses truthiness of `condition` itself, so a callable condition
        # (truthy by default) is displayed as the wrapped operation.
        if self.condition:
            return str(self.operation)
        return "no-operation"
class IfElse(Operation):
    """Run `operation1` only if `condition` is `True`, otherwise run `operation2`.

    `condition` can be static (plain `bool`) or dynamic (single argument
    callable), in the latter case it can be dependent on data, see below::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss


        step = TrainStep(criterion, device)
        step ** tt.IfElse(
            lambda loss: loss > 10,
            tt.callbacks.Logger("VERY HIGH LOSS!!!"),
            tt.callbacks.Logger("LOSS IS NOT THAT HIGH..."),
        )

    Parameters
    ----------
    condition: bool | Callable(Any) -> bool
        If boolean value, it decides directly which operation is run.
        If Callable, it is called with `data` on every `forward` and its
        result decides which operation is run.
    operation1: torchtraining.Op | Callable
        Operation or callable getting single argument (`data`) and returning
        anything. Run when the condition holds.
    operation2: torchtraining.Op | Callable
        Operation or callable getting single argument (`data`) and returning
        anything. Run when the condition does not hold.

    Returns
    -------
    Any
        If `True`, returns value from `operation1`, otherwise from
        `operation2`.

    """

    def __init__(
        self,
        condition: typing.Union[bool, typing.Callable[[typing.Any], bool]],
        operation1: typing.Callable[[typing.Any,], typing.Any],
        operation2: typing.Callable[[typing.Any,], typing.Any],
    ):
        super().__init__()
        self.condition = condition
        self.operation1 = operation1
        self.operation2 = operation2

    def forward(self, data: typing.Any) -> typing.Any:
        """
        Arguments
        ---------
        data: Any
            Anything you want (usually `torch.Tensor` like stuff).

        Returns
        -------
        Any
            Output of `operation1` or `operation2`, depending on `condition`.

        """
        # A callable is always truthy, so it has to be evaluated on `data`
        # instead of being tested directly; otherwise `operation2` could
        # never be reached with a dynamic condition.
        if callable(self.condition):
            decision = self.condition(data)
        else:
            decision = self.condition

        if decision:
            return self.operation1(data)
        return self.operation2(data)

    def __str__(self) -> str:
        # Uses truthiness of `condition` itself, so a callable condition
        # (truthy by default) is displayed as `operation1`.
        if self.condition:
            return str(self.operation1)
        return str(self.operation2)
class ToAll(Operation):
    r"""Apply operation to each element of sample.

    .. note::

        If you want to apply operation to all nested elements (e.g.
        in nested `tuple`), please use `torchtraining.Flatten` object first.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return logits, targets


        step = TrainStep(criterion, device)
        # Callable is applied separately to `logits` and to `targets`
        step ** tt.ToAll(lambda tensor: tensor.detach())

    Parameters
    ----------
    operation: Callable
        Single argument callable applied to each element of sample.

    Returns
    -------
    Tuple[operation(subsample)]
        Tuple consisting of subsamples with operation applied,
        in the original order.

    """

    def __init__(self, operation: typing.Callable):
        super().__init__()
        self.operation = operation

    def forward(self, sample):
        """
        Arguments
        ---------
        sample: Iterable[Any]
            Iterable whose elements are passed, one by one, to `operation`.

        """
        transformed = map(self.operation, sample)
        return tuple(transformed)
class Lambda(Operation):
    """Run user specified operation on `data`.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return accuracy


        step = TrainStep(criterion, device)
        # If you want to get that SOTA badly, we got ya covered
        step ** tt.Lambda(lambda accuracy: accuracy * 2)

    Parameters
    ----------
    operation: Callable(Any) -> Any
        Single argument callable getting data and returning some value.
    name: str, optional
        `string` representation of this operation (returned by `str()`).
        Default: `torchtraining.Lambda`

    Returns
    -------
    Any
        Value returned from `operation`

    """

    def __init__(
        self,
        operation: typing.Callable[[typing.Any,], typing.Any],
        name: str = "torchtraining.Lambda",
    ):
        super().__init__()
        # Keep the callable and its display name in separate attributes;
        # the callable must not be overwritten by the name.
        self.operation = operation
        self.name = name

    def __str__(self) -> str:
        return self.name

    def forward(self, data: typing.Any) -> typing.Any:
        """
        Arguments
        ---------
        data: Any
            Anything you want (usually `torch.Tensor` like stuff).

        Returns
        -------
        Any
            Whatever `operation(data)` returns.

        """
        return self.operation(data)