
Source code for torchtraining.steps

"""Perform single step on data and via specific module(s).

.. note::

    **IMPORTANT**: This module is one of the core features,
    so be sure to understand how it works.
    It defines how you perform a single
    step through the data.

See `Introduction tutorial <>`_ for example of `step`.

Usually it looks something along these lines::

    class Step(tt.steps.Step):
        def forward(self, module, batch):
            images, labels = batch
            images, labels = images.to(self.device), labels.to(self.device)

            predictions = module(images)
            loss = self.criterion(predictions, labels)

            return loss, predictions, labels

    step = Step(criterion=torch.nn.BCEWithLogitsLoss(), gradient=True, device=torch.device("cuda"))

.. note::

    **IMPORTANT**: You can override `__init__` if you wish to pass
    other arguments.

.. note::

    **IMPORTANT**: You can change the `forward` signature to anything you
    desire. Just be sure to pass appropriate data to it (via `iteration` or `epoch`,
    or a plain `__call__`).

.. note::

    **IMPORTANT**: `module` is passed in from other objects and can be anything.
    In the GAN tutorial, for example, it is a `Tuple` of `torch.nn.Module` instances.

"""

import abc
import collections
import dataclasses
import typing

import torch

from . import _base, utils
from .utils import steps as steps_utils

# General-purpose step: the `gradient` flag decides whether autograd is
# enabled while the step runs, so one class serves training and evaluation.
# NOTE(review): `steps_utils.docstring` presumably assembles the class
# docstring from `header` and `body` -- confirm in `utils.steps`.
@steps_utils.docstring(
    header="General `step`, usable both in training & evaluation.",
    body="User should override `forward` method.",
)
class Step(_base.Step):
    def __init__(
        self,
        criterion: typing.Callable,
        gradient,
        device=None,
    ):
        """Store the step configuration on the instance.

        Arguments:
            criterion: Callable kept as `self.criterion` for use in `forward`.
            gradient: Value handed to `torch.set_grad_enabled` on every call.
            device: Kept as `self.device`; this class does not use it itself.
        """
        super().__init__()
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        """Run the parent `__call__` with grad mode set to `self.gradient`.

        All positional and keyword arguments are forwarded unchanged and the
        parent's result is returned as-is.
        """
        # The context manager restores the previous grad mode on exit,
        # including when the wrapped call raises.
        with torch.set_grad_enabled(self.gradient):
            outputs = super().__call__(*args, **kwargs)
        return outputs

    # NOTE(review): `abc.abstractmethod` only blocks instantiation when the
    # class uses an `ABCMeta`-derived metaclass; whether `_base.Step` does is
    # not visible here -- confirm.
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Perform a single step; subclasses are expected to override this."""
# Training flavour of `Step`: the gradient flag is fixed to `True`.
# NOTE(review): `steps_utils.docstring` presumably assembles the class
# docstring from `header` and `body` -- confirm in `utils.steps`.
@steps_utils.docstring(
    header="Perform user specified training step with enabled gradient.",
    body="Users should override forward method.",
)
class Train(Step):
    def __init__(self, criterion: typing.Callable, device=None):
        """Configure the step with gradient enabled.

        Arguments:
            criterion: Callable forwarded to `Step.__init__`.
            device: Forwarded to `Step.__init__` unchanged.
        """
        super().__init__(criterion=criterion, gradient=True, device=device)

    # Re-declared as abstract; users provide the actual training logic.
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Perform a single training step; meant to be overridden."""
# Evaluation flavour of `Step`: the gradient flag is fixed to `False`.
# NOTE(review): `steps_utils.docstring` presumably assembles the class
# docstring from `header` and `body` -- confirm in `utils.steps`.
@steps_utils.docstring(
    header="Perform user specified evaluation step with disabled gradient.",
    body="Users should override forward method.",
)
class Eval(Step):
    def __init__(self, criterion: typing.Callable, device=None):
        """Configure the step with gradient disabled.

        Arguments:
            criterion: Callable forwarded to `Step.__init__`.
            device: Forwarded to `Step.__init__` unchanged.
        """
        super().__init__(criterion=criterion, gradient=False, device=device)

    # Re-declared as abstract; users provide the actual evaluation logic.
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Perform a single evaluation step; meant to be overridden."""