Shortcuts

Source code for torchfunc.hooks.registrators

r"""
**This module makes hook registration easier (e.g. based on** `type` **or** `index` **within a network).**

Example::

    # Example forward pre hook (`inputs` is a tuple of positional arguments)
    def example_forward_pre(module, inputs):
        return tuple(value + 1 for value in inputs)

    # MNIST classifier
    model = torch.nn.Sequential(
        torch.nn.Linear(784, 100),
        torch.nn.ReLU(),
        torch.nn.Linear(100, 50),
        torch.nn.ReLU(),
        torch.nn.Linear(50, 10),
    )
    registrator = torchfunc.hooks.registrators.ForwardPre(example_forward_pre)
    # Register forward pre hook for all torch.nn.Linear submodules
    registrator.modules(model, types=(torch.nn.Linear,))

Instead of `types` you can select submodules by `indices` (their positions in the
iteration order), and you can iterate over direct `children` instead of all (recursive) `modules`.
"""

import typing

import torch

from .._base import Base
from ._dev_utils import register_condition


class _Registrator(Base):
    r"""**{}**

    The registrator stores a single `hook` and attaches it to the submodules
    selected via `modules` (recursive) or `children` (non-recursive),
    keeping one handle per registration.

    Attributes
    ----------
    hook : Callable
        Hook attached to every selected submodule.
    handles : List[torch.utils.hooks.RemovableHandle]
        Handles for registered hooks, each corresponds to specific submodule.
        Can be used to unregister certain hooks (prefer `remove`).

    """

    def __init__(self, register_method: str, hook: typing.Callable):
        # Name of the torch.nn.Module method used to attach `hook`
        # (e.g. "register_forward_hook"); resolved with getattr per submodule.
        self._register_method: str = register_method
        self.hook: typing.Callable = hook
        self.handles: typing.List[torch.utils.hooks.RemovableHandle] = []

    def _register_hook(
        self,
        network: torch.nn.Module,
        iterating_function: str,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> None:
        """Attach `hook` to every submodule accepted by `register_condition`.

        `index` passed to `register_condition` is the position of the
        submodule in the sequence returned by `network.<iterating_function>()`.
        """
        submodules = getattr(network, iterating_function)()
        for index, submodule in enumerate(submodules):
            if register_condition(submodule, types, index, indices):
                register = getattr(submodule, self._register_method)
                self.handles.append(register(self.hook))

    def __iter__(self):
        """Iterate over registered handles in registration order."""
        return iter(self.handles)

    def __len__(self):
        """Return the number of currently registered hooks."""
        return len(self.handles)

    def remove(self, index: int) -> None:
        r"""**Remove hook specified by** `index`.

        The handle is popped from `handles` and its hook is unregistered from
        the submodule, so indices of handles after `index` shift down by one.

        Parameters
        ----------
        index: int
            Position of the handle in `handles` (registration order).

        Raises
        ------
        IndexError
            If `index` is out of range for `handles`.

        """
        handle = self.handles.pop(index)
        handle.remove()

    def modules(
        self,
        module: torch.nn.Module,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> "_Registrator":
        r"""**Register** `hook` **using types and/or indices via** `modules` **method**.

        This function will use `modules` method of `torch.nn.Module` to iterate
        recursively over all submodules. Note that `torch.nn.Module.modules`
        yields `module` itself first, so index `0` refers to the root module.
        If you wish to iterate non-recursively, use `children`.

        **Important:**

        If `types` and `indices` are left with their default values, `hook`
        will be registered on all modules.

        Parameters
        ----------
        module : torch.nn.Module
            Module (usually neural network) on whose submodules `hook` will be registered.
        types : Tuple[typing.Any], optional
            Module types on which `hook` will be registered. E.g. `(torch.nn.Conv2d, torch.nn.Linear)`
            will register `hook` on every module being instance of either `Conv2d` or `Linear`.
            Default: `None`
        indices : Iterable[int], optional
            Indices (positions in `module.modules()` iteration order) of modules
            on which `hook` will be registered.
            Default: `None`

        Returns
        -------
        self
        """
        self._register_hook(module, "modules", types, indices)
        return self

    def children(
        self,
        network: torch.nn.Module,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> "_Registrator":
        r"""**Register** `hook` **using types and/or indices via** `children` **method**.

        This function will use `children` method of `torch.nn.Module` to iterate
        over direct submodules only. If you wish to iterate recursively, use `modules`.

        **Important:**

        If `types` and `indices` are left with their default values, `hook`
        will be registered on all direct children.

        Parameters
        ----------
        network : torch.nn.Module
            Module (usually neural network) on whose children `hook` will be registered.
        types : Tuple[typing.Any], optional
            Module types on which `hook` will be registered. E.g. `(torch.nn.Conv2d, torch.nn.Linear)`
            will register `hook` on every child being instance of either `Conv2d` or `Linear`.
            Default: `None`
        indices : Iterable[int], optional
            Indices (positions in `network.children()` iteration order) of children
            on which `hook` will be registered.
            Default: `None`

        Returns
        -------
        self
        """
        self._register_hook(network, "children", types, indices)
        return self


class ForwardPre(_Registrator):
    # Class documentation is generated from the shared _Registrator template.
    __doc__ = _Registrator.__doc__.format(
        "Register forward pre hook based on module's type or indices."
    )

    def __init__(self, hook: typing.Callable):
        r"""Create a registrator attaching `hook` as a forward pre hook.

        Parameters
        ----------
        hook : Callable
            Hook registered on every selected submodule through
            `torch.nn.Module.register_forward_pre_hook`.
        """
        # _Registrator.__init__ stores `hook` on self; no need to set it here.
        super().__init__("register_forward_pre_hook", hook)


class Forward(_Registrator):
    # Class documentation is generated from the shared _Registrator template.
    __doc__ = _Registrator.__doc__.format(
        "Register forward hook based on module's type or indices."
    )

    def __init__(self, hook: typing.Callable):
        r"""Create a registrator attaching `hook` as a forward hook.

        Parameters
        ----------
        hook : Callable
            Hook registered on every selected submodule through
            `torch.nn.Module.register_forward_hook`.
        """
        # _Registrator.__init__ stores `hook` on self; no need to set it here.
        super().__init__("register_forward_hook", hook)


class Backward(_Registrator):
    # Class documentation is generated from the shared _Registrator template.
    __doc__ = _Registrator.__doc__.format(
        "Register backward hook based on module's type or indices."
    )

    def __init__(self, hook: typing.Callable, *, full: bool = False):
        r"""Create a registrator attaching `hook` as a backward hook.

        Parameters
        ----------
        hook : Callable
            Hook registered on every selected submodule.
        full : bool, optional
            If `True`, register through `torch.nn.Module.register_full_backward_hook`
            (available since PyTorch 1.8). Otherwise use the legacy
            `torch.nn.Module.register_backward_hook`, which PyTorch deprecates.
            Default: `False` (original behavior)
        """
        register_method = (
            "register_full_backward_hook" if full else "register_backward_hook"
        )
        # _Registrator.__init__ stores `hook` on self; no need to set it here.
        super().__init__(register_method, hook)