
# Source code for torchtraining.callbacks

"""Traditionally known callback-like pipes.

.. note::

    **IMPORTANT**: This module is one of core features
    so be sure to understand how it works.

This module allows users to (for example):

    * `save` their best model
    * terminate training (early stopping)
    * log data to `stdout`

Example::

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            return loss, targets

    step = TrainStep(criterion, device)
    step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()

Users can also use specific `callbacks` which integrate with third-party tools, for example:

    * tensorboard
    * neptune
    * comet

.. note::

    **IMPORTANT**: Most of the training related logging/saving/processing
    is (or will be) in this package.

"""

import importlib
import numbers
import operator
import pathlib
import sys
import time
import typing

import loguru
import torch

from .. import _base, exceptions
from ..utils import general as utils

if utils.modules_exist("torch.utils.tensorboard"):
    from . import tensorboard

if utils.modules_exist("neptune"):
    from . import neptune

if utils.modules_exist("comet_ml"):
    from . import comet

class Save(_base.Operation):
    """Save best module according to specified metric.

    .. note::

        **IMPORTANT**: This class plays the role of `ModelCheckpointer`
        known from other training libs. It is user's role to load module
        and pass to `step`, hence we provide only `saving` part
        of checkpointing (may be subject to change).

    Example::

        import operator


        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss


        step = TrainStep(criterion, device)
        iteration = tt.iterations.Train(step, module, dataloader)

        # Lower (operator.lt) loss than current best -> save the model
        iteration ** tt.accumulators.Mean() ** tt.callbacks.Save(
            module, "model.pt", comparator=operator.lt
        )

    Parameters
    ----------
    module: torch.nn.Module
        Module to save.
    path: pathlib.Path
        Path where module will be saved. Usually ends with `.pt` suffix,
        see PyTorch documentation.
    comparator: Callable(Number, Number) -> bool, optional
        Function comparing two values - current metric and best metric.
        If ``true``, save new module and use current value as the best one.
        One can use Python's standard operator library for this argument.
        Default: `operator.gt` (`current` > `best`)
    method: Callable(torch.nn.Module, pathlib.Path) -> None, optional
        Method to save `torch.nn.Module`. Takes module and path
        and returns anything (return value is discarded).
        Might be useful to transform model into `torch.jit.ScriptModule`
        or do some preprocessing before saving.
        Default: `torch.save` (whole model saving)
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `NONE`

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(
        self,
        module: torch.nn.Module,
        path: pathlib.Path,
        comparator: typing.Callable = operator.gt,
        method: typing.Callable = None,
        log: typing.Union[str, int] = "NONE",
    ):
        super().__init__()

        self.module = module
        self.path = path
        self.comparator = comparator
        # Choose the saving callable up front. Wrapping the whole conditional
        # in a lambda would make the lambda body `... if method is None else
        # method`, which returns a custom `method` instead of calling it.
        self.method = torch.save if method is None else method
        self.log = log

        # Best value seen so far; `None` until the first call to `forward`.
        self.best = None

    def forward(self, data: typing.Any) -> typing.Any:
        """
        Arguments
        ---------
        data: Any
            Anything which can be passed to `comparator`
            (e.g. `torch.Tensor`).

        Returns
        -------
        Any
            `data`, unchanged.

        """
        # The first value always counts as the best one (nothing to compare).
        if self.best is None or self.comparator(data, self.best):
            self.best = data
            self.method(self.module, self.path)
            loguru.logger.log(
                self.log, "New best value: {}".format(self.best),
            )

        return data
class TimeStopping(_base.Operation):
    """Stop `epoch` after specified duration.

    Python's `time.time()` functionality is used.

    Can be placed anywhere (e.g. `step ** TimeStopping(60 * 60)`)
    as it's not data dependent.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss


        step = TrainStep(criterion, device)
        iteration = tt.iterations.Train(step, module, dataloader)

        # Stop after 30 minutes
        iteration ** tt.callbacks.TimeStopping(duration=60 * 30)

    Parameters
    ----------
    duration: int | float
        How long to run (in seconds) before exiting program.
        The clock starts when this object is constructed.
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `NONE`

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(
        self, duration: float, log="NONE",
    ):
        super().__init__()

        self.duration = duration
        self.log = log
        # Reference timestamp taken at construction, not at the first call.
        self._start = time.time()

    def forward(self, data):
        """
        Arguments
        ---------
        data: Any
            Anything as `data` will be simply forwarded

        Returns
        -------
        Any
            `data`, unchanged.

        Raises
        ------
        exceptions.TimeStopping
            When more than `duration` seconds elapsed since construction.

        """
        # Compare elapsed time against the limit. The power operator binds
        # tighter than subtraction, so `start ** duration` would be evaluated
        # first and the check would never express "elapsed > duration".
        elapsed = time.time() - self._start
        if elapsed > self.duration:
            loguru.logger.log(
                self.log, "Stopping after {} seconds.".format(self.duration)
            )
            raise exceptions.TimeStopping()
        return data
class TerminateOnNan(_base.Operation):
    """Stop `epoch` if any `NaN` value encountered in `data`.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss, targets


        step = TrainStep(criterion, device)
        step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()

    Parameters
    ----------
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `NONE`

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(
        self, log: typing.Union[str, int] = "NONE",
    ):
        super().__init__()

        self.log = log

    def forward(self, data):
        """
        Arguments
        ---------
        data: torch.Tensor
            Tensor possibly containing `NaN` values.

        Returns
        -------
        torch.Tensor
            `data`, unchanged, when it contains no `NaN`.

        Raises
        ------
        exceptions.TerminateOnNan
            When at least one element of `data` is `NaN`.

        """
        # A single NaN anywhere in the tensor is enough to abort.
        if torch.any(torch.isnan(data)):
            loguru.logger.log(self.log, "NaN values found, exiting with 1.")
            raise exceptions.TerminateOnNan()
        return data
class EarlyStopping(_base.Operation):
    """Stop `epoch` if `patience` was reached without improvement.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss, accuracy


        step = TrainStep(criterion, device)
        iteration = tt.iterations.Train(step, module, dataloader)

        # Stop if mean accuracy did not improve for `5` iterations
        iteration ** tt.Select(accuracy=1) ** tt.accumulators.Mean() ** tt.callbacks.EarlyStopping(
            patience=5
        )

        # Assume epoch was created from `iteration`

    Parameters
    ----------
    patience: int
        How long not to terminate if metric does not improve
    delta: Number, optional
        Difference between `best` value and current considered as an
        improvement. Default: `0`.
    comparator: Callable(Number, Number) -> bool, optional
        Function comparing two values - current metric and best metric.
        If ``true``, reset patience and use current value as the best one.
        One can use Python's standard `operator` library for this argument.
        Default: `operator.gt` (`current` > `best`)
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `NONE`

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(
        self,
        patience: int,
        delta: numbers.Number = 0,
        comparator: typing.Callable = operator.gt,
        log="NONE",
    ):
        super().__init__()

        self.patience = patience
        # NOTE(review): `delta` is stored, but `forward` below compares
        # `data` with `best` directly - confirm whether `comparator` is
        # expected to account for `delta` itself.
        self.delta = delta
        self.comparator = comparator
        self.log = log

        # Best value seen so far; `None` until the first call to `forward`.
        self.best = None
        # Starts at -1 and is reset to -1 on every improvement.
        self._counter = -1

    def forward(self, data):
        """
        Arguments
        ---------
        data: Any
            Anything which can be passed to `comparator`
            (e.g. `torch.Tensor`).

        Returns
        -------
        Any
            `data`, unchanged.

        Raises
        ------
        exceptions.EarlyStopping
            When the non-improvement counter equals `patience`.

        """
        # The first value always counts as an improvement.
        if self.best is None or self.comparator(data, self.best):
            self.best = data
            self._counter = -1
        else:
            self._counter += 1

        if self._counter == self.patience:
            loguru.logger.log(
                self.log, "Stopping early, best found: {}".format(self.best)
            )
            raise exceptions.EarlyStopping()

        # Forward the data so the operation can be chained, as documented.
        return data
class Unfreeze(_base.Operation):
    """Unfreeze module's parameters after `n` steps.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss, accuracy


        step = TrainStep(criterion, device)
        iteration = tt.iterations.Train(step, module, dataloader)

        # Assume `module`'s parameters are frozen

        # Doesn't matter what data goes in, so you can unfreeze however you wish
        # And it doesn't matter what the accumulated value is
        iteration ** tt.Select(accuracy=1) ** tt.accumulators.Sum() ** tt.callbacks.Unfreeze(
            module
        )

    Parameters
    ----------
    module: torch.nn.Module
        Module whose `parameters` will be unfrozen
        (`requires_grad` set to `True`).
    n: int, optional
        Module will be unfrozen after this many steps. Default: `0`
        (unfreeze on the first call).
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `NONE`

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(self, module, n: int = 0, log="NONE"):
        super().__init__()

        self.module = module
        self.n = n
        self.log = log

        # Starts at -1 so that the first call to `forward` counts as step 0.
        self._counter = -1

    def forward(self, data):
        """
        Arguments
        ---------
        data: Any
            Anything as data is simply forwarded

        Returns
        -------
        Any
            `data`, unchanged.

        """
        self._counter += 1
        # Equality (not >=): parameters are unfrozen exactly once, on the
        # call whose zero-based index equals `n`.
        if self._counter == self.n:
            loguru.logger.log(self.log, "Unfreezing module's parameters")
            for param in self.module.parameters():
                param.requires_grad_(True)
        return data
class Log(_base.Operation):
    r"""Log data using `loguru.logger`.

    Example::

        class TrainStep(tt.steps.Train):
            def forward(self, module, sample):
                ...
                return loss, accuracy


        step = TrainStep(criterion, device)
        iteration = tt.iterations.Train(step, module, dataloader)

        # Log with loguru.logger accuracy
        iteration ** tt.Select(accuracy=1) ** tt.callbacks.Log("Accuracy")

    Parameters
    ----------
    name : str
        Name under which data will be logged.
        It will be in format "{name}: {data}"
    log : str | int, optional
        Severity level for logging object's actions.
        Available levels of logging:

            * NONE 0
            * TRACE 5
            * DEBUG 10
            * INFO 20
            * SUCCESS 25
            * WARNING 30
            * ERROR 40
            * CRITICAL 50

        Default: `INFO`

    Arguments
    ---------
    data: Any
        Anything which can be sensibly represented with `__str__`
        magic method.

    Returns
    -------
    Any
        Data passed in forward

    """

    def __init__(self, name: str, log="INFO"):
        super().__init__()

        self.name = name
        self.log = log

    def forward(self, data):
        """
        Arguments
        ---------
        data: Any
            Anything which can be sensibly represented with `__str__`
            magic method.

        Returns
        -------
        Any
            `data`, unchanged.

        """
        loguru.logger.log(self.log, "{}: {}".format(self.name, data))
        return data