
# Source code for torchtraining.accelerators

"""Accelerators enabling distributed (multi-GPU/multi-node) training.

Accelerators should be instantiated only once and used on the top-most
module, chosen in the following order:

    * epoch (if exists)
    * iteration (if exists)
    * step

Those are the only objects which can be "piped" into producers, for example::

    tt.accelerators.Horovod(...) ** tt.iterations.Iteration(...)

Accelerators should be used in this way (although it is not always necessary).
See the `horovod` module for an example.
"""


import importlib
import typing

import torch

from .._base import Accelerator
from ..utils import general as utils

if utils.modules_exist("horovod", "horovod.torch"):

    import horovod.torch as hvd

    class Horovod(Accelerator):
        """Accelerate training using Uber's Horovod framework.

        See `torchtraining.accelerators.horovod` package for more information.

        .. note::

            **IMPORTANT**: This object needs `horovod` Python package to be visible.
            You can install it with `pip install -U torchtraining[horovod]`.
            Also you should export `CUDA_HOME` variable like this:
            `CUDA_HOME=/opt/cuda pip install -U torchtraining[horovod]` (your path may vary)

        Parameters
        ----------
        model: torch.nn.Module
            Module to be broadcasted to all processes.
        rank: int, optional
            Root process rank. Default: `0`
        per_worker_threads: int, optional
            Number of threads which can be utilized by each process.
            Default: `pytorch`'s default
        comm: List, optional
            List specifying ranks for the communicator, relative to the `MPI_COMM_WORLD`
            communicator OR the MPI communicator to use. Given communicator will be duplicated.
            If `None`, Horovod will use MPI_COMM_WORLD Communicator.
            Default: `None`

        Raises
        ------
        ValueError
            If `per_worker_threads` is given and smaller than `1`. This is checked
            before Horovod is initialized, so no process state is modified on failure.

        """

        def __init__(
            self,
            model,
            rank: int = 0,
            per_worker_threads: typing.Optional[int] = None,
            comm=None,
        ):
            # Validate user input first so that a bad argument does not leave
            # Horovod initialized / the CUDA device switched before raising.
            if per_worker_threads is not None and per_worker_threads < 1:
                raise ValueError("Each worker needs at least one thread to run.")

            hvd.init(comm)
            # Pin this process to the GPU whose index equals its local rank.
            if torch.cuda.is_available():
                torch.cuda.set_device(hvd.local_rank())
            if per_worker_threads is not None:
                torch.set_num_threads(per_worker_threads)
            # Send the root process's parameters to every other process.
            hvd.broadcast_parameters(model.state_dict(), root_rank=rank)

    # Imported only when Horovod is available; presumably the `horovod`
    # submodule itself requires `horovod.torch` (TODO confirm).
    from . import horovod