
torchtraining.steps

Perform a single step on data via specific module(s).

Note

IMPORTANT: This module is one of the core features of torchtraining, so make sure you understand how it works. It defines how a single step through the data is performed.

See the Introduction tutorial for an example of a step.

A typical step looks something like this:

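import torch
import torchtraining as tt

class Step(tt.steps.Step):
    def forward(self, module, batch):
        # Unpack the batch and move both tensors to the configured device
        images, labels = batch
        images, labels = images.to(self.device), labels.to(self.device)

        predictions = module(images)
        loss = self.criterion(predictions, labels)

        return loss, predictions, labels

# Pass a criterion instance; Step also requires the gradient flag
step = Step(
    criterion=torch.nn.BCEWithLogitsLoss(),
    gradient=True,
    device=torch.device("cuda"),
)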
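The following sketch illustrates the contract described in this section without depending on torchtraining internals. MiniStep is a hypothetical stand-in, not the library's implementation: it assumes a step stores criterion, gradient and device as attributes and runs forward under torch.set_grad_enabled(gradient), which matches the documented meaning of the gradient parameter. It uses the CPU so it runs without a GPU.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only, not torchtraining's code)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        # Assumption: gradient tracking is switched on/off around forward
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class ClassificationStep(MiniStep):
    def forward(self, module, batch):
        images, labels = batch
        images, labels = images.to(self.device), labels.to(self.device)
        predictions = module(images)
        loss = self.criterion(predictions, labels)
        return loss, predictions, labels


# Tiny CPU-only usage: 4 samples with 3 features, one binary target each
module = torch.nn.Linear(3, 1)
batch = (torch.randn(4, 3), torch.randint(0, 2, (4, 1)).float())
cpu = torch.device("cpu")

train_step = ClassificationStep(torch.nn.BCEWithLogitsLoss(), gradient=True, device=cpu)
eval_step = ClassificationStep(torch.nn.BCEWithLogitsLoss(), gradient=False, device=cpu)

train_loss, _, _ = train_step(module, batch)  # tracked by autograd
eval_loss, _, _ = eval_step(module, batch)    # computed without a graph
```

The same forward serves both purposes; only the gradient flag decides whether autograd records the computation.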

Note

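IMPORTANT: You can override __init__ if you need to pass additional arguments. Call super().__init__ with the standard arguments so that self.criterion and self.device remain available in forward.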
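A minimal sketch of overriding __init__, built on a hypothetical MiniStep stand-in (not torchtraining's code) that is assumed to run forward under torch.set_grad_enabled(gradient). The extra threshold argument is made up for illustration; the point is that super().__init__ still receives the standard arguments, so self.criterion keeps working inside forward.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class ThresholdStep(MiniStep):
    # `threshold` is a hypothetical extra argument, added to show the pattern
    def __init__(self, criterion, gradient, threshold, device=None):
        # Delegate the standard arguments so self.criterion etc. stay available
        super().__init__(criterion, gradient, device)
        self.threshold = threshold

    def forward(self, module, batch):
        inputs, targets = batch
        logits = module(inputs)
        loss = self.criterion(logits, targets)
        # Turn probabilities into 0/1 predictions using the custom threshold
        predicted = (torch.sigmoid(logits) > self.threshold).float()
        return loss, predicted, targets


step = ThresholdStep(torch.nn.BCEWithLogitsLoss(), gradient=False, threshold=0.7)
module = torch.nn.Linear(2, 1)
batch = (torch.randn(5, 2), torch.ones(5, 1))
loss, predicted, targets = step(module, batch)
```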

Note

IMPORTANT: You can change the signature of forward to anything you like. Just make sure the matching data is passed to it, either through an iteration or epoch, or by calling the step directly (__call__).
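As a sketch of a custom signature, the hypothetical RegressionStep below takes inputs and targets as separate arguments instead of a single batch tuple and is called directly. MiniStep is again an illustrative stand-in, assumed to hand every argument of __call__ to forward unchanged; it is not torchtraining's implementation.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        # Every positional and keyword argument is passed to forward as-is
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class RegressionStep(MiniStep):
    # Custom signature: inputs and targets arrive as separate arguments
    def forward(self, module, inputs, targets):
        predictions = module(inputs)
        return self.criterion(predictions, targets)


torch.manual_seed(0)
module = torch.nn.Linear(2, 1)
inputs, targets = torch.randn(3, 2), torch.randn(3, 1)
step = RegressionStep(torch.nn.MSELoss(), gradient=True)

loss = step(module, inputs, targets)                    # positional call
loss_kw = step(module, inputs=inputs, targets=targets)  # keyword call
```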

Note

IMPORTANT: module is passed in by other objects and can be anything. In the GAN tutorial, for example, it is a tuple of torch.nn.Module instances.
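The sketch below shows a forward that receives a tuple of modules, in the spirit of the GAN case. The DiscriminatorStep class, the (generator, discriminator) ordering and the tiny linear networks are illustrative assumptions, and MiniStep is a hypothetical stand-in rather than torchtraining's code.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class DiscriminatorStep(MiniStep):
    def forward(self, module, real, noise):
        # Here `module` is a tuple: (generator, discriminator)
        generator, discriminator = module
        # Detach so this step only produces gradients for the discriminator
        fake = generator(noise).detach()
        real_logits = discriminator(real)
        fake_logits = discriminator(fake)
        # Real samples are labelled 1, generated samples 0
        real_loss = self.criterion(real_logits, torch.ones_like(real_logits))
        fake_loss = self.criterion(fake_logits, torch.zeros_like(fake_logits))
        return real_loss + fake_loss


generator = torch.nn.Linear(4, 2)      # noise (4 features) -> fake sample (2 features)
discriminator = torch.nn.Linear(2, 1)  # sample -> single logit
step = DiscriminatorStep(torch.nn.BCEWithLogitsLoss(), gradient=True)

loss = step((generator, discriminator), torch.randn(8, 2), torch.randn(8, 4))
loss.backward()
```

Because the generated samples are detached, the backward pass fills gradients for the discriminator only; the generator's parameters are left untouched.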

class torchtraining.steps.Eval(criterion: Callable, device=None)

Bases: torchtraining.steps.Step

Perform a user-specified evaluation step with gradient disabled.

Users should override the forward method.

Parameters
  • criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.

  • device (torch.device) – Device to which tensors can be moved. Available in forward as the self.device attribute.

abstract forward(*args, **kwargs)
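Conceptually, Eval is a Step with gradient switched off. The sketch below expresses that relationship with hypothetical stand-ins (MiniStep and MiniEval); it is consistent with the signatures documented here but is not torchtraining's source.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class MiniEval(MiniStep):
    # Same signature as Eval: no `gradient` argument, it is always off
    def __init__(self, criterion, device=None):
        super().__init__(criterion, gradient=False, device=device)


class ValidationStep(MiniEval):
    def forward(self, module, batch):
        inputs, targets = batch
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        predictions = module(inputs)
        return self.criterion(predictions, targets), predictions


module = torch.nn.Linear(3, 2)  # 3 features -> 2 class logits
batch = (torch.randn(6, 3), torch.randint(0, 2, (6,)))
step = ValidationStep(torch.nn.CrossEntropyLoss(), device=torch.device("cpu"))
loss, predictions = step(module, batch)
```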

class torchtraining.steps.Step(criterion: Callable, gradient, device=None)

Bases: torchtraining._base.Step

General step, usable in both training and evaluation.

Users should override the forward method.

Parameters
  • criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.

  • gradient (bool) – Whether to turn gradient computation on (training) or off (evaluation).

  • device (torch.device) – Device to which tensors can be moved. Available in forward as the self.device attribute.

abstract forward(*args, **kwargs)

class torchtraining.steps.Train(criterion: Callable, device=None)

Bases: torchtraining.steps.Step

Perform a user-specified training step with gradient enabled.

Users should override the forward method.

Parameters
  • criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.

  • device (torch.device) – Device to which tensors can be moved. Available in forward as the self.device attribute.

abstract forward(*args, **kwargs)
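Train is the mirror image of Eval: a Step whose gradient is always on. The sketch below uses hypothetical stand-ins (MiniStep, MiniTrain) rather than torchtraining's source, and fixes the model weights so the effect of one SGD update can be checked by hand.

```python
import torch


class MiniStep:
    """Hypothetical stand-in for tt.steps.Step (illustration only)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class MiniTrain(MiniStep):
    # Same signature as Train: gradient is always on
    def __init__(self, criterion, device=None):
        super().__init__(criterion, gradient=True, device=device)


class RegressionTrainStep(MiniTrain):
    def forward(self, module, batch):
        inputs, targets = batch
        predictions = module(inputs)
        return self.criterion(predictions, targets)


module = torch.nn.Linear(2, 1)
with torch.no_grad():
    # Fixed weights make the result easy to check by hand
    module.weight.fill_(1.0)
    module.bias.fill_(0.0)

optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
step = RegressionTrainStep(torch.nn.MSELoss())

batch = (torch.ones(4, 2), torch.zeros(4, 1))
loss = step(module, batch)  # every prediction is 2.0, so the MSE is 4.0
optimizer.zero_grad()
loss.backward()             # gradient is 4.0 for each weight and for the bias
optimizer.step()            # weight: 1.0 - 0.1 * 4.0 = 0.6, bias: -0.4
```

Here the backward pass and optimizer update are written out by hand after the step returns, purely to show that gradients flow; the sketch does not model how torchtraining wires these pieces together.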