torchtraining.callbacks¶
Submodules
Pipes playing the role traditionally known as callbacks.
Note
IMPORTANT: This module is one of the core features, so be sure to understand how it works.
This module allows the user to (for example):
* `save` their best model
* terminate training (early stopping)
* log data to `stdout`
Example:
    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss, targets

    step = TrainStep(criterion, device)
    step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()
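The `**` chaining above can be understood as operator piping: each stage receives the previous stage's output. The following is a minimal conceptual sketch of such piping via `__pow__` overloading, not torchtraining's actual internals; `Pipe` and its methods are hypothetical names. Note that `**` is right-associative in Python, so `a ** b ** c` builds the chain from the tail backwards while data still flows left to right.

```python
class Pipe:
    """Hypothetical stage that forwards its output to the next stage."""

    def __init__(self, func):
        self.func = func
        self.next = None

    def __pow__(self, other):
        # Attach `other` as the next stage; return self so chains compose.
        self.next = other
        return self

    def __call__(self, data):
        result = self.func(data)
        return self.next(result) if self.next is not None else result


double = Pipe(lambda x: x * 2)
increment = Pipe(lambda x: x + 1)
pipeline = double ** increment  # pipeline(3) -> 7
```

With this model, a callback is just another stage that inspects the value flowing through it and passes it along unchanged.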
Users can also use specific callbacks which integrate with third-party tools, namely:
* tensorboard
* neptune
* comet
Note
IMPORTANT: Most of the training related logging/saving/processing is (or will be) in this package.
-
class torchtraining.callbacks.EarlyStopping(patience: int, delta: numbers.Number = 0, comparator: Callable = <built-in function gt>, log='NONE')[source]¶

Stop epoch if patience was reached without improvement.

Example:

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss, accuracy

    step = TrainStep(criterion, device)
    iteration = tt.iterations.Train(step, module, dataloader)

    # Stop if mean accuracy did not improve for `5` iterations
    iteration ** tt.Select(accuracy=1) ** tt.accumulators.Mean() ** tt.callbacks.EarlyStopping(
        patience=5
    )
    # Assume `epoch` was created from `iteration`
- Parameters
  patience (int) – How many steps to wait without improvement before terminating.
  delta (Number, optional) – Difference between the best value and the current value which counts as an improvement. Default: 0.
  comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values: the current metric and the best metric. If true, patience is reset and the current value becomes the best one. One can use Python's standard operator library for this argument. Default: operator.gt (current > best).
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: NONE.
- Returns
Data passed in forward
- Return type
Any
-
forward(data)[source]¶
- Parameters
  data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
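The patience/delta/comparator mechanics described above can be sketched as a standalone class. This is a hypothetical illustration of the logic, not torchtraining's implementation; `EarlyStopper` and `step` are invented names.

```python
import operator


class EarlyStopper:
    """Hypothetical sketch of patience-based early stopping."""

    def __init__(self, patience, delta=0, comparator=operator.gt):
        self.patience = patience
        self.delta = delta
        self.comparator = comparator
        self.best = None
        self.counter = 0

    def step(self, value):
        """Feed one metric value; return True when training should stop."""
        # With operator.gt and delta >= 0, an improvement must exceed
        # the previous best by at least `delta` (for operator.lt you
        # would negate the delta).
        improved = self.best is None or self.comparator(value, self.best + self.delta)
        if improved:
            self.best = value
            self.counter = 0  # improvement resets patience
        else:
            self.counter += 1
        return self.counter >= self.patience
```

For example, with `patience=2` the third consecutive non-improving value triggers a stop.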
-
class torchtraining.callbacks.Log(name: str, log='INFO')[source]¶

Log data using loguru.logger.

Example:

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss, accuracy

    step = TrainStep(criterion, device)
    iteration = tt.iterations.Train(step, module, dataloader)

    # Log accuracy with loguru.logger
    iteration ** tt.Select(accuracy=1) ** tt.callbacks.Log("Accuracy")
- Parameters
  name (str) – Name under which data will be logged, in the format "{name}: {data}".
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: INFO.
  data (Any) – Anything which can be sensibly represented with the __str__ magic method.
- Returns
Data passed in forward
- Return type
Any
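A logging callback of this kind is a pass-through stage: it reports "{name}: {data}" and returns the data unchanged so the pipeline can continue. Here is a minimal hypothetical sketch using a plain callable sink instead of loguru.logger; `make_log` is an invented name.

```python
def make_log(name, sink=print):
    """Build a pass-through logging stage; `sink` stands in for loguru.logger."""

    def log(data):
        sink(f"{name}: {data}")  # report under the given name
        return data              # pass data along unchanged

    return log
```

Usage: `make_log("Accuracy")(0.9)` prints the message and returns `0.9`.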
-
class torchtraining.callbacks.Save(module: torch.nn.modules.module.Module, path: pathlib.Path, comparator: Callable = <built-in function gt>, method: Callable = None, log: Union[str, int] = 'NONE')[source]¶

Save best module according to specified metric.

Note
IMPORTANT: This class plays the role of the ModelCheckpointer known from other training libraries. It is the user's role to load the module and pass it to step; hence we provide only the saving part of checkpointing (may be subject to change).

Example:

    import operator

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss

    step = TrainStep(criterion, device)
    iteration = tt.iterations.Train(step, module, dataloader)

    # Lower (operator.lt) loss than current best -> save the model
    iteration ** tt.accumulators.Mean() ** tt.callbacks.Save(
        module, "my_model.pt", comparator=operator.lt
    )
- Parameters
  module (torch.nn.Module) – Module to save.
  path (pathlib.Path) – Path where the module will be saved. Usually ends with the pt suffix; see PyTorch documentation.
  comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values: the current metric and the best metric. If true, the new module is saved and the current value becomes the best one. One can use Python's standard operator library for this argument. Default: operator.gt (current > best).
  method (Callable(torch.nn.Module, pathlib.Path) -> None, optional) – Method used to save torch.nn.Module. Takes the module and path and returns anything (the return value is discarded). Might be useful to transform the model into torch.jit.ScriptModule or do some preprocessing before saving. Default: torch.save (whole model saving).
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: NONE.
- Returns
Data passed in forward
- Return type
Any
-
forward(data: Any) → Any[source]¶
- Parameters
  data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
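The "save when the metric beats the best so far" logic can be sketched in isolation. This is a hypothetical illustration, not torchtraining's code; `BestSaver` is an invented name, and the default `method` here is a no-op stand-in for torch.save.

```python
import operator


class BestSaver:
    """Hypothetical sketch: invoke `method(module, path)` only on improvement."""

    def __init__(self, module, path, comparator=operator.gt, method=None):
        self.module = module
        self.path = path
        self.comparator = comparator
        # No-op stand-in for a torch.save-style persistence function.
        self.method = method or (lambda module, path: None)
        self.best = None

    def __call__(self, value):
        # The very first value always counts as the best so far.
        if self.best is None or self.comparator(value, self.best):
            self.best = value
            self.method(self.module, self.path)  # persist the improved module
        return value  # pass the metric along unchanged
```

With `comparator=operator.lt` this saves whenever a lower loss than the current best is seen, mirroring the example above.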
-
class torchtraining.callbacks.TerminateOnNan(log: Union[str, int] = 'NONE')[source]¶

Stop epoch if any NaN value is encountered in data.

Example:

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss, targets

    step = TrainStep(criterion, device)
    step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()
- Parameters
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: NONE.
- Returns
Data passed in forward
- Return type
Any
-
forward(data)[source]¶
- Parameters
  data (torch.Tensor) – Tensor possibly containing NaN values.
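The detection itself reduces to a NaN check over the incoming values. A sketch on plain floats, using the standard library; for a real torch.Tensor one would typically use `torch.isnan(tensor).any()` instead. `contains_nan` is an invented helper name.

```python
import math


def contains_nan(values):
    """Return True if any element is NaN (NaN != NaN, hence math.isnan)."""
    return any(math.isnan(v) for v in values)
```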
-
class torchtraining.callbacks.TimeStopping(duration: float, log='NONE')[source]¶

Stop epoch after the specified duration.

Python's time.time() functionality is used.

Can be placed anywhere (e.g. step ** TimeStopping(60 * 60)) as it is not data dependent.

Example:

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss

    step = TrainStep(criterion, device)
    iteration = tt.iterations.Train(step, module, dataloader)

    # Stop after 30 minutes
    iteration ** tt.callbacks.TimeStopping(duration=60 * 30)
- Parameters
  duration (int | float) – How long to run (in seconds) before exiting the program.
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: NONE.
- Returns
Data passed in forward
- Return type
Any
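Duration-based stopping amounts to recording a start timestamp and checking elapsed wall-clock time on each call. A hypothetical sketch (invented `TimeBudget` name); the injectable `clock` parameter, defaulting to time.time, makes the logic testable without waiting.

```python
import time


class TimeBudget:
    """Hypothetical sketch: report whether `duration` seconds have elapsed."""

    def __init__(self, duration, clock=time.time):
        self.duration = duration
        self.clock = clock
        self.start = self.clock()  # timestamp taken at creation

    def expired(self):
        # Not data dependent: only compares elapsed time against the budget.
        return self.clock() - self.start >= self.duration
```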
-
class torchtraining.callbacks.Unfreeze(module, n: int = 0, log='NONE')[source]¶

Unfreeze module's parameters after n steps.

Example:

    class TrainStep(tt.steps.Train):
        def forward(self, module, sample):
            ...
            return loss, accuracy

    step = TrainStep(criterion, device)
    iteration = tt.iterations.Train(step, module, dataloader)

    # Assume `module`'s parameters are frozen

    # It doesn't matter what data goes in, so you can unfreeze however you wish
    # (the accumulated value itself is irrelevant)
    iteration ** tt.Select(accuracy=1) ** tt.accumulators.Sum() ** tt.callbacks.Unfreeze(
        module
    )
- Parameters
  module (torch.nn.Module) – Module whose parameters will be unfrozen (requires_grad set to True).
  n (int) – Module will be unfrozen after this many steps.
  log (str | int, optional) – Severity level for logging object's actions. Available levels of logging: NONE (0), TRACE (5), DEBUG (10), INFO (20), SUCCESS (25), WARNING (30), ERROR (40), CRITICAL (50). Default: NONE.
- Returns
Data passed in forward
- Return type
Any
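The step-counted unfreezing can be sketched without torch: count calls, and after `n` of them flip the module's frozen state. With a real torch.nn.Module one would set `param.requires_grad = True` for each parameter; here a plain object with a `frozen` flag stands in. `Unfreezer` and `FrozenModule` are invented names.

```python
class FrozenModule:
    """Stand-in for a module whose parameters start out frozen."""

    def __init__(self):
        self.frozen = True


class Unfreezer:
    """Hypothetical sketch: unfreeze `module` after `n` steps, pass data through."""

    def __init__(self, module, n=0):
        self.module = module
        self.n = n
        self.steps = 0

    def __call__(self, data):
        if self.steps == self.n:
            # Stand-in for setting requires_grad = True on each parameter.
            self.module.frozen = False
        self.steps += 1
        return data  # incoming data is irrelevant and passed along unchanged
```

With `n=0` (the default) the module is unfrozen on the first call.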