# Source code for torchlayers
import inspect
import io
import typing
import warnings
import torch
from . import (_dev_utils, _inferable, activations, convolution, normalization,
pooling, regularization, upsample)
from ._version import __version__
__all__ = ["build", "infer", "Lambda", "Reshape", "Concatenate"]
def build(module, *args, **kwargs):
    """Build PyTorch layer or module by providing example input.

    This function should **always** be used after creating a module with
    `torchlayers`, especially when shape inference is involved.

    Works similarly to the `build` functionality provided by `keras`.

    The example input is pushed through `module` once (in eval mode, without
    gradient tracking) and the module is then serialized and deserialized in
    memory. This round trip is intended to "compile" the module to PyTorch
    primitives to remove any overhead.

    `torchlayers` also supports a `post_build` method to perform some action
    after the shape was inferred (weight initialization example below)::

        import torch
        import torchlayers as tl

        class _MyModuleImpl(torch.nn.Linear):
            def post_build(self):
                # You can do anything here really
                torch.nn.init.eye_(self.weight)

        MyModule = tl.infer(_MyModuleImpl)

    `post_build` should have no arguments other than `self` so all necessary
    data should be saved in `module` beforehand.

    Parameters
    ----------
    module : torch.nn.Module
        Instance of module to build
    *args
        Arguments required by module's `forward`
    **kwargs
        Keyword arguments required by module's `forward`

    Returns
    -------
    torch.nn.Module
        New module instance produced by the in-memory save/load round trip,
        in training mode, with every `post_build` hook already executed.

    Raises
    ------
    ValueError
        If some submodule has a `post_build` attribute which is not callable.
    """

    def torch_compile(module):
        # Recent PyTorch releases default `torch.load` to weights-only
        # unpickling, which rejects whole pickled modules. The buffer holds
        # an object serialized a moment ago in this very process, so full
        # unpickling is requested whenever the parameter is available.
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False

        with io.BytesIO() as buffer:
            torch.save(module, buffer)
            buffer.seek(0)
            return torch.load(buffer, **load_kwargs)

    def run_post(module):
        for submodule in module.modules():
            post_build = getattr(submodule, "post_build", None)
            if post_build is None:
                continue
            if not callable(post_build):
                raise ValueError(
                    "{}'s post_build is required to be a method.".format(submodule)
                )
            post_build()

    module.eval()
    try:
        with torch.no_grad():
            module(*args, **kwargs)
    finally:
        # Switch back to training mode even when the forward pass fails, so
        # the caller's module is never silently left in eval mode.
        module.train()

    module = torch_compile(module)
    run_post(module)
    return module
def infer(module_class, index: int = 1):
    """Allows custom user modules to infer input shape.

    Input shape should be the first argument after `self`.

    Usually used as class decorator, e.g.::

        import torch
        import torchlayers as tl

        class _StrangeLinearImpl(torch.nn.Linear):
            def __init__(self, in_features, out_features, bias: bool = True):
                super().__init__(in_features, out_features, bias)
                self.params = torch.nn.Parameter(torch.randn(out_features))

            def forward(self, inputs):
                return super().forward(inputs) + self.params

        # Now you can use shape inference of in_features
        StrangeLinear = tl.infer(_StrangeLinearImpl)

        # in_features can be inferred
        layer = StrangeLinear(out_features=64)

    Parameters
    ----------
    module_class: torch.nn.Module
        Class of module to be updated with shape inference capabilities.
    index: int, optional
        Index into `tensor.shape` input which should be inferred, e.g. tensor.shape[1].
        Default: `1` (`0` being batch dimension)

    Returns
    -------
    type
        `module_class` itself when its `__init__` signature has at most one
        parameter (nothing besides `self`). Otherwise a new
        `torch.nn.Module` subclass with the same `__name__`, which keeps
        `module_class` as a class attribute and whose `__init__`, `forward`,
        `__repr__`, `__getattr__` and `__reduce__` are generated by the
        factories in `_dev_utils.infer`.
    """
    init_arguments = [
        str(argument)
        for argument in inspect.signature(module_class.__init__).parameters.values()
    ]

    # No argument other than self: there is nothing to infer.
    if len(init_arguments) <= 1:
        return module_class

    inferred_module = type(
        module_class.__name__,
        (torch.nn.Module,),
        {_dev_utils.infer.MODULE_CLASS: module_class},
    )
    parsed_arguments, uninferable_arguments = _dev_utils.infer.parse_arguments(
        init_arguments, inferred_module
    )

    # Generated methods, created and attached in this fixed order.
    generated_methods = {
        "__init__": _dev_utils.infer.create_init(parsed_arguments),
        "forward": _dev_utils.infer.create_forward(
            _dev_utils.infer.MODULE,
            _dev_utils.infer.MODULE_CLASS,
            parsed_arguments,
            index,
        ),
        "__repr__": _dev_utils.infer.create_repr(
            _dev_utils.infer.MODULE, **uninferable_arguments
        ),
        "__getattr__": _dev_utils.infer.create_getattr(_dev_utils.infer.MODULE),
        "__reduce__": _dev_utils.infer.create_reduce(
            _dev_utils.infer.MODULE, parsed_arguments
        ),
    }
    for method_name, method in generated_methods.items():
        setattr(inferred_module, method_name, method)

    return inferred_module
class Lambda(torch.nn.Module):
    """Use any function as `torch.nn.Module`

    Simple proxy which allows you to use your own custom function in
    `torch.nn.Sequential` and other containers requiring `torch.nn.Module`
    as input::

        import torch
        import torchlayers as tl

        model = torch.nn.Sequential(tl.Lambda(lambda tensor: tensor ** 2))
        model(torch.randn(64 , 20))

    Parameters
    ----------
    function : Callable
        Any user specified function

    Returns
    -------
    Any
        Anything `function` returns

    """

    def __init__(self, function: typing.Callable):
        super().__init__()
        # Kept as a plain attribute and invoked on every forward pass.
        self.function: typing.Callable = function

    def forward(self, *args, **kwargs) -> typing.Any:
        """Call the wrapped function with all given arguments and return its result."""
        result = self.function(*args, **kwargs)
        return result
class Concatenate(torch.nn.Module):
    """Concatenate list of tensors.

    Mainly useful in `torch.nn.Sequential` when previous layer returns multiple
    tensors, e.g.::

        import torch
        import torchlayers as tl

        class Foo(torch.nn.Module):
            # Return same tensor three times
            # You could explicitly return a list or tuple as well
            def forward(self, tensor):
                return tensor, tensor, tensor

        model = torch.nn.Sequential(Foo(), tl.Concatenate(dim=1))
        model(torch.randn(64 , 20))  # shape (64, 60)

    All tensors must have the same shape (except in the concatenating dimension).

    Parameters
    ----------
    dim : int
        Dimension along which tensors will be concatenated

    Returns
    -------
    torch.Tensor
        Concatenated tensor along specified `dim`.

    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim: int = dim

    def forward(self, inputs):
        """Join the sequence of tensors `inputs` along `self.dim`."""
        concatenated = torch.cat(inputs, self.dim)
        return concatenated
class Reshape(torch.nn.Module):
    """Reshape tensor excluding `batch` dimension

    Reshapes input `torch.Tensor` features while preserving batch dimension.
    Standard `torch.reshape` values (e.g. `-1`) are supported, e.g.::

        import torch
        import torchlayers as tl

        layer = tl.Reshape(20, -1)
        layer(torch.randn(64, 80)) # shape (64, 20, 4)

    The number of elements per sample must be compatible with the requested
    shape. If possible, no copy of `tensor` will be performed.

    Parameters
    ----------
    shapes: *int
        Variable length list of shapes used in the reshape; the batch
        dimension must not be included.

    Returns
    -------
    torch.Tensor
        Reshaped tensor of shape `(batch, *shapes)`

    """

    def __init__(self, *shapes: int):
        super().__init__()
        # Variable-length tuple of target dimensions (batch excluded).
        self.shapes: typing.Tuple[int, ...] = shapes

    def forward(self, tensor):
        """Reshape `tensor` to `(tensor.shape[0], *self.shapes)`."""
        # The first dimension is taken from the input so batch size is kept.
        return torch.reshape(tensor, (tensor.shape[0], *self.shapes))
###############################################################################
#
# MODULE ATTRIBUTE GETTERS
#
###############################################################################
def __dir__():
    """Return the attribute names advertised by this module.

    The list is the names of `torch.nn`, followed by the layers defined in
    this file (`Lambda`, `Concatenate`, `Reshape`), followed by the names of
    each `torchlayers` submodule. Duplicates are not removed.
    """
    names = dir(torch.nn) + ["Lambda", "Concatenate", "Reshape"]
    for submodule in (
        convolution,
        normalization,
        upsample,
        pooling,
        regularization,
        activations,
    ):
        names += dir(submodule)
    return names
def __getattr__(name: str):
    """Resolve `name` against the `torchlayers` submodules and `torch.nn`.

    Submodules are searched in a fixed order, with `torch.nn` last, and the
    first attribute that is not `None` wins. If `name` is listed by
    `_inferable.torch.all()` or `_inferable.custom.all()`, the class found is
    wrapped with `infer`, using the index returned by
    `_dev_utils.helpers.get_per_module_index`.

    Raises
    ------
    AttributeError
        If none of the searched modules provides `name`.
    """
    search_order = (
        convolution,
        normalization,
        pooling,
        regularization,
        upsample,
        activations,
        torch.nn,
    )
    for candidate in search_order:
        module_class = getattr(candidate, name, None)
        if module_class is not None:
            break
    else:
        # Loop finished without a hit in any module.
        raise AttributeError("module {} has no attribute {}".format(__name__, name))

    inferable_names = _inferable.torch.all() + _inferable.custom.all()
    if name not in inferable_names:
        return module_class
    return infer(module_class, _dev_utils.helpers.get_per_module_index(module_class))