
Source code for torchfunc.hooks.registrators

**This module makes hook registration easier (e.g. based on** `type` **or** `index` **of submodules within a network).**


    # Example forward pre hook: `inputs` is a tuple of positional arguments,
    # so each element is modified and a tuple is returned
    def example_forward_pre(module, inputs):
        return tuple(inp + 1 for inp in inputs)

    # MNIST classifier
    model = torch.nn.Sequential(
        torch.nn.Linear(784, 100),
        torch.nn.Linear(100, 50),
        torch.nn.Linear(50, 10),
    )
    registrator = torchfunc.hooks.registrators.ForwardPre(example_forward_pre)
    # Register the forward pre hook for all torch.nn.Linear submodules
    registrator.modules(model, types=(torch.nn.Linear,))

You can specify `indices` instead of (or in addition to) `types`, for example `indices=[0, 2]` to register the hook
only on the first and third submodules, and iterate over `children` (non-recursive) instead of `modules` (recursive).

import typing

import torch

from .._base import Base
from ._dev_utils import register_condition

class _Registrator(Base):
    r"""{}

    Attributes
    ----------
    handles : List[torch.utils.hooks.RemovableHandle]
        Handles for registered hooks, each corresponds to specific submodule.
        Can be used to unregister certain hooks (though discouraged).

    """

    def __init__(self, register_method: str, hook: typing.Callable):
        # Name of the torch.nn.Module method used to attach `hook`
        # (e.g. "register_forward_hook"); looked up with getattr on each submodule.
        self._register_method: str = register_method
        self.hook: typing.Callable = hook
        self.handles: typing.List[torch.utils.hooks.RemovableHandle] = []

    def _register_hook(
        self,
        network: torch.nn.Module,
        iterating_function: str,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> None:
        r"""Attach `hook` to every submodule of `network` accepted by `register_condition`.

        Parameters
        ----------
        network : torch.nn.Module
            Module whose submodules are iterated.
        iterating_function : str
            Name of the `torch.nn.Module` method used to enumerate submodules
            (`"modules"` or `"children"`). The index given to `register_condition`
            is the position of the submodule in this enumeration.
        types : Tuple[typing.Any], optional
            Forwarded unchanged to `register_condition`.
        indices : List[int], optional
            Forwarded unchanged to `register_condition`.

        """
        for index, submodule in enumerate(getattr(network, iterating_function)()):
            if register_condition(submodule, types, index, indices):
                self.handles.append(
                    getattr(submodule, self._register_method)(self.hook)
                )

    def __iter__(self):
        return iter(self.handles)

    def __len__(self):
        return len(self.handles)

    def remove(self, index: int) -> None:
        r"""**Remove hook specified by** `index`.

        The handle is removed from `handles` and the hook is detached from its module.

        Parameters
        ----------
        index: int
            Index of subhook (usually registered for layer)

        """
        handle = self.handles.pop(index)
        # Popping alone only forgets the handle; the hook would stay attached
        # to the module unless the handle itself is removed.
        handle.remove()

    def modules(
        self,
        module: torch.nn.Module,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> "_Registrator":
        r"""**Register** `hook` **using types and/or indices via** `modules` **method**.

        This function will use `modules` method of `torch.nn.Module` to iterate over
        available submodules. If you wish to iterate non-recursively, use `children`.

        If `types` and `indices` are left with their default values, all modules
        will have `hook` registered.

        Parameters
        ----------
        module : torch.nn.Module
            Module (usually neural network) on whose submodules `hook` will be registered.
        types : Tuple[typing.Any], optional
            Module types for which `hook` will be registered. E.g. `(torch.nn.Conv2d, torch.nn.Linear)`
            will register `hook` on every module being instance of either `Conv2d` or `Linear`.
            Default: `None`
        indices : Iterable[int], optional
            Indices of submodules (in the order yielded by `modules()`) on which
            `hook` will be registered.
            Default: `None`

        Returns
        -------
        self
            Allows call chaining.

        """
        self._register_hook(module, "modules", types, indices)
        return self

    def children(
        self,
        network: torch.nn.Module,
        types: typing.Optional[typing.Tuple[typing.Any, ...]] = None,
        indices: typing.Optional[typing.List[int]] = None,
    ) -> "_Registrator":
        r"""**Register** `hook` **using types and/or indices via** `children` **method**.

        This function will use `children` method of `torch.nn.Module` to iterate over
        available submodules. If you wish to iterate recursively, use `modules`.

        If `types` and `indices` are left with their default values, all modules
        will have `hook` registered.

        Parameters
        ----------
        network : torch.nn.Module
            Module (usually neural network) on whose direct children `hook` will be registered.
        types : Tuple[typing.Any], optional
            Module types for which `hook` will be registered. E.g. `(torch.nn.Conv2d, torch.nn.Linear)`
            will register `hook` on every module being instance of either `Conv2d` or `Linear`.
            Default: `None`
        indices : Iterable[int], optional
            Indices of children (in the order yielded by `children()`) on which
            `hook` will be registered.
            Default: `None`

        Returns
        -------
        self
            Allows call chaining.

        """
        self._register_hook(network, "children", types, indices)
        return self

[docs]class ForwardPre(_Registrator): __doc__ = _Registrator.__doc__.format( "Register forward pre hook based on module's type or indices." ) def __init__(self, hook: typing.Callable): self.hook = hook super().__init__("register_forward_pre_hook", self.hook)
[docs]class Forward(_Registrator): __doc__ = _Registrator.__doc__.format( "Register forward hook based on module's type or indices." ) def __init__(self, hook: typing.Callable): self.hook = hook super().__init__("register_forward_hook", self.hook)
[docs]class Backward(_Registrator): __doc__ = _Registrator.__doc__.format( "Register backward hook based on module's type or indices." ) def __init__(self, hook: typing.Callable): self.hook = hook super().__init__("register_backward_hook", self.hook)