torchtraining.steps
Perform a single step over a batch of data using the given module(s).
Note
IMPORTANT: This module is one of the library's core features: it defines how a single step through the data is performed, so be sure to understand how it works.
See the Introduction tutorial for an example of a step.
A step usually looks something like this:
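import torch
import torchtraining as tt


class Step(tt.steps.Step):
    def forward(self, module, batch):
        images, labels = batch
        # Move the batch to the device given at construction time.
        images, labels = images.to(self.device), labels.to(self.device)
        predictions = module(images)
        loss = self.criterion(predictions, labels)
        return loss, predictions, labels


# The criterion must be an instance, not the class itself. gradient=True is
# passed because the Step signature documented below requires it.
step = Step(criterion=torch.nn.BCEWithLogitsLoss(), gradient=True, device=torch.device("cuda"))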
Note
IMPORTANT: You can override __init__ if you wish to pass other arguments.
Note
IMPORTANT: You can change the signature of forward to anything you need. Just be sure to pass the appropriate data to it, either via iteration or epoch objects or by calling the step directly (__call__).
Note
IMPORTANT: module is passed in from other objects and can be anything. In the GAN tutorial, for example, it is a tuple of torch.nn.Module instances.
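The three notes above can be combined in one sketch. To stay runnable without torchtraining installed, it uses a hypothetical stand-in base class, StepSketch, that only mimics the documented interface (criterion, gradient, device, and a forward method to override). It is not the library's implementation, and wrapping forward in torch.set_grad_enabled inside __call__ is an assumption of this sketch. DiscriminatorStep, noise_size, and the toy Linear generator and discriminator are likewise made up for illustration.

```python
import torch


class StepSketch:
    """Hypothetical stand-in for the documented Step interface (not library code)."""

    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        # Assumption: the gradient flag toggles autograd around forward.
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class DiscriminatorStep(StepSketch):
    # Overridden __init__: accepts an extra argument (noise_size).
    def __init__(self, criterion, device=None, noise_size=8):
        super().__init__(criterion, gradient=True, device=device)
        self.noise_size = noise_size

    # Custom forward signature: `modules` is a tuple rather than one module.
    def forward(self, modules, real):
        generator, discriminator = modules
        real = real.to(self.device)
        noise = torch.randn(real.shape[0], self.noise_size, device=self.device)
        fake = generator(noise).detach()  # do not backpropagate into the generator
        logits = discriminator(torch.cat([real, fake]))
        targets = torch.cat(
            [
                torch.ones(len(real), 1, device=self.device),
                torch.zeros(len(fake), 1, device=self.device),
            ]
        )
        return self.criterion(logits, targets)


generator = torch.nn.Linear(8, 4)
discriminator = torch.nn.Linear(4, 1)
step = DiscriminatorStep(torch.nn.BCEWithLogitsLoss(), device=torch.device("cpu"))

# Direct call; the tuple of modules and the batch are passed by hand here.
loss = step((generator, discriminator), torch.randn(16, 4))
```

In this sketch the step is called as step(...) rather than step.forward(...), so that the base class can apply the gradient setting before forward runs.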
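class torchtraining.steps.Eval(criterion: Callable, device=None)
Bases: torchtraining.steps.Step
Perform a user-specified evaluation step with gradient disabled.
Users should override the forward method.
Parameters
- criterion (typing.Callable), device (torch.device) – As listed in the signature above; same meaning as for torchtraining.steps.Step.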
class torchtraining.steps.Step(criterion: Callable, gradient, device=None)
Bases: torchtraining._base.Step
General step, usable in both training and evaluation.
Users should override the forward method.
Parameters
- criterion (typing.Callable) – Criterion to use to get the loss value. Available in forward as the self.criterion attribute.
- gradient (bool) – Whether to turn gradient on or off (for training or evaluation, respectively).
- device (torch.device) – Device to which tensors can be cast. Available in forward as self.device.
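What the gradient flag means in practice can be illustrated with plain PyTorch. The sketch below assumes the flag behaves like torch.set_grad_enabled, which is a guess about the library's internals rather than something stated here. run_step, the Linear module, and the random data are hypothetical names and values used only for illustration.

```python
import torch

module = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)


def run_step(gradient):
    # Assumption: gradient=True/False maps onto enabling/disabling autograd.
    with torch.set_grad_enabled(gradient):
        return criterion(module(inputs), targets)


train_loss = run_step(gradient=True)   # can be backpropagated
eval_loss = run_step(gradient=False)   # no autograd graph is recorded

print(train_loss.requires_grad, eval_loss.requires_grad)  # True False
```

Only the loss computed with gradient enabled can have backward() called on it. The evaluation variant skips graph construction, which saves memory.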
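class torchtraining.steps.Train(criterion: Callable, device=None)
Bases: torchtraining.steps.Step
Perform a user-specified training step with gradient enabled.
Users should override the forward method.
Parameters
- criterion (typing.Callable), device (torch.device) – As listed in the signature above; same meaning as for torchtraining.steps.Step.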