torchtraining.callbacks¶
Pipes traditionally known as callbacks.
Note
IMPORTANT: This module is one of the core features, so be sure to understand how it works.
This module allows users to (for example):
* `save` their best model
* terminate training (early stopping)
* log data to `stdout`
Example:
class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, targets

step = TrainStep(criterion, device)
step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()
Users can also use specific callbacks which integrate with third party tools,
namely:
tensorboard
neptune
comet
Note
IMPORTANT: Most of the training-related logging/saving/processing is (or will be) in this package.
-
class
torchtraining.callbacks.EarlyStopping(patience: int, delta: numbers.Number = 0, comparator: Callable = <built-in function gt>, log='NONE')[source]¶

Stop epoch if patience was reached without improvement.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy

step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Stop if mean accuracy did not improve for `5` iterations
iteration ** tt.Select(accuracy=1) ** tt.accumulators.Mean() ** tt.callbacks.EarlyStopping(
    patience=5
)

# Assume epoch was created from `iteration`
- Parameters
patience (int) – How long not to terminate if metric does not improve
delta (Number, optional) – Difference between best value and current value considered an improvement. Default: 0.
comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values - current metric and best metric. If true, reset patience and use current value as the best one. One can use Python's standard operator library for this argument. Default: operator.gt (current > best)
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
NONE
- Returns
Data passed in forward
- Return type
Any
-
forward(data)[source]¶
- Parameters
data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
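The patience mechanics above can be sketched as follows. This is a minimal illustration, not the library's implementation; the class name `EarlyStoppingSketch` is hypothetical and the `delta` handling is simplified for maximized metrics:

```python
import operator


class EarlyStoppingSketch:
    """Minimal sketch of patience-based early stopping."""

    def __init__(self, patience, delta=0, comparator=operator.gt):
        self.patience = patience
        self.delta = delta  # simplified: assumes a metric being maximized
        self.comparator = comparator
        self.best = None
        self._counter = 0

    def __call__(self, value):
        """Return True when training should stop."""
        if self.best is None or self.comparator(value, self.best + self.delta):
            # Improvement: remember the new best value and reset patience.
            self.best = value
            self._counter = 0
        else:
            self._counter += 1
        return self._counter >= self.patience


stop = EarlyStoppingSketch(patience=2)
# Metric improves once, then stalls; stop fires after 2 stalled steps.
decisions = [stop(v) for v in [0.5, 0.6, 0.6, 0.6]]
```

Swapping `comparator=operator.lt` would treat lower values (e.g. a loss) as improvements, mirroring the documented parameter.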
-
class
torchtraining.callbacks.Log(name: str, log='INFO')[source]¶

Log data using loguru.logger.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy

step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Log accuracy with loguru.logger
iteration ** tt.Select(accuracy=1) ** tt.callbacks.Log("Accuracy")
- Parameters
name (str) – Name under which data will be logged, in the format "{name}: {data}"
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
INFO
data (Any) – Anything which can be sensibly represented with the __str__ magic method.
- Returns
Data passed in forward
- Return type
Any
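The pass-through logging behaviour can be approximated with the standard logging module. This is only a sketch (the real callback uses loguru.logger; `log_pipe` is a hypothetical name):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tt.callbacks")


def log_pipe(name, data, level=logging.INFO):
    """Sketch of the Log callback: emit '{name}: {data}' and pass data through."""
    logger.log(level, "%s: %s", name, data)
    return data  # data flows through the pipe unchanged


result = log_pipe("Accuracy", 0.93)
```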
-
class
torchtraining.callbacks.Save(module: torch.nn.modules.module.Module, path: pathlib.Path, comparator: Callable = <built-in function gt>, method: Callable = None, log: Union[str, int] = 'NONE')[source]¶ Save best module according to specified metric.
Note
IMPORTANT: This class plays the role of ModelCheckpointer known from other training libs. It is the user's role to load the module and pass it to step, hence we provide only the saving part of checkpointing (may be subject to change).

Example:

import operator

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss

step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Lower (operator.lt) loss than current best -> save the model
iteration ** tt.accumulators.Mean() ** tt.callbacks.Save(
    module, "my_model.pt", comparator=operator.lt
)
- Parameters
module (torch.nn.Module) – Module to save.
path (pathlib.Path) – Path where module will be saved. Usually ends with the pt suffix, see PyTorch documentation.
comparator (Callable(Number, Number) -> bool, optional) – Function comparing two values - current metric and best metric. If true, save new module and use current value as the best one. One can use Python's standard operator library for this argument. Default: operator.gt (current > best)
method (Callable(torch.nn.Module, pathlib.Path) -> None, optional) – Method to save torch.nn.Module. Takes module and path and returns anything (return value is discarded). Might be useful to transform the model into torch.jit.ScriptModule or do some preprocessing before saving. Default: torch.save (whole model saving)
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
NONE
- Returns
Data passed in forward
- Return type
Any
-
forward(data: Any) → Any[source]¶
- Parameters
data (Any) – Anything which can be passed to comparator (e.g. torch.Tensor).
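The metric-gated saving described above can be sketched without torch as follows. `SaveSketch` is a hypothetical name, and the pickle-based default merely stands in for torch.save:

```python
import operator
import pathlib
import pickle
import tempfile


class SaveSketch:
    """Minimal sketch of saving a module only when the metric improves."""

    def __init__(self, module, path, comparator=operator.gt, method=None):
        self.module = module
        self.path = pathlib.Path(path)
        self.comparator = comparator
        # Default method stands in for torch.save (whole-object serialization).
        self.method = method or (
            lambda module, path: path.write_bytes(pickle.dumps(module))
        )
        self.best = None

    def __call__(self, metric):
        """Save the module whenever metric beats the current best, pass metric on."""
        if self.best is None or self.comparator(metric, self.best):
            self.best = metric
            self.method(self.module, self.path)
        return metric  # data flows through the pipe unchanged


directory = pathlib.Path(tempfile.mkdtemp())
# Lower loss is better, so compare with operator.lt
saver = SaveSketch({"weights": [1, 2, 3]}, directory / "model.pkl",
                   comparator=operator.lt)
saver(0.9)  # first value always saves
saver(1.2)  # worse -> no save
saver(0.4)  # better -> overwrites the saved file
```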
-
class
torchtraining.callbacks.TerminateOnNan(log: Union[str, int] = 'NONE')[source]¶

Stop epoch if any NaN value is encountered in data.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, targets

step = TrainStep(criterion, device)
step ** tt.Select(loss=0) ** tt.callbacks.TerminateOnNan()
- Parameters
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
NONE
- Returns
Data passed in forward
- Return type
Any
-
forward(data)[source]¶
- Parameters
data (torch.Tensor) – Tensor possibly containing NaN values.
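The check itself can be sketched with plain Python. The real callback operates on torch.Tensor (presumably via torch.isnan); `terminate_on_nan` is a hypothetical name here:

```python
import math


def terminate_on_nan(data):
    """Sketch: raise when any value in `data` is NaN, otherwise pass data through."""
    values = data if isinstance(data, (list, tuple)) else [data]
    if any(math.isnan(v) for v in values):
        raise RuntimeError("NaN encountered, terminating epoch")
    return data  # data flows through the pipe unchanged


terminate_on_nan(0.5)  # finite values pass through untouched
```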
-
class
torchtraining.callbacks.TimeStopping(duration: float, log='NONE')[source]¶

Stop epoch after specified duration.

Python's time.time() functionality is used.

Can be placed anywhere (e.g. step ** TimeStopping(60 * 60)) as it's not data dependent.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss

step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Stop after 30 minutes
iteration ** tt.callbacks.TimeStopping(duration=60 * 30)
- Parameters
duration (int | float) – How long to run (in seconds) before exiting program.
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
NONE
- Returns
Data passed in forward
- Return type
Any
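The wall-clock check can be sketched as below, using time.time() as the docs state (`TimeStoppingSketch` and `should_stop` are illustrative names, not the library's API):

```python
import time


class TimeStoppingSketch:
    """Sketch: signal stop once `duration` seconds have elapsed since creation."""

    def __init__(self, duration):
        self.duration = duration
        self._start = time.time()  # wall-clock start, per the documented behaviour

    def should_stop(self):
        return time.time() - self._start >= self.duration


stopper = TimeStoppingSketch(duration=0.01)
time.sleep(0.02)  # after the duration elapses, should_stop() returns True
```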
-
class
torchtraining.callbacks.Unfreeze(module, n: int = 0, log='NONE')[source]¶

Unfreeze module's parameters after n steps.

Example:

class TrainStep(tt.steps.Train):
    def forward(self, module, sample):
        ...
        return loss, accuracy

step = TrainStep(criterion, device)
iteration = tt.iterations.Train(step, module, dataloader)

# Assume `module`'s parameters are frozen
# It doesn't matter what data goes in, so you can unfreeze however you wish
# And it doesn't matter what the accumulated value is
iteration ** tt.Select(accuracy=1) ** tt.accumulators.Sum() ** tt.callbacks.Unfreeze(
    module
)
- Parameters
module (torch.nn.Module) – Module whose parameters will be unfrozen (grad set to True).
n (int) – Module will be unfrozen after this many steps.
log (str | int, optional) –
Severity level for logging object’s actions. Available levels of logging:
NONE 0
TRACE 5
DEBUG 10
INFO 20
SUCCESS 25
WARNING 30
ERROR 40
CRITICAL 50
Default:
NONE
- Returns
Data passed in forward
- Return type
Any
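The step-counting unfreeze can be sketched without torch as follows. `Param` stands in for a torch parameter, and `requires_grad` is PyTorch's flag name (an assumption here; the actual callback operates on `module.parameters()`):

```python
class Param:
    """Stand-in for a torch parameter with a requires_grad flag."""

    def __init__(self):
        self.requires_grad = False  # frozen initially


class UnfreezeSketch:
    """Sketch: flip parameters' requires_grad to True after n steps."""

    def __init__(self, parameters, n=0):
        self.parameters = list(parameters)
        self.n = n
        self._steps = 0

    def __call__(self, data):
        self._steps += 1
        if self._steps > self.n:
            for p in self.parameters:
                p.requires_grad = True  # unfreeze every parameter
        return data  # data flows through the pipe unchanged


params = [Param(), Param()]
unfreeze = UnfreezeSketch(params, n=2)
for _ in range(3):
    unfreeze(None)  # parameters stay frozen for the first 2 steps
```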