torchfunc.module

This module provides functionality related to torch.nn.Module instances (e.g. freezing parameters).

For performance analysis of torch.nn.Module instances, see the performance subpackage.

class torchfunc.module.Snapshot(*modules: torch.nn.modules.module.Module)

Save module snapshots in memory and/or on disk.

Further modules can be added with + or +=, and either their state or the whole model can be saved to disk with save_state() or save(), respectively.

All added modules are kept until they are removed with the pop() method.

Additionally, self-explanatory operations such as len(), iteration (__iter__) and item access are provided. There is no __setitem__, as mutating contained modules is discouraged.

Example:

import torch
import torchfunc

snapshot = torchfunc.module.Snapshot()
snapshot += torch.nn.Sequential(torch.nn.Linear(784, 10), torch.nn.Sigmoid())
snapshot.save("models")  # Save all modules to the models folder (which must already exist)

Parameters

*modules (torch.nn.Module) – Variable number of PyTorch modules to be kept.
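The container behaviour described above (adding with + and +=, pop(), len(), iteration and item access, no __setitem__) can be sketched in plain PyTorch. The class SnapshotSketch below is hypothetical and is not torchfunc's implementation; in particular, it deep-copies every module it stores, which is an assumption made here so that continued training of the live model does not alter a snapshot that was already taken.

```python
import copy
from typing import Iterator, List

import torch


class SnapshotSketch:
    """Hypothetical in-memory snapshot container (not torchfunc's code)."""

    def __init__(self, *modules: torch.nn.Module) -> None:
        # Assumption: store deep copies so later updates to the live model
        # do not change snapshots that were already taken.
        self._modules: List[torch.nn.Module] = [copy.deepcopy(m) for m in modules]

    def __iadd__(self, module: torch.nn.Module) -> "SnapshotSketch":
        self._modules.append(copy.deepcopy(module))
        return self

    def __add__(self, module: torch.nn.Module) -> "SnapshotSketch":
        # Return a new container; the original one is left untouched.
        new = SnapshotSketch(*self._modules)
        new += module
        return new

    def pop(self, index: int = -1) -> torch.nn.Module:
        return self._modules.pop(index)

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[torch.nn.Module]:
        return iter(self._modules)

    def __getitem__(self, index: int) -> torch.nn.Module:
        return self._modules[index]

    # Deliberately no __setitem__: stored modules are not meant to be replaced.


model = torch.nn.Linear(4, 2)
snapshot = SnapshotSketch()
snapshot += model  # first snapshot: randomly initialised weights
with torch.no_grad():
    model.weight.fill_(0.0)  # keep mutating the live model
snapshot += model  # second snapshot: zeroed weights
```

Because of the copies, snapshot[0] still holds the original random weights after the live model changes; without copying, both entries would point at the same object.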

pop(index: int = -1)

Remove the module at index from memory.

Parameters

index (int, optional) – Index of module to be removed. Default: -1 (last module)

Returns

module – Module removed by this operation

Return type

torch.nn.Module

save(folder: pathlib.Path = None, remove: bool = False, *indices: int) → None

Save modules to disk.

Snapshot(s) will be saved using the following naming convention:

module_"index"_"timestamp".pt

See PyTorch’s documentation on saving and loading models for more information.

Parameters
  • folder (pathlib.Path, optional) – Folder in which the modules will be saved. It must already exist. Defaults to the current working directory.

  • remove (bool, optional) – Whether saved modules should be removed from memory after saving. Useful for keeping only the best/last model in memory. Default: False

  • *indices (int, optional) – Possibly empty varargs containing indices of modules to be saved. Negative indexing is supported. If empty, all modules are saved.
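Since *indices follows two parameters that have defaults, indices can only be supplied after folder and remove have been passed positionally (e.g. snapshot.save(folder, False, 0, -1)); a call like snapshot.save(0, -1) would bind 0 to folder and -1 to remove. The sketch below is a hypothetical stand-in (save_sketch and its n_modules parameter are not part of torchfunc) that only computes the paths it would write. It assumes negative indices are normalised to positions before naming, and it uses an arbitrary timestamp format, since neither detail is specified above.

```python
import pathlib
import time
from typing import List, Optional


def save_sketch(
    n_modules: int,
    folder: Optional[pathlib.Path] = None,
    remove: bool = False,
    *indices: int,
) -> List[pathlib.Path]:
    """Hypothetical stand-in for Snapshot.save: returns the paths it would write."""
    # `remove` only mirrors the documented signature; nothing is held in memory here.
    folder = pathlib.Path.cwd() if folder is None else pathlib.Path(folder)
    positions = range(n_modules)
    # Empty *indices means "every module"; range indexing handles negative
    # indices and raises IndexError for out-of-range ones.
    selected = positions if not indices else [positions[i] for i in indices]
    timestamp = time.strftime("%Y%m%d%H%M%S")  # assumed timestamp format
    return [folder / f"module_{i}_{timestamp}.pt" for i in selected]


# Indices come after folder and remove, which must therefore be given positionally.
paths = save_sketch(3, pathlib.Path("models"), False, 0, -1)
```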

save_state(folder: pathlib.Path = None, remove: bool = False, *indices: int) → None

Save modules’ states (state_dict) to disk.

Snapshot(s) will be saved using the following naming convention:

module_"index"_"timestamp".pt

See PyTorch’s docs about state_dict for more information.

Parameters
  • folder (pathlib.Path, optional) – Folder in which the module states will be saved. It must already exist. Defaults to the current working directory.

  • remove (bool, optional) – Whether saved modules should be removed from memory after saving. Useful for keeping only the best/last model in memory. Default: False

  • *indices (int, optional) – Possibly empty varargs containing indices of modules whose states should be saved. Negative indexing is supported. If empty, the states of all modules are saved.
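Saving only the state, rather than the whole module, means the architecture must be rebuilt in code before the weights can be loaded back. The round trip below uses plain torch.save, torch.load and load_state_dict; build_model is a hypothetical helper, and the file name follows the convention above with a placeholder in place of a real timestamp.

```python
import pathlib
import tempfile

import torch


def build_model() -> torch.nn.Module:
    # A state_dict holds only tensors, so the architecture is recreated in code.
    return torch.nn.Sequential(torch.nn.Linear(784, 10), torch.nn.Sigmoid())


model = build_model()

with tempfile.TemporaryDirectory() as tmp:
    # Placeholder timestamp; the real format is not specified in this page.
    path = pathlib.Path(tmp) / "module_0_timestamp.pt"
    torch.save(model.state_dict(), path)

    restored = build_model()  # fresh, randomly initialised copy of the architecture
    restored.load_state_dict(torch.load(path))  # overwrite it with the saved tensors
```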

torchfunc.module.bias_parameters(module: torch.nn.modules.module.Module, prefix: str = '', recurse: bool = True)

Iterate only over the module’s parameters that are considered biases.

Parameters
  • module (torch.nn.Module) – Module whose bias parameters will be iterated over.

  • prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)

  • recurse (bool, optional) – If True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True

Yields

torch.nn.Parameter – Module’s bias parameters (i.e. those whose names contain the string bias)
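Because selection is by name, any parameter registered with "bias" in its name is yielded, including the learnable shift of a normalization layer. The sketch below assumes torchfunc filters module.named_parameters() on that substring, as the description above suggests; bias_parameters_sketch is a hypothetical name, not the library function.

```python
from typing import Iterator

import torch


def bias_parameters_sketch(
    module: torch.nn.Module, prefix: str = "", recurse: bool = True
) -> Iterator[torch.nn.Parameter]:
    """Hypothetical re-implementation: select parameters by name only."""
    for name, parameter in module.named_parameters(prefix=prefix, recurse=recurse):
        if "bias" in name:
            yield parameter


model = torch.nn.Sequential(
    torch.nn.Linear(8, 4),  # registers "0.weight" and "0.bias"
    torch.nn.BatchNorm1d(4),  # its affine shift is registered as "1.bias"
    torch.nn.Linear(4, 1, bias=False),  # registers only "2.weight"
)
biases = list(bias_parameters_sketch(model))
```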

torchfunc.module.device(obj)

Return the device of a torch.nn.Module or any other object with a device attribute.

Example:

import torch
import torchfunc
module = torch.nn.Linear(100, 10)
print(torchfunc.module.device(module))  # cpu

Parameters

obj (torch.nn.Module or torch.Tensor) – Object with a device attribute (e.g. a tensor) or containing parameters that have one (e.g. a module)

Returns

Device on which the object is currently held. If the object is spread across multiple devices, None is returned.

Return type

Optional[torch.device]
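The lookup described above can be sketched as follows: objects exposing a device attribute (such as tensors) answer directly, while modules are inspected through their parameters. device_sketch is hypothetical, and returning None for a module without parameters is a choice made here, not documented torchfunc behaviour.

```python
from typing import Optional

import torch


def device_sketch(obj) -> Optional[torch.device]:
    """Hypothetical re-implementation of device inference."""
    # Tensors (and other objects with a `device` attribute) answer directly;
    # nn.Module has no such attribute, so modules fall through.
    if hasattr(obj, "device"):
        return obj.device
    # Collect the distinct devices of all parameters (torch.device is hashable).
    devices = {parameter.device for parameter in obj.parameters()}
    # Exactly one device -> return it; several (or none, by choice here) -> None.
    return devices.pop() if len(devices) == 1 else None
```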

torchfunc.module.freeze(module: torch.nn.modules.module.Module, weight: bool = True, bias: bool = True) → torch.nn.modules.module.Module

Freeze module’s parameters.

Sets requires_grad to False for the specified parameters of module. If both weight and bias are True (the default), ALL parameters are frozen, even those whose names contain neither weight nor bias.

To freeze only the parameters whose names contain bias or weight, call the function twice in a row (once with bias=True and weight=False, and once the other way round).

Example:

import torch
import torchfunc

logistic_regression = torch.nn.Sequential(
    torch.nn.Linear(784, 10),
    torch.nn.Sigmoid(),
)

# Freeze only the bias in logistic regression
torchfunc.module.freeze(logistic_regression, weight=False)

Parameters
  • module (torch.nn.Module) – Module whose weights and biases will be frozen.

  • weight (bool, optional) – Freeze weights. Default: True

  • bias (bool, optional) – Freeze bias. Default: True

Returns

module – Module after parameters were frozen

Return type

torch.nn.Module
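The effect of the bias-only example can be reproduced with plain PyTorch by switching off requires_grad on parameters whose names contain "bias". This assumes freeze works purely through requires_grad, as described above. After a backward pass the frozen bias has no gradient, so optimizers such as torch.optim.SGD skip it.

```python
import torch

logistic_regression = torch.nn.Sequential(
    torch.nn.Linear(784, 10),
    torch.nn.Sigmoid(),
)

# Same effect as freezing only the bias: disable gradients by parameter name.
for name, parameter in logistic_regression.named_parameters():
    if "bias" in name:
        parameter.requires_grad_(False)

loss = logistic_regression(torch.randn(32, 784)).sum()
loss.backward()  # gradients flow only into parameters that still require them
```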

torchfunc.module.named_parameters(module: torch.nn.modules.module.Module, name: str, prefix: str = '', recurse: bool = True)

Iterate only over the module’s parameters whose names contain name.

Parameters
  • module (torch.nn.Module) – Module whose parameters will be iterated over.

  • name (str) – Substring that a parameter’s name must contain for the parameter to be yielded

  • prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)

  • recurse (bool, optional) – If True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True

Yields

torch.nn.Parameter – Module’s parameters satisfying name constraint
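Two consequences of name-based filtering are worth seeing concretely: torch.nn.Sequential registers no parameters of its own, so recurse=False yields nothing for it, and parameter names are dotted paths, so a layer index can be matched as well. The generator below is a hypothetical sketch that assumes substring matching on module.named_parameters(); it is not torchfunc's code.

```python
from typing import Iterator

import torch


def parameters_named_sketch(
    module: torch.nn.Module, name: str, prefix: str = "", recurse: bool = True
) -> Iterator[torch.nn.Parameter]:
    """Hypothetical re-implementation: yield parameters whose name contains `name`."""
    for parameter_name, parameter in module.named_parameters(prefix=prefix, recurse=recurse):
        if name in parameter_name:
            yield parameter


model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.Linear(3, 1))

all_weights = list(parameters_named_sketch(model, "weight"))
# nn.Sequential has no direct parameters, so nothing is found without recursion.
direct_weights = list(parameters_named_sketch(model, "weight", recurse=False))
# Names look like "0.weight" and "1.bias". Substring matching is loose: in a
# model with 11 or more layers, "0." would also match "10.weight".
first_layer = list(parameters_named_sketch(model, "0."))
```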

torchfunc.module.switch_device(obj, target)

Context manager/decorator that switches the device of a torch.nn.Module or other obj to the specified target.

When the with block (or the decorated function) exits, the object is cast back to its original device.

Example:

import torch
import torchfunc

module = torch.nn.Linear(100, 10)
with torchfunc.module.switch_device(module, torch.device("cuda")):  # requires a CUDA device
    ...  # module is on cuda now

torchfunc.module.device(module)  # back on CPU

Parameters
  • obj (torch.nn.Module or torch.Tensor) – Object with a device attribute (e.g. a tensor) or containing parameters that have one (e.g. a module)

  • target (torch.device-like) – PyTorch device or string accepted by the to() method.
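A context manager that doubles as a decorator can be built with contextlib.contextmanager, whose result also works with @ syntax. The sketch below is hypothetical and covers modules only: Module.to() moves parameters in place, whereas Tensor.to() returns a new tensor, so a tensor cannot be switched in place this way. It also assumes the module has at least one parameter and that all of its parameters share a device. It falls back to the CPU when CUDA is unavailable so that it runs anywhere.

```python
import contextlib
from typing import Iterator, Union

import torch


@contextlib.contextmanager
def switch_device_sketch(
    module: torch.nn.Module, target: Union[str, torch.device]
) -> Iterator[torch.nn.Module]:
    """Hypothetical re-implementation for nn.Module only."""
    # Remember where the (first) parameter lives so it can be restored later.
    original = next(module.parameters()).device
    module.to(target)  # Module.to moves parameters in place
    try:
        yield module
    finally:
        module.to(original)  # restore even if the block raises


target = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(100, 10)

with switch_device_sketch(model, target) as moved:
    inside = moved.weight.device.type


@switch_device_sketch(model, target)  # the same object used as a decorator
def forward_once() -> torch.Tensor:
    return model(torch.zeros(1, 100, device=target))


out = forward_once()
```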

torchfunc.module.unfreeze(module: torch.nn.modules.module.Module, weight: bool = True, bias: bool = True) → torch.nn.modules.module.Module

Unfreeze module’s parameters.

Sets requires_grad to True for the specified parameters of module (all of them by default). Works as the complement of freeze; see its documentation.

Parameters
  • module (torch.nn.Module) – Module whose weights and biases will be unfrozen.

  • weight (bool, optional) – Unfreeze weights. Default: True

  • bias (bool, optional) – Unfreeze bias. Default: True

Returns

module – Module after parameters were unfrozen

Return type

torch.nn.Module
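A common pattern the freeze/unfreeze pair supports is fine-tuning: freeze the whole model, then unfreeze only the head. Given the defaults described above (weight=True, bias=True affect all parameters), this should correspond to freezing the model and then unfreezing its last layer. The sketch uses plain Module.requires_grad_() instead, so no torchfunc behaviour beyond what is documented is assumed, and it passes only trainable parameters to the optimizer.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 8),  # "backbone"
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),  # "head"
)

# Freeze every parameter, then unfreeze the head only.
model.requires_grad_(False)
model[2].requires_grad_(True)

# Give the optimizer just the parameters that will actually be trained.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)
```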

torchfunc.module.weight_parameters(module: torch.nn.modules.module.Module, prefix: str = '', recurse: bool = True)

Iterate only over the module’s parameters that are considered weights.

Parameters
  • module (torch.nn.Module) – Module whose weight parameters will be iterated over.

  • prefix (str, optional) – Prefix to prepend to all parameter names. Default: '' (no prefix)

  • recurse (bool, optional) – If True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module. Default: True

Yields

torch.nn.Parameter – Module’s weight parameters (i.e. those whose names contain the string weight)
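A typical use of splitting parameters into weights and biases is to apply weight decay only to weights. The sketch below does the split with plain named_parameters() filtering, assuming the same substring rule as weight_parameters and bias_parameters. Because the rule is name-based, the scale of torch.nn.LayerNorm (registered as "weight") also lands in the decay group, which is often not intended.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.LayerNorm(10),  # its scale is registered as "1.weight"
    torch.nn.Linear(10, 1),
)

# Names containing "weight" get decay; everything else (here: biases) does not.
decay, no_decay = [], []
for name, parameter in model.named_parameters():
    (decay if "weight" in name else no_decay).append(parameter)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```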