torchfunc.module¶
This module provides functionality related to torch.nn.Module instances (e.g. freezing parameters).
For performance analysis of torch.nn.Module please see the performance subpackage.
-
class torchfunc.module.Snapshot(*modules: torch.nn.modules.module.Module)[source]¶ Save module snapshots in memory and/or on disk.
New modules can be added with + or +=, and their state or the whole module can be saved to disk with the appropriate methods. All added modules are kept unless removed with the pop() method.
Additionally, self-explanatory methods like __len__, __iter__ and item access are provided (although there is no __setitem__, as mutating contained modules is discouraged).
Example:
snapshot = torchfunc.module.Snapshot()
snapshot += torch.nn.Sequential(torch.nn.Linear(784, 10), torch.nn.Sigmoid())
snapshot.save("models")  # Save all modules to the models folder
- Parameters
*modules (torch.nn.Module) – Varargs of PyTorch modules to be kept.
-
pop(index: int = -1)[source]¶ Remove module at index from memory.
- Parameters
index (int, optional) – Index of the module to be removed. Default: -1 (last module)
- Returns
module – Module removed by this operation
- Return type
torch.nn.Module
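Since modules are kept in insertion order, pop behaves like Python's list.pop. A minimal sketch of that behavior on a plain list of modules (a stand-in, not the library code):

```python
import torch

# Stand-in list mirroring the modules a Snapshot keeps in insertion order
snapshots = [torch.nn.Linear(4, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1)]

last = snapshots.pop()    # default index -1 removes the most recent module
first = snapshots.pop(0)  # any valid index works, negative or positive
```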
-
save(folder: pathlib.Path = None, remove: bool = False, *indices: int) → None[source]¶ Save module to disk.
Snapshot(s) will be saved using the following naming convention:
module_"index"_"timestamp".pt
See PyTorch’s docs for more information.
- Parameters
folder (pathlib.Path, optional) – Name of the folder where the module will be saved. It has to exist. Defaults to the current working directory.
remove (bool, optional) – Whether the module should be removed from memory after saving. Useful for keeping only the best/last model in memory. Default: False
*indices (int, optional) – Possibly empty varargs containing indices of modules to be saved. Negative indexing is supported. If empty, all modules are saved.
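The whole-module saving described above can be sketched in plain PyTorch. The save_module helper below is a hypothetical stand-in mirroring the documented naming convention, not the library implementation:

```python
import pathlib
import tempfile
import time
import torch

def save_module(module, index, folder=None):
    # Mirrors the documented naming scheme: module_"index"_"timestamp".pt
    folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd()
    path = folder / f"module_{index}_{int(time.time())}.pt"
    torch.save(module, path)  # the whole module object is serialized
    return path

with tempfile.TemporaryDirectory() as tmp:
    path = save_module(torch.nn.Linear(4, 2), index=0, folder=tmp)
    # weights_only=False is needed on newer PyTorch to unpickle a full module
    restored = torch.load(path, weights_only=False)
```

Because the entire module is pickled, it can be restored without re-creating the architecture first.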
-
save_state(folder: pathlib.Path = None, remove: bool = False, *indices: int) → None[source]¶ Save module’s state to disk.
Snapshot(s) will be saved using the following naming convention:
module_"index"_"timestamp".pt
See PyTorch’s docs about state_dict for more information.
- Parameters
folder (pathlib.Path, optional) – Name of the folder where the module will be saved. It has to exist. Defaults to the current working directory.
remove (bool, optional) – Whether module should be removed from memory after saving. Useful for keeping only best/last model in memory. Default: False
*indices (int, optional) – Possibly empty varargs containing indices of modules to be saved. Negative indexing is supported. If empty, all modules are saved.
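By contrast with save, only the state_dict reaches disk here, so the architecture must be recreated before loading. A sketch with a hypothetical save_module_state helper mirroring the same naming scheme:

```python
import pathlib
import tempfile
import time
import torch

def save_module_state(module, index, folder=None):
    # Same naming scheme, but only the state_dict is written to disk
    folder = pathlib.Path(folder) if folder is not None else pathlib.Path.cwd()
    path = folder / f"module_{index}_{int(time.time())}.pt"
    torch.save(module.state_dict(), path)
    return path

with tempfile.TemporaryDirectory() as tmp:
    source = torch.nn.Linear(4, 2)
    path = save_module_state(source, index=0, folder=tmp)
    target = torch.nn.Linear(4, 2)          # architecture recreated by hand
    target.load_state_dict(torch.load(path))
```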
-
torchfunc.module.bias_parameters(module: torch.nn.modules.module.Module, prefix: str = '', recurse: bool = True)[source]¶ Iterate only over module's parameters considered biases.
- Parameters
module (torch.nn.Module) – Module whose bias parameters will be yielded.
prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)
recurse (bool, optional) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True
- Yields
torch.nn.Parameter – Module's parameters considered biases (i.e. their name contains the string "bias")
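The filtering described above can be sketched in plain PyTorch; bias_parameters_sketch below is a hypothetical rough equivalent, not the library code:

```python
import torch

def bias_parameters_sketch(module, prefix="", recurse=True):
    # Rough equivalent: yield every parameter whose name contains "bias"
    for name, param in module.named_parameters(prefix=prefix, recurse=recurse):
        if "bias" in name:
            yield param

model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
biases = list(bias_parameters_sketch(model))  # one bias per Linear layer
```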
-
torchfunc.module.device(obj)[source]¶ Return the device of a torch.nn.Module or other obj containing a device field.
Example:
module = torch.nn.Linear(100, 10)
print(torchfunc.module.device(module))  # "cpu"
- Parameters
obj (torch.nn.Module or torch.Tensor) – Object containing a device field or containing parameters with a device, e.g. a module
- Returns
Instance of the device on which the object is currently held. If the object is held on multiple devices, None is returned
- Return type
Optional[torch.device]
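The described behavior, including the None return when parameters are spread across devices, can be sketched as follows (device_of_sketch is a hypothetical rough equivalent, not the library code):

```python
import torch

def device_of_sketch(obj):
    # Rough equivalent per the description above
    if hasattr(obj, "device"):                     # tensors carry a device field
        return obj.device
    devices = {param.device for param in obj.parameters()}
    # A single shared device is returned; anything else yields None
    return devices.pop() if len(devices) == 1 else None
```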
-
torchfunc.module.freeze(module: torch.nn.modules.module.Module, weight: bool = True, bias: bool = True) → torch.nn.modules.module.Module[source]¶ Freeze module's parameters.
Sets requires_grad to False for the specified parameters in the module. If both bias and weight are specified, ALL parameters will be frozen (even if their names are matched by neither weight nor bias). If you want to freeze only parameters whose names contain bias or weight, call the function twice consecutively (once with bias=True and weight=False, and vice versa).
Example:
logistic_regression = torch.nn.Sequential(
    torch.nn.Linear(784, 10),
    torch.nn.Sigmoid(),
)
# Freeze only bias in logistic regression
torchfunc.freeze(logistic_regression, weight=False)
- Parameters
module (torch.nn.Module) – Module whose weights and biases will be frozen.
weight (bool, optional) – Freeze weights. Default: True
bias (bool, optional) – Freeze bias. Default: True
- Returns
module – Module after parameters were frozen
- Return type
torch.nn.Module
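The name-matching rule above, where both flags together freeze everything regardless of name, can be sketched in plain PyTorch (freeze_sketch is a hypothetical rough equivalent, not the library code):

```python
import torch

def freeze_sketch(module, weight=True, bias=True):
    # With both flags set, every parameter is frozen regardless of name;
    # otherwise only parameters whose name matches the enabled flag
    for name, param in module.named_parameters():
        if (weight and bias) or (weight and "weight" in name) or (bias and "bias" in name):
            param.requires_grad_(False)
    return module

layer = freeze_sketch(torch.nn.Linear(4, 2), weight=False)  # freeze bias only
```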
-
torchfunc.module.named_parameters(module: torch.nn.modules.module.Module, name: str, prefix: str = '', recurse: bool = True)[source]¶ Iterate only over module's parameters having name as part of their name.
- Parameters
module (torch.nn.Module) – Module whose matching parameters will be yielded.
name (str) – Substring that a parameter's name has to contain for the parameter to be yielded
prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)
recurse (bool, optional) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True
- Yields
torch.nn.Parameter – Module's parameters satisfying the name constraint
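Because matching is done on qualified names, the substring can also target a single submodule's parameter. A sketch with a hypothetical named_parameters_sketch rough equivalent (not the library code):

```python
import torch

def named_parameters_sketch(module, name, prefix="", recurse=True):
    # Rough equivalent: yield parameters whose qualified name contains `name`
    for param_name, param in module.named_parameters(prefix=prefix, recurse=recurse):
        if name in param_name:
            yield param

model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
# Qualified names here are "0.weight", "0.bias", "1.weight", "1.bias"
second_weight = list(named_parameters_sketch(model, "1.weight"))
```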
-
torchfunc.module.switch_device(obj, target)[source]¶ Context manager/decorator switching the device of a torch.nn.Module (or other obj) to the specified target.
After the with block (or decorated function) ends, the object is cast back to its original device.
Example:
module = torch.nn.Linear(100, 10)
with torchfunc.module.switch_device(module, torch.device("cuda")):
    ...  # module is on cuda now
torchfunc.module.device(module)  # back on CPU
- Parameters
obj (torch.nn.Module or torch.Tensor) – Object containing a device field or containing parameters with a device, e.g. a module
target (torch.device-like) – PyTorch device or string, compatible with a to cast.
-
torchfunc.module.unfreeze(module: torch.nn.modules.module.Module, weight: bool = True, bias: bool = True) → torch.nn.modules.module.Module[source]¶ Unfreeze module’s parameters.
Sets requires_grad to True for all parameters in the module. Works as a complementary function to freeze; see its documentation.
- Parameters
module (torch.nn.Module) – Module whose weights and biases will be unfrozen.
weight (bool, optional) – Unfreeze weights. Default: True
bias (bool, optional) – Unfreeze bias. Default: True
- Returns
module – Module after parameters were unfrozen
- Return type
torch.nn.Module
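A typical fine-tuning round-trip, sketched with a hypothetical unfreeze_sketch rough equivalent (not the library code), flips requires_grad back for the selected parameters:

```python
import torch

def unfreeze_sketch(module, weight=True, bias=True):
    # Complement of freeze: flip requires_grad back to True
    for name, param in module.named_parameters():
        if (weight and bias) or (weight and "weight" in name) or (bias and "bias" in name):
            param.requires_grad_(True)
    return module

layer = torch.nn.Linear(4, 2)
layer.weight.requires_grad_(False)   # previously frozen weight
layer.bias.requires_grad_(False)     # previously frozen bias
unfreeze_sketch(layer, bias=False)   # unfreeze weights only
```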
-
torchfunc.module.weight_parameters(module: torch.nn.modules.module.Module, prefix: str = '', recurse: bool = True)[source]¶ Iterate only over module's parameters considered weights.
- Parameters
module (torch.nn.Module) – Module whose weight parameters will be yielded.
prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)
recurse (bool, optional) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True
- Yields
torch.nn.Parameter – Module's parameters considered weights (i.e. their name contains the string "weight")