Source code for torchtraining.iterations

import typing

import torch
from rich import progress

from . import _base, exceptions
from .utils import iterations as iterations_utils


@iterations_utils.docs(
    header="Perform `step` (`train` or `eval`) until `data` is exhausted",
    body="Provided `module` will be passed to every `step`.",
)
class Iteration(_base.Iteration):
    def __init__(
        self,
        step: typing.Any,
        module: torch.nn.Module,
        data: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
        train: bool,
        log: typing.Union[int, str] = "NONE",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.step = step
        self.module = module
        self.data = data
        self.train = train
        self.log = log
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Switch the module between train/eval mode and toggle gradient
        # tracking accordingly before delegating to the base class.
        self.module.train(self.train)
        with torch.set_grad_enabled(self.train):
            yield from super().__call__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Run `step` on every sample from `data`, showing a `rich` progress bar.
        for sample in progress.track(
            self.data, description=" ", *self.args, **self.kwargs
        ):
            yield self.step(self.module, sample, *args, **kwargs)
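
A minimal usage sketch of `Iteration`, assuming the package is importable as `torchtraining` and that the base `_base.Iteration.__call__` simply relays the values yielded by `forward`; the `classification_step` callable, the linear module, and the random dataset below are hypothetical stand-ins, not part of the library:

import torch

import torchtraining as tt  # assumed import name for the package shown above


# Hypothetical step: the source above only requires a callable taking (module, sample, ...).
def classification_step(module, sample):
    images, labels = sample
    return torch.nn.functional.cross_entropy(module(images), labels)


module = torch.nn.Linear(10, 2)
data = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.randn(32, 10), torch.randint(0, 2, (32,))
    ),
    batch_size=8,
)

# train=False puts the module in eval mode and disables gradients (see __call__ above).
iteration = tt.iterations.Iteration(classification_step, module, data, train=False)
for loss in iteration():
    print(loss.item())

Because `forward` is a generator, nothing runs until the returned generator is consumed; the progress bar advances one tick per sample drawn from `data`.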

@iterations_utils.docs(
    header="Perform training step until `data` is exhausted",
    body="Provided `module` will be passed to every `step`.",
)
class Train(Iteration):
    def __init__(
        self,
        step: typing.Any,
        module: torch.nn.Module,
        data: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
        log: typing.Union[int, str] = "NONE",
        *args,
        **kwargs,
    ):
        super().__init__(step, module, data, True, log, *args, **kwargs)

@iterations_utils.docs(
    header="Perform evaluation step until `data` is exhausted",
    body="Provided `module` will be passed to every `step`.",
)
class Eval(Iteration):
    def __init__(
        self,
        step: typing.Any,
        module: torch.nn.Module,
        data: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
        log: typing.Union[int, str] = "NONE",
        *args,
        **kwargs,
    ):
        super().__init__(step, module, data, False, log, *args, **kwargs)
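
A complementary sketch for the two subclasses, under the same assumptions as above (the `regression_step`, module, optimizer, and data are hypothetical): `Train` merely pins `train=True` (train mode, gradients enabled), while `Eval` pins `train=False`:

import torch

import torchtraining as tt  # assumed import name


def regression_step(module, sample):  # hypothetical step callable
    inputs, targets = sample
    return torch.nn.functional.mse_loss(module(inputs), targets)


module = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(16, 4), torch.randn(16, 1)),
    batch_size=4,
)

# Train(step, module, data) behaves like Iteration(step, module, data, True).
for loss in tt.iterations.Train(regression_step, module, data)():
    optimizer.zero_grad()
    loss.backward()  # gradients are enabled inside Train
    optimizer.step()

# Eval(step, module, data) behaves like Iteration(step, module, data, False).
for loss in tt.iterations.Eval(regression_step, module, data)():
    print(loss.item())  # computed without gradient tracking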

# @utils.iterations.docs(
#     header="Perform evaluation `step`s until `data` is exhausted",
#     body="Provided `module` will be passed to every `step`.",
# )
# class MultiIteration(_base.Iteration):
#     def __init__(
#         self,
#         steps: typing.Iterable[typing.Callable[..., typing.Any]],
#         modules: typing.Iterable[torch.nn.Module],
#         datas: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
#         train: bool,
#         intervals: typing.Optional[typing.Iterable[int]] = None,
#         log: typing.Union[int, str] = "NONE",
#         *args,
#         **kwargs,
#     ):
#         super().__init__()
#         self.steps = steps
#         self.modules = modules
#         self.datas = datas
#         self.train = train
#         if intervals is None:
#             self.intervals = tuple(range(len(steps)))
#         else:
#             self.intervals = intervals
#         self.log = log
#         self.args = args
#         self.kwargs = kwargs
#
#     def __call__(self, *args, **kwargs):
#         for module in self.modules:
#             module.train(self.train)
#         with torch.set_grad_enabled(self.train):
#             yield from super().__call__(*args, **kwargs)
#
#     def forward(self, *args, **kwargs):
#         pass
#
#
# @utils.iterations.docs(
#     header="Perform training `step`s until `data` is exhausted",
#     body="Provided `module` will be passed to every `step`.",
# )
# class MultiTrain(MultiIteration):
#     def __init__(
#         self,
#         step: typing.Any,
#         module: torch.nn.Module,
#         data: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
#         log: typing.Union[int, str] = "NONE",
#         *args,
#         **kwargs,
#     ):
#         super().__init__(step, module, data, True, log, *args, **kwargs)
#
#
# @utils.iterations.docs(
#     header="Perform evaluation `step`s until `data` is exhausted",
#     body="Provided `module` will be passed to every `step`.",
# )
# class MultiEval(MultiIteration):
#     def __init__(
#         self,
#         step: typing.Any,
#         module: torch.nn.Module,
#         data: typing.Union[torch.utils.data.Dataset, torch.utils.data.DataLoader],
#         log: typing.Union[int, str] = "NONE",
#         *args,
#         **kwargs,
#     ):
#         super().__init__(step, module, data, False, log, *args, **kwargs)