Shortcuts

Source code for torchtraining.cast

"""Cast tensors in a functional fashion.

Use this module to cast `step` outputs to a desired type, or to a lower
precision in order to save memory (though this should rarely be needed).

"""

import abc

import torch

from ._base import Operation


def _docstring(klass) -> str:
    klass.__doc__ = """Cast `torch.Tensor` instance to {cast}.

.. note::

    **IMPORTANT**: Only `torch.Tensor` can be passed as `memory_format`
    is specified during casting.


Returns
-------
{cast}
    Casted `data`

""".format(
        cast=klass.__name__
    )
    return klass


def _forward_docstring(function):
    function.__doc__ = """
    Arguments
    ---------
    data: torch.Tensor
        Tensor to be casted
    """
    return function


class _Cast(Operation):
    """Common base for casts that accept a ``memory_format``.

    Parameters
    ----------
    memory_format: torch.memory_format, optional
        Memory format stored on the instance; subclasses pass it to the
        tensor conversion method they call. Default: `torch.preserve_format`.

    """

    def __init__(self, memory_format=torch.preserve_format):
        super().__init__()
        # Kept as a public attribute; read by each subclass's ``forward``.
        self.memory_format = memory_format

    # NOTE(review): ``abc.abstractmethod`` is only enforced at instantiation
    # if ``Operation`` uses ``abc.ABCMeta`` — not visible from this file.
    @abc.abstractmethod
    def forward(self, data):
        """Cast ``data``; concrete subclasses provide the implementation."""


@_docstring
class BFloat16(_Cast):
    # Converts incoming tensors to ``torch.bfloat16`` via
    # ``torch.Tensor.bfloat16``, honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.bfloat16(memory_format=fmt)
@_docstring
class Bool(_Cast):
    # Converts incoming tensors to ``torch.bool`` via ``torch.Tensor.bool``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.bool(memory_format=fmt)
@_docstring
class Byte(_Cast):
    # Converts incoming tensors to ``torch.uint8`` via ``torch.Tensor.byte``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.byte(memory_format=fmt)
@_docstring
class Char(_Cast):
    # Converts incoming tensors to ``torch.int8`` via ``torch.Tensor.char``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.char(memory_format=fmt)
@_docstring
class Double(_Cast):
    # Converts incoming tensors to ``torch.float64`` via
    # ``torch.Tensor.double``, honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.double(memory_format=fmt)
@_docstring
class Float(_Cast):
    # Converts incoming tensors to ``torch.float32`` via
    # ``torch.Tensor.float``, honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.float(memory_format=fmt)
@_docstring
class Half(_Cast):
    # Converts incoming tensors to ``torch.float16`` via ``torch.Tensor.half``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.half(memory_format=fmt)
@_docstring
class Int(_Cast):
    # Converts incoming tensors to ``torch.int32`` via ``torch.Tensor.int``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.int(memory_format=fmt)
@_docstring
class Long(_Cast):
    # Converts incoming tensors to ``torch.int64`` via ``torch.Tensor.long``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.long(memory_format=fmt)
@_docstring
class Short(_Cast):
    # Converts incoming tensors to ``torch.int16`` via ``torch.Tensor.short``,
    # honouring the configured memory format.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        fmt = self.memory_format
        return data.short(memory_format=fmt)
@_docstring
class Item(Operation):
    # Returns the value of a tensor as a Python scalar via
    # ``torch.Tensor.item`` (which, per PyTorch, expects a one-element
    # tensor). Class and method docstrings are generated by the decorators.
    @_forward_docstring
    def forward(self, data):
        scalar = data.item()
        return scalar
@_docstring
class Numpy(Operation):
    # Converts a tensor to a NumPy array via ``torch.Tensor.numpy``.
    # NOTE(review): that call raises for tensors requiring grad or living on
    # a non-CPU device; callers presumably pass detached CPU tensors.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        array = data.numpy()
        return array
@_docstring
class List(Operation):
    # Converts a tensor into (possibly nested) Python lists.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        # The PyTorch method is ``tolist`` (no underscore); ``to_list`` does
        # not exist on ``torch.Tensor`` and would raise ``AttributeError``.
        return data.tolist()
@_docstring
class MKLDNN(Operation):
    # Converts a tensor to the MKL-DNN layout via ``torch.Tensor.to_mkldnn``.
    # (Class and method docstrings are generated by the decorators.)
    @_forward_docstring
    def forward(self, data):
        converted = data.to_mkldnn()
        return converted
class Sparse(Operation):
    """Cast `torch.Tensor` to sparse format.

    Parameters
    ----------
    sparse_dims: int, optional
        The number of sparse dimensions to include in the new sparse tensor.
        Default: `None` (``torch.Tensor.to_sparse`` is then called without
        arguments and applies its own default).

    """

    def __init__(self, sparse_dims=None):
        super().__init__()
        self.sparse_dims = sparse_dims

    @_forward_docstring
    def forward(self, data):
        # ``Tensor.to_sparse`` has no ``sparse_dims`` keyword (its parameter
        # is ``sparseDims``), so the original keyword call raised TypeError.
        # Pass the count positionally, and only when the user provided one.
        if self.sparse_dims is None:
            return data.to_sparse()
        return data.to_sparse(self.sparse_dims)
# Cast to the type of another tensor
class As(Operation):
    """Cast `torch.Tensor` to the type of a reference tensor.

    Parameters
    ----------
    other: torch.Tensor
        Reference tensor; each incoming tensor is converted to its type
        via ``torch.Tensor.type_as``.

    """

    def __init__(self, other):
        super().__init__()
        self.other = other

    @_forward_docstring
    def forward(self, data):
        reference = self.other
        return data.type_as(reference)
###############################################################################
#
# TYPE ALIASES
#
###############################################################################

# Fixed-width names for the integer casts defined above.
Int8 = Char
UInt8 = Byte
Int16 = Short
Int32 = Int
Int64 = Long

# Fixed-width names for the floating-point casts defined above.
Float16 = Half
Float32 = Float
Float64 = Double