
torchtraining.callbacks

Pipes providing functionality traditionally known as callbacks.

Note

IMPORTANT: This module is one of the core features, so be sure to understand how it works.

This module allows the user to (for example):

* `save` their best model
* terminate training (early stopping)
* log data to `stdout`

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, targets


step = TrainStep(criterion, device)
step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()

Users can also use specific callbacks which integrate with third-party tools (see the sketch after the list below), namely:

  • tensorboard

  • neptune

  • comet
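
These integrations plug into the same ** pipe syntax. The sketch below is illustrative only: the `Scalar` pipe and its (writer, tag) arguments are assumptions, so consult the tt.callbacks.tensorboard documentation for the exact names.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/experiment")

# Hypothetical pipe: write the mean loss to TensorBoard each iteration.
# `Scalar` and its (writer, tag) arguments are assumptions, not confirmed API.
iteration ** tt.Select(loss=0) ** tt.accumulators.Mean() ** tt.callbacks.tensorboard.Scalar(
    writer, "Loss/train"
)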

Note

IMPORTANT: Most of the training-related logging/saving/processing is (or will be) located in this package.

class torchtraining.callbacks.EarlyStopping(patience: int, delta: numbers.Number = 0, comparator: Callable = <built-in function gt>, log='NONE')[source]

Stop epoch if patience was reached without improvement.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy


step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Stop if mean accuracy did not improve for `5` iterations
iteration ** tt.Select(accuracy=1) ** tt.accumulators.Mean() ** tt.callbacks.EarlyStopping(
    patience=5
)
# Assume epoch was created from `iteration`
Parameters
  • patience (int) – Number of steps without improvement to wait before terminating.

  • delta (Number, optional) – Minimum difference between the best value and the current one that counts as an improvement. Default: 0.

  • comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values: the current metric and the best metric. If True, reset patience and use the current value as the best one. One can use Python’s standard operator library for this argument. Default: operator.gt (current > best)

  • log (str | int, optional) –

    Severity level for logging object’s actions. Available levels of logging:

    • NONE 0

    • TRACE 5

    • DEBUG 10

    • INFO 20

    • SUCCESS 25

    • WARNING 30

    • ERROR 40

    • CRITICAL 50

    Default: NONE

Returns

Data passed in forward

Return type

Any

forward(data)[source]
Parameters

data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
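
When the monitored value should decrease (e.g. a loss), pass operator.lt as the comparator. A minimal sketch, assuming the step returns loss as its first output as in the examples above (how delta interacts with the comparator follows its description):

import operator

# Stop if the mean loss shows no improvement (per operator.lt and delta)
# for 3 consecutive iterations
iteration ** tt.Select(loss=0) ** tt.accumulators.Mean() ** tt.callbacks.EarlyStopping(
    patience=3, delta=0.01, comparator=operator.lt
)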

class torchtraining.callbacks.Log(name: str, log='INFO')[source]

Log data using loguru.logger.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy


step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Log accuracy with loguru.logger
iteration ** tt.Select(accuracy=1) ** tt.callbacks.Log("Accuracy")
Parameters
  • name (str) – Name under which data will be logged. The message will be in the format “{name}: {data}”.

  • log (str | int, optional) –

    Severity level for logging object’s actions. Available levels of logging:

    • NONE 0

    • TRACE 5

    • DEBUG 10

    • INFO 20

    • SUCCESS 25

    • WARNING 30

    • ERROR 40

    • CRITICAL 50

    Default: INFO

Returns

Data passed in forward

Return type

Any

forward(data)[source]
Parameters

data (Any) – Anything which can be sensibly represented with __str__ magic method.
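
The log argument also controls the severity at which the message is emitted. A minimal sketch, assuming the step returns loss as its first output:

# Report the running mean of the loss at WARNING severity instead of the default INFO
iteration ** tt.Select(loss=0) ** tt.accumulators.Mean() ** tt.callbacks.Log(
    "Mean loss", log="WARNING"
)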

class torchtraining.callbacks.Save(module: torch.nn.modules.module.Module, path: pathlib.Path, comparator: Callable = <built-in function gt>, method: Callable = None, log: Union[str, int] = 'NONE')[source]

Save best module according to specified metric.

Note

IMPORTANT: This class plays the role of the ModelCheckpointer known from other training libraries. It is the user’s responsibility to load the module and pass it to the step, hence only the saving part of checkpointing is provided (may be subject to change).

Example:

import operator


class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss


step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Lower (operator.lt) loss than current best -> save the model
iteration ** tt.accumulators.Mean() ** tt.callbacks.Save(
    module, "my_model.pt", comparator=operator.lt
)
Parameters
  • module (torch.nn.Module) – Module to save.

  • path (pathlib.Path) – Path where module will be saved. Usually ends with pt suffix, see PyTorch documentation.

  • comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values: the current metric and the best metric. If True, save the new module and use the current value as the best one. One can use Python’s standard operator library for this argument. Default: operator.gt (current > best)

  • method (Callable(torch.nn.Module, pathlib.Path) -> None, optional) – Method to save torch.nn.Module. Takes module and path and returns anything (return value is discarded). Might be useful to transform model into torch.jit.ScriptModule or do some preprocessing before saving. Default: torch.save (whole model saving)

  • log (str | int, optional) –

    Severity level for logging object’s actions. Available levels of logging:

    • NONE 0

    • TRACE 5

    • DEBUG 10

    • INFO 20

    • SUCCESS 25

    • WARNING 30

    • ERROR 40

    • CRITICAL 50

    Default: NONE

Returns

Data passed in forward

Return type

Any

forward(data: Any) → Any[source]
Parameters

data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
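
The method argument controls what exactly is written to disk. A minimal sketch that persists only the state_dict, assuming (as documented above) that Save calls method(module, path) whenever the metric improves:

import operator

import torch


def save_state_dict(module, path):
    # Store only parameters and buffers instead of pickling the whole module
    torch.save(module.state_dict(), path)


iteration ** tt.accumulators.Mean() ** tt.callbacks.Save(
    module, "my_model.pt", comparator=operator.lt, method=save_state_dict
)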

class torchtraining.callbacks.TerminateOnNan(log: Union[str, int] = 'NONE')[source]

Stop epoch if any NaN value is encountered in the data.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, targets


step = TrainStep(criterion, device)
step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()
Parameters

log (str | int, optional) –

Severity level for logging object’s actions. Available levels of logging:

  • NONE 0

  • TRACE 5

  • DEBUG 10

  • INFO 20

  • SUCCESS 25

  • WARNING 30

  • ERROR 40

  • CRITICAL 50

Default: NONE

Returns

Data passed in forward

Return type

Any

forward(data)[source]
Parameters

data (torch.Tensor) – Tensor possibly containing NaN values.
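
For reference, detecting NaNs in a tensor comes down to standard PyTorch calls; the snippet below is illustrative only, not the callback's actual implementation:

import torch

data = torch.tensor([1.0, float("nan"), 2.0])

# torch.isnan flags NaN entries element-wise; .any() reduces them to a single boolean
if torch.isnan(data).any():
    print("NaN encountered, the epoch would be terminated")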

class torchtraining.callbacks.TimeStopping(duration: float, log='NONE')[source]

Stop epoch after specified duration.

Python’s time.time() functionality is used.

Can be placed anywhere (e.g. step ** TimeStopping(60 * 60)) as it’s not data dependent.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss


step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Stop after 30 minutes
iteration ** tt.callbacks.TimeStopping(duration=60 * 30)
Parameters
  • duration (int | float) – How long to run (in seconds) before the epoch is stopped.

  • log (str | int, optional) –

    Severity level for logging object’s actions. Available levels of logging:

    • NONE 0

    • TRACE 5

    • DEBUG 10

    • INFO 20

    • SUCCESS 25

    • WARNING 30

    • ERROR 40

    • CRITICAL 50

    Default: NONE

Returns

Data passed in forward

Return type

Any

forward(data)[source]
Parameters

data (Any) – Anything; the data is simply forwarded.
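
For reference, the underlying mechanism is a plain wall-clock comparison via time.time(); the snippet below is illustrative only, not the callback's actual implementation:

import time

start = time.time()
budget = 60 * 30  # 30 minutes

# Conceptually, the callback compares elapsed wall-clock time against the budget
if time.time() - start > budget:
    print("Time budget exceeded, the epoch would be stopped")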

class torchtraining.callbacks.Unfreeze(module, n: int = 0, log='NONE')[source]

Unfreeze module’s parameters after n steps.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy


step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Assume `module`'s parameters are frozen

# Doesn't matter what data goes in, so you can unfreeze however you wish
# And it doesn't matter what the accumulated value is
iteration ** tt.Select(accuracy=1) ** tt.accumulators.Sum() ** tt.callbacks.Unfreeze(
    module
)
Parameters
  • module (torch.nn.Module) – Module whose parameters will be unfrozen (requires_grad set to True).

  • n (int) – Module will be unfrozen after this many steps.

  • log (str | int, optional) –

    Severity level for logging object’s actions. Available levels of logging:

    • NONE 0

    • TRACE 5

    • DEBUG 10

    • INFO 20

    • SUCCESS 25

    • WARNING 30

    • ERROR 40

    • CRITICAL 50

    Default: NONE

Returns

Data passed in forward

Return type

Any

forward(data)[source]
Parameters

data (Any) – Anything; the data is simply forwarded.
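
The Unfreeze example above assumes the module's parameters were frozen beforehand; in plain PyTorch that is typically done as follows:

# Freeze all parameters up front; Unfreeze will re-enable gradients after `n` steps
for parameter in module.parameters():
    parameter.requires_grad_(False)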