torchtraining.steps
Perform a single step on data via specific module(s).
Note
IMPORTANT: This module is one of the core features, so be sure to understand how it works. It defines how you perform a single step through the data.
See the Introduction tutorial for an example of a step.
Usually it looks something like this:
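```python
import torch
import torchtraining as tt

class Step(tt.steps.Step):
    def forward(self, module, batch):
        images, labels = batch
        # self.device and self.criterion come from the constructor arguments
        images, labels = images.to(self.device), labels.to(self.device)
        predictions = module(images)
        loss = self.criterion(predictions, labels)
        return loss, predictions, labels

# Pass a criterion instance (not the class) plus the required gradient flag
step = Step(criterion=torch.nn.BCEWithLogitsLoss(), gradient=True, device=torch.device("cuda"))
```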
Note
IMPORTANT: You can override __init__ if you wish to pass other arguments.
Note
IMPORTANT: You can change the forward signature to anything you desire. Just be sure to pass the appropriate data to it, either via iteration or epoch, or through a plain __call__.
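The sketch below illustrates both overrides from the two notes above. It runs without torchtraining. `SketchStep` is a hypothetical stand-in that only mirrors the documented `(criterion, gradient, device)` constructor. It also assumes that calling a step forwards its arguments to `forward`. `RegularizedStep` and its `l2` argument are illustrative names and are not part of the library.

```python
import torch


class SketchStep:
    # Hypothetical stand-in for torchtraining.steps.Step, NOT the library's code.
    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        # Assumption: __call__ passes everything to forward with autograd toggled.
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class RegularizedStep(SketchStep):
    # Overridden __init__: accepts an extra (hypothetical) l2 coefficient.
    def __init__(self, criterion, l2, device=None):
        super().__init__(criterion, gradient=True, device=device)
        self.l2 = l2

    # Overridden forward: (module, inputs, targets) instead of (module, batch).
    def forward(self, module, inputs, targets):
        predictions = module(inputs)
        penalty = sum(p.pow(2).sum() for p in module.parameters())
        loss = self.criterion(predictions, targets) + self.l2 * penalty
        return loss, predictions


module = torch.nn.Linear(3, 1)
step = RegularizedStep(torch.nn.MSELoss(), l2=0.01)
# The caller must supply data matching the custom forward signature.
loss, predictions = step(module, torch.randn(4, 3), torch.randn(4, 1))
loss.backward()  # gradients flow because gradient=True
```

The same pattern applies to a real subclass: whatever object calls the step must supply arguments that match the signature you chose for `forward`.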
Note
IMPORTANT: module is passed in from other objects and can be anything. In the GAN tutorial, for example, it is a Tuple of torch.nn.Module.
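When module is a tuple, forward simply unpacks it. The sketch below shows a hypothetical discriminator update where module is `(generator, discriminator)`. The stand-in `SketchStep` base, the layer sizes, and the 2-dimensional noise are illustrative assumptions. They are not taken from the GAN tutorial or from torchtraining itself.

```python
import torch


class SketchStep:
    # Hypothetical stand-in for torchtraining.steps.Step, NOT the library's code.
    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)


class DiscriminatorStep(SketchStep):
    def forward(self, module, batch):
        generator, discriminator = module  # module is a Tuple of torch.nn.Module
        noise = torch.randn(batch.shape[0], 2)
        # detach() keeps this loss from back-propagating into the generator
        fake = generator(noise).detach()
        real_logits = discriminator(batch)
        fake_logits = discriminator(fake)
        loss = self.criterion(real_logits, torch.ones_like(real_logits)) + self.criterion(
            fake_logits, torch.zeros_like(fake_logits)
        )
        return loss, real_logits, fake_logits


generator = torch.nn.Linear(2, 4)
discriminator = torch.nn.Linear(4, 1)
step = DiscriminatorStep(torch.nn.BCEWithLogitsLoss(), gradient=True)
loss, real_logits, fake_logits = step((generator, discriminator), torch.randn(8, 4))
loss.backward()  # only the discriminator receives gradients
```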
class torchtraining.steps.Eval(criterion: Callable, device=None)
Bases: torchtraining.steps.Step
Perform a user-specified evaluation step with gradient disabled.
Users should override the forward method.
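Parameters
criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.
device (torch.device) – Device to which tensors can be cast. Available in forward as the self.device attribute.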
class torchtraining.steps.Step(criterion: Callable, gradient, device=None)
Bases: torchtraining._base.Step
General step, usable both in training and evaluation. Users should override the forward method.
Parameters
criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.
gradient (bool) – Whether to turn gradient computation on (for training) or off (for evaluation).
device (torch.device) – Device to which tensors can be cast. Available in forward as the self.device attribute.
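The gradient flag is what separates a training step from an evaluation step. The sketch below assumes the flag is applied with `torch.set_grad_enabled` around `forward`. This page does not show torchtraining's actual implementation, and `SketchStep` is a hypothetical stand-in. Under that assumption, gradient=True behaves like Train and gradient=False behaves like Eval.

```python
import torch


class SketchStep:
    # Hypothetical stand-in for torchtraining.steps.Step, NOT the library's code.
    def __init__(self, criterion, gradient, device=None):
        self.criterion = criterion
        self.gradient = gradient
        self.device = device

    def __call__(self, *args, **kwargs):
        # Assumption: autograd is switched on/off only for the duration of forward.
        with torch.set_grad_enabled(self.gradient):
            return self.forward(*args, **kwargs)

    def forward(self, module, batch):
        inputs, targets = batch
        # Move data to self.device when one was given.
        if self.device is not None:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
        predictions = module(inputs)
        return self.criterion(predictions, targets), predictions


module = torch.nn.Linear(3, 1)
batch = (torch.randn(5, 3), torch.randn(5, 1))
criterion = torch.nn.MSELoss()

train_loss, _ = SketchStep(criterion, gradient=True)(module, batch)  # like Train
eval_loss, _ = SketchStep(criterion, gradient=False)(module, batch)  # like Eval
print(train_loss.requires_grad, eval_loss.requires_grad)  # True False
```

Both losses have the same value. Only the autograd graph differs, which is why evaluation with gradient disabled saves memory.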
class torchtraining.steps.Train(criterion: Callable, device=None)
Bases: torchtraining.steps.Step
Perform a user-specified training step with gradient enabled.
Users should override the forward method.
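Parameters
criterion (typing.Callable) – Criterion used to compute the loss value. Available in forward as the self.criterion attribute.
device (torch.device) – Device to which tensors can be cast. Available in forward as the self.device attribute.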