# Source code for torchfunc.performance.layers
r"""
**Check any performance caveats related to PyTorch and its layers.**
Using functionalities below you can check whether your architecture follows
current good practices related to performance of `torch.nn.Module` concrete layers.
"""
import abc
import collections
import sys
import typing
import torch
from .._base import Base
class Depthwise(Base):
    r"""**Check whether any convolution layer is a so-called depthwise convolution.**

    Depthwise convolution is faster for images with input data in format
    (batch, height, width, channel) as specialized kernels are available.
    Currently PyTorch does not support this functionality, so using those may actually
    slow down your neural network.

    Depthwise convolution might still be useful in order to save memory, not so
    performance-wise.

    For easy to follow guidelines, use `tips` method of this class.

    Example::

        model = torch.nn.Sequential(
            torch.nn.Conv1d(64, 64, kernel_size=3, groups=64),
            torch.nn.Conv2d(3, 32, kernel_size=3, groups=1),
            torch.nn.Conv2d(32, 32, kernel_size=3, groups=32),
        )
        for index in torchfunc.performance.layers.Depthwise().children(model):
            print(index)  # Should print 0 and 2

    Attributes
    ----------
    checkers : Tuple[Callable], optional
        Functions checking whether given module is depthwise convolution.
        Should return True in such case, False otherwise.
        Default: `Depthwise.default_checker`; if module's groups count is equal
        to module's `in_channels` True is returned. Works for PyTorch's `ConvNd` layers.
    """

    def __init__(
        self,
        checkers: typing.Optional[
            typing.Tuple[typing.Callable[[torch.nn.Module], bool], ...]
        ] = None,
    ):
        # Fall back to the built-in ConvNd-compatible check when no custom
        # checkers were provided.
        self.checkers: typing.Tuple[typing.Callable, ...] = (
            (Depthwise.default_checker,) if checkers is None else checkers
        )

    @classmethod
    def default_checker(cls, module: torch.nn.Module) -> bool:
        r"""**Default checking method suitable for PyTorch's built-in convolution layers.**

        Checks whether count of groups is equal to count of in_channels.

        **Important:**

        If you want to provide custom checker, you should return `True`
        (module being depthwise convolution) or `False` for any
        module that is passed to this function.

        Parameters
        ----------
        module : torch.nn.Module
            Module (or submodule) for which True means it's depthwise.

        Returns
        -------
        bool
            True if the module looks like a depthwise convolution, False otherwise.
        """
        # Only PyTorch ConvNd-like modules expose both attributes; anything
        # else cannot be a depthwise convolution from our point of view.
        if hasattr(module, "groups") and hasattr(module, "in_channels"):
            # groups == in_channels == 1 is a plain convolution, hence the
            # extra `!= 1` guard.
            return module.groups == module.in_channels and module.in_channels != 1
        return False

    def _analyse(self, module: torch.nn.Module, function: str):
        # Iterate submodules via the named traversal method ("modules" or
        # "children") and yield each flagged index exactly once, even when
        # multiple checkers match the same submodule.
        for index, submodule in enumerate(getattr(module, function)()):
            if any(checker(submodule) for checker in self.checkers):
                yield index

    def modules(self, module: torch.nn.Module):
        r"""**Look for Depthwise convolution using** `modules()` **method (recursive scanning).**

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Yields
        ------
        int
            Indices where module is considered depthwise convolution.
        """
        yield from self._analyse(module, "modules")

    def children(self, module: torch.nn.Module):
        r"""**Look for Depthwise convolution using module's** `children()` **method (shallow scanning).**

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Yields
        ------
        int
            Indices where module is considered depthwise convolution.
        """
        yield from self._analyse(module, "children")

    def tips(self, module: torch.nn.Module) -> str:
        r"""**Return** `str` **representation of** `modules()` **method.**

        It is advised to use this function to get tips in order to easily fix
        performance issues related to depthwise convolution.

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Returns
        -------
        str
            String representing tips related to depthwise convolution.
            Empty string when no depthwise convolution was found.
        """
        # Materialize the generator: a generator object is always truthy, so
        # testing it directly would return the warning even for clean models.
        depthwise = list(self.modules(module))
        if depthwise:
            return (
                "Depthwise convolutions are not currently using specialized kernel and might be slower.\n"
                + "See this issue: https://github.com/pytorch/pytorch/issues/18631 for more information.\n"
                + "Indices of those modules:\n"
                + str(depthwise)
                + "\nYou may want to decrease number of groups (like it's done for ResNeXt) for possible speed & accuracy improvements."
            )
        return ""
class Inplace(Base):
    r"""**Check whether any submodule/child of module is set to inplace mode.**

    Inplace operations may interfere with traced module (kernel fusion) and cause slowdowns.

    See `this issue <https://github.com/pytorch/pytorch/issues/23655>`__ for more information.

    **Example**::

        model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, groups=64),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(64, 64, kernel_size=3),
            torch.nn.ReLU6(inplace=True),
            torch.nn.Conv2d(64, 128, kernel_size=3, groups=32),
        )
        for index in torchfunc.performance.layers.Inplace().children(model):
            print(index)  # Should print 1 and 3

    For easy to follow guidelines, use `tips` method of this class.

    Attributes
    ----------
    inplace : Tuple[str], optional
        Attributes names indicating whether current op is inplace. Do not specify if you are not using
        custom modules not following pytorch's conventions. Default: `("inplace",)`.
        Existence of all those attributes will be checked in module. If any of them exists
        and is `True`, it will be considered as inplace operation.
    """

    def __init__(self, inplace: typing.Tuple[str, ...] = ("inplace",)):
        # Attribute names whose truthy presence marks a module as in-place.
        self.inplace = inplace

    def _analyse(self, module: torch.nn.Module, method: str):
        # Walk submodules via the named traversal method ("modules" or
        # "children"); yield each index at most once when any of the
        # configured attributes exists and evaluates truthy.
        for index, submodule in enumerate(getattr(module, method)()):
            if any(
                getattr(submodule, attribute, False) for attribute in self.inplace
            ):
                yield index

    def modules(self, module: torch.nn.Module):
        r"""**Look for inplace operation using** `modules()` **method (recursive scanning).**

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Yields
        ------
        int
            Indices where module is probably `inplace`.
        """
        yield from self._analyse(module, "modules")

    def children(self, module: torch.nn.Module):
        r"""**Look for inplace operation using** `children()` **method (shallow scanning).**

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Yields
        ------
        int
            Indices where module is probably `inplace`.
        """
        yield from self._analyse(module, "children")

    def tips(self, module: torch.nn.Module) -> str:
        r"""**Return** `str` **representation of** `modules()` **method.**

        It is advised to use this function to get tips in order to easily fix
        performance issues related to inplace operations.

        Parameters
        ----------
        module : torch.nn.Module
            Module to be scanned

        Returns
        -------
        str
            String representing tips related to inplace operations.
            Empty string when no inplace operation was found.
        """
        # Materialize the generator: a generator object is always truthy, so
        # testing it directly would return the warning even for clean models.
        inplace = list(self.modules(module))
        if inplace:
            return (
                "In-place operations might harm kernel fusion. Indices of those modules:\n"
                + str(inplace)
                + "\nYou may want to remove inplace flag (see this issue: https://github.com/pytorch/pytorch/issues/23655)"
            )
        return ""