diff --git a/docs/1.1.0/.buildinfo b/docs/1.1.0/.buildinfo
new file mode 100644
index 000000000000..b1656b206005
--- /dev/null
+++ b/docs/1.1.0/.buildinfo
@@ -0,0 +1,4 @@
+# Sphinx build info version 1
+# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 6d6a417f86940ceb333a9ffb31bb21d2
+tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/1.1.0/__config__.html b/docs/1.1.0/__config__.html
new file mode 100644
index 000000000000..6362b72a0334
--- /dev/null
+++ b/docs/1.1.0/__config__.html
@@ -0,0 +1,536 @@
+r"""
+The torch package contains data structures for multi-dimensional
+tensors and defines mathematical operations over them.
+Additionally, it provides many utilities for efficient serialization of
+Tensors and arbitrary types, and other useful utilities.
+
+It has a CUDA counterpart that enables you to run your tensor computations
+on an NVIDIA GPU with compute capability >= 3.0.
+"""
+
+import os
+import sys
+import platform
+from ._utils import _import_dotted_name
+from ._utils_internal import get_file_path, prepare_multiprocessing_environment
+from .version import __version__ # noqa: F401
+from ._six import string_classes as _string_classes
+
+__all__ = [
+ 'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
+ 'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed',
+ 'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
+ 'no_grad', 'enable_grad', 'rand', 'randn',
+ 'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
+ 'ShortStorage', 'CharStorage', 'ByteStorage',
+ 'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
+ 'ShortTensor', 'CharTensor', 'ByteTensor', 'Tensor',
+]
+
+################################################################################
+# Load the extension module
+################################################################################
+
+# Loading the extension with the RTLD_GLOBAL option allows us to avoid linking
+# extension modules against the _C shared object. Their missing THP symbols
+# will be automatically filled in by the dynamic loader.
+import os as _dl_flags
+
+# If we have numpy, it *must* be imported before the call to setdlopenflags(),
+# or there is a risk that later C modules will segfault when importing numpy.
+try:
+ import numpy as _np # noqa: F401
+except ImportError:
+ pass
+
+if platform.system() == 'Windows':
+ # first get nvToolsExt PATH
+ def get_nvToolsExt_path():
+ NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt')
+
+ if _dl_flags.path.exists(NVTOOLEXT_HOME):
+ return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64')
+ else:
+ return ''
+
+ py_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(sys.executable), 'Library', 'bin')
+ th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib')
+
+ dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']]
+
+ # then add the path to env
+ _dl_flags.environ['PATH'] = ';'.join(dll_paths)
+
+else:
+ # first check if the os package has the required flags
+ if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
+ try:
+ # next try if DLFCN exists
+ import DLFCN as _dl_flags
+ except ImportError:
+ # as a last attempt, use compile-time constants
+ import torch._dl as _dl_flags
+
+ old_flags = sys.getdlopenflags()
+ sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
+
+del _dl_flags
+
+from torch._C import *
+
+__all__ += [name for name in dir(_C)
+ if name[0] != '_' and
+ not name.endswith('Base')]
+
+if platform.system() != 'Windows':
+ sys.setdlopenflags(old_flags)
+ del old_flags
+
+################################################################################
+# Define basic utilities
+################################################################################
+
+
+def typename(o):
+ if isinstance(o, torch.Tensor):
+ return o.type()
+
+ module = ''
+ class_name = ''
+ if hasattr(o, '__module__') and o.__module__ != 'builtins' \
+ and o.__module__ != '__builtin__' and o.__module__ is not None:
+ module = o.__module__ + '.'
+
+ if hasattr(o, '__qualname__'):
+ class_name = o.__qualname__
+ elif hasattr(o, '__name__'):
+ class_name = o.__name__
+ else:
+ class_name = o.__class__.__name__
+
+ return module + class_name
+
+
+def is_tensor(obj):
+ r"""Returns True if `obj` is a PyTorch tensor.
+
+ Args:
+ obj (Object): Object to test
+ """
+ return isinstance(obj, torch.Tensor)
+
+
+def is_storage(obj):
+ r"""Returns True if `obj` is a PyTorch storage object.
+
+ Args:
+ obj (Object): Object to test
+ """
+ return type(obj) in _storage_classes
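+
+
+# Illustrative sketch (not part of the original module): how is_tensor and
+# is_storage distinguish a tensor from its backing storage. The helper below
+# is hypothetical and is never called at import time.
+def _example_is_tensor_is_storage():
+    x = torch.tensor([1.0, 2.0, 3.0])
+    assert torch.is_tensor(x) and not torch.is_storage(x)
+    s = x.storage()                                  # the tensor's backing FloatStorage
+    assert torch.is_storage(s) and not torch.is_tensor(s)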
+
+
+def set_default_tensor_type(t):
+ r"""Sets the default ``torch.Tensor`` type to floating point tensor type
+ :attr:`t`. This type will also be used as default floating point type for
+ type inference in :func:`torch.tensor`.
+
+ The default floating point tensor type is initially ``torch.FloatTensor``.
+
+ Args:
+ t (type or string): the floating point tensor type or its name
+
+ Example::
+
+ >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
+ torch.float32
+ >>> torch.set_default_tensor_type(torch.DoubleTensor)
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
+ torch.float64
+
+ """
+ if isinstance(t, _string_classes):
+ t = _import_dotted_name(t)
+ _C._set_default_tensor_type(t)
+
+
+def set_default_dtype(d):
+ r"""Sets the default floating point dtype to :attr:`d`. This type will be
+ used as default floating point type for type inference in
+ :func:`torch.tensor`.
+
+ The default floating point dtype is initially ``torch.float32``.
+
+ Args:
+ d (:class:`torch.dtype`): the floating point dtype to make the default
+
+ Example::
+
+ >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
+ torch.float32
+ >>> torch.set_default_dtype(torch.float64)
+ >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
+ torch.float64
+
+ """
+ _C._set_default_dtype(d)
+
+# If you edit these imports, please update torch/__init__.py.in as well
+from .random import set_rng_state, get_rng_state, manual_seed, initial_seed
+from .serialization import save, load
+from ._tensor_str import set_printoptions
+
+################################################################################
+# Define Storage and Tensor classes
+################################################################################
+
+from .tensor import Tensor
+from .storage import _StorageBase
+
+
+class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
+ pass
+
+
+class FloatStorage(_C.FloatStorageBase, _StorageBase):
+ pass
+
+
+class HalfStorage(_C.HalfStorageBase, _StorageBase):
+ pass
+
+
+class LongStorage(_C.LongStorageBase, _StorageBase):
+ pass
+
+
+class IntStorage(_C.IntStorageBase, _StorageBase):
+ pass
+
+
+class ShortStorage(_C.ShortStorageBase, _StorageBase):
+ pass
+
+
+class CharStorage(_C.CharStorageBase, _StorageBase):
+ pass
+
+
+class ByteStorage(_C.ByteStorageBase, _StorageBase):
+ pass
+
+
+class BoolStorage(_C.BoolStorageBase, _StorageBase):
+ pass
+
+_storage_classes = {
+ DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
+ CharStorage, ByteStorage, HalfStorage, BoolStorage
+}
+
+# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
+_tensor_classes = set()
+
+
+################################################################################
+# Initialize extension
+################################################################################
+
+def manager_path():
+ if platform.system() == 'Windows':
+ return b""
+ path = get_file_path('torch', 'bin', 'torch_shm_manager')
+ prepare_multiprocessing_environment(get_file_path('torch'))
+ if not os.path.exists(path):
+ raise RuntimeError("Unable to find torch_shm_manager at " + path)
+ return path.encode('utf-8')
+
+
+# Shared memory manager needs to know the exact location of manager executable
+_C._initExtension(manager_path())
+del manager_path
+
+for name in dir(_C._VariableFunctions):
+ if name.startswith('__'):
+ continue
+ globals()[name] = getattr(_C._VariableFunctions, name)
+
+################################################################################
+# Import interface functions defined in Python
+################################################################################
+
+# needs to be after the above ATen bindings so we can overwrite from Python side
+from .functional import *
+
+
+################################################################################
+# Remove unnecessary members
+################################################################################
+
+del DoubleStorageBase
+del FloatStorageBase
+del LongStorageBase
+del IntStorageBase
+del ShortStorageBase
+del CharStorageBase
+del ByteStorageBase
+del BoolStorageBase
+
+################################################################################
+# Import most common subpackages
+################################################################################
+
+import torch.cuda
+import torch.autograd
+from torch.autograd import no_grad, enable_grad, set_grad_enabled # noqa: F401
+import torch.nn
+import torch.optim
+import torch.multiprocessing
+import torch.sparse
+import torch.utils.backcompat
+import torch.onnx
+import torch.jit
+import torch.hub
+import torch.random
+import torch.distributions
+import torch.testing
+import torch.backends.cuda
+import torch.backends.mkl
+import torch.backends.openmp
+import torch.__config__
+
+_C._init_names(list(torch._storage_classes))
+
+# attach docstrings to torch and tensor functions
+from . import _torch_docs, _tensor_docs, _storage_docs
+del _torch_docs, _tensor_docs, _storage_docs
+
+
+def compiled_with_cxx11_abi():
+ r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
+ return _C._GLIBCXX_USE_CXX11_ABI
+
+
+# Import the ops "namespace"
+from torch._ops import ops # noqa: F401
+
+# Import the quasi random sampler
+import torch.quasirandom
+
+import torch
+
+
+def show():
+ """
+ Return a human-readable string with descriptions of the
+ configuration of PyTorch.
+ """
+ return torch._C._show_config()
+
+# TODO: In principle, we could provide more structured version/config
+# information here. We're not doing so for now, but will consider it if
+# someone asks for it.
+
+import math
+import torch
+from torch._six import inf
+
+
+class __PrinterOptions(object):
+ precision = 4
+ threshold = 1000
+ edgeitems = 3
+ linewidth = 80
+ sci_mode = None
+
+
+PRINT_OPTS = __PrinterOptions()
+
+
+# We could use **kwargs, but this will give better docs
+def set_printoptions(
+ precision=None,
+ threshold=None,
+ edgeitems=None,
+ linewidth=None,
+ profile=None,
+ sci_mode=None
+):
+ r"""Set options for printing. Items shamelessly taken from NumPy
+
+ Args:
+ precision: Number of digits of precision for floating point output
+ (default = 4).
+ threshold: Total number of array elements which trigger summarization
+ rather than full `repr` (default = 1000).
+ edgeitems: Number of array items in summary at beginning and end of
+ each dimension (default = 3).
+ linewidth: The number of characters per line for the purpose of
+ inserting line breaks (default = 80). Thresholded matrices will
+ ignore this parameter.
+ profile: Sane defaults for pretty printing. Can override with any of
+ the above options. (any one of `default`, `short`, `full`)
+ sci_mode: Enable (True) or disable (False) scientific notation. If
+ None (default) is specified, the value is defined by `_Formatter`
+ """
+ if profile is not None:
+ if profile == "default":
+ PRINT_OPTS.precision = 4
+ PRINT_OPTS.threshold = 1000
+ PRINT_OPTS.edgeitems = 3
+ PRINT_OPTS.linewidth = 80
+ elif profile == "short":
+ PRINT_OPTS.precision = 2
+ PRINT_OPTS.threshold = 1000
+ PRINT_OPTS.edgeitems = 2
+ PRINT_OPTS.linewidth = 80
+ elif profile == "full":
+ PRINT_OPTS.precision = 4
+ PRINT_OPTS.threshold = inf
+ PRINT_OPTS.edgeitems = 3
+ PRINT_OPTS.linewidth = 80
+
+ if precision is not None:
+ PRINT_OPTS.precision = precision
+ if threshold is not None:
+ PRINT_OPTS.threshold = threshold
+ if edgeitems is not None:
+ PRINT_OPTS.edgeitems = edgeitems
+ if linewidth is not None:
+ PRINT_OPTS.linewidth = linewidth
+ PRINT_OPTS.sci_mode = sci_mode
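+
+
+# Illustrative sketch (not part of the original module): how the profiles and
+# per-option overrides above affect printing. The helper is hypothetical and is
+# never called at import time; exact output formatting may vary by build.
+def _example_set_printoptions():
+    t = torch.arange(0., 2000.)
+    torch.set_printoptions(profile="short")        # precision=2, edgeitems=2
+    print(t)                                       # summarized, since numel > threshold
+    torch.set_printoptions(precision=6, sci_mode=True)
+    print(torch.tensor([0.000123]))                # forced scientific notation
+    torch.set_printoptions(profile="default")      # restore the defaults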
+
+
+class _Formatter(object):
+ def __init__(self, tensor):
+ self.floating_dtype = tensor.dtype.is_floating_point
+ self.int_mode = True
+ self.sci_mode = False
+ self.max_width = 1
+
+ with torch.no_grad():
+ tensor_view = tensor.reshape(-1)
+
+ if not self.floating_dtype:
+ for value in tensor_view:
+ value_str = '{}'.format(value)
+ self.max_width = max(self.max_width, len(value_str))
+
+ else:
+ nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
+
+ if nonzero_finite_vals.numel() == 0:
+ # no valid number, do nothing
+ return
+
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+ nonzero_finite_abs = nonzero_finite_vals.abs().double()
+ nonzero_finite_min = nonzero_finite_abs.min().double()
+ nonzero_finite_max = nonzero_finite_abs.max().double()
+
+ for value in nonzero_finite_vals:
+ if value != torch.ceil(value):
+ self.int_mode = False
+ break
+
+ if self.int_mode:
+ # in int_mode for floats, all numbers are integers; we append a decimal point to finite
+ # values to indicate that the tensor is of floating type. Add 1 to the length to account for this.
+ if nonzero_finite_max / nonzero_finite_min > 1000. or nonzero_finite_max > 1.e8:
+ self.sci_mode = True
+ for value in nonzero_finite_vals:
+ value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
+ self.max_width = max(self.max_width, len(value_str))
+ else:
+ for value in nonzero_finite_vals:
+ value_str = ('{:.0f}').format(value)
+ self.max_width = max(self.max_width, len(value_str) + 1)
+ else:
+ # Check if scientific representation should be used.
+ if nonzero_finite_max / nonzero_finite_min > 1000.\
+ or nonzero_finite_max > 1.e8\
+ or nonzero_finite_min < 1.e-4:
+ self.sci_mode = True
+ for value in nonzero_finite_vals:
+ value_str = ('{{:.{}e}}').format(PRINT_OPTS.precision).format(value)
+ self.max_width = max(self.max_width, len(value_str))
+ else:
+ for value in nonzero_finite_vals:
+ value_str = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
+ self.max_width = max(self.max_width, len(value_str))
+
+ if PRINT_OPTS.sci_mode is not None:
+ self.sci_mode = PRINT_OPTS.sci_mode
+
+ def width(self):
+ return self.max_width
+
+ def format(self, value):
+ if self.floating_dtype:
+ if self.sci_mode:
+ ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
+ elif self.int_mode:
+ ret = '{:.0f}'.format(value)
+ if not (math.isinf(value) or math.isnan(value)):
+ ret += '.'
+ else:
+ ret = ('{{:.{}f}}').format(PRINT_OPTS.precision).format(value)
+ else:
+ ret = '{}'.format(value)
+ return (self.max_width - len(ret)) * ' ' + ret
+
+
+def _scalar_str(self, formatter):
+ return formatter.format(self.item())
+
+
+def _vector_str(self, indent, formatter, summarize):
+ # length includes spaces and comma between elements
+ element_length = formatter.width() + 2
+ elements_per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))))
+ char_per_line = element_length * elements_per_line
+
+ if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
+ data = ([formatter.format(val) for val in self[:PRINT_OPTS.edgeitems].tolist()] +
+ [' ...'] +
+ [formatter.format(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()])
+ else:
+ data = [formatter.format(val) for val in self.tolist()]
+
+ data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
+ lines = [', '.join(line) for line in data_lines]
+ return '[' + (',' + '\n' + ' ' * (indent + 1)).join(lines) + ']'
+
+
+def _tensor_str_with_formatter(self, indent, formatter, summarize):
+ dim = self.dim()
+
+ if dim == 0:
+ return _scalar_str(self, formatter)
+ if dim == 1:
+ return _vector_str(self, indent, formatter, summarize)
+
+ if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
+ slices = ([_tensor_str_with_formatter(self[i], indent + 1, formatter, summarize)
+ for i in range(0, PRINT_OPTS.edgeitems)] +
+ ['...'] +
+ [_tensor_str_with_formatter(self[i], indent + 1, formatter, summarize)
+ for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
+ else:
+ slices = [_tensor_str_with_formatter(self[i], indent + 1, formatter, summarize)
+ for i in range(0, self.size(0))]
+
+ tensor_str = (',' + '\n' * (dim - 1) + ' ' * (indent + 1)).join(slices)
+ return '[' + tensor_str + ']'
+
+
+def _tensor_str(self, indent):
+ if self.numel() == 0:
+ return '[]'
+
+ summarize = self.numel() > PRINT_OPTS.threshold
+ if self.dtype is torch.float16:
+ self = self.float()
+ formatter = _Formatter(get_summarized_data(self) if summarize else self)
+ return _tensor_str_with_formatter(self, indent, formatter, summarize)
+
+
+def _add_suffixes(tensor_str, suffixes, indent, force_newline):
+ tensor_strs = [tensor_str]
+ last_line_len = len(tensor_str) - tensor_str.rfind('\n') + 1
+ for suffix in suffixes:
+ suffix_len = len(suffix)
+ if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
+ tensor_strs.append(',\n' + ' ' * indent + suffix)
+ last_line_len = indent + suffix_len
+ force_newline = False
+ else:
+ tensor_strs.append(', ' + suffix)
+ last_line_len += suffix_len + 2
+ tensor_strs.append(')')
+ return ''.join(tensor_strs)
+
+
+def get_summarized_data(self):
+ dim = self.dim()
+ if dim == 0:
+ return self
+ if dim == 1:
+ if self.size(0) > 2 * PRINT_OPTS.edgeitems:
+ return torch.cat((self[:PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems:]))
+ else:
+ return self
+ if self.size(0) > 2 * PRINT_OPTS.edgeitems:
+ start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
+ end = ([self[i]
+ for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))])
+ return torch.stack([get_summarized_data(x) for x in (start + end)])
+ else:
+ return torch.stack([get_summarized_data(x) for x in self])
+
+
+def _str(self):
+ prefix = 'tensor('
+ indent = len(prefix)
+
+ suffixes = []
+ if not torch._C._is_default_type_cuda():
+ if self.device.type == 'cuda':
+ suffixes.append('device=\'' + str(self.device) + '\'')
+ else:
+ if self.device.type == 'cpu' or torch.cuda.current_device() != self.device.index:
+ suffixes.append('device=\'' + str(self.device) + '\'')
+
+ has_default_dtype = self.dtype == torch.get_default_dtype() or self.dtype == torch.int64
+
+ if self.is_sparse:
+ suffixes.append('size=' + str(tuple(self.shape)))
+ suffixes.append('nnz=' + str(self._nnz()))
+ if not has_default_dtype:
+ suffixes.append('dtype=' + str(self.dtype))
+ indices_prefix = 'indices=tensor('
+ indices = self._indices().detach()
+ indices_str = _tensor_str(indices, indent + len(indices_prefix))
+ if indices.numel() == 0:
+ indices_str += ', size=' + str(tuple(indices.shape))
+ values_prefix = 'values=tensor('
+ values = self._values().detach()
+ values_str = _tensor_str(values, indent + len(values_prefix))
+ if values.numel() == 0:
+ values_str += ', size=' + str(tuple(values.shape))
+ tensor_str = indices_prefix + indices_str + '),\n' + ' ' * indent + values_prefix + values_str + ')'
+ else:
+ if self.numel() == 0 and not self.is_sparse:
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
+ if self.dim() != 1:
+ suffixes.append('size=' + str(tuple(self.shape)))
+
+ # In an empty tensor, there are no elements to infer if the dtype
+ # should be int64, so it must be shown explicitly.
+ if self.dtype != torch.get_default_dtype():
+ suffixes.append('dtype=' + str(self.dtype))
+ tensor_str = '[]'
+ else:
+ if not has_default_dtype:
+ suffixes.append('dtype=' + str(self.dtype))
+ if self.layout != torch.strided:
+ tensor_str = _tensor_str(self.to_dense(), indent)
+ else:
+ tensor_str = _tensor_str(self, indent)
+
+ if self.layout != torch.strided:
+ suffixes.append('layout=' + str(self.layout))
+
+ if self.grad_fn is not None:
+ name = type(self.grad_fn).__name__
+ if name == 'CppFunction':
+ name = self.grad_fn.name().rsplit('::', 1)[-1]
+ suffixes.append('grad_fn=<{}>'.format(name))
+ elif self.requires_grad:
+ suffixes.append('requires_grad=True')
+
+ return _add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse)
+
+import torch
+import warnings
+from collections import defaultdict
+
+
+def _type(self, dtype=None, non_blocking=False, **kwargs):
+ """Returns the type if `dtype` is not provided, else casts this object to
+ the specified type.
+
+ If this is already of the correct type, no copy is performed and the
+ original object is returned.
+
+ Args:
+ dtype (type or string): The desired type
+ non_blocking (bool): If ``True``, and the source is in pinned memory
+ and destination is on the GPU or vice versa, the copy is performed
+ asynchronously with respect to the host. Otherwise, the argument
+ has no effect.
+ **kwargs: For compatibility, may contain the key ``async`` in place of
+ the ``non_blocking`` argument. The ``async`` arg is deprecated.
+ """
+ non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs)
+ if dtype is None:
+ return self.__module__ + '.' + self.__class__.__name__
+
+ if isinstance(dtype, str):
+ dtype = _import_dotted_name(dtype)
+ if dtype == type(self):
+ return self
+ if self.is_sparse:
+ if not dtype.is_sparse:
+ raise RuntimeError("Cannot cast sparse tensor to dense tensor")
+ new_module_name = dtype.__module__.replace('.sparse', '')
+ new_values_type_name = new_module_name + '.' + dtype.__name__
+ new_values = torch._values(self).type(new_values_type_name, non_blocking)
+ new_indices_type_name = new_module_name + '.LongTensor'
+ new_indices = torch._indices(self).type(new_indices_type_name, non_blocking)
+ return dtype(new_indices, new_values, self.size())
+ if dtype.is_sparse:
+ raise RuntimeError("Cannot cast dense tensor to sparse tensor")
+ return dtype(self.size()).copy_(self, non_blocking)
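+
+
+# Illustrative sketch (not part of the original module): the `.type()` contract
+# described above, as exposed on tensors. The helper is hypothetical and is
+# never called at import time.
+def _example_type():
+    x = torch.zeros(3)
+    assert x.type() == 'torch.FloatTensor'         # no dtype given: just report the type name
+    y = x.type('torch.DoubleTensor')               # cast via a dotted type name
+    assert y.dtype == torch.float64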
+
+
+def _cuda(self, device=None, non_blocking=False, **kwargs):
+ """Returns a copy of this object in CUDA memory.
+
+ If this object is already in CUDA memory and on the correct device, then
+ no copy is performed and the original object is returned.
+
+ Args:
+ device (int): The destination GPU id. Defaults to the current device.
+ non_blocking (bool): If ``True`` and the source is in pinned memory,
+ the copy will be asynchronous with respect to the host. Otherwise,
+ the argument has no effect.
+ **kwargs: For compatibility, may contain the key ``async`` in place of
+ the ``non_blocking`` argument.
+ """
+ non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs)
+ if self.is_cuda:
+ if device is None:
+ device = torch.cuda.current_device()
+ if self.get_device() == device:
+ return self
+ else:
+ if device is None:
+ device = -1
+ with torch.cuda.device(device):
+ if self.is_sparse:
+ new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
+ indices = torch._indices(self).cuda(device, non_blocking)
+ values = torch._values(self).cuda(device, non_blocking)
+ return new_type(indices, values, self.size())
+ else:
+ new_type = getattr(torch.cuda, self.__class__.__name__)
+ return new_type(self.size()).copy_(self, non_blocking)
+
+
+def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
+ if not kwargs:
+ return non_blocking
+ if len(kwargs) != 1 or 'async' not in kwargs:
+ message = "{}() got an unexpected keyword argument '{}'"
+ argument = list(kwargs.keys()).pop()
+ raise TypeError(message.format(function_name, argument))
+ warnings.warn("'async' is deprecated; use 'non_blocking'")
+ return kwargs['async']
+
+
+# Note [Don't serialize hooks]
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Since time immemorial, we have serialized the backward hooks associated with
+# variables. This kind of half-worked--Python can pickle global functions
+# (but not closures!)--but there were problems.
+#
+# - It's fragile. If you serialize a backward hook into a saved
+# model, and then you rename the function associated with the hook,
+# now your saved model is broken and you can't load it anymore.
+#
+# - It's not actually used. The standard recommendation is to
+# serialize the *state_dict* of a model, not the model itself
+# (since this is more stable to code changes affecting the model
+# serialization), and the state dict saves "data" only, thus
+# stripping the backward hooks. In some cases, hooks are
+# essential to the well-functioning of a model (e.g., DDP),
+# but DDP already manages re-adding the hooks!
+#
+# - We didn't serialize them in many cases. Prior to #10220, we
+# were dropping backward hooks in ForkingPickler. We "fixed" this
+# to be convenient with other serialization sites, but lack of
+# serializing backward hooks wasn't actually the root cause of
+# the bug.
+#
+# With these cases in mind, we have decided that a better strategy
+# is to just NOT serialize hooks at all.
+#
+# Since this is a BC-breaking change, we should warn when we previously
+# serialized a hook, but no longer do so. This will be done by adding a special
+# sentinel property to hooks which will be used to suppress this warning. If a hook
+# has the property _torch_serialize_ignore, we will not emit a warning if we
+# attempt to serialize a Tensor with this hook attached to it.
+#
+# By the way, when _backward_hooks is skipped, we must give an EMPTY
+# OrderedDict(); if you pass None, you'll run afoul of #12219.
+
+
+def _rebuild_tensor(storage, storage_offset, size, stride):
+ # first construct a tensor with the correct dtype/device
+ t = torch.tensor([], dtype=storage.dtype, device=storage.device)
+ return t.set_(storage, storage_offset, size, stride)
+
+
+def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
+ tensor = _rebuild_tensor(storage, storage_offset, size, stride)
+ tensor.requires_grad = requires_grad
+ # NB: This line exists only for backwards compatibility; the
+ # general expectation is that backward_hooks is an empty
+ # OrderedDict. See Note [Don't serialize hooks]
+ tensor._backward_hooks = backward_hooks
+ return tensor
+
+
+def _rebuild_parameter(data, requires_grad, backward_hooks):
+ param = torch.nn.Parameter(data, requires_grad)
+ # NB: This line exists only for backwards compatibility; the
+ # general expectation is that backward_hooks is an empty
+ # OrderedDict. See Note [Don't serialize hooks]
+ param._backward_hooks = backward_hooks
+
+ return param
+
+
+def _import_dotted_name(name):
+ components = name.split('.')
+ obj = __import__(components[0])
+ for component in components[1:]:
+ obj = getattr(obj, component)
+ return obj
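+
+
+# Illustrative sketch (not part of the original module): _import_dotted_name
+# resolves a dotted string to the object it names. The helper is hypothetical
+# and is never called at import time.
+def _example_import_dotted_name():
+    cls = _import_dotted_name('torch.FloatTensor')
+    assert cls is torch.FloatTensor
+    fn = _import_dotted_name('torch.nn.functional.relu')
+    assert callable(fn)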
+
+
+# Taken from python 3.5 docs
+def _accumulate(iterable, fn=lambda x, y: x + y):
+ 'Return running totals'
+ # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+ # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
+ it = iter(iterable)
+ try:
+ total = next(it)
+ except StopIteration:
+ return
+ yield total
+ for element in it:
+ total = fn(total, element)
+ yield total
+
+
+def _flatten_dense_tensors(tensors):
+ """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
+ same dense type.
+
+ Since inputs are dense, the resulting tensor will be a concatenated 1D
+ buffer. Element-wise operation on this buffer will be equivalent to
+ operating individually.
+
+ Arguments:
+ tensors (Iterable[Tensor]): dense tensors to flatten.
+
+ Returns:
+ A contiguous 1D buffer containing input tensors.
+ """
+ if len(tensors) == 1:
+ return tensors[0].contiguous().view(-1)
+ flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
+ return flat
+
+
+def _flatten_sparse_tensors(tensors):
+ """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
+ one of values. Assume tensors are of same sparse type.
+
+ Arguments:
+ tensors (Iterable[Tensor]): sparse tensors to flatten.
+
+ Returns:
+ A tuple of two contiguous 1D buffers, one containing input tensors'
+ indices and the other containing the values.
+ """
+ flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors])
+ flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors])
+ return flat_indices, flat_values
+
+
+def _unflatten_dense_tensors(flat, tensors):
+ """View a flat buffer using the sizes of tensors. Assume that tensors are of
+ same dense type, and that flat is given by _flatten_dense_tensors.
+
+ Arguments:
+ flat (Tensor): flattened dense tensors to unflatten.
+ tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
+ unflatten flat.
+
+ Returns:
+ Unflattened dense tensors with sizes same as tensors and values from
+ flat.
+ """
+ outputs = []
+ offset = 0
+ for tensor in tensors:
+ numel = tensor.numel()
+ outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
+ offset += numel
+ return tuple(outputs)
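+
+
+# Illustrative sketch (not part of the original module): flatten/unflatten form
+# a round trip, which is how buckets of same-typed tensors are packed into one
+# buffer and recovered. The helper is hypothetical and never called at import time.
+def _example_flatten_roundtrip():
+    tensors = [torch.ones(2, 3), torch.arange(4.)]
+    flat = _flatten_dense_tensors(tensors)         # a single 1D buffer of 10 elements
+    assert flat.numel() == 10
+    restored = _unflatten_dense_tensors(flat, tensors)
+    for original, back in zip(tensors, restored):
+        assert torch.equal(original, back)         # shapes and values preserved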
+
+
+def _unflatten_sparse_tensors(flat, tensors):
+ """View flat buffer (containing indices and values) using the sizes of
+ tensors. Assume that tensors are of same sparse type, and that flat is given
+ by _flatten_sparse_tensors.
+
+ Arguments:
+ flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
+ tensors to unflatten.
+ tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
+ unflatten flat.
+
+ Returns:
+ Unflattened sparse tensors with sizes same as tensors and values from
+ flat.
+ """
+ flat_indices, flat_values = flat
+ indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors])
+ values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors])
+ outputs = []
+ for t, i, v in zip(tensors, indices, values):
+ outputs.append(t.new(i, v, t.size()))
+ return tuple(outputs)
+
+
+def _reorder_tensors_as(tensors, ordered_tensors):
+ """Assume that tensors are of same order as ordered_tensors within their
+ types, e.g., from _take_tensors. Reorder them to be of same order as
+ ordered_tensors.
+
+ Arguments:
+ tensors (Iterable[Tensor]): tensors to be reordered. They should be of
+ the same order as ordered_tensors within their own types.
+ ordered_tensors (Iterable[Tensor]): tensors whose order will be the
+ reference.
+
+ Returns:
+ Ordered tuple of tensors with contents from tensors and order of
+ ordered_tensors.
+ """
+ type_dict = defaultdict(list)
+ for tensor in tensors:
+ type_dict[tensor.type()].append(tensor)
+ type_dict = {t: iter(coll) for t, coll in type_dict.items()}
+ return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors)
+
+
+def _take_tensors(tensors, size_limit):
+ """Group tensors into chunks. This generator yields a chunk at a time, each
+ containing tensors of the same type, up to a certain byte limit in total size.
+
+ Args:
+ tensors (Sequence): A sequence of tensors to be separated into chunks.
+ size_limit (int): The limit of each chunk in bytes.
+
+ Yields:
+ Blocks of tensors of same type and within size_limit. The yielded
+ tensors are only ordered as the original sequence within its types.
+ """
+ buf_dict = defaultdict(lambda: [[], 0])
+ for tensor in tensors:
+ t = tensor.type()
+ if tensor.is_sparse:
+ indices = torch._indices(tensor)
+ values = torch._values(tensor)
+ size = indices.numel() * indices.element_size() + values.numel() * values.element_size()
+ else:
+ size = tensor.numel() * tensor.element_size()
+ buf_and_size = buf_dict[t]
+ if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
+ yield buf_and_size[0]
+ buf_and_size = buf_dict[t] = [[], 0]
+ buf_and_size[0].append(tensor)
+ buf_and_size[1] += size
+ for buf, _ in buf_dict.values():
+ if len(buf) > 0:
+ yield buf
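+
+
+# Illustrative sketch (not part of the original module): _take_tensors groups
+# same-typed tensors into chunks whose total size stays under size_limit bytes.
+# The helper is hypothetical and is never called at import time.
+def _example_take_tensors():
+    tensors = [torch.ones(100), torch.ones(100), torch.ones(300)]
+    # float32 elements are 4 bytes each, so the 800-byte limit holds 200 elements
+    chunks = list(_take_tensors(tensors, size_limit=800))
+    assert [len(chunk) for chunk in chunks] == [2, 1]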
+
+
+# annotation decorator to get annotations in a way that is compatible
+# with both Python 2 and 3
+def annotate(ret, **kwargs):
+ def dec(fun):
+ fun.__annotations__ = dict(kwargs)
+ fun.__annotations__['return'] = ret
+ return fun
+ return dec
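+
+
+# Illustrative sketch (not part of the original module): annotate() attaches
+# __annotations__ in a way that also works on Python 2. The helper is
+# hypothetical and is never called at import time.
+def _example_annotate():
+    @annotate(torch.Tensor, x=torch.Tensor, scale=float)
+    def scaled(x, scale):
+        return x * scale
+    # __annotations__ now maps 'x', 'scale' and 'return' to the given types
+    return scaled.__annotations__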
+
+"""
+``torch.autograd`` provides classes and functions implementing automatic
+differentiation of arbitrary scalar valued functions. It requires minimal
+changes to the existing code - you only need to declare :class:`Tensor` s
+for which gradients should be computed with the ``requires_grad=True`` keyword.
+"""
+import torch
+import warnings
+
+from .variable import Variable
+from .function import Function, NestedIOFunction # noqa: F401
+from .gradcheck import gradcheck, gradgradcheck # noqa: F401
+from .grad_mode import no_grad, enable_grad, set_grad_enabled # noqa: F401
+from .anomaly_mode import detect_anomaly, set_detect_anomaly # noqa: F401
+from . import profiler # noqa: F401
+
+__all__ = ['Variable', 'Function', 'backward', 'grad_mode']
+
+
+def _make_grads(outputs, grads):
+ new_grads = []
+ for out, grad in zip(outputs, grads):
+ if isinstance(grad, torch.Tensor):
+ new_grads.append(grad)
+ elif grad is None:
+ if out.requires_grad:
+ if out.numel() != 1:
+ raise RuntimeError("grad can be implicitly created only for scalar outputs")
+ new_grads.append(torch.ones_like(out))
+ else:
+ new_grads.append(None)
+ else:
+ raise TypeError("gradients can be either Tensors or None, but got " +
+ type(grad).__name__)
+ return tuple(new_grads)
+
+
+def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
+ r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
+
+ The graph is differentiated using the chain rule. If any of ``tensors``
+ are non-scalar (i.e. their data has more than one element) and require
+ gradient, then the Jacobian-vector product is computed; in this case the
+ function additionally requires specifying ``grad_tensors``.
+ It should be a sequence of matching length, that contains the "vector"
+ in the Jacobian-vector product, usually the gradient of the differentiated
+ function w.r.t. corresponding tensors (``None`` is an acceptable value for
+ all tensors that don't need gradient tensors).
+
+ This function accumulates gradients in the leaves - you might need to zero
+ them before calling it.
+
+ Arguments:
+ tensors (sequence of Tensor): Tensors of which the derivative will be
+ computed.
+ grad_tensors (sequence of (Tensor or None)): The "vector" in the Jacobian-vector
+ product, usually gradients w.r.t. each element of corresponding tensors.
+ None values can be specified for scalar Tensors or ones that don't require
+ grad. If a None value would be acceptable for all grad_tensors, then this
+ argument is optional.
+ retain_graph (bool, optional): If ``False``, the graph used to compute the grad
+ will be freed. Note that in nearly all cases setting this option to ``True``
+ is not needed and often can be worked around in a much more efficient
+ way. Defaults to the value of ``create_graph``.
+ create_graph (bool, optional): If ``True``, graph of the derivative will
+ be constructed, allowing to compute higher order derivative products.
+ Defaults to ``False``.
+ """
+ if grad_variables is not None:
+ warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
+ if grad_tensors is None:
+ grad_tensors = grad_variables
+ else:
+ raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
+ "arguments both passed to backward(). Please only "
+ "use 'grad_tensors'.")
+
+ tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+
+ if grad_tensors is None:
+ grad_tensors = [None] * len(tensors)
+ elif isinstance(grad_tensors, torch.Tensor):
+ grad_tensors = [grad_tensors]
+ else:
+ grad_tensors = list(grad_tensors)
+
+ grad_tensors = _make_grads(tensors, grad_tensors)
+ if retain_graph is None:
+ retain_graph = create_graph
+
+ Variable._execution_engine.run_backward(
+ tensors, grad_tensors, retain_graph, create_graph,
+ allow_unreachable=True) # allow_unreachable flag
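+
+
+# Illustrative sketch (not part of the original source): for a non-scalar output,
+# backward() needs grad_tensors, the "vector" in the Jacobian-vector product.
+# The helper is hypothetical and is never called at import time.
+def _example_backward():
+    x = torch.ones(3, requires_grad=True)
+    y = x * 2                                                  # non-scalar output
+    backward([y], grad_tensors=[torch.tensor([0.1, 1.0, 10.0])])
+    # dy/dx = 2 everywhere, scaled element-wise by the supplied vector
+    assert torch.allclose(x.grad, torch.tensor([0.2, 2.0, 20.0]))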
+
+
+def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
+ only_inputs=True, allow_unused=False):
+ r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.
+
+ ``grad_outputs`` should be a sequence of length matching ``outputs``,
+ containing the "vector" in the Jacobian-vector product, usually the pre-computed
+ gradients w.r.t. each of the outputs. If an output doesn't require_grad,
+ then the gradient can be ``None``.
+
+ If ``only_inputs`` is ``True``, the function will only return a list of gradients
+ w.r.t the specified inputs. If it's ``False``, then gradient w.r.t. all remaining
+ leaves will still be computed, and will be accumulated into their ``.grad``
+ attribute.
+
+ Arguments:
+ outputs (sequence of Tensor): outputs of the differentiated function.
+ inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
+ returned (and not accumulated into ``.grad``).
+ grad_outputs (sequence of Tensor): The "vector" in the Jacobian-vector product.
+ Usually gradients w.r.t. each output. None values can be specified for scalar
+ Tensors or ones that don't require grad. If a None value would be acceptable
+ for all grad_tensors, then this argument is optional. Default: None.
+ retain_graph (bool, optional): If ``False``, the graph used to compute the grad
+ will be freed. Note that in nearly all cases setting this option to ``True``
+ is not needed and often can be worked around in a much more efficient
+ way. Defaults to the value of ``create_graph``.
+ create_graph (bool, optional): If ``True``, graph of the derivative will
+ be constructed, allowing to compute higher order derivative products.
+ Default: ``False``.
+ allow_unused (bool, optional): If ``False``, specifying inputs that were not
+ used when computing outputs (and therefore their grad is always zero)
+ is an error. Defaults to ``False``.
+ """
+ if not only_inputs:
+ warnings.warn("only_inputs argument is deprecated and is ignored now "
+ "(defaults to True). To accumulate gradient for other "
+ "parts of the graph, please use torch.autograd.backward.")
+
+ outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
+ inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
+ if grad_outputs is None:
+ grad_outputs = [None] * len(outputs)
+ elif isinstance(grad_outputs, torch.Tensor):
+ grad_outputs = [grad_outputs]
+ else:
+ grad_outputs = list(grad_outputs)
+
+ grad_outputs = _make_grads(outputs, grad_outputs)
+ if retain_graph is None:
+ retain_graph = create_graph
+
+ return Variable._execution_engine.run_backward(
+ outputs, grad_outputs, retain_graph, create_graph,
+ inputs, allow_unused)
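+
+
+# Illustrative sketch (not part of the original source): grad() returns the
+# gradients instead of accumulating them into .grad. The helper is hypothetical
+# and is never called at import time.
+def _example_grad():
+    x = torch.tensor([2.0, 3.0], requires_grad=True)
+    y = (x ** 2).sum()
+    gx, = grad(y, x)                               # d(sum of x^2)/dx = 2x
+    assert torch.equal(gx, torch.tensor([4.0, 6.0]))
+    assert x.grad is None                          # nothing was accumulated into .grad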
+
+
+# This function applies in the case of gradient checkpointing for memory
+# optimization. Currently, gradient checkpointing only supports the imperative
+# backwards call, i.e. torch.autograd.backward(); torch.autograd.grad() won't
+# work. The reason is that torch.autograd.grad() only calculates the grads for
+# the inputs that are passed by the user, and doesn't calculate grads for
+# anything else, e.g. model parameters such as weights and biases. However,
+# torch.autograd.backward() actually computes the grads for the weights as well.
+#
+# This function returns whether the checkpointing is valid, i.e. whether the
+# backward pass was started via torch.autograd.backward or via
+# torch.autograd.grad. The implementation works by maintaining a thread-local
+# variable in torch/csrc/autograd/engine.cpp which looks at the FunctionTask
+# in the stack; before a FunctionTask is executed in evaluate_function, it
+# checks whether the reentrant backwards call is imperative or not.
+# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
+def _is_checkpoint_valid():
+ return Variable._execution_engine.is_checkpoint_valid()
+
+
+def variable(*args, **kwargs):
+ warnings.warn("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")
+ return torch.tensor(*args, **kwargs)
+
+
+if not torch._C._autograd_init():
+ raise RuntimeError("autograd initialization failed")
+
+import torch
+
+
+class detect_anomaly(object):
+ r"""Context-manager that enables anomaly detection for the autograd engine.
+
+ This does two things:
+ - Running the forward pass with detection enabled will allow the backward
+ pass to print the traceback of the forward operation that created the failing
+ backward function.
+ - Any backward computation that generates "nan" values will raise an error.
+
+ Example:
+
+ >>> import torch
+ >>> from torch import autograd
+ >>> class MyFunc(autograd.Function):
+ ... @staticmethod
+ ... def forward(ctx, inp):
+ ... return inp.clone()
+ ... @staticmethod
+ ... def backward(ctx, gO):
+ ... # Error during the backward pass
+ ... raise RuntimeError("Some error in backward")
+ ... return gO.clone()
+ >>> def run_fn(a):
+ ... out = MyFunc.apply(a)
+ ... return out.sum()
+ >>> inp = torch.rand(10, 10, requires_grad=True)
+ >>> out = run_fn(inp)
+ >>> out.backward()
+ Traceback (most recent call last):
+ File "<stdin>", line 1, in <module>
+ File "/your/pytorch/install/torch/tensor.py", line 93, in backward
+ torch.autograd.backward(self, gradient, retain_graph, create_graph)
+ File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+ allow_unreachable=True) # allow_unreachable flag
+ File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+ return self._forward_cls.backward(self, *args)
+ File "<stdin>", line 8, in backward
+ RuntimeError: Some error in backward
+ >>> with autograd.detect_anomaly():
+ ... inp = torch.rand(10, 10, requires_grad=True)
+ ... out = run_fn(inp)
+ ... out.backward()
+ Traceback of forward call that caused the error:
+ File "tmp.py", line 53, in <module>
+ out = run_fn(inp)
+ File "tmp.py", line 44, in run_fn
+ out = MyFunc.apply(a)
+ Traceback (most recent call last):
+ File "<stdin>", line 4, in <module>
+ File "/your/pytorch/install/torch/tensor.py", line 93, in backward
+ torch.autograd.backward(self, gradient, retain_graph, create_graph)
+ File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
+ allow_unreachable=True) # allow_unreachable flag
+ File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
+ return self._forward_cls.backward(self, *args)
+ File "<stdin>", line 8, in backward
+ RuntimeError: Some error in backward
+
+ """
+
+ def __init__(self):
+ self.prev = torch.is_anomaly_enabled()
+
+ def __enter__(self):
+ torch.set_anomaly_enabled(True)
+
+ def __exit__(self, *args):
+ torch.set_anomaly_enabled(self.prev)
+ return False
+
+
+class set_detect_anomaly(object):
+ r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
+
+ ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
+ based on its argument :attr:`mode`.
+ It can be used as a context-manager or as a function.
+
+ See ``detect_anomaly`` above for details of the anomaly detection behaviour.
+
+ Arguments:
+ mode (bool): Flag whether to enable anomaly detection (``True``),
+ or disable (``False``).
+
+ """
+
+ def __init__(self, mode):
+ self.prev = torch.is_anomaly_enabled()
+ torch.set_anomaly_enabled(mode)
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ torch.set_anomaly_enabled(self.prev)
+ return False
+
+import torch
+import torch._C as _C
+import torch.utils.hooks as hooks
+from torch._six import with_metaclass
+import functools
+import warnings
+from collections import OrderedDict
+
+
+class _ContextMethodMixin(object):
+
+ def save_for_backward(self, *tensors):
+ r"""Saves given tensors for a future call to :func:`~Function.backward`.
+
+ **This should be called at most once, and only from inside the**
+ :func:`forward` **method.**
+
+ Later, saved tensors can be accessed through the :attr:`saved_tensors`
+ attribute. Before returning them to the user, a check is made to ensure
+ they weren't used in any in-place operation that modified their content.
+
+ Arguments can also be ``None``.
+ """
+ self.to_save = tensors
+
+ def mark_dirty(self, *args):
+ r"""Marks given tensors as modified in an in-place operation.
+
+ **This should be called at most once, only from inside the**
+ :func:`forward` **method, and all arguments should be inputs.**
+
+ Every tensor that's been modified in-place in a call to :func:`forward`
+ should be given to this function, to ensure correctness of our checks.
+ It doesn't matter whether the function is called before or after
+ modification.
+ """
+ self.dirty_tensors = args
+
+ def mark_shared_storage(self, *pairs):
+ warnings.warn(
+ 'mark_shared_storage is deprecated. '
+ 'Tensors with shared storages are automatically tracked. Note '
+ 'that calls to `set_()` are not tracked')
+
+ def mark_non_differentiable(self, *args):
+ r"""Marks outputs as non-differentiable.
+
+ **This should be called at most once, only from inside the**
+ :func:`forward` **method, and all arguments should be outputs.**
+
+ This will mark outputs as not requiring gradients, increasing the
+ efficiency of backward computation. You still need to accept a gradient
+ for each output in :meth:`~Function.backward`, but it's always going to
+ be a zero tensor with the same shape as the corresponding output.
+
+ This is used e.g. for indices returned from a max :class:`Function`.
+ """
+ self.non_differentiable = args
+
+
+class _HookMixin(object):
+
+ @staticmethod
+ def _register_hook(backward_hooks, hook):
+ if backward_hooks is None:
+ backward_hooks = OrderedDict()
+ handle = hooks.RemovableHandle(backward_hooks)
+ backward_hooks[handle.id] = hook
+ return backward_hooks, handle
+
+
+class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
+ _is_legacy = False
+
+ def apply(self, *args):
+ return self._forward_cls.backward(self, *args)
+
+
+class FunctionMeta(type):
+ """Function metaclass.
+
+ This metaclass sets up the following properties:
+ _is_legacy: True if forward is not defined as a static method.
+ _backward_cls: The Function class corresponding to the differentiated
+ version of this function (which is generated on the fly by this
+ metaclass).
+ """
+
+ def __init__(cls, name, bases, attrs):
+ for super_cls in cls.mro():
+ forward = super_cls.__dict__.get('forward')
+ if forward is not None:
+ has_static_forward = isinstance(forward, staticmethod) or isinstance(forward, classmethod)
+ break
+
+ cls._is_legacy = not has_static_forward
+
+ # old-style functions
+ if not has_static_forward:
+ return super(FunctionMeta, cls).__init__(name, bases, attrs)
+
+ backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
+ cls._backward_cls = backward_fn
+
+ return super(FunctionMeta, cls).__init__(name, bases, attrs)
+
+
+class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
+ r"""Records operation history and defines formulas for differentiating ops.
+
+ Every operation performed on :class:`Tensor` s creates a new function
+ object, that performs the computation, and records that it happened.
+ The history is retained in the form of a DAG of functions, with edges
+ denoting data dependencies (``input <- output``). Then, when backward is
+ called, the graph is processed in the topological ordering, by calling
+ :func:`backward` methods of each :class:`Function` object, and passing
+ returned gradients on to next :class:`Function` s.
+
+ Normally, the only way users interact with functions is by creating
+ subclasses and defining new operations. This is a recommended way of
+ extending torch.autograd.
+
+ Each function object is meant to be used only once (in the forward pass).
+
+ Examples::
+
+ >>> class Exp(Function):
+ >>>
+ >>> @staticmethod
+ >>> def forward(ctx, i):
+ >>> result = i.exp()
+ >>> ctx.save_for_backward(result)
+ >>> return result
+ >>>
+ >>> @staticmethod
+ >>> def backward(ctx, grad_output):
+ >>> result, = ctx.saved_tensors
+ >>> return grad_output * result
+ """
+
+ # only for backward compatibility
+ __call__ = _C._FunctionBase._do_forward
+
+ # for the tracer
+ is_traceable = False
+
+ @staticmethod
+ def forward(ctx, *args, **kwargs):
+ r"""Performs the operation.
+
+ This function is to be overridden by all subclasses.
+
+ It must accept a context ctx as the first argument, followed by any
+ number of arguments (tensors or other types).
+
+ The context can be used to store tensors that can be then retrieved
+ during the backward pass.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def backward(ctx, *grad_outputs):
+ r"""Defines a formula for differentiating the operation.
+
+ This function is to be overridden by all subclasses.
+
+ It must accept a context :attr:`ctx` as the first argument, followed by
+ as many outputs as :func:`forward` returned, and it should return as many
+ tensors as there were inputs to :func:`forward`. Each argument is the
+ gradient w.r.t the given output, and each returned value should be the
+ gradient w.r.t. the corresponding input.
+
+ The context can be used to retrieve tensors saved during the forward
+ pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
+ of booleans representing whether each input needs gradient. E.g.,
+ :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
+ first input to :func:`forward` needs gradient computed w.r.t. the
+ output.
+ """
+ raise NotImplementedError
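+
+
+# Illustrative sketch (not part of the original source): defining and applying a
+# custom Function such as the Exp example from the class docstring above. The
+# helper is hypothetical and is never called at import time.
+def _example_apply_custom_function():
+    class Exp(Function):
+        @staticmethod
+        def forward(ctx, i):
+            result = i.exp()
+            ctx.save_for_backward(result)
+            return result
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            result, = ctx.saved_tensors
+            return grad_output * result
+
+    x = torch.zeros(3, requires_grad=True)
+    y = Exp.apply(x)                   # runs the custom forward
+    y.sum().backward()                 # gradient of exp(x) is exp(x) itself
+    assert torch.allclose(x.grad, torch.ones(3))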
+
+
+def once_differentiable(fn):
+
+ @functools.wraps(fn)
+ def wrapper(ctx, *args):
+ with torch.no_grad():
+ outputs = fn(ctx, *args)
+
+ if not torch.is_grad_enabled():
+ return outputs
+
+ # If any of the inputs have requires_grad=True, we force the outputs
+ # to have requires_grad=True but point to a grad_fn which throws an
+ # error message during (double) back-propagation.
+ # XXX: this is only an approximation of requires_grad - there's no way
+ # to figure out if fn didn't use ctx.saved_tensors and as a result
+ # some Tensors might require grad, even if no args do.
+ # Unfortunately, this leads to unexpected error messages ("no nodes
+ # require computing gradients"), but I don't have a better idea.
+ # These functions would raise an error in backward anyway.
+ requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
+ for arg in args)
+ if not requires_grad:
+ return outputs
+
+ if not isinstance(outputs, tuple):
+ outputs = (outputs,)
+
+ err_fn = torch._C._functions.DelayedError(
+ b"trying to differentiate twice a function that was marked "
+ b"with @once_differentiable", len(outputs))
+
+ # Create aliases of each output that has requires_grad=True. We need
+ # at least one of the inputs to err_fn to require grad so that the
+ # output will have a grad_fn.
+ def fake_requires_grad(var):
+ if var is not None:
+ var = var.detach()
+ var.requires_grad = True
+ return var
+
+ return err_fn(*[fake_requires_grad(v) for v in outputs])
+ return wrapper
+
+
+def traceable(fn_cls):
+ r"""Marks Function as traceable for the JIT.
+
+ Traceable functions have additional restrictions - they can't pass any
+ data-dependent values to backward (e.g. Prod passes the output, which makes
+ it non-traceable), and their backward should be implemented entirely in terms
+ of operations on autograd Tensors in all cases.
+
+ DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH
+ CARE (or can give incorrect results otherwise).
+ """
+ fn_cls.is_traceable = True
+ return fn_cls
+
+
+class InplaceFunction(Function):
+
+ def __init__(self, inplace=False):
+ super(InplaceFunction, self).__init__()
+ self.inplace = inplace
+
+
+def _nested_map(condition, fn, condition_msg=None):
+ def _map(obj):
+ if condition(obj):
+ return fn(obj)
+ elif obj is None:
+ return None
+ elif isinstance(obj, (list, tuple)):
+ return type(obj)(_map(x) for x in obj)
+ elif isinstance(obj, dict):
+ return {x : _map(obj[x]) for x in obj}
+ else:
+ raise ValueError("Auto nesting doesn't know how to process "
+ "an input object of type " + torch.typename(obj) +
+ (". Accepted types: " + condition_msg +
+ ", or lists/tuples of them"
+ if condition_msg else ""))
+
+ return _map
+
+
+def _jit_unwrap_structured(obj):
+ if hasattr(obj, "_jit_unwrap"):
+ return obj._jit_unwrap()
+ return obj
+
+
+def _iter_filter(condition, allow_unknown=False, condition_msg=None,
+ conversion=None):
+ def _iter(obj):
+ if conversion is not None:
+ obj = conversion(obj)
+ if condition(obj):
+ yield obj
+ elif obj is None:
+ return
+ elif isinstance(obj, (list, tuple)):
+ for o in obj:
+ for var in _iter(o):
+ yield var
+ elif isinstance(obj, dict):
+ # We only accept primitive key types, so we needn't inspect them
+ for o in obj.values():
+ for var in _iter(o):
+ yield var
+ elif allow_unknown:
+ yield obj
+ else:
+ raise ValueError("Auto nesting doesn't know how to process "
+ "an input object of type " + torch.typename(obj) +
+ (". Accepted types: " + condition_msg +
+ ", or lists/tuples of them"
+ if condition_msg else ""))
+
+ return _iter
+
+
+def _unflatten(input, proto):
+ # unflatten a list or tuple input into a nested list/tuple structure
+ # specified by proto
+ def unflatten_helper(input, proto):
+ res = []
+ if hasattr(proto, "_jit_wrap"):
+ return proto._jit_wrap(input)
+ if not isinstance(proto, (list, tuple)):
+ return input[0], input[1:]
+ for e in proto:
+ if e is None:
+ res.append(e)
+ else:
+ res_e, input = unflatten_helper(input, e)
+ res.append(res_e)
+ return type(proto)(res), input
+
+ return unflatten_helper(input, proto)[0]
+
+
+_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
+ condition_msg="jit's Values or None")
+_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
+ conversion=_jit_unwrap_structured)
+_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
+ allow_unknown=True,
+ condition_msg="Tensors (permissive)")
+_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
+ condition_msg="Tensors or None")
+_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
+ condition_msg="Tensors")
+
+
+class NestedIOFunction(Function):
+
+ def _do_forward(self, *input):
+ self._nested_input = input
+ flat_input = tuple(_iter_tensors(input))
+ flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
+ nested_output = self._nested_output
+ nested_tensors = _unflatten(flat_output, self._nested_output)
+ return nested_tensors
+
+ def _do_backward(self, gradients, retain_variables):
+ self.retain_variables = retain_variables
+ result = super(NestedIOFunction, self)._do_backward(gradients, retain_variables)
+ if not retain_variables:
+ del self._nested_output
+ del self._to_save_nested
+ return result
+
+ def backward(self, *gradients):
+ nested_gradients = _unflatten(gradients, self._nested_output)
+ result = self.backward_extended(*nested_gradients)
+ return tuple(_iter_None_tensors(result))
+
+ __call__ = _do_forward
+
+ def forward(self, *args):
+ nested_tensors = _map_tensor_data(self._nested_input)
+ result = self.forward_extended(*nested_tensors)
+ del self._nested_input
+ self._nested_output = result
+ return tuple(_iter_tensors(result))
+
+ def save_for_backward(self, *args):
+ self.to_save = tuple(_iter_tensors(args))
+ self._to_save_nested = args
+
+ @property
+ def saved_tensors(self):
+ flat_tensors = super(NestedIOFunction, self).saved_tensors
+ return _unflatten(flat_tensors, self._to_save_nested)
+
+ def mark_dirty(self, *args, **kwargs):
+ self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
+
+ def mark_non_differentiable(self, *args, **kwargs):
+ self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
+
+ def forward_extended(self, *input):
+ raise NotImplementedError
+
+ def backward_extended(self, *grad_output):
+ raise NotImplementedError
+
+import torch
+import functools
+
+
+class no_grad(object):
+ r"""Context-manager that disables gradient calculation.
+
+ Disabling gradient calculation is useful for inference, when you are sure
+ that you will not call :meth:`Tensor.backward()`. It will reduce memory
+ consumption for computations that would otherwise have `requires_grad=True`.
+ In this mode, the result of every computation will have
+ `requires_grad=False`, even when the inputs have `requires_grad=True`.
+
+ Also functions as a decorator.
+
+
+ Example::
+
+ >>> x = torch.tensor([1], requires_grad=True)
+ >>> with torch.no_grad():
+ ... y = x * 2
+ >>> y.requires_grad
+ False
+ >>> @torch.no_grad()
+ ... def doubler(x):
+ ... return x * 2
+ >>> z = doubler(x)
+ >>> z.requires_grad
+ False
+ """
+ def __enter__(self):
+ self.prev = torch.is_grad_enabled()
+ torch._C.set_grad_enabled(False)
+
+ def __exit__(self, *args):
+ torch.set_grad_enabled(self.prev)
+ return False
+
+ def __call__(self, func):
+ @functools.wraps(func)
+ def decorate_no_grad(*args, **kwargs):
+ with self:
+ return func(*args, **kwargs)
+ return decorate_no_grad
+
+
+[docs]class enable_grad(object):
+ r"""Context-manager that enables gradient calculation.
+
+ Enables gradient calculation inside a :class:`~no_grad` context. This has
+ no effect outside of :class:`~no_grad`.
+
+ Also functions as a decorator.
+
+
+ Example::
+
+ >>> x = torch.tensor([1], requires_grad=True)
+ >>> with torch.no_grad():
+ ... with torch.enable_grad():
+ ... y = x * 2
+ >>> y.requires_grad
+ True
+ >>> y.backward()
+ >>> x.grad
+ >>> @torch.enable_grad()
+ ... def doubler(x):
+ ... return x * 2
+ >>> with torch.no_grad():
+ ... z = doubler(x)
+ >>> z.requires_grad
+ True
+
+ """
+ def __enter__(self):
+ self.prev = torch.is_grad_enabled()
+ torch._C.set_grad_enabled(True)
+
+ def __exit__(self, *args):
+ torch.set_grad_enabled(self.prev)
+ return False
+
+ def __call__(self, func):
+ @functools.wraps(func)
+ def decorate_enable_grad(*args, **kwargs):
+ with self:
+ return func(*args, **kwargs)
+ return decorate_enable_grad
+
+
+[docs]class set_grad_enabled(object):
+ r"""Context-manager that sets gradient calculation to on or off.
+
+ ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
+ It can be used as a context-manager or as a function.
+
+ Arguments:
+ mode (bool): Flag whether to enable grad (``True``), or disable
+ (``False``). This can be used to conditionally enable
+ gradients.
+
+
+ Example::
+
+ >>> x = torch.tensor([1], requires_grad=True)
+ >>> is_train = False
+ >>> with torch.set_grad_enabled(is_train):
+ ... y = x * 2
+ >>> y.requires_grad
+ False
+ >>> torch.set_grad_enabled(True)
+ >>> y = x * 2
+ >>> y.requires_grad
+ True
+ >>> torch.set_grad_enabled(False)
+ >>> y = x * 2
+ >>> y.requires_grad
+ False
+
+ """
+
+ def __init__(self, mode):
+ self.prev = torch.is_grad_enabled()
+ torch._C.set_grad_enabled(mode)
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ torch.set_grad_enabled(self.prev)
+ return False
+
+import torch
+from torch._six import container_abcs, istuple
+import torch.testing
+from itertools import product
+import warnings
+
+
+def zero_gradients(x):
+ if isinstance(x, torch.Tensor):
+ if x.grad is not None:
+ x.grad.detach_()
+ x.grad.data.zero_()
+ elif isinstance(x, container_abcs.Iterable):
+ for elem in x:
+ zero_gradients(elem)
+
+
+def make_jacobian(input, num_out):
+ if isinstance(input, torch.Tensor):
+ if not input.is_floating_point():
+ return None
+ if not input.requires_grad:
+ return None
+ return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
+ elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
+ jacobians = list(filter(
+ lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
+ if not jacobians:
+ return None
+ return type(input)(jacobians)
+ else:
+ return None
+
+
+def iter_tensors(x, only_requiring_grad=False):
+ if isinstance(x, torch.Tensor):
+ if x.requires_grad or not only_requiring_grad:
+ yield x
+ elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
+ for elem in x:
+ for result in iter_tensors(elem, only_requiring_grad):
+ yield result
+
+
+def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
+ """
+ input: input to `fn`
+ target: the Tensors with respect to which Jacobians are calculated (default=`input`)
+
+ Note that `target` may not even be part of `input` to `fn`, so please be
+ **very careful** in this function not to clone `target`.
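+
+ The Jacobian entries are estimated with central differences,
+ ``(fn(x + eps) - fn(x - eps)) / (2 * eps)``. As an illustrative sketch, for
+ ``fn(x) = x ** 2`` at ``x = 3.0`` with ``eps = 1e-3`` this evaluates to
+ approximately ``6.0``, matching the analytical derivative ``2 * x``.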
+ """
+ if target is None:
+ target = input
+ output_size = fn(input).numel()
+ jacobian = make_jacobian(target, output_size)
+
+ # It's much easier to iterate over flattened lists of tensors.
+ # These are references to the same objects in jacobian, so any changes
+ # will be reflected in it as well.
+ x_tensors = [t for t in iter_tensors(target, True)]
+ j_tensors = [t for t in iter_tensors(jacobian)]
+
+ # TODO: compare structure
+ for x_tensor, d_tensor in zip(x_tensors, j_tensors):
+ # need data here to get around the version check because without .data,
+ # the following code updates version but doesn't change content
+ if x_tensor.is_sparse:
+ def get_stride(size):
+ dim = len(size)
+ tmp = 1
+ stride = [0] * dim
+ for i in reversed(range(dim)):
+ stride[i] = tmp
+ tmp *= size[i]
+ return stride
+
+ x_nnz = x_tensor._nnz()
+ x_size = list(x_tensor.size())
+ x_indices = x_tensor._indices().t()
+ x_values = x_tensor._values().data
+ x_stride = get_stride(x_size)
+
+ for i in range(x_nnz):
+ x_value = x_values[i]
+ for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
+ indices = x_indices[i].tolist() + list(x_idx)
+ d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
+ orig = x_value[x_idx].item()
+ x_value[x_idx] = orig - eps
+ outa = fn(input).clone()
+ x_value[x_idx] = orig + eps
+ outb = fn(input).clone()
+ x_value[x_idx] = orig
+ r = (outb - outa) / (2 * eps)
+ d_tensor[d_idx] = r.detach().reshape(-1)
+ elif x_tensor.layout == torch._mkldnn:
+ if len(input) != 1:
+ raise ValueError('gradcheck currently only supports functions with 1 input, but got: ',
+ len(input))
+ x_tensor = x_tensor.data
+ for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+ # this is really inefficient, but without indexing implemented, there's
+ # not really a better way than converting back and forth
+ x_tensor_dense = x_tensor.to_dense()
+ orig = x_tensor_dense[x_idx].item()
+
+ x_tensor_dense[x_idx] = orig - eps
+ x_tensor_mkl = x_tensor_dense.to_mkldnn()
+ outa = fn([x_tensor_mkl])
+
+ x_tensor_dense[x_idx] = orig + eps
+ x_tensor_mkl = x_tensor_dense.to_mkldnn()
+ outb = fn([x_tensor_mkl])
+
+ r = (outb - outa) / (2 * eps)
+ d_tensor[d_idx] = r.detach().reshape(-1)
+ else:
+ x_tensor = x_tensor.data
+ for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+ orig = x_tensor[x_idx].item()
+ x_tensor[x_idx] = orig - eps
+ outa = fn(input).clone()
+ x_tensor[x_idx] = orig + eps
+ outb = fn(input).clone()
+ x_tensor[x_idx] = orig
+ r = (outb - outa) / (2 * eps)
+ d_tensor[d_idx] = r.detach().reshape(-1)
+
+ return jacobian
+
+
+def get_analytical_jacobian(input, output):
+ # it is easier to call to_dense() on the sparse output than
+ # to modify analytical jacobian
+ if output.is_sparse:
+ raise ValueError('Sparse output is not supported at gradcheck yet. '
+ 'Please call to_dense() on the output of fn for gradcheck.')
+ if output.layout == torch._mkldnn:
+ raise ValueError('MKLDNN output is not supported at gradcheck yet. '
+ 'Please call to_dense() on the output of fn for gradcheck.')
+ diff_input_list = list(iter_tensors(input, True))
+ jacobian = make_jacobian(input, output.numel())
+ jacobian_reentrant = make_jacobian(input, output.numel())
+ grad_output = torch.zeros_like(output)
+ flat_grad_output = grad_output.view(-1)
+ reentrant = True
+ correct_grad_sizes = True
+
+ for i in range(flat_grad_output.numel()):
+ flat_grad_output.zero_()
+ flat_grad_output[i] = 1
+ for jacobian_c in (jacobian, jacobian_reentrant):
+ grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
+ retain_graph=True, allow_unused=True)
+ for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
+ if d_x is not None and d_x.size() != x.size():
+ correct_grad_sizes = False
+ elif jacobian_x.numel() != 0:
+ if d_x is None:
+ jacobian_x[:, i].zero_()
+ else:
+ d_x_dense = d_x.to_dense() if not d_x.layout == torch.strided else d_x
+ assert jacobian_x[:, i].numel() == d_x_dense.numel()
+ jacobian_x[:, i] = d_x_dense.contiguous().view(-1)
+
+ for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
+ if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() != 0:
+ reentrant = False
+
+ return jacobian, reentrant, correct_grad_sizes
+
+
+def _as_tuple(x):
+ if istuple(x):
+ return x
+ elif isinstance(x, list):
+ return tuple(x)
+ else:
+ return x,
+
+
+def _differentiable_outputs(x):
+ return tuple(o for o in _as_tuple(x) if o.requires_grad)
+
+
+[docs]def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True, check_sparse_nnz=False):
+ r"""Check gradients computed via small finite differences against analytical
+ gradients w.r.t. tensors in :attr:`inputs` that are of floating point type
+ and with ``requires_grad=True``.
+
+ The check between numerical and analytical gradients uses :func:`~torch.allclose`.
+
+ .. note::
+ The default values are designed for :attr:`input` of double precision.
+ This check will likely fail if :attr:`input` is of less precision, e.g.,
+ ``FloatTensor``.
+
+ .. warning::
+ If any checked tensor in :attr:`input` has overlapping memory, i.e.,
+ different indices pointing to the same memory address (e.g., from
+ :func:`torch.expand`), this check will likely fail because the numerical
+ gradients computed by point perturbation at such indices will change
+ values at all other indices that share the same memory address.
+
+ Args:
+ func (function): a Python function that takes Tensor inputs and returns
+ a Tensor or a tuple of Tensors
+ inputs (tuple of Tensor or Tensor): inputs to the function
+ eps (float, optional): perturbation for finite differences
+ atol (float, optional): absolute tolerance
+ rtol (float, optional): relative tolerance
+ raise_exception (bool, optional): indicating whether to raise an exception if
+ the check fails. The exception gives more information about the
+ exact nature of the failure. This is helpful when debugging gradchecks.
+ check_sparse_nnz (bool, optional): if True, gradcheck allows SparseTensor inputs,
+ and for any SparseTensor input, gradcheck will perform its check at nnz positions only.
+
+ Returns:
+ True if all differences satisfy allclose condition
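+
+ Example (a minimal illustrative call, assuming a double-precision input so
+ that the default tolerances apply)::
+
+ >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
+ >>> torch.autograd.gradcheck(torch.sin, (x,))
+ True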
+ """
+ def fail_test(msg):
+ if raise_exception:
+ raise RuntimeError(msg)
+ return False
+
+ tupled_inputs = _as_tuple(inputs)
+ if any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)) and not check_sparse_nnz:
+ return fail_test('gradcheck expects all tensor inputs to be dense when check_sparse_nnz is set to False.')
+
+ # Make sure that gradients are saved for all inputs
+ any_input_requiring_grad = False
+ some_input_not_requiring_grad = False
+ for inp in tupled_inputs:
+ if isinstance(inp, torch.Tensor):
+ if inp.requires_grad:
+ if inp.dtype != torch.float64:
+ warnings.warn(
+ 'At least one of the inputs that requires gradient '
+ 'is not of double precision floating point. '
+ 'This check will likely fail unless all the inputs are '
+ 'of double precision floating point. ')
+ any_input_requiring_grad = True
+ else:
+ some_input_not_requiring_grad = True
+ inp.retain_grad()
+ if not any_input_requiring_grad:
+ raise ValueError(
+ 'gradcheck expects at least one input tensor to require gradient, '
+ 'but none of them has requires_grad=True.')
+ if some_input_not_requiring_grad:
+ raise ValueError(
+ 'gradcheck expects that if at least one input tensor requires gradient, '
+ 'then all other inputs should have requires_grad=True.')
+
+ func_out = func(*tupled_inputs)
+ output = _differentiable_outputs(func_out)
+
+ if not output:
+ for i, o in enumerate(func_out):
+ def fn(input):
+ return _as_tuple(func(*input))[i]
+ numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+ for n in numerical:
+ if len(torch.nonzero(n)) > 0:
+ return fail_test('Numerical gradient for function expected to be zero')
+ return True
+
+ for i, o in enumerate(output):
+ if not o.requires_grad:
+ continue
+
+ def fn(input):
+ return _as_tuple(func(*input))[i]
+
+ analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o)
+ numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+
+ if not correct_grad_sizes:
+ return fail_test('Analytical gradient has incorrect size')
+
+ for j, (a, n) in enumerate(zip(analytical, numerical)):
+ if a.numel() != 0 or n.numel() != 0:
+ if not torch.allclose(a, n, rtol, atol):
+ return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
+ 'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
+
+ if not reentrant:
+ return fail_test('Backward is not reentrant, i.e., running backward with same '
+ 'input and grad_output multiple times gives different values, '
+ 'although analytical gradient matches numerical gradient')
+
+ # check if the backward multiplies by grad_output
+ output = _differentiable_outputs(func(*tupled_inputs))
+ if any([o.requires_grad for o in output]):
+ diff_input_list = list(iter_tensors(tupled_inputs, True))
+ if not diff_input_list:
+ raise RuntimeError("no Tensors requiring grad found in input")
+ grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o) for o in output],
+ allow_unused=True)
+ for gi, i in zip(grads_input, diff_input_list):
+ if gi is None:
+ continue
+ if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
+ if gi.layout != i.layout:
+ return fail_test('grad is incorrect layout')
+ if gi.layout == torch.sparse_coo:
+ if gi.sparse_dim() != i.sparse_dim():
+ return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
+ if gi.dense_dim() != i.dense_dim():
+ return fail_test('grad is sparse tensor, but has incorrect dense_dim')
+ gi = gi.to_dense()
+ i = i.to_dense()
+ if not gi.eq(0).all():
+ return fail_test('backward not multiplied by grad_output')
+ if gi.type() != i.type():
+ return fail_test("grad is incorrect type")
+ if gi.size() != i.size():
+ return fail_test('grad is incorrect size')
+
+ return True
+
+
+[docs]def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-3,
+ gen_non_contig_grad_outputs=False, raise_exception=True):
+ r"""Check gradients of gradients computed via small finite differences
+ against analytical gradients w.r.t. tensors in :attr:`inputs` and
+ :attr:`grad_outputs` that are of floating point type and with
+ ``requires_grad=True``.
+
+ This function checks that backpropagating through the gradients computed
+ with the given :attr:`grad_outputs` is correct.
+
+ The check between numerical and analytical gradients uses :func:`~torch.allclose`.
+
+ .. note::
+ The default values are designed for :attr:`input` and
+ :attr:`grad_outputs` of double precision. This check will likely fail if
+ they are of less precision, e.g., ``FloatTensor``.
+
+ .. warning::
+ If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
+ overlapping memory, i.e., different indices pointing to the same memory
+ address (e.g., from :func:`torch.expand`), this check will likely fail
+ because the numerical gradients computed by point perturbation at such
+ indices will change values at all other indices that share the same
+ memory address.
+
+ Args:
+ func (function): a Python function that takes Tensor inputs and returns
+ a Tensor or a tuple of Tensors
+ inputs (tuple of Tensor or Tensor): inputs to the function
+ grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
+ respect to the function's outputs.
+ eps (float, optional): perturbation for finite differences
+ atol (float, optional): absolute tolerance
+ rtol (float, optional): relative tolerance
+ gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
+ ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
+ randomly generated gradient outputs are made to be noncontiguous
+ raise_exception (bool, optional): indicating whether to raise an exception if
+ the check fails. The exception gives more information about the
+ exact nature of the failure. This is helpful when debugging gradchecks.
+
+ Returns:
+ True if all differences satisfy allclose condition
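+
+ Example (a minimal illustrative call on a twice-differentiable function,
+ assuming double-precision inputs)::
+
+ >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
+ >>> torch.autograd.gradgradcheck(lambda t: (t * t).sum(), (x,))
+ True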
+ """
+ tupled_inputs = _as_tuple(inputs)
+
+ if grad_outputs is None:
+ # If grad_outputs is not specified, create random Tensors of the same
+ # shape, type, and device as the outputs
+ def randn_like(x):
+ y = torch.testing.randn_like(x if x.is_floating_point() else x.double())
+ if gen_non_contig_grad_outputs:
+ y = torch.testing.make_non_contiguous(y)
+ return y.requires_grad_()
+ outputs = _as_tuple(func(*tupled_inputs))
+ tupled_grad_outputs = tuple(randn_like(x) for x in outputs)
+ else:
+ tupled_grad_outputs = _as_tuple(grad_outputs)
+
+ num_outputs = len(tupled_grad_outputs)
+
+ def new_func(*args):
+ input_args = args[:-num_outputs]
+ grad_outputs = args[-num_outputs:]
+ outputs = _differentiable_outputs(func(*input_args))
+ input_args = tuple(x for x in input_args if isinstance(x, torch.Tensor) and x.requires_grad)
+ grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)
+ return grad_inputs
+
+ return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception)
+
+import itertools
+import torch
+
+from collections import defaultdict, namedtuple
+from operator import attrgetter
+
+
+class EventList(list):
+ """A list of Events (for pretty printing)"""
+ def __init__(self, *args, **kwargs):
+ super(EventList, self).__init__(*args, **kwargs)
+ self._cpu_children_populated = False
+
+ def __str__(self):
+ return self.table()
+
+ def populate_cpu_children(self):
+ """Populates child events into each underlying FunctionEvent object.
+ One event is a child of another if [s1, e1) is inside [s2, e2), where
+ s1 and e1 are the start and end of the child event's interval, and
+ s2 and e2 are the start and end of the parent event's interval.
+
+ Example: in the event list [[0, 10], [1, 3], [3, 4]], the interval [0, 10]
+ would be a parent of the two other intervals.
+
+ If for any reason two intervals intersect only partially, this function
+ will not record a parent-child relationship between them.
+ """
+ if self.cpu_children_populated:
+ return
+ events = sorted(
+ self,
+ key=attrgetter("thread"),
+ )
+ threads = itertools.groupby(events, key=attrgetter("thread"))
+
+ # For each thread we keep a stack of current nested parents.
+ # We maintain the invariant that each interval is a subset of all other
+ # intervals lower in the stack.
+ #
+ # First we sort the intervals by their start time. Then we iterate over them.
+ # Every time we see a new interval we remove several parents from
+ # the top until we restore the invariant. Then a parent-child relationship
+ # is recorded if the stack is not empty.
+ # Finally we add the new interval to the stack.
+ #
+ # The algorithm has O(N * log(N)) complexity, where N is the number of
+ # intervals.
+ for thread_id, thread_events in threads:
+ thread_events = sorted(
+ thread_events,
+ key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end],
+ )
+ current_events = []
+ cur_end = 0
+ for event in thread_events:
+ while len(current_events) > 0:
+ parent = current_events[-1]
+ if event.cpu_interval.start >= parent.cpu_interval.end or \
+ event.cpu_interval.end > parent.cpu_interval.end:
+ # this can't be a parent
+ current_events.pop()
+ else:
+ parent.append_cpu_child(event)
+ break
+
+ current_events.append(event)
+
+ self._cpu_children_populated = True
+
+ @property
+ def self_cpu_time_total(self):
+ return sum([event.self_cpu_time_total for event in self])
+
+ @property
+ def cpu_children_populated(self):
+ return self._cpu_children_populated
+
+ def table(self, sort_by=None, row_limit=100):
+ """Prints an EventList as a nicely formatted table.
+
+ Arguments:
+ sort_by (str, optional): Attribute used to sort entries. By default
+ they are printed in the same order as they were registered.
+ Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
+ ``cuda_time_total``, ``count``.
+
+ Returns:
+ A string containing the table.
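+
+ Example (illustrative, assuming ``prof`` is a finished
+ :class:`~torch.autograd.profiler.profile`)::
+
+ >>> print(prof.table(sort_by="cpu_time_total"))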
+ """
+ return build_table(self, sort_by=sort_by, row_limit=row_limit)
+
+ def export_chrome_trace(self, path):
+ """Exports an EventList as a Chrome tracing tools file.
+
+ The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.
+
+ Arguments:
+ path (str): Path where the trace will be written.
+ """
+ import json
+ with open(path, 'w') as f:
+ chrome_events = []
+ next_id = 0
+ for evt in self:
+ chrome_events.append(dict(
+ name=evt.name,
+ ph='X',
+ ts=evt.cpu_interval.start,
+ dur=evt.cpu_interval.elapsed_us(),
+ tid=evt.thread,
+ pid='CPU functions',
+ args={},
+ ))
+ for k in evt.kernels:
+ # 's' and 'f' draw Flow arrows from
+ # the CPU launch to the GPU kernel
+ chrome_events.append(dict(
+ name=evt.name,
+ ph='s',
+ ts=evt.cpu_interval.start,
+ tid=evt.thread,
+ pid='CPU functions',
+ id=next_id,
+ cat='cpu_to_cuda',
+ args={},
+ ))
+ chrome_events.append(dict(
+ name=k.name,
+ ph='f',
+ ts=k.interval.start,
+ tid=k.device,
+ pid='CUDA functions',
+ id=next_id,
+ cat='cpu_to_cuda',
+ args={},
+ ))
+ chrome_events.append(dict(
+ name=k.name,
+ ph='X',
+ ts=k.interval.start,
+ dur=k.interval.elapsed_us(),
+ tid=k.device,
+ pid='CUDA functions',
+ args={},
+ ))
+ next_id += 1
+
+ json.dump(chrome_events, f)
+
+ def key_averages(self):
+ """Averages all function events over their keys.
+
+ Returns:
+ An EventList containing FunctionEventAvg objects.
+ """
+ self.populate_cpu_children()
+ stats = defaultdict(FunctionEventAvg)
+ for evt in self:
+ stats[evt.key] += evt
+ return EventList(stats.values())
+
+ def total_average(self):
+ """Averages all events.
+
+ Returns:
+ A FunctionEventAvg object.
+ """
+ total_stat = FunctionEventAvg()
+ for evt in self:
+ total_stat += evt
+ total_stat.key = None
+ total_stat.key = 'Total'
+ return total_stat
+
+
+[docs]class profile(object):
+ """Context manager that manages autograd profiler state and holds a summary of results.
+
+ Arguments:
+ enabled (bool, optional): Setting this to False makes this context manager a no-op.
+ Default: ``True``.
+
+ use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
+ Adds approximately 4us of overhead to each tensor operation.
+ Default: ``False``
+
+ .. warning::
+ This context manager should not be called recursively, i.e. at most one
+ instance should be enabled at any given time.
+
+ Example:
+ >>> x = torch.randn((1, 1), requires_grad=True)
+ >>> with torch.autograd.profiler.profile() as prof:
+ ... y = x ** 2
+ ... y.backward()
+ >>> # NOTE: some columns were removed for brevity
+ ... print(prof)
+ ------------------------------------- --------------- ---------------
+ Name CPU time CUDA time
+ ------------------------------------- --------------- ---------------
+ PowConstant 142.036us 0.000us
+ N5torch8autograd9GraphRootE 63.524us 0.000us
+ PowConstantBackward 184.228us 0.000us
+ MulConstant 50.288us 0.000us
+ PowConstant 28.439us 0.000us
+ Mul 20.154us 0.000us
+ N5torch8autograd14AccumulateGradE 13.790us 0.000us
+ N5torch8autograd5CloneE 4.088us 0.000us
+ """
+
+ def __init__(self, enabled=True, use_cuda=False):
+ self.enabled = enabled
+ self.use_cuda = use_cuda
+ self.function_events = None
+ if not self.enabled:
+ return
+ self.entered = False
+
+ def __enter__(self):
+ if not self.enabled:
+ return
+ if self.entered:
+ raise RuntimeError("autograd profiler traces are not reentrant")
+ self.entered = True
+ profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
+ else torch.autograd.ProfilerState.CPU
+ torch.autograd._enable_profiler(profiler_kind)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not self.enabled:
+ return
+ records = torch.autograd._disable_profiler()
+ self.function_events = EventList(parse_cpu_trace(records))
+ return False
+
+ def __repr__(self):
+ if self.function_events is None:
+ return '<unfinished torch.autograd.profile>'
+ return repr(self.function_events)
+
+ def __str__(self):
+ if self.function_events is None:
+ return '<unfinished torch.autograd.profile>'
+ return str(self.function_events)
+
+ def _check_finish(self):
+ if self.function_events is None:
+ raise RuntimeError("can't export a trace that didn't finish running")
+ self.function_events.populate_cpu_children()
+
+[docs] def table(self, sort_by=None, row_limit=100):
+ self._check_finish()
+ return self.function_events.table(sort_by=sort_by, row_limit=row_limit)
+ table.__doc__ = EventList.table.__doc__
+
+[docs] def export_chrome_trace(self, path):
+ self._check_finish()
+ return self.function_events.export_chrome_trace(path)
+ export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
+
+
+[docs] def key_averages(self):
+ self._check_finish()
+ return self.function_events.key_averages()
+ key_averages.__doc__ = EventList.key_averages.__doc__
+
+[docs] def total_average(self):
+ self._check_finish()
+ return self.function_events.total_average()
+ total_average.__doc__ = EventList.total_average.__doc__
+
+ @property
+ def self_cpu_time_total(self):
+ """ Returns total time spent on CPU obtained as a sum of
+ all self times across all the events.
+ """
+ self._check_finish()
+ return self.function_events.self_cpu_time_total
+
+
+[docs]class emit_nvtx(object):
+ """Context manager that makes every autograd operation emit an NVTX range.
+
+ It is useful when running the program under nvprof::
+
+ nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
+
+ Unfortunately, there's no way to force nvprof to flush the data it collected
+ to disk, so for CUDA profiling one has to use this context manager to annotate
+ nvprof traces and wait for the process to exit before inspecting them.
+ Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
+ :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
+ e.g. in Python REPL.
+
+ .. warning::
+ This context manager should not be called recursively, i.e. at most one
+ instance should be enabled at any given time.
+
+ Arguments:
+ enabled (bool, optional): Setting this to False makes this context manager a no-op.
+ Default: ``True``.
+
+ Example:
+ >>> with torch.cuda.profiler.profile():
+ ... model(x) # Warmup CUDA memory allocator and profiler
+ ... with torch.autograd.profiler.emit_nvtx():
+ ... model(x)
+
+ **Forward-backward correlation**
+
+ When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
+ correlating each backward-pass op with the corresponding forward-pass op can be difficult.
+ To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
+ generates.
+
+ During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running
+ counter, incremented each time a new backward Function object is created and stashed for backward.
+ Thus, the `seq=<N>` annotation associated with each forward function range tells you that
+ if a backward Function object is created by this forward function,
+ the backward object will receive sequence number N.
+ During the backward pass, the top-level range wrapping each C++ backward Function's
+ ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that
+ the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq``
+ numbers in forward, you can track down which forward op created each backward Function.
+
+ Any functions executed during the backward pass are also decorated with ``seq=<N>``. During
+ default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
+ ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with
+ backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
+ objects with the earlier forward pass.
+
+ **Double-backward**
+
+ If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
+ if you are setting up for a double-backward), each function's execution during backward
+ is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects
+ to be executed later during double-backward, just as the original functions in the forward pass did.
+ The relationship between backward and double-backward is conceptually the same as the relationship
+ between forward and backward: The functions still emit current-sequence-number-tagged ranges,
+ the Function objects they create still stash those sequence numbers, and during the eventual
+ double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
+ numbers, which can be compared to `seq` numbers from the backward pass.
+
+ .. warning::
+ The sequence number is thread-local, and some forward functions don't create an associated
+ backward Function object (instead delegating that to sub-functions further down the call chain).
+ For these reasons, the correspondence of stashed sequence numbers in
+ backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is
+ not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully
+ disambiguate which forward function created which
+ backward Function object. You may need to make a judgment based on analytic knowledge of what
+ the expected correspondence should be.
+ """
+ def __init__(self, enabled=True):
+ self.enabled = enabled
+ self.entered = False
+
+ def __enter__(self):
+ if not self.enabled:
+ return
+ if self.entered:
+ raise RuntimeError("NVTX annotation context manager is not reentrant")
+ self.entered = True
+ torch.cuda.synchronize()
+ torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not self.enabled:
+ return
+ torch.cuda.synchronize()
+ torch.autograd._disable_profiler()
+ return False
+
+
+[docs]def load_nvprof(path):
+ """Opens an nvprof trace file and parses autograd annotations.
+
+ Arguments:
+ path (str): path to nvprof trace
+ """
+ return EventList(parse_nvprof_trace(path))
+
+
+################################################################################
+# FunctionEvent
+
+def format_time(time_us):
+ """Defines how to format time in FunctionEvent"""
+ US_IN_SECOND = 1000.0 * 1000.0
+ US_IN_MS = 1000.0
+ if time_us >= US_IN_SECOND:
+ return '{:.3f}s'.format(time_us / US_IN_SECOND)
+ if time_us >= US_IN_MS:
+ return '{:.3f}ms'.format(time_us / US_IN_MS)
+ return '{:.3f}us'.format(time_us)
+
+
+def format_time_share(time_us, total_time_us):
+ """Defines how to format time in FunctionEvent"""
+ if total_time_us == 0:
+ assert(time_us == 0)
+ return "NaN"
+ return '{:.2f}%'.format(time_us * 100.0 / total_time_us)
+
+
+def attr_formatter(name):
+ return property(lambda self: format_time(getattr(self, name)))
+
+
+class FormattedTimesMixin(object):
+ """Helpers for FunctionEvent and FunctionEventAvg.
+
+ The subclass should define `*_time_total` and `count` attributes.
+ """
+ cpu_time_str = attr_formatter('cpu_time')
+ cuda_time_str = attr_formatter('cuda_time')
+ cpu_time_total_str = attr_formatter('cpu_time_total')
+ cuda_time_total_str = attr_formatter('cuda_time_total')
+ self_cpu_time_total_str = attr_formatter('self_cpu_time_total')
+
+ @property
+ def cpu_time(self):
+ return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count
+
+ @property
+ def cuda_time(self):
+ return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count
+
+
+class Interval(object):
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ def elapsed_us(self):
+ return self.end - self.start
+
+
+Kernel = namedtuple('Kernel', ['name', 'device', 'interval'])
+
+
+# TODO: record TID too
+class FunctionEvent(FormattedTimesMixin):
+ """Profiling information about a single function."""
+ def __init__(self, id, name, thread, cpu_start, cpu_end):
+ self.id = id
+ self.name = name
+ self.cpu_interval = Interval(cpu_start, cpu_end)
+ self.thread = thread
+ self.kernels = []
+ self.count = 1
+ self.cpu_children = []
+
+ def append_kernel(self, name, device, start, end):
+ self.kernels.append(Kernel(name, device, Interval(start, end)))
+
+ def append_cpu_child(self, child):
+ """Append a CPU child of type FunctionEvent.
+
+ One is supposed to append only direct children to the event so that
+ the self CPU time is reported correctly.
+ """
+ assert(isinstance(child, FunctionEvent))
+ self.cpu_children.append(child)
+
+ @property
+ def self_cpu_time_total(self):
+ return self.cpu_time_total - sum(
+ [child.cpu_time_total for child in self.cpu_children]
+ )
+
+ @property
+ def cuda_time_total(self):
+ return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels)
+
+ @property
+ def cpu_time_total(self):
+ return self.cpu_interval.elapsed_us()
+
+ @property
+ def key(self):
+ return self.name
+
+ def __repr__(self):
+ return (
+ '<FunctionEvent id={} cpu_time={} cpu_start={} cpu_end={} '
+ 'cpu_children={} cuda_time={} name={} thread={}>'.format(
+ self.id,
+ self.cpu_time_str,
+ self.cpu_interval.start,
+ self.cpu_interval.end,
+ str([child.id for child in self.cpu_children]),
+ self.cuda_time_str,
+ self.name,
+ self.thread
+ )
+ )
+
+
+class FunctionEventAvg(FormattedTimesMixin):
+ """Used to average stats over multiple FunctionEvent objects."""
+ def __init__(self):
+ self.key = None
+ self.count = 0
+ self.cpu_time_total = 0
+ self.cuda_time_total = 0
+ self.self_cpu_time_total = 0
+
+ def __iadd__(self, other):
+ if self.key is None:
+ self.key = other.key
+ assert isinstance(other, FunctionEvent)
+ assert other.key == self.key
+ self.cpu_time_total += other.cpu_time
+ self.cuda_time_total += other.cuda_time
+ self.self_cpu_time_total += other.self_cpu_time_total
+ self.count += 1
+ return self
+
+ def __repr__(self):
+ return '<FunctionEventAvg cpu_time={} cuda_time={} key={}>'.format(
+ self.cpu_time_str, self.cuda_time_str, self.key)
+
+
+################################################################################
+# Utilities
+
+class StringTable(defaultdict):
+ def __missing__(self, key):
+ self[key] = torch._C._demangle(key)
+ return self[key]
+
+
+################################################################################
+# CPU checkpoints
+
+def parse_cpu_trace(thread_records):
+ next_id = 0
+ start_record = None
+ cuda_records = {}
+ functions = []
+ record_stack = []
+ string_table = StringTable()
+
+ # CUDA start events and the overall profiler start event don't happen
+ # at exactly the same time because we need to record an event on each device,
+ # and each record takes ~4us. So we adjust here by adding the difference
+ # in CPU time between the profiler start event and the CUDA start event
+ # for the device.
+ def adjusted_time(cuda_record):
+ assert cuda_record.device() != -1
+ cuda_time_0 = cuda_records[cuda_record.device()]
+ return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0)
+
+ # '__start_profile' is not guaranteed to be first, so we must find it here
+ for record in itertools.chain(*thread_records):
+ if record.name() == '__start_profile':
+ start_record = record
+ elif record.name() == '__cuda_start_event':
+ assert record.device() != -1
+ cuda_records[record.device()] = record
+ assert start_record is not None
+
+ for record in itertools.chain(*thread_records):
+ if record.kind() == 'mark':
+ continue
+ elif record.kind() == 'push':
+ record_stack.append((next_id, record))
+ next_id += 1
+ elif record.kind() == 'pop':
+ function_id, start = record_stack.pop()
+ fe = FunctionEvent(
+ id=function_id,
+ name=string_table[start.name()],
+ thread=start.thread_id(),
+ cpu_start=start_record.cpu_elapsed_us(start),
+ cpu_end=start_record.cpu_elapsed_us(record))
+ if start.has_cuda():
+ cuda_start = adjusted_time(start)
+ cuda_end = adjusted_time(record)
+ fe.append_kernel(start.name(),
+ start.device(),
+ cuda_start,
+ cuda_end)
+ functions.append(fe)
+
+ functions.sort(key=lambda evt: evt.cpu_interval.start)
+ return functions
+
+
+################################################################################
+# CUDA checkpoints
+
+class EnforceUnique(object):
+ """Raises an error if a key is seen more than once."""
+ def __init__(self):
+ self.seen = set()
+
+ def see(self, *key):
+ if key in self.seen:
+ raise RuntimeError('duplicate key: ' + str(key))
+ self.seen.add(key)
+
+
+def parse_nvprof_trace(path):
+ import sqlite3
+ conn = sqlite3.connect(path)
+ conn.row_factory = sqlite3.Row
+
+ # Parse strings table
+ strings = {}
+ for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
+ strings[r["id"]] = torch._C._demangle(r["value"])
+
+ # First, find all functions and create FunctionEvents for them
+ marker_query = """
+ SELECT
+ start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
+ FROM
+ CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
+ ON start.id = end.id
+ WHERE
+ start.name != 0 AND end.name = 0
+ """
+ functions = []
+ functions_map = {}
+ unique = EnforceUnique()
+ for row in conn.execute(marker_query):
+ unique.see(row['marker_id'])
+ evt = FunctionEvent(id=row['marker_id'],
+ name=strings[row['name']],
+ cpu_start=row['start_time'],
+ cpu_end=row['end_time'],
+ thread=0) # TODO: find in sqlite database
+ functions.append(evt)
+ functions_map[evt.id] = evt
+
+ # Now, correlate all kernels with FunctionEvents
+ kernel_query = """
+ SELECT
+ start.id AS marker_id, start.name, start.timestamp, end.timestamp,
+ runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
+ kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
+ FROM
+ CUPTI_ACTIVITY_KIND_MARKER AS start
+ INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
+ ON start.id = end.id
+ INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
+ ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
+ INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
+ ON kernel.correlationId = runtime.correlationId
+ """
+ unique = EnforceUnique()
+ for row in conn.execute(kernel_query):
+ unique.see(row['marker_id'], row['runtime_id'])
+ assert row['cbid'] == 13 # 13 == Launch
+ evt = functions_map[row['marker_id']]
+ evt.append_kernel(row['kernel_name'],
+ 0,
+ row['kernel_start'],
+ row['kernel_end'])
+
+ functions.sort(key=lambda evt: evt.cpu_interval.start)
+ return functions
+
+
+################################################################################
+# Pretty printer
+
+def build_table(events, sort_by=None, header=None, row_limit=100):
+ """Prints a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
+ if sort_by is not None:
+ events = EventList(sorted(
+ events, key=lambda evt: getattr(evt, sort_by), reverse=True
+ ))
+
+ name_lengths = [len(evt.key) for evt in events]
+ if len(name_lengths) == 0:
+ return ""
+ max_name_length = max(name_lengths)
+ max_name_length += 4 # Add some nice padding
+ col_width = 15
+ col_format = ' {: >' + str(col_width) + '}'
+ row_format = '{: <' + str(max_name_length) + '}' + col_format * 9
+ header_sep = '-' * max_name_length + (' ' + '-' * col_width) * 9
+
+ # Have to use a list because nonlocal is Py3 only...
+ result = []
+
+ def append(s):
+ result.append(s)
+ result.append('\n') # Yes, newline after the end as well
+
+ self_cpu_time_total = sum([event.self_cpu_time_total for event in events])
+ cuda_time_total = sum([evt.cuda_time_total for evt in events])
+ # Actual printing
+ if header is not None:
+ line_length = max_name_length + (col_width + 2) * 5
+ append('=' * line_length)
+ append(header)
+ append(header_sep)
+ append(row_format.format(
+ 'Name',
+ 'Self CPU total %',
+ 'Self CPU total',
+ 'CPU total %',
+ 'CPU total',
+ 'CPU time avg',
+ 'CUDA total %',
+ 'CUDA total',
+ 'CUDA time avg',
+ 'Number of Calls',
+ ))
+ append(header_sep)
+ for evt in events[:row_limit]:
+ append(row_format.format(
+ evt.key, # Name
+ # Self CPU total %
+ format_time_share(evt.self_cpu_time_total, self_cpu_time_total),
+ evt.self_cpu_time_total_str, # Self CPU total
+ # CPU total %
+ format_time_share(evt.cpu_time_total, self_cpu_time_total),
+ evt.cpu_time_total_str, # CPU total
+ evt.cpu_time_str, # CPU time avg
+ # CUDA time total %
+ format_time_share(evt.cuda_time_total, cuda_time_total),
+ evt.cuda_time_total_str,
+ evt.cuda_time_str, # Cuda time avg
+ evt.count, # Number of calls
+ ))
+ append(header_sep)
+ append("Self CPU time total: {}".format(format_time(self_cpu_time_total)))
+ append("CUDA time total: {}".format(format_time(cuda_time_total)))
+ return ''.join(result)
+
+r"""
+This package adds support for CUDA tensor types, which implement the same
+functions as CPU tensors but utilize GPUs for computation.
+
+It is lazily initialized, so you can always import it, and use
+:func:`is_available()` to determine if your system supports CUDA.
+
+:ref:`cuda-semantics` has more details about working with CUDA.
+"""
+
+import contextlib
+import platform
+import ctypes
+import os
+import torch
+import traceback
+import warnings
+from torch._six import raise_from
+from subprocess import Popen, PIPE
+from multiprocessing.util import register_after_fork as _register_after_fork
+from ._utils import _get_device_index
+
+_initialized = False
+_queued_calls = [] # don't invoke these until initialization occurs
+_in_bad_fork = False # this global is also used in torch.manual_seed
+_original_pid = False
+_cudart = None
+
+
+def find_cuda_windows_lib():
+ proc = Popen(['where', 'cudart64*.dll'], stdout=PIPE, stderr=PIPE, stdin=PIPE)
+ out, err = proc.communicate()
+ out = out.decode().strip()
+ if len(out) > 0:
+ if out.find('\r\n') != -1:
+ out = out.split('\r\n')[0]
+ cuda_lib_name = os.path.basename(out)
+ cuda_lib = os.path.splitext(cuda_lib_name)[0]
+ cuda_lib = str(cuda_lib)
+ return ctypes.cdll.LoadLibrary(cuda_lib)
+ else:
+ return None
+
+
+[docs]def is_available():
+ r"""Returns a bool indicating if CUDA is currently available."""
+ if (not hasattr(torch._C, '_cuda_isDriverSufficient') or
+ not torch._C._cuda_isDriverSufficient()):
+ return False
+ return torch._C._cuda_getDeviceCount() > 0
+
+
+def _sleep(cycles):
+ torch._C._cuda_sleep(cycles)
+
+
+def _load_cudart():
+ # First check the main program for CUDA symbols
+ if platform.system() == 'Windows':
+ lib = find_cuda_windows_lib()
+ else:
+ lib = ctypes.cdll.LoadLibrary(None)
+ if hasattr(lib, 'cudaGetErrorName'):
+ return lib
+
+ raise RuntimeError(
+ "couldn't find libcudart. Make sure CUDA libraries are installed in a "
+ "default location, or that they're in {}."
+ .format('DYLD_LIBRARY_PATH' if platform.system() == 'Darwin' else
+ 'LD_LIBRARY_PATH'))
+
+
+def _check_driver():
+ if not hasattr(torch._C, '_cuda_isDriverSufficient'):
+ raise AssertionError("Torch not compiled with CUDA enabled")
+ if not torch._C._cuda_isDriverSufficient():
+ if torch._C._cuda_getDriverVersion() == 0:
+ # found no NVIDIA driver on the system
+ raise AssertionError("""
+Found no NVIDIA driver on your system. Please check that you
+have an NVIDIA GPU and installed a driver from
+http://www.nvidia.com/Download/index.aspx""")
+ else:
+ # TODO: directly link to the alternative bin that needs install
+ raise AssertionError("""
+The NVIDIA driver on your system is too old (found version {}).
+Please update your GPU driver by downloading and installing a new
+version from the URL: http://www.nvidia.com/Download/index.aspx
+Alternatively, go to: https://pytorch.org to install
+a PyTorch version that has been compiled with your version
+of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
+
+
+def _check_capability():
+ incorrect_binary_warn = """
+ Found GPU%d %s which requires CUDA_VERSION >= %d for
+ optimal performance and fast startup time, but your PyTorch was compiled
+ with CUDA_VERSION %d. Please install the correct PyTorch binary
+ using instructions from https://pytorch.org
+ """
+
+ old_gpu_warn = """
+ Found GPU%d %s which is of cuda capability %d.%d.
+ PyTorch no longer supports this GPU because it is too old.
+ The minimum cuda capability that we support is 3.5.
+ """
+
+ CUDA_VERSION = torch._C._cuda_getCompiledVersion()
+ for d in range(device_count()):
+ capability = get_device_capability(d)
+ major = capability[0]
+ name = get_device_name(d)
+ if CUDA_VERSION < 8000 and major >= 6:
+ warnings.warn(incorrect_binary_warn % (d, name, 8000, CUDA_VERSION))
+ elif CUDA_VERSION < 9000 and major >= 7:
+ warnings.warn(incorrect_binary_warn % (d, name, 9000, CUDA_VERSION))
+ elif capability == (3, 0) or major < 3:
+ warnings.warn(old_gpu_warn % (d, name, major, capability[1]))
+
+
+def _lazy_call(callable):
+ if _initialized:
+ callable()
+ else:
+ # Don't store the actual traceback to avoid memory cycle
+ _queued_calls.append((callable, traceback.format_stack()))
+
+_lazy_call(_check_capability)
+
+
+class DeferredCudaCallError(Exception):
+ pass
+
+
+[docs]def init():
+ r"""Initialize PyTorch's CUDA state. You may need to call
+ this explicitly if you are interacting with PyTorch via
+ its C API, as Python bindings for CUDA functionality will not
+ be available until this initialization takes place. Ordinary users
+ should not need this, as all of PyTorch's CUDA methods
+ automatically initialize CUDA state on-demand.
+
+ Does nothing if the CUDA state is already initialized.
+ """
+ _lazy_init()
+
+
+def _lazy_init():
+ global _initialized, _cudart, _original_pid, _queued_calls
+ if _initialized:
+ return
+ if _in_bad_fork:
+ from sys import version_info
+ if version_info < (3, 4):
+ msg = ("To use CUDA with multiprocessing, you must use Python "
+ "3.4+ and the 'spawn' start method")
+ else:
+ msg = ("To use CUDA with multiprocessing, you must use the "
+ "'spawn' start method")
+ raise RuntimeError(
+ "Cannot re-initialize CUDA in forked subprocess. " + msg)
+ _check_driver()
+ torch._C._cuda_init()
+ _cudart = _load_cudart()
+ _cudart.cudaGetErrorName.restype = ctypes.c_char_p
+ _cudart.cudaGetErrorString.restype = ctypes.c_char_p
+ _original_pid = os.getpid()
+ _initialized = True
+ # Important to do this after _initialized, since some queued calls
+ # may themselves call _lazy_init()
+ for queued_call, orig_traceback in _queued_calls:
+ try:
+ queued_call()
+ except Exception as e:
+ msg = ("CUDA call failed lazily at initialization with error: {}\n\n"
+ "CUDA call was originally invoked at:\n\n{}").format(str(e), orig_traceback)
+ raise_from(DeferredCudaCallError(msg), e)
+
+
+def _after_fork(arg):
+ global _initialized, _in_bad_fork
+ if _initialized and _original_pid != os.getpid():
+ _initialized = False
+ _in_bad_fork = True
+ _CudaBase.__new__ = _lazy_new
+
+
+_register_after_fork(_after_fork, _after_fork)
+
+
+def cudart():
+ _lazy_init()
+ return _cudart
+
+
+class cudaStatus(object):
+ SUCCESS = 0
+ ERROR_NOT_READY = 34
+
+
+class CudaError(RuntimeError):
+ def __init__(self, code):
+ msg = cudart().cudaGetErrorString(code).decode('utf-8')
+ super(CudaError, self).__init__('{0} ({1})'.format(msg, code))
+
+
+def check_error(res):
+ if res != cudaStatus.SUCCESS:
+ raise CudaError(res)
+
+
+[docs]class device(object):
+ r"""Context-manager that changes the selected device.
+
+ Arguments:
+ device (torch.device or int): device index to select. It's a no-op if
+ this argument is a negative integer or ``None``.
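+
+ Example (illustrative; assumes at least one CUDA device is available)::
+
+ >>> with torch.cuda.device(0):
+ ... x = torch.zeros(3, device='cuda')
+ >>> x.device
+ device(type='cuda', index=0)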
+ """
+
+ def __init__(self, device):
+ self.idx = _get_device_index(device, optional=True)
+ self.prev_idx = -1
+
+ def __enter__(self):
+ if self.idx == -1:
+ return
+ self.prev_idx = torch._C._cuda_getDevice()
+ if self.prev_idx != self.idx:
+ torch._C._cuda_setDevice(self.idx)
+ _lazy_init()
+
+ def __exit__(self, *args):
+ if self.prev_idx != self.idx:
+ torch._C._cuda_setDevice(self.prev_idx)
+ return False
+
+
+[docs]class device_of(device):
+ r"""Context-manager that changes the current device to that of given object.
+
+ You can use both tensors and storages as arguments. If a given object is
+ not allocated on a GPU, this is a no-op.
+
+ Arguments:
+ obj (Tensor or Storage): object allocated on the selected device.
+ """
+
+ def __init__(self, obj):
+ idx = obj.get_device() if obj.is_cuda else -1
+ super(device_of, self).__init__(idx)
+
+
+[docs]def set_device(device):
+ r"""Sets the current device.
+
+ Usage of this function is discouraged in favor of :any:`device`. In most
+ cases it's better to use the ``CUDA_VISIBLE_DEVICES`` environment variable.
+
+ Arguments:
+ device (torch.device or int): selected device. This function is a no-op
+ if this argument is negative.
+ """
+ device = _get_device_index(device)
+ if device >= 0:
+ torch._C._cuda_setDevice(device)
+
+
+[docs]def get_device_name(device=None):
+ r"""Gets the name of a device.
+
+ Arguments:
+ device (torch.device or int, optional): device for which to return the
+ name. This function is a no-op if this argument is a negative
+ integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+ """
+ return get_device_properties(device).name
+
+
+[docs]def get_device_capability(device=None):
+ r"""Gets the cuda capability of a device.
+
+ Arguments:
+ device (torch.device or int, optional): device for which to return the
+ device capability. This function is a no-op if this argument is
+ a negative integer. It uses the current device, given by
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ (default).
+
+ Returns:
+ tuple(int, int): the major and minor cuda capability of the device
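+
+ Example (illustrative; the returned value depends on the installed GPU)::
+
+ >>> torch.cuda.get_device_capability(0)
+ (7, 0)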
+ """
+ prop = get_device_properties(device)
+ return prop.major, prop.minor
+
+
+def get_device_properties(device):
+ if not _initialized:
+ init() # will define _get_device_properties and _CudaDeviceProperties
+ device = _get_device_index(device, optional=True)
+ if device < 0 or device >= device_count():
+ raise AssertionError("Invalid device id")
+ return _get_device_properties(device)
+
+
+[docs]@contextlib.contextmanager
+def stream(stream):
+ r"""Context-manager that selects a given stream.
+
+ All CUDA kernels queued within its context will be enqueued on a selected
+ stream.
+
+ Arguments:
+ stream (Stream): selected stream. This manager is a no-op if it's
+ ``None``.
+
+ .. note:: Streams are per-device. If the selected stream is not on the
+ current device, this function will also change the current device to
+ match the stream.
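+
+ Example (illustrative; requires a CUDA device)::
+
+ >>> s = torch.cuda.Stream()
+ >>> with torch.cuda.stream(s):
+ ... y = torch.ones(5, device='cuda') + 1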
+ """
+ if stream is None:
+ yield
+ return
+ src_prev_stream = current_stream()
+
+ if src_prev_stream.device != stream.device:
+ # The given stream is on a different device; have to restore the
+ # current_stream on that device on exit as well
+ with device(stream.device):
+ dst_prev_stream = current_stream()
+
+ torch._C._cuda_setStream(stream._cdata)
+ try:
+ yield
+ finally:
+ if src_prev_stream.device != stream.device:
+ torch._C._cuda_setStream(dst_prev_stream._cdata)
+ torch._C._cuda_setStream(src_prev_stream._cdata)
+
+
+[docs]def device_count():
+ r"""Returns the number of GPUs available."""
+ if is_available():
+ return torch._C._cuda_getDeviceCount()
+ else:
+ return 0
+
+
+[docs]def current_device():
+ r"""Returns the index of a currently selected device."""
+ _lazy_init()
+ return torch._C._cuda_getDevice()
+
+
+[docs]def synchronize(device=None):
+ r"""Waits for all kernels in all streams on a CUDA device to complete.
+
+ Arguments:
+ device (torch.device or int, optional): device for which to synchronize.
+ It uses the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+ """
+ _lazy_init()
+ with torch.cuda.device(device):
+ return torch._C._cuda_synchronize()
+
+
+[docs]def ipc_collect():
+ r"""Force collects GPU memory after it has been released by CUDA IPC.
+
+ .. note::
+ Checks if any sent CUDA tensors could be cleaned from the memory. Force
+ closes the shared memory file used for reference counting if there are no
+ active counters. Useful when the producer process has stopped actively
+ sending tensors and wants to release unused memory.
+ """
+ _lazy_init()
+ return torch._C._cuda_ipc_collect()
+
+
+[docs]def current_stream(device=None):
+ r"""Returns the currently selected :class:`Stream` for a given device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ the currently selected :class:`Stream` for the current device, given
+ by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ (default).
+ """
+ _lazy_init()
+ return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream(
+ _get_device_index(device, optional=True)))
+
+
+[docs]def default_stream(device=None):
+ r"""Returns the default :class:`Stream` for a given device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ the default :class:`Stream` for the current device, given by
+ :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ (default).
+ """
+ _lazy_init()
+ return torch.cuda.Stream(_cdata=torch._C._cuda_getDefaultStream(
+ _get_device_index(device, optional=True)))
+
+
+[docs]def current_blas_handle():
+ r"""Returns cublasHandle_t pointer to current cuBLAS handle"""
+ _lazy_init()
+ return torch._C._cuda_getCurrentBlasHandle()
+
+
+[docs]def empty_cache():
+ r"""Releases all unoccupied cached memory currently held by the caching
+ allocator so that those can be used in other GPU application and visible in
+ `nvidia-smi`.
+
+ .. note::
+ :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
+ memory available for PyTorch. See :ref:`cuda-memory-management` for
+ more details about GPU memory management.
+ """
+ if _initialized:
+ torch._C._cuda_emptyCache()
+
+
+[docs]def memory_allocated(device=None):
+ r"""Returns the current GPU memory occupied by tensors in bytes for a given
+ device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ This is likely less than the amount shown in `nvidia-smi` since some
+ unused memory can be held by the caching allocator and some context
+ needs to be created on GPU. See :ref:`cuda-memory-management` for more
+ details about GPU memory management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_memoryAllocated(device)
+
+
+[docs]def max_memory_allocated(device=None):
+ r"""Returns the maximum GPU memory occupied by tensors in bytes for a given
+ device.
+
+ By default, this returns the peak allocated memory since the beginning of
+ this program. :func:`~torch.cuda.reset_max_memory_allocated` can be used to
+ reset the starting point in tracking this metric. For example, these two
+ functions can measure the peak allocated memory usage of each iteration in a
+ training loop.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
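+
+ Example (an illustrative sketch; ``model`` and ``data_loader`` are
+ placeholders for your own objects)::
+
+ >>> for inputs in data_loader:
+ ... torch.cuda.reset_max_memory_allocated()
+ ... loss = model(inputs).sum()
+ ... loss.backward()
+ ... peak_bytes = torch.cuda.max_memory_allocated()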
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_maxMemoryAllocated(device)
+
+
+[docs]def reset_max_memory_allocated(device=None):
+ r"""Resets the starting point in tracking maximum GPU memory occupied by
+ tensors for a given device.
+
+ See :func:`~torch.cuda.max_memory_allocated` for details.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_resetMaxMemoryAllocated(device)
+
+
+[docs]def memory_cached(device=None):
+ r"""Returns the current GPU memory managed by the caching allocator in bytes
+ for a given device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_memoryCached(device)
+
+
+[docs]def max_memory_cached(device=None):
+ r"""Returns the maximum GPU memory managed by the caching allocator in bytes
+ for a given device.
+
+ By default, this returns the peak cached memory since the beginning of this
+ program. :func:`~torch.cuda.reset_max_memory_cached` can be used to reset
+ the starting point in tracking this metric. For example, these two functions
+ can measure the peak cached memory amount of each iteration in a training
+ loop.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_maxMemoryCached(device)
+
+
+[docs]def reset_max_memory_cached(device=None):
+ r"""Resets the starting point in tracking maximum GPU memory managed by the
+ caching allocator for a given device.
+
+ See :func:`~torch.cuda.max_memory_cached` for details.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ statistic for the current device, given by :func:`~torch.cuda.current_device`,
+ if :attr:`device` is ``None`` (default).
+
+ .. note::
+ See :ref:`cuda-memory-management` for more details about GPU memory
+ management.
+ """
+ device = _get_device_index(device, optional=True)
+ return torch._C._cuda_resetMaxMemoryCached(device)
+
+
+def _host_allocator():
+ _lazy_init()
+ return torch._C._cuda_cudaHostAllocator()
+
+
+@contextlib.contextmanager
+def _free_mutex():
+ torch._C._cuda_lock_mutex()
+ try:
+ yield
+ finally:
+ torch._C._cuda_unlock_mutex()
+
+
+from .random import *
+
+################################################################################
+# Define Storage and Tensor classes
+################################################################################
+
+
+from ..storage import _StorageBase
+
+
+def _dummy_type(name):
+ def init_err(self):
+ class_name = self.__class__.__name__
+ raise RuntimeError(
+ "Tried to instantiate dummy base class {}".format(class_name))
+ return type(name, (object,), {"__init__": init_err})
+
+
+if not hasattr(torch._C, 'CudaDoubleStorageBase'):
+ # Define dummy base classes
+ for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']:
+ storage_name = 'Cuda{0}StorageBase'.format(t)
+ tensor_name = 'Cuda{0}TensorBase'.format(t)
+
+ torch._C.__dict__[storage_name] = _dummy_type(storage_name)
+ torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
+
+ torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
+ torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
+
+
+@staticmethod
+def _lazy_new(cls, *args, **kwargs):
+ _lazy_init()
+ # We need this method only for lazy init, so we can remove it
+ del _CudaBase.__new__
+ return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
+
+
+class _CudaBase(object):
+ is_cuda = True
+ is_sparse = False
+
+ def type(self, *args, **kwargs):
+ with device(self.get_device()):
+ return super(_CudaBase, self).type(*args, **kwargs)
+
+ __new__ = _lazy_new
+
+
+class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
+ pass
+
+
+class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
+ pass
+
+
+class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
+ pass
+
+
+class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
+ pass
+
+
+class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
+ pass
+
+
+class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
+ pass
+
+
+class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
+ pass
+
+
+class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
+ pass
+
+
+class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase):
+ pass
+
+torch._storage_classes.add(DoubleStorage)
+torch._storage_classes.add(FloatStorage)
+torch._storage_classes.add(LongStorage)
+torch._storage_classes.add(IntStorage)
+torch._storage_classes.add(ShortStorage)
+torch._storage_classes.add(CharStorage)
+torch._storage_classes.add(ByteStorage)
+torch._storage_classes.add(HalfStorage)
+torch._storage_classes.add(BoolStorage)
+
+from . import sparse # noqa: F401
+from . import profiler # noqa: F401
+from . import nvtx # noqa: F401
+from .streams import Stream, Event # noqa: F401
+
+import torch
+from . import nccl
+from torch._utils import _take_tensors, _flatten_dense_tensors, \
+ _unflatten_dense_tensors, _reorder_tensors_as
+
+
+[docs]def broadcast(tensor, devices):
+ """Broadcasts a tensor to a number of GPUs.
+
+ Arguments:
+ tensor (Tensor): tensor to broadcast.
+ devices (Iterable): an iterable of devices among which to broadcast.
+ Note that it should be like (src, dst1, dst2, ...), the first element
+ of which is the source device to broadcast from.
+
+ Returns:
+ A tuple containing copies of the ``tensor``, placed on devices
+ corresponding to indices from ``devices``.
+ """
+ return torch._C._broadcast(tensor, devices)
+
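+# Editor's note: a small sketch (editor-added, assumes at least two visible GPUs)
+# of how ``broadcast`` above copies one tensor to several devices.
+def _example_broadcast():
+ import torch
+ import torch.cuda.comm as comm
+ src = torch.randn(4, device='cuda:0')
+ copies = comm.broadcast(src, devices=[0, 1]) # (src, dst1, ...) device indices
+ return copies # copies[0] on cuda:0, copies[1] on cuda:1, identical values
+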
+
+[docs]def broadcast_coalesced(tensors, devices, buffer_size=10485760):
+ """Broadcasts a sequence tensors to the specified GPUs.
+ Small tensors are first coalesced into a buffer to reduce the number
+ of synchronizations.
+
+ Arguments:
+ tensors (sequence): tensors to broadcast.
+ devices (Iterable): an iterable of devices among which to broadcast.
+ Note that it should be like (src, dst1, dst2, ...), the first element
+ of which is the source device to broadcast from.
+ buffer_size (int): maximum size of the buffer used for coalescing
+
+ Returns:
+ A tuple containing copies of each tensor in ``tensors``, placed on
+ devices corresponding to indices from ``devices``.
+ """
+ return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
+
+
+[docs]def reduce_add(inputs, destination=None):
+ """Sums tensors from multiple GPUs.
+
+ All inputs should have matching shapes.
+
+ Arguments:
+ inputs (Iterable[Tensor]): an iterable of tensors to add.
+ destination (int, optional): a device on which the output will be
+ placed (default: current device).
+
+ Returns:
+ A tensor containing an elementwise sum of all inputs, placed on the
+ ``destination`` device.
+ """
+ # TODO: try to find an input on another gpu, copy it,
+ # and accumulate into the copy
+ if destination is None:
+ destination = torch.cuda.current_device()
+ input_size = inputs[0].size()
+ nccl_root = None
+ for i, inp in enumerate(inputs):
+ assert inp.is_cuda, "reduce_add expects all inputs to be on GPUs"
+ if inp.get_device() == destination:
+ nccl_root = i
+ if inp.size() != input_size:
+ got = 'x'.join(str(x) for x in inp.size())
+ expected = 'x'.join(str(x) for x in input_size)
+ raise ValueError("input {} has invalid size: got {}, but expected "
+ "{}".format(i, got, expected))
+ if nccl_root is None:
+ raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors")
+ result = inp.new(device=destination).resize_as_(inp).zero_()
+
+ if nccl.is_available(inputs) and inputs[0].get_device() == destination:
+ outputs = [result] + [t.new(t.size()) for t in inputs[1:]]
+ nccl.reduce(inputs, outputs, root=nccl_root)
+ return result
+ for inp in inputs:
+ input_correct_gpu = inp.cuda(result.get_device())
+ result.add_(input_correct_gpu)
+ return result
+
+
+def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
+ """Sums tensors from multiple GPUs.
+
+ Small tensors are first coalesced into a buffer to reduce the number
+ of synchronizations.
+
+ Arguments:
+ inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
+ contain tensors from a single device.
+ destination (int, optional): a device on which the output will be
+ placed (default: current device).
+ buffer_size (int): maximum size of the buffer used for coalescing
+
+ Returns:
+ A tuple of tensors containing an elementwise sum of each group of
+ inputs, placed on the ``destination`` device.
+ """
+ # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
+ # return `inputs`.
+ dense_tensors = [[] for _ in inputs] # shape (num_gpus, num_tensors)
+ output = []
+ ref_order = []
+ # process sparse ones first since they may have different sizes on different gpus
+ for tensor_at_gpus in zip(*inputs):
+ if all(t.is_sparse for t in tensor_at_gpus):
+ result = reduce_add(tensor_at_gpus, destination)
+ output.append(result)
+ ref_order.append(tensor_at_gpus[0])
+ else:
+ for coll, t in zip(dense_tensors, tensor_at_gpus):
+ coll.append(t.to_dense() if t.is_sparse else t)
+ ref_order.append(dense_tensors[0][-1])
+ itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
+ # now the dense ones, which have consistent sizes
+ for chunks in zip(*itrs):
+ flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks]
+ flat_result = reduce_add(flat_tensors, destination)
+ for t in _unflatten_dense_tensors(flat_result, chunks[0]):
+ # The unflattened tensors do not share storage, and we don't expose
+ # base flat tensor anyways, so give them different version counters.
+ # See NOTE [ Version Counter in comm.*_coalesced ]
+ output.append(t.data)
+ return tuple(_reorder_tensors_as(output, ref_order))
+
+
+[docs]def scatter(tensor, devices, chunk_sizes=None, dim=0, streams=None):
+ """Scatters tensor across multiple GPUs.
+
+ Arguments:
+ tensor (Tensor): tensor to scatter.
+ devices (Iterable[int]): iterable of ints, specifying among which
+ devices the tensor should be scattered.
+ chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
+ each device. It should match ``devices`` in length and sum to
+ ``tensor.size(dim)``. If not specified, the tensor will be divided
+ into equal chunks.
+ dim (int, optional): A dimension along which to chunk the tensor.
+
+ Returns:
+ A tuple containing chunks of the ``tensor``, spread across given
+ ``devices``.
+ """
+ return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
+
+
+[docs]def gather(tensors, dim=0, destination=None):
+ """Gathers tensors from multiple GPUs.
+
+ Tensor sizes in all dimensions other than ``dim`` have to match.
+
+ Arguments:
+ tensors (Iterable[Tensor]): iterable of tensors to gather.
+ dim (int): a dimension along which the tensors will be concatenated.
+ destination (int, optional): output device (-1 means CPU, default:
+ current device)
+
+ Returns:
+ A tensor located on ``destination`` device, that is a result of
+ concatenating ``tensors`` along ``dim``.
+ """
+ return torch._C._gather(tensors, dim, destination)
+
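+# Editor's note: a hedged round-trip sketch (editor-added, assumes two visible
+# GPUs) combining ``scatter`` and ``gather`` above; not part of the module.
+def _example_scatter_gather():
+ import torch
+ import torch.cuda.comm as comm
+ full = torch.arange(8., device='cuda:0') # tensor to split along dim 0
+ chunks = comm.scatter(full, devices=[0, 1]) # one half on each GPU
+ rebuilt = comm.gather(chunks, dim=0, destination=-1) # -1 gathers back onto the CPU
+ return torch.equal(full.cpu(), rebuilt) # True
+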
+import os
+import glob
+import ctypes
+import platform
+
+lib = None
+
+__all__ = ['range_push', 'range_pop', 'mark']
+
+
+def windows_nvToolsExt_lib():
+ lib_path = windows_nvToolsExt_path()
+ if len(lib_path) > 0:
+ lib_name = os.path.basename(lib_path)
+ lib = os.path.splitext(lib_name)[0]
+ return ctypes.cdll.LoadLibrary(lib)
+ else:
+ return None
+
+
+def windows_nvToolsExt_path():
+ WINDOWS_HOME = 'C:/Program Files/NVIDIA Corporation/NvToolsExt'
+ NVTOOLEXT_HOME = os.getenv('NVTOOLSEXT_PATH', WINDOWS_HOME)
+ if os.path.exists(NVTOOLEXT_HOME):
+ lib_paths = glob.glob(NVTOOLEXT_HOME + '/bin/x64/nvToolsExt*.dll')
+ if len(lib_paths) > 0:
+ lib_path = lib_paths[0]
+ return lib_path
+ return ''
+
+
+def _libnvToolsExt():
+ global lib
+ if lib is None:
+ if platform.system() != 'Windows':
+ lib = ctypes.cdll.LoadLibrary(None)
+ else:
+ lib = windows_nvToolsExt_lib()
+ lib.nvtxMarkA.restype = None
+ return lib
+
+
+[docs]def range_push(msg):
+ """
+ Pushes a range onto a stack of nested range spans. Returns zero-based
+ depth of the range that is started.
+
+ Arguments:
+ msg (string): ASCII message to associate with range
+ """
+ if _libnvToolsExt() is None:
+ raise RuntimeError('Unable to load nvToolsExt library')
+ return lib.nvtxRangePushA(ctypes.c_char_p(msg.encode("ascii")))
+
+
+[docs]def range_pop():
+ """
+ Pops a range off of a stack of nested range spans. Returns the
+ zero-based depth of the range that is ended.
+ """
+ if _libnvToolsExt() is None:
+ raise RuntimeError('Unable to load nvToolsExt library')
+ return lib.nvtxRangePop()
+
+
+[docs]def mark(msg):
+ """
+ Describe an instantaneous event that occurred at some point.
+
+ Arguments:
+ msg (string): ASCII message to associate with the event.
+ """
+ if _libnvToolsExt() is None:
+ raise RuntimeError('Unable to load nvToolsExt library')
+ return lib.nvtxMarkA(ctypes.c_char_p(msg.encode("ascii")))
+
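+# Editor's note: a minimal sketch (editor-added) of annotating a region with the
+# three functions above so it shows up in tools such as nvprof or Nsight.
+def _example_nvtx_annotation(step):
+ import torch
+ torch.cuda.nvtx.range_push('step {}'.format(step)) # open a named range
+ torch.cuda.nvtx.mark('inside step') # instantaneous marker
+ torch.cuda.nvtx.range_pop() # close the range
+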
+from torch import _C, device
+from . import _lazy_init, _lazy_call, device_count, device as device_ctx_manager
+
+__all__ = ['get_rng_state', 'get_rng_state_all',
+ 'set_rng_state', 'set_rng_state_all',
+ 'manual_seed', 'manual_seed_all',
+ 'seed', 'seed_all', 'initial_seed']
+
+
+[docs]def get_rng_state(device=device('cuda')):
+ r"""Returns the random number generator state of the current
+ GPU as a ByteTensor.
+
+ Args:
+ device (torch.device or int, optional): The device to return the RNG state of.
+ Default: ``torch.device('cuda')`` (i.e., the current CUDA device).
+
+ .. warning::
+ This function eagerly initializes CUDA.
+ """
+ _lazy_init()
+ with device_ctx_manager(device):
+ return _C._cuda_getRNGState()
+
+
+[docs]def get_rng_state_all():
+ r"""Returns a tuple of ByteTensor representing the random number states of all devices."""
+
+ results = []
+ for i in range(device_count()):
+ with device_ctx_manager(i):
+ results.append(get_rng_state())
+ return results
+
+
+[docs]def set_rng_state(new_state, device=device('cuda')):
+ r"""Sets the random number generator state of the current GPU.
+
+ Args:
+ new_state (torch.ByteTensor): The desired state
+ device (torch.device or int, optional): The device to set the RNG state.
+ Default: ``torch.device('cuda')`` (i.e., the current CUDA device).
+ """
+ new_state_copy = new_state.clone()
+
+ # NB: What if device=-1? You might be afraid that the "current"
+ # device would change by the time we actually get around to invoking
+ # the lazy callback. But actually, this is not possible: changing
+ # the current device involves a CUDA call, which would in turn
+ # initialize the state. So then _lazy_call would execute cb
+ # immediately.
+ def cb():
+ with device_ctx_manager(device):
+ _C._cuda_setRNGState(new_state_copy)
+
+ _lazy_call(cb)
+
+
+[docs]def set_rng_state_all(new_states):
+ r"""Sets the random number generator state of all devices.
+
+ Args:
+ new_states (tuple of torch.ByteTensor): The desired state for each device"""
+ for i, state in enumerate(new_states):
+ set_rng_state(state, i)
+
+
+[docs]def manual_seed(seed):
+ r"""Sets the seed for generating random numbers for the current GPU.
+ It's safe to call this function if CUDA is not available; in that
+ case, it is silently ignored.
+
+ Args:
+ seed (int): The desired seed.
+
+ .. warning::
+ If you are working with a multi-GPU model, this function is insufficient
+ to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
+ """
+ seed = int(seed)
+ _lazy_call(lambda: _C._cuda_manualSeed(seed))
+
+
+[docs]def manual_seed_all(seed):
+ r"""Sets the seed for generating random numbers on all GPUs.
+ It's safe to call this function if CUDA is not available; in that
+ case, it is silently ignored.
+
+ Args:
+ seed (int): The desired seed.
+ """
+ seed = int(seed)
+ _lazy_call(lambda: _C._cuda_manualSeedAll(seed))
+
+
+[docs]def seed():
+ r"""Sets the seed for generating random numbers to a random number for the current GPU.
+ It's safe to call this function if CUDA is not available; in that
+ case, it is silently ignored.
+
+ .. warning::
+ If you are working with a multi-GPU model, this function will only initialize
+ the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
+ """
+ _lazy_call(lambda: _C._cuda_seed())
+
+
+[docs]def seed_all():
+ r"""Sets the seed for generating random numbers to a random number on all GPUs.
+ It's safe to call this function if CUDA is not available; in that
+ case, it is silently ignored.
+ """
+ _lazy_call(lambda: _C._cuda_seedAll())
+
+
+[docs]def initial_seed():
+ r"""Returns the current random seed of the current GPU.
+
+ .. warning::
+ This function eagerly initializes CUDA.
+ """
+ _lazy_init()
+ return _C._cuda_initialSeed()
+
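+# Editor's note: a small sketch (editor-added) of saving and restoring the CUDA
+# RNG state with the functions above so a random computation can be replayed.
+def _example_replay_random():
+ import torch
+ torch.cuda.manual_seed_all(42) # seed every visible GPU
+ state = torch.cuda.get_rng_state() # snapshot the current GPU's state
+ first = torch.randn(3, device='cuda')
+ torch.cuda.set_rng_state(state) # rewind to the snapshot
+ second = torch.randn(3, device='cuda')
+ return torch.equal(first, second) # True: identical draws after the restore
+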
+import ctypes
+import torch
+
+
+[docs]class Stream(torch._C._CudaStreamBase):
+ r"""Wrapper around a CUDA stream.
+
+ A CUDA stream is a linear sequence of execution that belongs to a specific
+ device, independent from other streams. See :ref:`cuda-semantics` for
+ details.
+
+ Arguments:
+ device(torch.device or int, optional): a device on which to allocate
+ the stream. If :attr:`device` is ``None`` (default) or a negative
+ integer, this will use the current device.
+ priority(int, optional): priority of the stream. Lower numbers
+ represent higher priorities.
+ """
+
+ def __new__(cls, device=None, priority=0, **kwargs):
+ with torch.cuda.device(device):
+ return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
+
+[docs] def wait_event(self, event):
+ r"""Makes all future work submitted to the stream wait for an event.
+
+ Arguments:
+ event (Event): an event to wait for.
+
+ .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see `CUDA
+ documentation`_ for more info.
+
+ This function returns without waiting for :attr:`event`: only future
+ operations are affected.
+
+ .. _CUDA documentation:
+ http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
+ """
+ event.wait(self)
+
+[docs] def wait_stream(self, stream):
+ r"""Synchronizes with another stream.
+
+ All future work submitted to this stream will wait until all kernels
+ submitted to a given stream at the time of call complete.
+
+ Arguments:
+ stream (Stream): a stream to synchronize.
+
+ .. note:: This function returns without waiting for currently enqueued
+ kernels in :attr:`stream`: only future operations are affected.
+ """
+ self.wait_event(stream.record_event())
+
+[docs] def record_event(self, event=None):
+ r"""Records an event.
+
+ Arguments:
+ event (Event, optional): event to record. If not given, a new one
+ will be allocated.
+
+ Returns:
+ Recorded event.
+ """
+ if event is None:
+ event = Event()
+ event.record(self)
+ return event
+
+[docs] def query(self):
+ r"""Checks if all the work submitted has been completed.
+
+ Returns:
+ A boolean indicating if all kernels in this stream are completed."""
+ return super(Stream, self).query()
+
+[docs] def synchronize(self):
+ r"""Wait for all the kernels in this stream to complete.
+
+ .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
+ `CUDA documentation`_ for more info.
+
+ .. _CUDA documentation:
+ http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
+ """
+ super(Stream, self).synchronize()
+
+ @property
+ def _as_parameter_(self):
+ return ctypes.c_void_p(self.cuda_stream)
+
+ def __eq__(self, o):
+ if isinstance(o, Stream):
+ return super(Stream, self).__eq__(o)
+ return False
+
+ def __hash__(self):
+ return hash((self.cuda_stream, self.device))
+
+ def __repr__(self):
+ return ('<torch.cuda.Stream device={0} cuda_stream={1:#x}>'
+ .format(self.device, self.cuda_stream))
+
+
+[docs]class Event(torch._C._CudaEventBase):
+ r"""Wrapper around a CUDA event.
+
+ CUDA events are synchronization markers that can be used to monitor the
+ device's progress, to accurately measure timing, and to synchronize CUDA
+ streams.
+
+ The underlying CUDA events are lazily initialized when the event is first
+ recorded or exported to another process. After creation, only streams on the
+ same device may record the event. However, streams on any device can wait on
+ the event.
+
+ Arguments:
+ enable_timing (bool, optional): indicates if the event should measure time
+ (default: ``False``)
+ blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
+ interprocess (bool): if ``True``, the event can be shared between processes
+ (default: ``False``)
+
+ .. _CUDA documentation:
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
+ """
+
+ def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
+ return super(Event, cls).__new__(
+ cls,
+ enable_timing=enable_timing, blocking=blocking, interprocess=interprocess)
+
+[docs] @classmethod
+ def from_ipc_handle(cls, device, handle):
+ r"""Reconstruct an event from an IPC handle on the given device."""
+ return super(Event, cls).from_ipc_handle(device, handle)
+
+[docs] def record(self, stream=None):
+ r"""Records the event in a given stream.
+
+ Uses ``torch.cuda.current_stream()`` if no stream is specified. The
+ stream's device must match the event's device."""
+ if stream is None:
+ stream = torch.cuda.current_stream()
+ super(Event, self).record(stream)
+
+[docs] def wait(self, stream=None):
+ r"""Makes all future work submitted to the given stream wait for this
+ event.
+
+ Uses ``torch.cuda.current_stream()`` if no stream is specified."""
+ if stream is None:
+ stream = torch.cuda.current_stream()
+ super(Event, self).wait(stream)
+
+[docs] def query(self):
+ r"""Checks if all work currently captured by event has completed.
+
+ Returns:
+ A boolean indicating if all work currently captured by event has
+ completed.
+ """
+ return super(Event, self).query()
+
+[docs] def elapsed_time(self, end_event):
+ r"""Returns the time elapsed in milliseconds after the event was
+ recorded and before the end_event was recorded.
+ """
+ return super(Event, self).elapsed_time(end_event)
+
+[docs] def synchronize(self):
+ r"""Waits for the event to complete.
+
+ Waits until the completion of all work currently captured in this event.
+ This prevents the CPU thread from proceeding until the event completes.
+
+ .. note:: This is a wrapper around ``cudaEventSynchronize()``: see `CUDA
+ documentation`_ for more info.
+
+ .. _CUDA documentation:
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
+ """
+ super(Event, self).synchronize()
+
+[docs] def ipc_handle(self):
+ r"""Returns an IPC handle of this event. If not recorded yet, the event
+ will use the current device. """
+ return super(Event, self).ipc_handle()
+
+ @property
+ def _as_parameter_(self):
+ return ctypes.c_void_p(self.cuda_event)
+
+ def __repr__(self):
+ if self.cuda_event:
+ return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
+ else:
+ return '<torch.cuda.Event uninitialized>'
+
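+# Editor's note: a hedged sketch (editor-added) of timing work enqueued on a side
+# stream with the Stream and Event wrappers above; ``x`` is a hypothetical input.
+def _example_time_on_stream(x):
+ import torch
+ s = torch.cuda.Stream()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ with torch.cuda.stream(s): # make ``s`` the current stream
+ start.record()
+ y = x.cuda() * 2
+ end.record()
+ end.synchronize() # wait until the recorded work is done
+ return y, start.elapsed_time(end) # elapsed time in milliseconds
+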
+import torch
+
+
+def is_available():
+ return hasattr(torch._C, "_c10d_init")
+
+
+if is_available() and not torch._C._c10d_init():
+ raise RuntimeError("Failed to initialize PyTorch distributed support")
+
+
+if is_available():
+ from .distributed_c10d import * # noqa: F401
+ # Variables prefixed with underscore are not auto imported
+ # See the comment in `distributed_c10d.py` above `_backend` on why we expose
+ # this.
+ from .distributed_c10d import _backend # noqa: F401
+
+import torch
+import warnings
+from torch._six import string_classes
+from datetime import timedelta
+
+# This module is wildcard imported from torch.distributed.
+# TODO: specify __all__
+
+from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401
+from . import (
+ AllreduceOptions,
+ BroadcastOptions,
+ GatherOptions,
+ ReduceOptions,
+ ReduceScatterOptions,
+ ScatterOptions,
+)
+from . import ReduceOp
+from . import PrefixStore
+
+
+_MPI_AVAILABLE = True
+_NCCL_AVAILABLE = True
+_GLOO_AVAILABLE = True
+
+
+try:
+ from . import ProcessGroupMPI
+except ImportError:
+ _MPI_AVAILABLE = False
+
+try:
+ from . import ProcessGroupNCCL
+except ImportError:
+ _NCCL_AVAILABLE = False
+
+try:
+ from . import ProcessGroupGloo
+except ImportError:
+ _GLOO_AVAILABLE = False
+
+
+[docs]class Backend(object):
+ """
+ An enum-like class of available backends: GLOO, NCCL, and MPI.
+
+ The values of this class are lowercase strings, e.g., ``"gloo"``. They can
+ be accessed as attributes, e.g., ``Backend.NCCL``.
+
+ This class can be directly called to parse the string, e.g.,
+ ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
+ return the parsed lowercase string if so. It also accepts uppercase strings,
+ e.g., ``Backend("GLOO")`` returns ``"gloo"``.
+
+ .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
+ initial value of some fields. Users should neither use it directly
+ nor assume its existence.
+ """
+ UNDEFINED = "undefined"
+ GLOO = "gloo"
+ NCCL = "nccl"
+ MPI = "mpi"
+ TCP = "tcp"
+
+ def __new__(cls, name):
+ if not isinstance(name, string_classes):
+ raise ValueError("Backend name must be a string, but got: {}".format(name))
+ value = getattr(Backend, name.upper(), Backend.UNDEFINED)
+
+ if value == Backend.TCP:
+ raise ValueError("TCP backend has been deprecated. Please use "
+ "Gloo or MPI backend for collective operations "
+ "on CPU tensors.")
+ elif value == Backend.UNDEFINED:
+ raise ValueError("Invalid backend: '{}'".format(name))
+ return value
+
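+# Editor's note: a tiny illustration (editor-added) of the parsing behaviour
+# documented in the Backend class above.
+def _example_backend_parsing():
+ assert Backend('NCCL') == Backend.NCCL == 'nccl' # uppercase input is accepted
+ assert Backend('gloo') == 'gloo'
+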
+# `_backend`, `dist_backend`, and `reduce_op` are here to maintain backward
+# compatibility with pre-c10d distributed package.
+# TODO: remove them when users are ready to take a hard dependency on PyTorch 1.
+_backend = Backend.UNDEFINED
+dist_backend = Backend
+
+
+[docs]class reduce_op(object):
+ r"""
+ Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
+ ``MIN``, and ``MAX``.
+
+ :class:`~torch.distributed.ReduceOp` is recommended to use instead.
+ """
+
+ def __init__(self):
+ # __members__ is a dict storing key-value pairs for enum classes
+ for k, v in ReduceOp.__members__.items():
+ setattr(self, k, v)
+ self.__members__ = ReduceOp.__members__
+
+ def __getattribute__(self, key):
+ warnings.warn("torch.distributed.reduce_op is deprecated, please use "
+ "torch.distributed.ReduceOp instead")
+ return object.__getattribute__(self, key)
+
+reduce_op = reduce_op()
+
+
+class group(object):
+ WORLD = object()
+
+
+class GroupMember(object):
+ # Alias to group.WORLD for backward compatibility
+ WORLD = group.WORLD
+ NON_GROUP_MEMBER = object()
+
+
+# Cached process groups
+# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
+# For MPI pg, it is a map from ProcessGroup to (Backend, None)
+_pg_map = {}
+# Process group's names, map from ProcessGroup to str
+_pg_names = {}
+# Process group's global rank to local rank mapping
+_pg_group_ranks = {}
+
+# Default process group state
+_default_pg = None
+_default_pg_init_method = None
+
+# Default process group wide timeout, if applicable.
+# This currently only applies to the gloo backend. To make an attempt at
+# backwards compatibility with THD, we use an extraordinarily high default
+# timeout, given that THD did not have timeouts.
+_default_pg_timeout = timedelta(minutes=30)
+
+# Process group count for default naming
+_group_count = 0
+
+
+def _rank_not_in_group(group):
+ """
+ Helper that checks if the current process's rank is not in a given group
+
+ """
+ if group == GroupMember.WORLD:
+ return False
+ return group == GroupMember.NON_GROUP_MEMBER
+
+
+def _get_group_rank(group, rank):
+ """
+ Helper that gets a given group's local rank in the group from a given global
+ rank
+
+ """
+ if group is GroupMember.WORLD:
+ raise RuntimeError("group.WORLD does not have local rank to global "
+ "rank mapping")
+ if group not in _pg_group_ranks:
+ raise RuntimeError("The given group does not exist")
+ try:
+ group_rank = _pg_group_ranks[group][rank]
+ except KeyError:
+ raise RuntimeError("The global rank is not part of the group")
+ return group_rank
+
+
+def _get_global_rank(group, group_rank):
+ """
+ Helper that gets a given group's global rank from a given local rank in the
+ group
+
+ """
+ if group is GroupMember.WORLD:
+ raise RuntimeError("group.WORLD does not have local rank to global "
+ "rank mapping")
+ group_rank_map = _pg_group_ranks[group]
+ for rank, grp_rank in group_rank_map.items():
+ if grp_rank == group_rank:
+ return rank
+ raise RuntimeError("The group rank is not part of the group")
+
+
+def _check_default_pg():
+ """
+ Helper that checks if the default ProcessGroup has been initialized, with
+ assertion
+
+ """
+ assert _default_pg is not None, \
+ "Default process group is not initialized"
+
+
+def _get_group_size(group):
+ """
+ Helper that gets a given group's world size
+
+ """
+ if group is GroupMember.WORLD:
+ _check_default_pg()
+ return _default_pg.size()
+ if group not in _pg_group_ranks:
+ raise RuntimeError("The given group does not exist")
+ return len(_pg_group_ranks[group])
+
+
+def _check_single_tensor(param, param_name):
+ """
+ Helper that checks that the parameter ``param_name`` is a single tensor
+
+ """
+ if not isinstance(param, torch.Tensor):
+ raise RuntimeError("Invalid function argument. Expecting parameter: {} "
+ "to be a torch.Tensor type".format(param_name))
+
+
+def _check_tensor_list(param, param_name):
+ """
+ Helper that checks that the parameter ``param_name`` is a list of tensors
+
+ """
+ wrong_type = False
+ if isinstance(param, list):
+ for p in param:
+ if not isinstance(p, torch.Tensor):
+ wrong_type = True
+ break
+ else:
+ wrong_type = True
+ if wrong_type:
+ raise RuntimeError("Invalid function argument. Expecting parameter: {} "
+ "to be a List[torch.Tensor] type".format(param_name))
+
+
+[docs]def is_mpi_available():
+ """
+ Checks if the MPI backend is available.
+
+ """
+ return _MPI_AVAILABLE
+
+
+[docs]def is_nccl_available():
+ """
+ Checks if the NCCL backend is available.
+
+ """
+ return _NCCL_AVAILABLE
+
+
+def is_gloo_available():
+ """
+ Checks if the Gloo backend is available.
+
+ """
+ return _GLOO_AVAILABLE
+
+
+[docs]def is_initialized():
+ """
+ Checks whether the default process group has been initialized
+
+ """
+ return _default_pg is not None
+
+
+def _get_default_group():
+ """
+ Returns the default process group created by init_process_group
+
+ """
+ if not is_initialized():
+ raise RuntimeError("Default process group has not been initialized, "
+ "please make sure to call init_process_group.")
+ return _default_pg
+
+
+def _get_default_store():
+ """
+ Returns the default store created by init_process_group
+
+ """
+ if not is_initialized():
+ raise RuntimeError("Default process group has not been initialized, "
+ "please make sure to call init_process_group.")
+ _, default_store = _pg_map[_default_pg]
+ return default_store
+
+
+[docs]def get_backend(group=group.WORLD):
+ """
+ Returns the backend of the given process group.
+
+ Arguments:
+ group (ProcessGroup, optional): The process group to work on. The
+ default is the general main process group. If another specific group
+ is specified, the calling process must be part of :attr:`group`.
+
+ Returns:
+ The backend of the given process group as a lower case string.
+
+ """
+ _check_default_pg()
+
+ if group == GroupMember.WORLD:
+ pg = _default_pg
+ else:
+ pg = group
+ if _rank_not_in_group(pg):
+ raise RuntimeError("Invalid process group specified")
+ return _pg_map.get(pg, None)[0]
+
+
+[docs]def init_process_group(backend,
+ init_method=None,
+ timeout=_default_pg_timeout,
+ world_size=-1,
+ rank=-1,
+ store=None,
+ group_name=''):
+ """
+ Initializes the default distributed process group, which will also
+ initialize the distributed package.
+
+ There are 2 main ways to initialize a process group:
+ 1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
+ 2. Specify ``init_method`` (a URL string) which indicates where/how
+ to discover peers. Optionally specify ``rank`` and ``world_size``,
+ or encode all required parameters in the URL and omit them.
+ If neither is specified, ``init_method`` is assumed to be "env://".
+
+
+ Arguments:
+ backend (str or Backend): The backend to use. Depending on
+ build-time configurations, valid values include ``mpi``, ``gloo``,
+ and ``nccl``. This field should be given as a lowercase string
+ (e.g., ``"gloo"``), which can also be accessed via
+ :class:`Backend` attributes (e.g., ``Backend.GLOO``). If using
+ multiple processes per machine with ``nccl`` backend, each process
+ must have exclusive access to every GPU it uses, as sharing GPUs
+ between processes can result in deadlocks.
+ init_method (str, optional): URL specifying how to initialize the
+ process group. Default is "env://" if no
+ ``init_method`` or ``store`` is specified.
+ Mutually exclusive with ``store``.
+ world_size (int, optional): Number of processes participating in
+ the job. Required if ``store`` is specified.
+ rank (int, optional): Rank of the current process.
+ Required if ``store`` is specified.
+ store(Store, optional): Key/value store accessible to all workers, used
+ to exchange connection/address information.
+ Mutually exclusive with ``init_method``.
+ timeout (timedelta, optional): Timeout for operations executed against
+ the process group. Default value equals 30 minutes.
+ This is only applicable for the ``gloo`` backend.
+ group_name (str, optional, deprecated): Group name.
+
+ To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
+ on a system that supports MPI. The same applies to NCCL as well.
+
+ """
+ global _pg_group_ranks
+ global _backend
+ global _default_pg
+ global _default_pg_init_method
+
+ if not isinstance(timeout, timedelta):
+ raise RuntimeError("Expected timeout argument to be of type "
+ "datetime.timedelta")
+
+ if _default_pg is not None:
+ raise RuntimeError("trying to initialize the default process group "
+ "twice!")
+
+ assert (store is None) or (init_method is None), \
+ "Cannot specify both init_method and store."
+
+ if store is not None:
+ assert world_size > 0, 'world_size must be positive if using store'
+ assert rank >= 0, 'rank must be non-negative if using store'
+ elif init_method is None:
+ init_method = "env://"
+
+ backend = Backend(backend)
+
+ if backend == Backend.MPI:
+ _default_pg = _new_process_group_helper(
+ -1,
+ -1,
+ [],
+ Backend.MPI,
+ None,
+ group_name=group_name,
+ timeout=timeout)
+ else:
+ # backward compatible API
+ if store is None:
+ url = init_method
+ if world_size != -1 and rank != -1:
+ url += "?rank={}&world_size={}".format(rank, world_size)
+ elif rank != -1:
+ url += "?rank={}".format(rank)
+ elif world_size != -1:
+ url += "?world_size={}".format(world_size)
+
+ store, rank, world_size = next(rendezvous(url))
+ store.set_timeout(timeout)
+
+ _default_pg = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ timeout=timeout)
+
+ _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())}
+ _backend = _pg_map[_default_pg][0]
+ _default_pg_init_method = init_method
+
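+# Editor's note: a minimal sketch (editor-added, not library code) of the two
+# initialization styles described above; the TCP address, port and world size
+# are placeholders.
+def _example_init_with_url(rank, world_size):
+ import torch.distributed as dist
+ dist.init_process_group(backend='gloo',
+ init_method='tcp://10.1.1.20:23456', # placeholder address
+ rank=rank, world_size=world_size)
+
+def _example_init_with_env():
+ import torch.distributed as dist
+ # assumes MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set in the environment
+ dist.init_process_group(backend='nccl')
+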
+
+def _new_process_group_helper(world_size,
+ rank,
+ group_ranks,
+ backend,
+ store,
+ group_name=None,
+ timeout=_default_pg_timeout):
+ """
+ Create a new distributed process group.
+
+ This function must be called by ALL processes in the global group, even if
+ the calling process is not part of the newly created group. In that case,
+ this function returns GroupMember.NON_GROUP_MEMBER.
+
+ This function is called with ``group_ranks == []`` for the default group.
+ """
+ global _pg_map
+ global _group_count
+ global _pg_names
+
+ if not group_name:
+ group_name = str(_group_count)
+ _group_count += 1
+
+ if group_name in _pg_names.values():
+ raise RuntimeError("The specified group name has already been "
+ "created, please use a different group name")
+
+ if not isinstance(timeout, timedelta):
+ raise RuntimeError("Expected timeout argument to be of type "
+ "datetime.timedelta")
+
+ # The list of group ranks is empty if we're creating the default group.
+ is_default_group = (len(group_ranks) == 0)
+
+ backend = Backend(backend)
+ if backend == Backend.MPI:
+ if not is_mpi_available():
+ raise RuntimeError("Distributed package doesn't have MPI built in")
+ pg = ProcessGroupMPI.create(group_ranks)
+ if not pg:
+ return GroupMember.NON_GROUP_MEMBER
+ _pg_map[pg] = (Backend.MPI, None)
+ _pg_names[pg] = group_name
+ else:
+ # If this is a subgroup (which means group_ranks is specified),
+ # we check if the current process is a member of the new group.
+ if not is_default_group:
+ global_rank = _default_pg.rank()
+ if global_rank not in group_ranks:
+ return GroupMember.NON_GROUP_MEMBER
+
+ # Use the group name as prefix in the default store, such that
+ # a single store can be reused by multiple groups.
+ prefix_store = PrefixStore(group_name, store)
+
+ if backend == Backend.GLOO:
+ pg = ProcessGroupGloo(
+ prefix_store,
+ rank,
+ world_size,
+ timeout=timeout)
+ _pg_map[pg] = (Backend.GLOO, store)
+ _pg_names[pg] = group_name
+ elif backend == Backend.NCCL:
+ if not is_nccl_available():
+ raise RuntimeError("Distributed package doesn't have NCCL "
+ "built in")
+ pg = ProcessGroupNCCL(
+ prefix_store,
+ rank,
+ world_size,
+ group_name)
+ _pg_map[pg] = (Backend.NCCL, store)
+ _pg_names[pg] = group_name
+ else:
+ raise RuntimeError("Unsupported distributed backend by group")
+
+ return pg
+
+
+def destroy_process_group(group=group.WORLD):
+ """
+ Destroy a given process group, and deinitialize the distributed package
+
+ Arguments:
+ group (ProcessGroup, optional): The process group to be destroyed. If
+ group.WORLD is given, all process
+ groups, including the default one,
+ will be destroyed.
+ """
+ global _pg_map
+ global _pg_names
+ global _pg_group_ranks
+ global _default_pg
+ global _default_pg_init_method
+
+ if group == GroupMember.NON_GROUP_MEMBER:
+ return
+
+ if group == GroupMember.WORLD:
+ pg = _default_pg
+ else:
+ pg = group
+
+ if _pg_map.get(pg, None) is None:
+ raise RuntimeError("Invalid process group specified")
+
+ if group == GroupMember.WORLD:
+ _default_pg = None
+ _default_pg_init_method = None
+ _pg_map.clear()
+ _pg_names.clear()
+ _pg_group_ranks.clear()
+ else:
+ del _pg_map[pg]
+ del _pg_names[pg]
+ del _pg_group_ranks[pg]
+
+
+[docs]def get_rank(group=group.WORLD):
+ """
+ Returns the rank of the current process in the given process group.
+
+ Rank is a unique identifier assigned to each process within a distributed
+ process group. Ranks are always consecutive integers ranging from 0 to
+ ``world_size - 1``.
+
+ Arguments:
+ group (ProcessGroup, optional): The process group to work on
+
+ Returns:
+ The rank of the process group
+ -1, if not part of the group
+
+ """
+ if _rank_not_in_group(group):
+ return -1
+
+ _check_default_pg()
+ if group == GroupMember.WORLD:
+ return _default_pg.rank()
+
+ return _get_group_rank(group, _default_pg.rank())
+
+
+[docs]def get_world_size(group=group.WORLD):
+ """
+ Returns the number of processes in the current process group
+
+ Arguments:
+ group (ProcessGroup, optional): The process group to work on
+
+ Returns:
+ The world size of the process group
+ -1, if not part of the group
+
+ """
+ if _rank_not_in_group(group):
+ return -1
+
+ return _get_group_size(group)
+
+
+[docs]def isend(tensor,
+ dst,
+ group=group.WORLD,
+ tag=0):
+ """
+ Sends a tensor asynchronously.
+
+ Arguments:
+ tensor (Tensor): Tensor to send.
+ dst (int): Destination rank.
+ group (ProcessGroup, optional): The process group to work on
+ tag (int, optional): Tag to match send with remote recv
+
+ Returns:
+ A distributed request object.
+ None, if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ return _default_pg.send([tensor], dst, tag)
+ else:
+ group_dst_rank = _get_group_rank(group, dst)
+ return group.send([tensor], group_dst_rank, tag)
+
+
+[docs]def irecv(tensor,
+ src,
+ group=group.WORLD,
+ tag=0):
+ """
+ Receives a tensor asynchronously.
+
+ Arguments:
+ tensor (Tensor): Tensor to fill with received data.
+ src (int): Source rank.
+ group (ProcessGroup, optional): The process group to work on
+ tag (int, optional): Tag to match recv with remote send
+
+ Returns:
+ A distributed request object.
+ None, if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ return _default_pg.recv([tensor], src, tag)
+ else:
+ group_src_rank = _get_group_rank(group, src)
+ return group.recv([tensor], group_src_rank, tag)
+
+
+[docs]def send(tensor,
+ dst,
+ group=group.WORLD,
+ tag=0):
+ """
+ Sends a tensor synchronously.
+
+ Arguments:
+ tensor (Tensor): Tensor to send.
+ dst (int): Destination rank.
+ group (ProcessGroup, optional): The process group to work on
+ tag (int, optional): Tag to match send with remote recv
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ _default_pg.send([tensor], dst, tag).wait()
+ else:
+ group_dst_rank = _get_group_rank(group, dst)
+ group.send([tensor], group_dst_rank, tag).wait()
+
+
+[docs]def recv(tensor,
+ src=None,
+ group=group.WORLD,
+ tag=0):
+ """
+ Receives a tensor synchronously.
+
+ Arguments:
+ tensor (Tensor): Tensor to fill with received data.
+ src (int, optional): Source rank. Will receive from any
+ process if unspecified.
+ group (ProcessGroup, optional): The process group to work on
+ tag (int, optional): Tag to match recv with remote send
+
+ Returns:
+ Sender rank
+ -1, if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return -1
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ pg = _default_pg
+ else:
+ pg = group
+
+ if src is None:
+ work = pg.recv_anysource([tensor], tag)
+ work.wait()
+ src_rank = work.source_rank()
+ if group == GroupMember.WORLD:
+ return src_rank
+ else:
+ return _get_global_rank(pg, src_rank)
+ else:
+ if group == GroupMember.WORLD:
+ pg.recv([tensor], src, tag).wait()
+ else:
+ group_src_rank = _get_group_rank(pg, src)
+ pg.recv([tensor], group_src_rank, tag).wait()
+ return src
+
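+# Editor's note: a small point-to-point sketch (editor-added) using ``send`` and
+# ``recv`` above; it assumes the default group was initialized with two ranks.
+def _example_ping():
+ import torch
+ import torch.distributed as dist
+ t = torch.zeros(1)
+ if dist.get_rank() == 0:
+ t += 1
+ dist.send(t, dst=1) # blocking send to rank 1
+ else:
+ dist.recv(t, src=0) # blocking receive from rank 0
+ return t # both ranks now hold tensor([1.])
+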
+
+[docs]def broadcast_multigpu(tensor_list,
+ src,
+ group=group.WORLD,
+ async_op=False,
+ src_tensor=0):
+ """
+ Broadcasts the tensor to the whole group with multiple GPU tensors
+ per node.
+
+ ``tensor`` must have the same number of elements in all the GPUs from
+ all processes participating in the collective. Each tensor in the list must
+ be on a different GPU.
+
+ Only the ``nccl`` and ``gloo`` backends are currently supported;
+ tensors should only be GPU tensors.
+
+ Arguments:
+ tensor_list (List[Tensor]): Tensors that participate in the collective
+ operation. If ``src`` is the rank, then the specified ``src_tensor``
+ element of ``tensor_list`` (``tensor_list[src_tensor]``) will be
+ broadcast to all other tensors (on different GPUs) in the src process
+ and all tensors in ``tensor_list`` of other non-src processes.
+ You also need to make sure that ``len(tensor_list)`` is the same
+ for all the distributed processes calling this function.
+
+ src (int): Source rank.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+ src_tensor (int, optional): Source tensor rank within ``tensor_list``
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ if _rank_not_in_group(group):
+ return
+
+ opts = BroadcastOptions()
+ opts.rootRank = src
+ opts.rootTensor = src_tensor
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.broadcast(tensor_list, opts)
+ else:
+ group_src_rank = _get_group_rank(group, src)
+ opts.rootRank = group_src_rank
+ work = group.broadcast(tensor_list, opts)
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def broadcast(tensor,
+ src,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Broadcasts the tensor to the whole group.
+
+ ``tensor`` must have the same number of elements in all processes
+ participating in the collective.
+
+ Arguments:
+ tensor (Tensor): Data to be sent if ``src`` is the rank of current
+ process, and tensor to be used to save received data otherwise.
+ src (int): Source rank.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ opts = BroadcastOptions()
+ opts.rootRank = src
+ opts.rootTensor = 0
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.broadcast([tensor], opts)
+ else:
+ group_src_rank = _get_group_rank(group, src)
+ opts.rootRank = group_src_rank
+ work = group.broadcast([tensor], opts)
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def all_reduce_multigpu(tensor_list,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False):
+ r"""
+ Reduces the tensor data across all machines in such a way that all get
+ the final result. This function reduces a number of tensors on every node,
+ while each tensor resides on a different GPU.
+ Therefore, each input tensor in the tensor list needs to be a GPU tensor,
+ and each tensor in the tensor list needs to reside on a different GPU.
+
+ After the call, every ``tensor`` in ``tensor_list`` is going to be bitwise
+ identical in all processes.
+
+ Only the ``nccl`` and ``gloo`` backends are currently supported;
+ tensors should only be GPU tensors.
+
+ Arguments:
+ tensor_list (List[Tensor]): List of input and output tensors of
+ the collective. The function operates in-place and requires that
+ each tensor be a GPU tensor residing on a different GPU.
+ You also need to make sure that ``len(tensor_list)`` is the same for
+ all the distributed processes calling this function.
+ op (optional): One of the values from
+ ``torch.distributed.ReduceOp``
+ enum. Specifies an operation used for element-wise reductions.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ if _rank_not_in_group(group):
+ return
+
+ opts = AllreduceOptions()
+ opts.reduceOp = op
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.allreduce(tensor_list, opts)
+ else:
+ work = group.allreduce(tensor_list, opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def all_reduce(tensor,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Reduces the tensor data across all machines in such a way that all get
+ the final result.
+
+ After the call ``tensor`` is going to be bitwise identical in all processes.
+
+ Arguments:
+ tensor (Tensor): Input and output of the collective. The function
+ operates in-place.
+ op (optional): One of the values from
+ ``torch.distributed.ReduceOp``
+ enum. Specifies an operation used for element-wise reductions.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ opts = AllreduceOptions()
+ opts.reduceOp = op
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.allreduce([tensor], opts)
+ else:
+ work = group.allreduce([tensor], opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
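+# Editor's note: a hedged sketch (editor-added) of the in-place semantics of
+# ``all_reduce`` above; assumes an initialized default process group.
+def _example_all_reduce():
+ import torch
+ import torch.distributed as dist
+ t = torch.ones(2) * (dist.get_rank() + 1) # rank 0 -> [1, 1], rank 1 -> [2, 2], ...
+ dist.all_reduce(t, op=dist.ReduceOp.SUM) # every rank ends up with the element-wise sum
+ return t
+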
+
+[docs]def reduce_multigpu(tensor_list,
+ dst,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False,
+ dst_tensor=0):
+ """
+ Reduces the tensor data on multiple GPUs across all machines. Each tensor
+ in ``tensor_list`` should reside on a separate GPU.
+
+ Only the GPU of ``tensor_list[dst_tensor]`` on the process with rank ``dst``
+ is going to receive the final result.
+
+ Only the ``nccl`` backend is currently supported;
+ tensors should only be GPU tensors.
+
+ Arguments:
+ tensor_list (List[Tensor]): Input and output GPU tensors of the
+ collective. The function operates in-place.
+ You also need to make sure that ``len(tensor_list)`` is the same for
+ all the distributed processes calling this function.
+ dst (int): Destination rank
+ op (optional): One of the values from
+ ``torch.distributed.ReduceOp``
+ enum. Specifies an operation used for element-wise reductions.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+ dst_tensor (int, optional): Destination tensor rank within
+ ``tensor_list``
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, otherwise
+
+ """
+ if _rank_not_in_group(group):
+ return
+
+ opts = ReduceOptions()
+ opts.reduceOp = op
+ opts.rootRank = dst
+ opts.rootTensor = dst_tensor
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.reduce(tensor_list, opts)
+ else:
+ group_dst_rank = _get_group_rank(group, dst)
+ opts.rootRank = group_dst_rank
+ work = group.reduce(tensor_list, opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def reduce(tensor,
+ dst,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Reduces the tensor data across all machines.
+
+ Only the process with rank ``dst`` is going to receive the final result.
+
+ Arguments:
+ tensor (Tensor): Input and output of the collective. The function
+ operates in-place.
+ dst (int): Destination rank
+ op (optional): One of the values from
+ ``torch.distributed.ReduceOp``
+ enum. Specifies an operation used for element-wise reductions.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ opts = ReduceOptions()
+ opts.reduceOp = op
+ opts.rootRank = dst
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.reduce([tensor], opts)
+ else:
+ group_dst_rank = _get_group_rank(group, dst)
+ opts.rootRank = group_dst_rank
+ work = group.reduce([tensor], opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def all_gather_multigpu(output_tensor_lists,
+ input_tensor_list,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Gathers tensors from the whole group in a list.
+ Each tensor in ``input_tensor_list`` should reside on a separate GPU.
+
+ Only the ``nccl`` backend is currently supported;
+ tensors should only be GPU tensors.
+
+ Arguments:
+ output_tensor_lists (List[List[Tensor]]): Output lists. It should
+ contain correctly-sized tensors on each GPU to be used for output
+ of the collective, e.g. ``output_tensor_lists[i]`` contains the
+ all_gather result that resides on the GPU of
+ ``input_tensor_list[i]``.
+
+ Note that each element of ``output_tensor_lists`` has the size of
+ ``world_size * len(input_tensor_list)``, since the function all
+ gathers the result from every single GPU in the group. To interpret
+ each element of ``output_tensor_lists[i]``, note that
+ ``input_tensor_list[j]`` of rank k will appear in
+ ``output_tensor_lists[i][k * world_size + j]``
+
+ Also note that ``len(output_tensor_lists)``, and the size of each
+ element in ``output_tensor_lists`` (each element is a list,
+ therefore ``len(output_tensor_lists[i])``) need to be the same
+ for all the distributed processes calling this function.
+
+ input_tensor_list (List[Tensor]): List of tensors(on different GPUs) to
+ be broadcast from current process.
+ Note that ``len(input_tensor_list)`` needs to be the same for
+ all the distributed processes calling this function.
+
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
+ else:
+ work = group.allgather(output_tensor_lists, input_tensor_list)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def all_gather(tensor_list,
+ tensor,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Gathers tensors from the whole group in a list.
+
+ Arguments:
+ tensor_list (list[Tensor]): Output list. It should contain
+ correctly-sized tensors to be used for output of the collective.
+ tensor (Tensor): Tensor to be broadcast from current process.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_tensor_list(tensor_list, "tensor_list")
+ _check_single_tensor(tensor, "tensor")
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.allgather([tensor_list], [tensor])
+ else:
+ work = group.allgather([tensor_list], [tensor])
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
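+# Editor's note: a short sketch (editor-added) of ``all_gather`` above; assumes an
+# initialized default process group.
+def _example_all_gather():
+ import torch
+ import torch.distributed as dist
+ world_size = dist.get_world_size()
+ mine = torch.tensor([float(dist.get_rank())])
+ gathered = [torch.zeros(1) for _ in range(world_size)] # correctly-sized output list
+ dist.all_gather(gathered, mine) # every rank receives every rank's tensor
+ return gathered
+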
+
+[docs]def gather(tensor,
+ gather_list,
+ dst,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Gathers a list of tensors in a single process.
+
+ Arguments:
+ tensor (Tensor): Input tensor.
+ gather_list (list[Tensor]): List of appropriately-sized tensors to
+ use for received data. Required only in the receiving process.
+ dst (int): Destination rank. Required in all processes except the one
+ that is receiving the data.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ _check_tensor_list(gather_list, "gather_list")
+ if _rank_not_in_group(group):
+ return
+
+ my_rank = get_rank()
+ if dst == my_rank:
+ if gather_list is None:
+ raise RuntimeError("gather_list is a required argument in gather "
+ "destination")
+ input_tensors = [tensor]
+ output_tensors = [gather_list]
+ else:
+ if gather_list:
+ raise RuntimeError("non-empty gather_list can be given only "
+ "to gather destination")
+ input_tensors = [tensor]
+ output_tensors = []
+
+ opts = GatherOptions()
+ opts.rootRank = dst
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.gather(output_tensors, input_tensors, opts)
+ else:
+ group_dst_rank = _get_group_rank(group, dst)
+ opts.rootRank = group_dst_rank
+ work = group.gather(output_tensors, input_tensors, opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+[docs]def scatter(tensor,
+ scatter_list,
+ src,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Scatters a list of tensors to all processes in a group.
+
+ Each process will receive exactly one tensor and store its data in the
+ ``tensor`` argument.
+
+ Arguments:
+ tensor (Tensor): Output tensor.
+ scatter_list (list[Tensor]): List of tensors to scatter. Required only
+ in the process that is sending the data.
+ src (int): Source rank. Required in all processes except the one that
+ is sending the data.
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+
+ """
+ _check_single_tensor(tensor, "tensor")
+ _check_tensor_list(scatter_list, "scatter_list")
+ if _rank_not_in_group(group):
+ return
+
+ my_rank = get_rank()
+ if src == my_rank:
+ if scatter_list is None:
+ raise RuntimeError("scatter_list is a required argument in "
+ "scatter source")
+ input_tensors = [scatter_list]
+ output_tensors = [tensor]
+ else:
+ if scatter_list:
+ raise RuntimeError("non-empty scatter_list can be given only "
+ "to scatter source")
+ input_tensors = []
+ output_tensors = [tensor]
+
+ opts = ScatterOptions()
+ opts.rootRank = src
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.scatter(output_tensors, input_tensors, opts)
+ else:
+ group_src_rank = _get_group_rank(group, src)
+ opts.rootRank = group_src_rank
+ work = group.scatter(output_tensors, input_tensors, opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+def reduce_scatter_multigpu(output_tensor_list,
+ input_tensor_lists,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Reduces, then scatters a list of tensors to the whole group. Only the
+ ``nccl`` backend is currently supported.
+
+ Each tensor in ``output_tensor_list`` should reside on a separate GPU, as
+ should each list of tensors in ``input_tensor_lists``.
+
+ Arguments:
+ output_tensor_list (List[Tensor]): Output tensors (on different GPUs)
+ to receive the result of the operation.
+
+ Note that ``len(output_tensor_list)`` needs to be the same for all
+ the distributed processes calling this function.
+
+ input_tensor_lists (List[List[Tensor]]): Input lists. It should
+ contain correctly-sized tensors on each GPU to be used for input of
+ the collective, e.g. ``input_tensor_lists[i]`` contains the
+ reduce_scatter input that resides on the GPU of
+ ``output_tensor_list[i]``.
+
+ Note that each element of ``input_tensor_lists`` has the size of
+ ``world_size * len(output_tensor_list)``, since the function
+ scatters the result from every single GPU in the group. To
+ interpret each element of ``input_tensor_lists[i]``, note that
+ ``output_tensor_list[j]`` of rank k receives the reduce-scattered
+ result from ``input_tensor_lists[i][k * world_size + j]``
+
+ Also note that ``len(input_tensor_lists)``, and the size of each
+ element in ``input_tensor_lists`` (each element is a list,
+ therefore ``len(input_tensor_lists[i])``) need to be the same for
+ all the distributed processes calling this function.
+
+ op (optional): One of the values from ``torch.distributed.ReduceOp``,
+ specifying the element-wise reduction to apply.
+ group (ProcessGroup, optional): The process group to work on.
+ async_op (bool, optional): Whether this op should be an async op.
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group.
+
+ """
+ if _rank_not_in_group(group):
+ return
+
+ opts = ReduceScatterOptions()
+ opts.reduceOp = op
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.reduce_scatter(
+ output_tensor_list,
+ input_tensor_lists,
+ opts
+ )
+ else:
+ work = group.reduce_scatter(
+ output_tensor_list,
+ input_tensor_lists,
+ opts
+ )
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
+
+def reduce_scatter(output,
+ input_list,
+ op=ReduceOp.SUM,
+ group=group.WORLD,
+ async_op=False):
+ """
+ Reduces, then scatters a list of tensors to all processes in a group.
+
+ Arguments:
+ output (Tensor): Output tensor.
+ input_list (list[Tensor]): List of tensors to reduce and scatter.
+ op (optional): One of the values from ``torch.distributed.ReduceOp``,
+ specifying the element-wise reduction to apply.
+ group (ProcessGroup, optional): The process group to work on.
+ async_op (bool, optional): Whether this op should be an async op.
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group.
+
+ """
+ _check_single_tensor(output, "output")
+ _check_tensor_list(input_list, "input_list")
+ if _rank_not_in_group(group):
+ return
+
+ opts = ReduceScatterOptions()
+ opts.reduceOp = op
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.reduce_scatter([output], [input_list], opts)
+ else:
+ work = group.reduce_scatter([output], [input_list], opts)
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
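+# An illustrative sketch (not part of the library source): with ``reduce_scatter``
+# every rank supplies ``world_size`` input tensors; the j-th inputs are summed
+# across ranks and the result lands in ``output`` on rank j. Assumes an
+# initialized NCCL process group with one GPU per process.
+def _example_reduce_scatter_usage():
+    import torch
+    import torch.distributed as dist
+    world_size = dist.get_world_size()
+    device = torch.device("cuda", torch.cuda.current_device())
+    inputs = [torch.full((2,), float(j + 1), device=device)
+              for j in range(world_size)]
+    output = torch.zeros(2, device=device)
+    dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)
+    # On rank j, every entry of ``output`` now equals world_size * (j + 1).
+    return output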
+
+[docs]def barrier(group=group.WORLD,
+ async_op=False):
+ """
+ Synchronizes all processes.
+
+ This collective blocks processes until the whole group enters this function
+ if async_op is False, or, when async_op is True, until wait() is called on the
+ returned async work handle.
+
+ Arguments:
+ group (ProcessGroup, optional): The process group to work on
+ async_op (bool, optional): Whether this op should be an async op
+
+ Returns:
+ Async work handle, if async_op is set to True.
+ None, if not async_op or if not part of the group
+ """
+ if _rank_not_in_group(group):
+ return
+
+ if group == GroupMember.WORLD:
+ _check_default_pg()
+ work = _default_pg.barrier()
+ else:
+ work = group.barrier()
+
+ if async_op:
+ return work
+ else:
+ work.wait()
+
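+# An illustrative sketch (not part of the library source): overlapping local
+# work with a non-blocking barrier, assuming an initialized default group.
+def _example_async_barrier():
+    import torch.distributed as dist
+    work = dist.barrier(async_op=True)  # returns a work handle immediately
+    # ... rank-local work can proceed here while the barrier is in flight ...
+    work.wait()                         # block until every rank has arrived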
+
+[docs]def new_group(ranks=None, timeout=_default_pg_timeout, backend=None):
+ """
+ Creates a new distributed group.
+
+ This function requires that all processes in the main group (i.e. all
+ processes that are part of the distributed job) enter this function, even
+ if they are not going to be members of the group. Additionally, groups
+ should be created in the same order in all processes.
+
+ Arguments:
+ ranks (list[int]): List of ranks of group members.
+ timeout (timedelta, optional): Timeout for operations executed against
+ the process group. Default value equals 30 minutes.
+ This is only applicable for the ``gloo`` backend.
+ backend (str or Backend, optional): The backend to use. Depending on
+ build-time configurations, valid values are ``gloo`` and ``nccl``.
+ By default uses the same backend as the global group. This field
+ should be given as a lowercase string (e.g., ``"gloo"``), which can
+ also be accessed via :class:`Backend` attributes (e.g.,
+ ``Backend.GLOO``).
+
+ Returns:
+ A handle of distributed group that can be given to collective calls.
+ """
+
+ _check_default_pg()
+
+ global _pg_group_ranks
+
+ default_backend, default_store = _pg_map[_default_pg]
+ global_rank = _default_pg.rank()
+ global_world_size = _default_pg.size()
+
+ # Default to the same backend as the global process group
+ # if the backend is not specified.
+ if not backend:
+ backend = default_backend
+
+ # checks the input ranks
+ if ranks is not None:
+ ranks = sorted(ranks)
+ group_world_size = len(ranks)
+ if group_world_size > global_world_size:
+ raise RuntimeError("the new group's world size should be less than "
+ "or equal to the world size set by "
+ "init_process_group")
+ # check ranks' sanity
+ for rank in ranks:
+ if rank < 0 or rank >= global_world_size:
+ raise RuntimeError("The new group's rank should be within the "
+ "world_size set by init_process_group")
+ if global_rank in ranks:
+ group_rank = ranks.index(global_rank)
+ else:
+ group_rank = None
+ else:
+ ranks = list(range(global_world_size))
+ group_world_size = global_world_size
+ group_rank = global_rank
+
+ backend = Backend(backend)
+ pg = _new_process_group_helper(group_world_size,
+ group_rank,
+ ranks,
+ backend,
+ default_store,
+ timeout=timeout)
+
+ # Create the global rank to group rank mapping
+ _pg_group_ranks[pg] = {
+ global_rank: group_rank
+ for group_rank, global_rank in enumerate(ranks)
+ }
+
+ return pg
+
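+# An illustrative sketch (not part of the library source): every rank of the
+# job must call ``new_group``, even ranks that will not be members; only member
+# ranks should then use the returned handle. Assumes at least two ranks and an
+# initialized default group.
+def _example_new_group_usage():
+    import torch
+    import torch.distributed as dist
+    group = dist.new_group(ranks=[0, 1])   # called on *all* ranks
+    t = torch.ones(1)
+    if dist.get_rank() in (0, 1):
+        dist.all_reduce(t, group=group)    # only members use the group
+    return t
+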
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+
+[docs]class Bernoulli(ExponentialFamily):
+ r"""
+ Creates a Bernoulli distribution parameterized by :attr:`probs`
+ or :attr:`logits` (but not both).
+
+ Samples are binary (0 or 1). They take the value `1` with probability `p`
+ and `0` with probability `1 - p`.
+
+ Example::
+
+ >>> m = Bernoulli(torch.tensor([0.3]))
+ >>> m.sample() # 30% chance 1; 70% chance 0
+ tensor([ 0.])
+
+ Args:
+ probs (Number, Tensor): the probability of sampling `1`
+ logits (Number, Tensor): the log-odds of sampling `1`
+ """
+ arg_constraints = {'probs': constraints.unit_interval,
+ 'logits': constraints.real}
+ support = constraints.boolean
+ has_enumerate_support = True
+ _mean_carrier_measure = 0
+
+ def __init__(self, probs=None, logits=None, validate_args=None):
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ is_scalar = isinstance(probs, Number)
+ self.probs, = broadcast_all(probs)
+ else:
+ is_scalar = isinstance(logits, Number)
+ self.logits, = broadcast_all(logits)
+ self._param = self.probs if probs is not None else self.logits
+ if is_scalar:
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self._param.size()
+ super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Bernoulli, _instance)
+ batch_shape = torch.Size(batch_shape)
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(batch_shape)
+ new._param = new.probs
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(batch_shape)
+ new._param = new.logits
+ super(Bernoulli, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._param.new(*args, **kwargs)
+
+ @property
+ def mean(self):
+ return self.probs
+
+ @property
+ def variance(self):
+ return self.probs * (1 - self.probs)
+
+ @property
+ def param_shape(self):
+ return self._param.size()
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ with torch.no_grad():
+ return torch.bernoulli(self.probs.expand(shape))
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ logits, value = broadcast_all(self.logits, value)
+ return -binary_cross_entropy_with_logits(logits, value, reduction='none')
+
+[docs] def entropy(self):
+ return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none')
+
+[docs] def enumerate_support(self, expand=True):
+ values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
+ values = values.view((-1,) + (1,) * len(self._batch_shape))
+ if expand:
+ values = values.expand((-1,) + self._batch_shape)
+ return values
+
+ @property
+ def _natural_params(self):
+ return (torch.log(self.probs / (1 - self.probs)), )
+
+ def _log_normalizer(self, x):
+ return torch.log(1 + torch.exp(x))
+
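+# Note on the exponential-family form used above (added for clarity): the
+# natural parameter of the Bernoulli is the logit theta = log(p / (1 - p)) and
+# the log normalizer is F(theta) = log(1 + exp(theta)), so the mean is
+# recovered as dF/dtheta = sigmoid(theta) = p, matching the ``mean`` property.
+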
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.dirichlet import Dirichlet
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Beta(ExponentialFamily):
+ r"""
+ Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.
+
+ Example::
+
+ >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
+ >>> m.sample() # Beta distributed with concentration concentration1 and concentration0
+ tensor([ 0.1046])
+
+ Args:
+ concentration1 (float or Tensor): 1st concentration parameter of the distribution
+ (often referred to as alpha)
+ concentration0 (float or Tensor): 2nd concentration parameter of the distribution
+ (often referred to as beta)
+ """
+ arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
+ support = constraints.unit_interval
+ has_rsample = True
+
+ def __init__(self, concentration1, concentration0, validate_args=None):
+ if isinstance(concentration1, Number) and isinstance(concentration0, Number):
+ concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
+ else:
+ concentration1, concentration0 = broadcast_all(concentration1, concentration0)
+ concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
+ self._dirichlet = Dirichlet(concentration1_concentration0)
+ super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Beta, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new._dirichlet = self._dirichlet.expand(batch_shape)
+ super(Beta, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ return self.concentration1 / (self.concentration1 + self.concentration0)
+
+ @property
+ def variance(self):
+ total = self.concentration1 + self.concentration0
+ return (self.concentration1 * self.concentration0 /
+ (total.pow(2) * (total + 1)))
+
+[docs] def rsample(self, sample_shape=()):
+ value = self._dirichlet.rsample(sample_shape).select(-1, 0)
+ if isinstance(value, Number):
+ value = self._dirichlet.concentration.new_tensor(value)
+ return value
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ heads_tails = torch.stack([value, 1.0 - value], -1)
+ return self._dirichlet.log_prob(heads_tails)
+
+
+
+ @property
+ def concentration1(self):
+ result = self._dirichlet.concentration[..., 0]
+ if isinstance(result, Number):
+ return torch.tensor([result])
+ else:
+ return result
+
+ @property
+ def concentration0(self):
+ result = self._dirichlet.concentration[..., 1]
+ if isinstance(result, Number):
+ return torch.tensor([result])
+ else:
+ return result
+
+ @property
+ def _natural_params(self):
+ return (self.concentration1, self.concentration0)
+
+ def _log_normalizer(self, x, y):
+ return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
+
+from numbers import Number
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
+
+
+[docs]class Binomial(Distribution):
+ r"""
+ Creates a Binomial distribution parameterized by :attr:`total_count` and
+ either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
+ broadcastable with :attr:`probs`/:attr:`logits`.
+
+ Example::
+
+ >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
+ >>> x = m.sample()
+ tensor([ 0., 22., 71., 100.])
+
+ >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
+ >>> x = m.sample()
+ tensor([[ 4., 5.],
+ [ 7., 6.]])
+
+ Args:
+ total_count (int or Tensor): number of Bernoulli trials
+ probs (Tensor): Event probabilities
+ logits (Tensor): Event log-odds
+ """
+ arg_constraints = {'total_count': constraints.nonnegative_integer,
+ 'probs': constraints.unit_interval,
+ 'logits': constraints.real}
+ has_enumerate_support = True
+
+ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ self.total_count, self.probs, = broadcast_all(total_count, probs)
+ self.total_count = self.total_count.type_as(self.probs)
+ is_scalar = isinstance(self.probs, Number)
+ else:
+ self.total_count, self.logits, = broadcast_all(total_count, logits)
+ self.total_count = self.total_count.type_as(self.logits)
+ is_scalar = isinstance(self.logits, Number)
+
+ self._param = self.probs if probs is not None else self.logits
+ if is_scalar:
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self._param.size()
+ super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Binomial, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.total_count = self.total_count.expand(batch_shape)
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(batch_shape)
+ new._param = new.probs
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(batch_shape)
+ new._param = new.logits
+ super(Binomial, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._param.new(*args, **kwargs)
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.integer_interval(0, self.total_count)
+
+ @property
+ def mean(self):
+ return self.total_count * self.probs
+
+ @property
+ def variance(self):
+ return self.total_count * self.probs * (1 - self.probs)
+
+ @property
+ def param_shape(self):
+ return self._param.size()
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ with torch.no_grad():
+ max_count = max(int(self.total_count.max()), 1)
+ shape = self._extended_shape(sample_shape) + (max_count,)
+ bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
+ if self.total_count.min() != max_count:
+ arange = torch.arange(max_count, dtype=self._param.dtype, device=self._param.device)
+ mask = arange >= self.total_count.unsqueeze(-1)
+ if torch._C._get_tracing_state():
+ # [JIT WORKAROUND] lack of support for .masked_fill_()
+ bernoullis[mask.expand(shape)] = 0.
+ else:
+ bernoullis.masked_fill_(mask, 0.)
+ return bernoullis.sum(dim=-1)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ log_factorial_n = torch.lgamma(self.total_count + 1)
+ log_factorial_k = torch.lgamma(value + 1)
+ log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
+ # Note that: torch.log1p(-self.probs) == -torch.log1p(self.logits.exp())
+ return (log_factorial_n - log_factorial_k - log_factorial_nmk +
+ value * self.logits - self.total_count * torch.log1p(self.logits.exp()))
+
+[docs] def enumerate_support(self, expand=True):
+ total_count = int(self.total_count.max())
+ if not self.total_count.min() == total_count:
+ raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.")
+ values = torch.arange(1 + total_count, dtype=self._param.dtype, device=self._param.device)
+ values = values.view((-1,) + (1,) * len(self._batch_shape))
+ if expand:
+ values = values.expand((-1,) + self._batch_shape)
+ return values
+
+import torch
+from torch._six import nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
+
+
+[docs]class Categorical(Distribution):
+ r"""
+ Creates a categorical distribution parameterized by either :attr:`probs` or
+ :attr:`logits` (but not both).
+
+ .. note::
+ It is equivalent to the distribution that :func:`torch.multinomial`
+ samples from.
+
+ Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
+
+ If :attr:`probs` is 1-dimensional with length `K`, each element is the
+ relative probability of sampling the class at that index.
+
+ If :attr:`probs` is 2D, it is treated as a batch of relative probability
+ vectors.
+
+ .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
+ and it will be normalized to sum to 1.
+
+ See also: :func:`torch.multinomial`
+
+ Example::
+
+ >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
+ >>> m.sample() # equal probability of 0, 1, 2, 3
+ tensor(3)
+
+ Args:
+ probs (Tensor): event probabilities
+ logits (Tensor): event log probabilities
+ """
+ arg_constraints = {'probs': constraints.simplex,
+ 'logits': constraints.real}
+ has_enumerate_support = True
+
+ def __init__(self, probs=None, logits=None, validate_args=None):
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ if probs.dim() < 1:
+ raise ValueError("`probs` parameter must be at least one-dimensional.")
+ self.probs = probs / probs.sum(-1, keepdim=True)
+ else:
+ if logits.dim() < 1:
+ raise ValueError("`logits` parameter must be at least one-dimensional.")
+ self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
+ self._param = self.probs if probs is not None else self.logits
+ self._num_events = self._param.size()[-1]
+ batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size()
+ super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Categorical, _instance)
+ batch_shape = torch.Size(batch_shape)
+ param_shape = batch_shape + torch.Size((self._num_events,))
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(param_shape)
+ new._param = new.probs
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(param_shape)
+ new._param = new.logits
+ new._num_events = self._num_events
+ super(Categorical, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._param.new(*args, **kwargs)
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.integer_interval(0, self._num_events - 1)
+
+ @property
+ def param_shape(self):
+ return self._param.size()
+
+ @property
+ def mean(self):
+ return self.probs.new_tensor(nan).expand(self._extended_shape())
+
+ @property
+ def variance(self):
+ return self.probs.new_tensor(nan).expand(self._extended_shape())
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ sample_shape = self._extended_shape(sample_shape)
+ param_shape = sample_shape + torch.Size((self._num_events,))
+ probs = self.probs.expand(param_shape)
+ if self.probs.dim() == 1 or self.probs.size(0) == 1:
+ probs_2d = probs.view(-1, self._num_events)
+ else:
+ probs_2d = probs.contiguous().view(-1, self._num_events)
+ sample_2d = torch.multinomial(probs_2d, 1, True)
+ return sample_2d.contiguous().view(sample_shape)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ value = value.long().unsqueeze(-1)
+ value, log_pmf = torch.broadcast_tensors(value, self.logits)
+ value = value[..., :1]
+ return log_pmf.gather(-1, value).squeeze(-1)
+
+
+
+[docs] def enumerate_support(self, expand=True):
+ num_events = self._num_events
+ values = torch.arange(num_events, dtype=torch.long, device=self._param.device)
+ values = values.view((-1,) + (1,) * len(self._batch_shape))
+ if expand:
+ values = values.expand((-1,) + self._batch_shape)
+ return values
+
+import math
+from torch._six import inf, nan
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Cauchy(Distribution):
+ r"""
+ Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
+ independent normally distributed random variables with means `0` follows a
+ Cauchy distribution.
+
+ Example::
+
+ >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
+ >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1
+ tensor([ 2.3214])
+
+ Args:
+ loc (float or Tensor): mode or median of the distribution.
+ scale (float or Tensor): half width at half maximum.
+ """
+ arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(self, loc, scale, validate_args=None):
+ self.loc, self.scale = broadcast_all(loc, scale)
+ if isinstance(loc, Number) and isinstance(scale, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.loc.size()
+ super(Cauchy, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Cauchy, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ super(Cauchy, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ return self.loc.new_tensor(nan).expand(self._extended_shape())
+
+ @property
+ def variance(self):
+ return self.loc.new_tensor(inf).expand(self._extended_shape())
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ eps = self.loc.new(shape).cauchy_()
+ return self.loc + eps * self.scale
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()
+
+[docs] def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
+
+[docs] def icdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
+
+
+
+from torch.distributions import constraints
+from torch.distributions.gamma import Gamma
+
+
+[docs]class Chi2(Gamma):
+ r"""
+ Creates a Chi2 distribution parameterized by shape parameter :attr:`df`.
+ This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)``
+
+ Example::
+
+ >>> m = Chi2(torch.tensor([1.0]))
+ >>> m.sample() # Chi2 distributed with shape df=1
+ tensor([ 0.1046])
+
+ Args:
+ df (float or Tensor): shape parameter of the distribution
+ """
+ arg_constraints = {'df': constraints.positive}
+
+ def __init__(self, df, validate_args=None):
+ super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Chi2, _instance)
+ return super(Chi2, self).expand(batch_shape, new)
+
+ @property
+ def df(self):
+ return self.concentration * 2
+
+r"""
+PyTorch provides two global :class:`ConstraintRegistry` objects that link
+:class:`~torch.distributions.constraints.Constraint` objects to
+ :class:`~torch.distributions.transforms.Transform` objects. Both objects accept
+ a constraint and return a transform, but they make different guarantees about
+ bijectivity.
+
+1. ``biject_to(constraint)`` looks up a bijective
+ :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
+ to the given ``constraint``. The returned transform is guaranteed to have
+ ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
+2. ``transform_to(constraint)`` looks up a not-necessarily bijective
+ :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
+ to the given ``constraint``. The returned transform is not guaranteed to
+ implement ``.log_abs_det_jacobian()``.
+
+The ``transform_to()`` registry is useful for performing unconstrained
+optimization on constrained parameters of probability distributions, which are
+indicated by each distribution's ``.arg_constraints`` dict. These transforms often
+overparameterize a space in order to avoid rotation; they are thus more
+suitable for coordinate-wise optimization algorithms like Adam::
+
+ loc = torch.zeros(100, requires_grad=True)
+ unconstrained = torch.zeros(100, requires_grad=True)
+ scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
+ loss = -Normal(loc, scale).log_prob(data).sum()
+
+The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
+samples from a probability distribution with constrained ``.support`` are
+propagated in an unconstrained space, and algorithms are typically rotation
+ invariant::
+
+ dist = Exponential(rate)
+ unconstrained = torch.zeros(100, requires_grad=True)
+ sample = biject_to(dist.support)(unconstrained)
+ potential_energy = -dist.log_prob(sample).sum()
+
+.. note::
+
+ An example where ``transform_to`` and ``biject_to`` differ is
+ ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
+ :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
+ exponentiates and normalizes its inputs; this is a cheap and mostly
+ coordinate-wise operation appropriate for algorithms like SVI. In
+ contrast, ``biject_to(constraints.simplex)`` returns a
+ :class:`~torch.distributions.transforms.StickBreakingTransform` that
+ bijects its input down to a space with one fewer dimension; this is a more
+ expensive and less numerically stable transform, but it is needed for
+ algorithms like HMC.
+
+The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
+constraints and transforms using their ``.register()`` method either as a
+function on singleton constraints::
+
+ transform_to.register(my_constraint, my_transform)
+
+or as a decorator on parameterized constraints::
+
+ @transform_to.register(MyConstraintClass)
+ def my_factory(constraint):
+ assert isinstance(constraint, MyConstraintClass)
+ return MyTransform(constraint.param1, constraint.param2)
+
+You can create your own registry by creating a new :class:`ConstraintRegistry`
+object.
+"""
+
+import numbers
+
+from torch.distributions import constraints, transforms
+
+__all__ = [
+ 'ConstraintRegistry',
+ 'biject_to',
+ 'transform_to',
+]
+
+
+[docs]class ConstraintRegistry(object):
+ """
+ Registry to link constraints to transforms.
+ """
+ def __init__(self):
+ self._registry = {}
+ super(ConstraintRegistry, self).__init__()
+
+[docs] def register(self, constraint, factory=None):
+ """
+ Registers a :class:`~torch.distributions.constraints.Constraint`
+ subclass in this registry. Usage::
+
+ @my_registry.register(MyConstraintClass)
+ def construct_transform(constraint):
+ assert isinstance(constraint, MyConstraint)
+ return MyTransform(constraint.arg_constraints)
+
+ Args:
+ constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
+ A subclass of :class:`~torch.distributions.constraints.Constraint`, or
+ a singleton object of the desired class.
+ factory (callable): A callable that inputs a constraint object and returns
+ a :class:`~torch.distributions.transforms.Transform` object.
+ """
+ # Support use as decorator.
+ if factory is None:
+ return lambda factory: self.register(constraint, factory)
+
+ # Support calling on singleton instances.
+ if isinstance(constraint, constraints.Constraint):
+ constraint = type(constraint)
+
+ if not isinstance(constraint, type) or not issubclass(constraint, constraints.Constraint):
+ raise TypeError('Expected constraint to be either a Constraint subclass or instance, '
+ 'but got {}'.format(constraint))
+
+ self._registry[constraint] = factory
+ return factory
+
+ def __call__(self, constraint):
+ """
+ Looks up a transform to constrained space, given a constraint object.
+ Usage::
+
+ constraint = Normal.arg_constraints['scale']
+ scale = transform_to(constraint)(torch.zeros(1)) # constrained
+ u = transform_to(constraint).inv(scale) # unconstrained
+
+ Args:
+ constraint (:class:`~torch.distributions.constraints.Constraint`):
+ A constraint object.
+
+ Returns:
+ A :class:`~torch.distributions.transforms.Transform` object.
+
+ Raises:
+ `NotImplementedError` if no transform has been registered.
+ """
+ # Look up by Constraint subclass.
+ try:
+ factory = self._registry[type(constraint)]
+ except KeyError:
+ raise NotImplementedError(
+ 'Cannot transform {} constraints'.format(type(constraint).__name__))
+ return factory(constraint)
+
+
+biject_to = ConstraintRegistry()
+transform_to = ConstraintRegistry()
+
+
+################################################################################
+# Registration Table
+################################################################################
+
+@biject_to.register(constraints.real)
+@biject_to.register(constraints.real_vector)
+@transform_to.register(constraints.real)
+@transform_to.register(constraints.real_vector)
+def _transform_to_real(constraint):
+ return transforms.identity_transform
+
+
+@biject_to.register(constraints.positive)
+@transform_to.register(constraints.positive)
+def _transform_to_positive(constraint):
+ return transforms.ExpTransform()
+
+
+@biject_to.register(constraints.greater_than)
+@biject_to.register(constraints.greater_than_eq)
+@transform_to.register(constraints.greater_than)
+@transform_to.register(constraints.greater_than_eq)
+def _transform_to_greater_than(constraint):
+ return transforms.ComposeTransform([transforms.ExpTransform(),
+ transforms.AffineTransform(constraint.lower_bound, 1)])
+
+
+@biject_to.register(constraints.less_than)
+@transform_to.register(constraints.less_than)
+def _transform_to_less_than(constraint):
+ return transforms.ComposeTransform([transforms.ExpTransform(),
+ transforms.AffineTransform(constraint.upper_bound, -1)])
+
+
+@biject_to.register(constraints.interval)
+@biject_to.register(constraints.half_open_interval)
+@transform_to.register(constraints.interval)
+@transform_to.register(constraints.half_open_interval)
+def _transform_to_interval(constraint):
+ # Handle the special case of the unit interval.
+ lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0
+ upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1
+ if lower_is_0 and upper_is_1:
+ return transforms.SigmoidTransform()
+
+ loc = constraint.lower_bound
+ scale = constraint.upper_bound - constraint.lower_bound
+ return transforms.ComposeTransform([transforms.SigmoidTransform(),
+ transforms.AffineTransform(loc, scale)])
+
+
+@biject_to.register(constraints.simplex)
+def _biject_to_simplex(constraint):
+ return transforms.StickBreakingTransform()
+
+
+@transform_to.register(constraints.simplex)
+def _transform_to_simplex(constraint):
+ return transforms.SoftmaxTransform()
+
+
+# TODO define a bijection for LowerCholeskyTransform
+@transform_to.register(constraints.lower_cholesky)
+def _transform_to_lower_cholesky(constraint):
+ return transforms.LowerCholeskyTransform()
+
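+# An illustrative sketch (not part of the library source): using ``transform_to``
+# to optimize a positivity-constrained ``scale`` parameter in unconstrained
+# space, as described in the module docstring above. ``data`` is a hypothetical
+# tensor of observations.
+def _example_transform_to_usage(data):
+    import torch
+    from torch.distributions import Normal, transform_to
+    loc = torch.zeros(100, requires_grad=True)
+    unconstrained = torch.zeros(100, requires_grad=True)
+    optimizer = torch.optim.Adam([loc, unconstrained], lr=1e-2)
+    for _ in range(100):
+        optimizer.zero_grad()
+        scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
+        loss = -Normal(loc, scale).log_prob(data).sum()
+        loss.backward()
+        optimizer.step()
+    return loc, transform_to(Normal.arg_constraints['scale'])(unconstrained)
+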
+r"""
+The following constraints are implemented:
+
+- ``constraints.boolean``
+- ``constraints.dependent``
+- ``constraints.greater_than(lower_bound)``
+- ``constraints.integer_interval(lower_bound, upper_bound)``
+- ``constraints.interval(lower_bound, upper_bound)``
+- ``constraints.lower_cholesky``
+- ``constraints.lower_triangular``
+- ``constraints.nonnegative_integer``
+- ``constraints.positive``
+- ``constraints.positive_definite``
+- ``constraints.positive_integer``
+- ``constraints.real``
+- ``constraints.real_vector``
+- ``constraints.simplex``
+- ``constraints.unit_interval``
+"""
+
+import torch
+
+__all__ = [
+ 'Constraint',
+ 'boolean',
+ 'dependent',
+ 'dependent_property',
+ 'greater_than',
+ 'greater_than_eq',
+ 'integer_interval',
+ 'interval',
+ 'half_open_interval',
+ 'is_dependent',
+ 'less_than',
+ 'lower_cholesky',
+ 'lower_triangular',
+ 'nonnegative_integer',
+ 'positive',
+ 'positive_definite',
+ 'positive_integer',
+ 'real',
+ 'real_vector',
+ 'simplex',
+ 'unit_interval',
+]
+
+
+[docs]class Constraint(object):
+ """
+ Abstract base class for constraints.
+
+ A constraint object represents a region over which a variable is valid,
+ e.g. within which a variable can be optimized.
+ """
+[docs] def check(self, value):
+ """
+ Returns a byte tensor of `sample_shape + batch_shape` indicating
+ whether each event in value satisfies this constraint.
+ """
+ raise NotImplementedError
+
+ def __repr__(self):
+ return self.__class__.__name__[1:] + '()'
+
+
+class _Dependent(Constraint):
+ """
+ Placeholder for variables whose support depends on other variables.
+ These variables obey no simple coordinate-wise constraints.
+ """
+ def check(self, x):
+ raise ValueError('Cannot determine validity of dependent constraint')
+
+
+def is_dependent(constraint):
+ return isinstance(constraint, _Dependent)
+
+
+class _DependentProperty(property, _Dependent):
+ """
+ Decorator that extends @property to act like a `Dependent` constraint when
+ called on a class and act like a property when called on an object.
+
+ Example::
+
+ class Uniform(Distribution):
+ def __init__(self, low, high):
+ self.low = low
+ self.high = high
+ @constraints.dependent_property
+ def support(self):
+ return constraints.interval(self.low, self.high)
+ """
+ pass
+
+
+class _Boolean(Constraint):
+ """
+ Constrain to the two values `{0, 1}`.
+ """
+ def check(self, value):
+ return (value == 0) | (value == 1)
+
+
+class _IntegerInterval(Constraint):
+ """
+ Constrain to an integer interval `[lower_bound, upper_bound]`.
+ """
+ def __init__(self, lower_bound, upper_bound):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
+
+class _IntegerLessThan(Constraint):
+ """
+ Constrain to an integer interval `(-inf, upper_bound]`.
+ """
+ def __init__(self, upper_bound):
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (value % 1 == 0) & (value <= self.upper_bound)
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ return fmt_string
+
+
+class _IntegerGreaterThan(Constraint):
+ """
+ Constrain to an integer interval `[lower_bound, inf)`.
+ """
+ def __init__(self, lower_bound):
+ self.lower_bound = lower_bound
+
+ def check(self, value):
+ return (value % 1 == 0) & (value >= self.lower_bound)
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
+
+class _Real(Constraint):
+ """
+ Trivially constrain to the extended real line `[-inf, inf]`.
+ """
+ def check(self, value):
+ return value == value # False for NANs.
+
+
+class _GreaterThan(Constraint):
+ """
+ Constrain to a real half line `(lower_bound, inf]`.
+ """
+ def __init__(self, lower_bound):
+ self.lower_bound = lower_bound
+
+ def check(self, value):
+ return self.lower_bound < value
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
+
+class _GreaterThanEq(Constraint):
+ """
+ Constrain to a real half line `[lower_bound, inf)`.
+ """
+ def __init__(self, lower_bound):
+ self.lower_bound = lower_bound
+
+ def check(self, value):
+ return self.lower_bound <= value
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
+
+class _LessThan(Constraint):
+ """
+ Constrain to a real half line `[-inf, upper_bound)`.
+ """
+ def __init__(self, upper_bound):
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return value < self.upper_bound
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ return fmt_string
+
+
+class _Interval(Constraint):
+ """
+ Constrain to a real interval `[lower_bound, upper_bound]`.
+ """
+ def __init__(self, lower_bound, upper_bound):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (self.lower_bound <= value) & (value <= self.upper_bound)
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
+
+class _HalfOpenInterval(Constraint):
+ """
+ Constrain to a real interval `[lower_bound, upper_bound)`.
+ """
+ def __init__(self, lower_bound, upper_bound):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+
+ def check(self, value):
+ return (self.lower_bound <= value) & (value < self.upper_bound)
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
+
+class _Simplex(Constraint):
+ """
+ Constrain to the unit simplex in the innermost (rightmost) dimension.
+ Specifically: `x >= 0` and `x.sum(-1) == 1`.
+ """
+ def check(self, value):
+ return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all()
+
+
+class _LowerTriangular(Constraint):
+ """
+ Constrain to lower-triangular square matrices.
+ """
+ def check(self, value):
+ value_tril = value.tril()
+ return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
+
+
+class _LowerCholesky(Constraint):
+ """
+ Constrain to lower-triangular square matrices with positive diagonals.
+ """
+ def check(self, value):
+ value_tril = value.tril()
+ lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
+
+ positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
+ return lower_triangular & positive_diagonal
+
+
+class _PositiveDefinite(Constraint):
+ """
+ Constrain to positive-definite matrices.
+ """
+ def check(self, value):
+ matrix_shape = value.shape[-2:]
+ batch_shape = value.unsqueeze(0).shape[:-2]
+ # TODO: replace with batched linear algebra routine when one becomes available
+ # note that `symeig()` returns eigenvalues in ascending order
+ flattened_value = value.reshape((-1,) + matrix_shape)
+ return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
+ for v in flattened_value]).view(batch_shape)
+
+
+class _RealVector(Constraint):
+ """
+ Constrain to real-valued vectors. This is the same as `constraints.real`,
+ but additionally reduces across the `event_shape` dimension.
+ """
+ def check(self, value):
+ return (value == value).all() # False for NANs.
+
+
+# Public interface.
+dependent = _Dependent()
+dependent_property = _DependentProperty
+boolean = _Boolean()
+nonnegative_integer = _IntegerGreaterThan(0)
+positive_integer = _IntegerGreaterThan(1)
+integer_interval = _IntegerInterval
+real = _Real()
+real_vector = _RealVector()
+positive = _GreaterThan(0.)
+greater_than = _GreaterThan
+greater_than_eq = _GreaterThanEq
+less_than = _LessThan
+unit_interval = _Interval(0., 1.)
+interval = _Interval
+half_open_interval = _HalfOpenInterval
+simplex = _Simplex()
+lower_triangular = _LowerTriangular()
+lower_cholesky = _LowerCholesky()
+positive_definite = _PositiveDefinite()
+
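+# An illustrative sketch (not part of the library source): ``Constraint.check``
+# returns an element-wise byte mask, which is how distributions validate both
+# their parameters and the values passed to ``log_prob``.
+def _example_constraint_check():
+    import torch
+    from torch.distributions import constraints
+    values = torch.tensor([-0.5, 0.0, 0.3, 1.0, 1.5])
+    in_unit = constraints.unit_interval.check(values)    # -> [0, 1, 1, 1, 0]
+    pos = constraints.positive.check(values)              # strict inequality
+    on_simplex = constraints.simplex.check(torch.tensor([0.2, 0.3, 0.5]))
+    return in_unit, pos, on_simplex
+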
+import torch
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+
+
+# This helper is exposed for testing.
+def _Dirichlet_backward(x, concentration, grad_output):
+ total = concentration.sum(-1, True).expand_as(concentration)
+ grad = torch._dirichlet_grad(x, concentration, total)
+ return grad * (grad_output - (x * grad_output).sum(-1, True))
+
+
+class _Dirichlet(Function):
+ @staticmethod
+ def forward(ctx, concentration):
+ x = torch._sample_dirichlet(concentration)
+ ctx.save_for_backward(x, concentration)
+ return x
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ x, concentration = ctx.saved_tensors
+ return _Dirichlet_backward(x, concentration, grad_output)
+
+
+[docs]class Dirichlet(ExponentialFamily):
+ r"""
+ Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
+
+ Example::
+
+ >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
+ >>> m.sample() # Dirichlet distributed with the given concentration
+ tensor([ 0.1046, 0.8954])
+
+ Args:
+ concentration (Tensor): concentration parameter of the distribution
+ (often referred to as alpha)
+ """
+ arg_constraints = {'concentration': constraints.positive}
+ support = constraints.simplex
+ has_rsample = True
+
+ def __init__(self, concentration, validate_args=None):
+ if concentration.dim() < 1:
+ raise ValueError("`concentration` parameter must be at least one-dimensional.")
+ self.concentration = concentration
+ batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
+ super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Dirichlet, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.concentration = self.concentration.expand(batch_shape + self.event_shape)
+ super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def rsample(self, sample_shape=()):
+ shape = self._extended_shape(sample_shape)
+ concentration = self.concentration.expand(shape)
+ return _Dirichlet.apply(concentration)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) +
+ torch.lgamma(self.concentration.sum(-1)) -
+ torch.lgamma(self.concentration).sum(-1))
+
+ @property
+ def mean(self):
+ return self.concentration / self.concentration.sum(-1, True)
+
+ @property
+ def variance(self):
+ con0 = self.concentration.sum(-1, True)
+ return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1))
+
+[docs] def entropy(self):
+ k = self.concentration.size(-1)
+ a0 = self.concentration.sum(-1)
+ return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
+ (k - a0) * torch.digamma(a0) -
+ ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1))
+
+ @property
+ def _natural_params(self):
+ return (self.concentration, )
+
+ def _log_normalizer(self, x):
+ return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
+
+import torch
+import warnings
+from torch.distributions import constraints
+from torch.distributions.utils import lazy_property
+
+
+[docs]class Distribution(object):
+ r"""
+ Distribution is the abstract base class for probability distributions.
+ """
+
+ has_rsample = False
+ has_enumerate_support = False
+ _validate_args = False
+ support = None
+ arg_constraints = {}
+
+ @staticmethod
+ def set_default_validate_args(value):
+ if value not in [True, False]:
+ raise ValueError
+ Distribution._validate_args = value
+
+ def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None):
+ self._batch_shape = batch_shape
+ self._event_shape = event_shape
+ if validate_args is not None:
+ self._validate_args = validate_args
+ if self._validate_args:
+ for param, constraint in self.arg_constraints.items():
+ if constraints.is_dependent(constraint):
+ continue # skip constraints that cannot be checked
+ if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
+ continue # skip checking lazily-constructed args
+ if not constraint.check(getattr(self, param)).all():
+ raise ValueError("The parameter {} has invalid values".format(param))
+ super(Distribution, self).__init__()
+
+[docs] def expand(self, batch_shape, _instance=None):
+ """
+ Returns a new distribution instance (or populates an existing instance
+ provided by a derived class) with batch dimensions expanded to
+ `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
+ the distribution's parameters. As such, this does not allocate new
+ memory for the expanded distribution instance. Additionally,
+ this does not repeat any args checking or parameter broadcasting in
+ `__init__`, which happens when an instance is first created.
+
+ Args:
+ batch_shape (torch.Size): the desired expanded size.
+ _instance: new instance provided by subclasses that
+ need to override `.expand`.
+
+ Returns:
+ New distribution instance with batch dimensions expanded to
+ `batch_shape`.
+ """
+ raise NotImplementedError
+
+ @property
+ def batch_shape(self):
+ """
+ Returns the shape over which parameters are batched.
+ """
+ return self._batch_shape
+
+ @property
+ def event_shape(self):
+ """
+ Returns the shape of a single sample (without batching).
+ """
+ return self._event_shape
+
+ @property
+ def arg_constraints(self):
+ """
+ Returns a dictionary from argument names to
+ :class:`~torch.distributions.constraints.Constraint` objects that
+ should be satisfied by each argument of this distribution. Args that
+ are not tensors need not appear in this dict.
+ """
+ raise NotImplementedError
+
+ @property
+ def support(self):
+ """
+ Returns a :class:`~torch.distributions.constraints.Constraint` object
+ representing this distribution's support.
+ """
+ raise NotImplementedError
+
+ @property
+ def mean(self):
+ """
+ Returns the mean of the distribution.
+ """
+ raise NotImplementedError
+
+ @property
+ def variance(self):
+ """
+ Returns the variance of the distribution.
+ """
+ raise NotImplementedError
+
+ @property
+ def stddev(self):
+ """
+ Returns the standard deviation of the distribution.
+ """
+ return self.variance.sqrt()
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ """
+ Generates a sample_shape shaped sample or sample_shape shaped batch of
+ samples if the distribution parameters are batched.
+ """
+ with torch.no_grad():
+ return self.rsample(sample_shape)
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ """
+ Generates a sample_shape shaped reparameterized sample or sample_shape
+ shaped batch of reparameterized samples if the distribution parameters
+ are batched.
+ """
+ raise NotImplementedError
+
+[docs] def sample_n(self, n):
+ """
+ Generates n samples or n batches of samples if the distribution
+ parameters are batched.
+ """
+ warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
+ return self.sample(torch.Size((n,)))
+
+[docs] def log_prob(self, value):
+ """
+ Returns the log of the probability density/mass function evaluated at
+ `value`.
+
+ Args:
+ value (Tensor):
+ """
+ raise NotImplementedError
+
+[docs] def cdf(self, value):
+ """
+ Returns the cumulative distribution function evaluated at
+ `value`.
+
+ Args:
+ value (Tensor):
+ """
+ raise NotImplementedError
+
+[docs] def icdf(self, value):
+ """
+ Returns the inverse cumulative distribution function evaluated at
+ `value`.
+
+ Args:
+ value (Tensor):
+ """
+ raise NotImplementedError
+
+[docs] def enumerate_support(self, expand=True):
+ """
+ Returns tensor containing all values supported by a discrete
+ distribution. The result will enumerate over dimension 0, so the shape
+ of the result will be `(cardinality,) + batch_shape + event_shape`
+ (where `event_shape = ()` for univariate distributions).
+
+ Note that this enumerates over all batched tensors in lock-step
+ `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
+ along dim 0, but with the remaining batch dimensions being
+ singleton dimensions, `[[0], [1], ...]`.
+
+ To iterate over the full Cartesian product use
+ `itertools.product(m.enumerate_support())`.
+
+ Args:
+ expand (bool): whether to expand the support over the
+ batch dims to match the distribution's `batch_shape`.
+
+ Returns:
+ Tensor iterating over dimension 0.
+ """
+ raise NotImplementedError
+
+[docs] def entropy(self):
+ """
+ Returns entropy of distribution, batched over batch_shape.
+
+ Returns:
+ Tensor of shape batch_shape.
+ """
+ raise NotImplementedError
+
+[docs] def perplexity(self):
+ """
+ Returns perplexity of distribution, batched over batch_shape.
+
+ Returns:
+ Tensor of shape batch_shape.
+ """
+ return torch.exp(self.entropy())
+
+ def _extended_shape(self, sample_shape=torch.Size()):
+ """
+ Returns the size of the sample returned by the distribution, given
+ a `sample_shape`. Note that the batch and event shapes of a distribution
+ instance are fixed at the time of construction. If this is empty, the
+ returned shape is upcast to (1,).
+
+ Args:
+ sample_shape (torch.Size): the size of the sample to be drawn.
+ """
+ if not isinstance(sample_shape, torch.Size):
+ sample_shape = torch.Size(sample_shape)
+ return sample_shape + self._batch_shape + self._event_shape
+
+ def _validate_sample(self, value):
+ """
+ Argument validation for distribution methods such as `log_prob`,
+ `cdf` and `icdf`. The rightmost dimensions of a value to be
+ scored via these methods must agree with the distribution's batch
+ and event shapes.
+
+ Args:
+ value (Tensor): the tensor whose log probability is to be
+ computed by the `log_prob` method.
+ Raises
+ ValueError: when the rightmost dimensions of `value` do not match the
+ distribution's batch and event shapes.
+ """
+ if not isinstance(value, torch.Tensor):
+ raise ValueError('The value argument to log_prob must be a Tensor')
+
+ event_dim_start = len(value.size()) - len(self._event_shape)
+ if value.size()[event_dim_start:] != self._event_shape:
+ raise ValueError('The right-most size of value must match event_shape: {} vs {}.'.
+ format(value.size(), self._event_shape))
+
+ actual_shape = value.size()
+ expected_shape = self._batch_shape + self._event_shape
+ for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
+ if i != 1 and j != 1 and i != j:
+ raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
+ format(actual_shape, expected_shape))
+
+ if not self.support.check(value).all():
+ raise ValueError('The value argument must be within the support')
+
+ def _get_checked_instance(self, cls, _instance=None):
+ if _instance is None and type(self).__init__ != cls.__init__:
+ raise NotImplementedError("Subclass {} of {} that defines a custom __init__ method "
+ "must also define a custom .expand() method.".
+ format(self.__class__.__name__, cls.__name__))
+ return self.__new__(type(self)) if _instance is None else _instance
+
+ def __repr__(self):
+ param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
+ args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
+ if self.__dict__[p].numel() == 1
+ else self.__dict__[p].size()) for p in param_names])
+ return self.__class__.__name__ + '(' + args_string + ')'
+
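+# An illustrative sketch (not part of the library source): how ``batch_shape``,
+# ``event_shape`` and ``sample_shape`` combine in ``_extended_shape``. A
+# Dirichlet over 5 categories with a (4,)-shaped batch of concentrations has
+# batch_shape (4,) and event_shape (5,), so sampling with sample_shape (2,)
+# yields a tensor of shape sample_shape + batch_shape + event_shape = (2, 4, 5).
+def _example_distribution_shapes():
+    import torch
+    from torch.distributions import Dirichlet
+    d = Dirichlet(torch.ones(4, 5))
+    assert d.batch_shape == torch.Size([4])
+    assert d.event_shape == torch.Size([5])
+    assert d.sample((2,)).shape == torch.Size([2, 4, 5])
+    return d
+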
+import torch
+from torch.distributions.distribution import Distribution
+
+
+[docs]class ExponentialFamily(Distribution):
+ r"""
+ ExponentialFamily is the abstract base class for probability distributions belonging to an
+ exponential family, whose probability mass/density function has the form defined below
+
+ .. math::
+
+ p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
+
+ where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
+ :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
+ measure.
+
+ Note:
+ This class is an intermediary between the `Distribution` class and distributions which belong
+ to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
+ divergence methods. We use this class to compute the entropy and KL divergence using the AD
+ framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
+ Cross-entropies of Exponential Families).
+ """
+
+ @property
+ def _natural_params(self):
+ """
+ Abstract method for natural parameters. Returns a tuple of Tensors based
+ on the distribution
+ """
+ raise NotImplementedError
+
+ def _log_normalizer(self, *natural_params):
+ """
+ Abstract method for log normalizer function. Returns a log normalizer based on
+ the distribution and input
+ """
+ raise NotImplementedError
+
+ @property
+ def _mean_carrier_measure(self):
+ """
+ Abstract method for expected carrier measure, which is required for computing
+ entropy.
+ """
+ raise NotImplementedError
+
+[docs] def entropy(self):
+ """
+ Method to compute the entropy using Bregman divergence of the log normalizer.
+ """
+ result = -self._mean_carrier_measure
+ nparams = [p.detach().requires_grad_() for p in self._natural_params]
+ lg_normal = self._log_normalizer(*nparams)
+ gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
+ result += lg_normal.clone()
+ for np, g in zip(nparams, gradients):
+ result -= np * g
+ return result
+
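+# Note on the entropy computation above (added for clarity): for an exponential
+# family p(x; theta) = exp(<t(x), theta> - F(theta) + k(x)), differentiating the
+# log normalizer gives E[t(x)] = grad F(theta), so the entropy is
+#     H = F(theta) - <theta, grad F(theta)> - E[k(x)],
+# which is exactly what ``entropy()`` computes: the log normalizer minus
+# sum_i theta_i * dF/dtheta_i, minus the mean carrier measure.
+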
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Exponential(ExponentialFamily):
+ r"""
+ Creates an Exponential distribution parameterized by :attr:`rate`.
+
+ Example::
+
+ >>> m = Exponential(torch.tensor([1.0]))
+ >>> m.sample() # Exponential distributed with rate=1
+ tensor([ 0.1046])
+
+ Args:
+ rate (float or Tensor): rate = 1 / scale of the distribution
+ """
+ arg_constraints = {'rate': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+ _mean_carrier_measure = 0
+
+ @property
+ def mean(self):
+ return self.rate.reciprocal()
+
+ @property
+ def stddev(self):
+ return self.rate.reciprocal()
+
+ @property
+ def variance(self):
+ return self.rate.pow(-2)
+
+ def __init__(self, rate, validate_args=None):
+ self.rate, = broadcast_all(rate)
+ batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
+ super(Exponential, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Exponential, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.rate = self.rate.expand(batch_shape)
+ super(Exponential, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ if torch._C._get_tracing_state():
+ # [JIT WORKAROUND] lack of support for ._exponential()
+ u = torch.rand(shape, dtype=self.rate.dtype, device=self.rate.device)
+ return -(-u).log1p() / self.rate
+ return self.rate.new(shape).exponential_() / self.rate
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return self.rate.log() - self.rate * value
+
+[docs] def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return 1 - torch.exp(-self.rate * value)
+
+[docs] def icdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return -torch.log(1 - value) / self.rate
+
+
+
+ @property
+ def _natural_params(self):
+ return (-self.rate, )
+
+ def _log_normalizer(self, x):
+ return -torch.log(-x)
+
+from numbers import Number
+import torch
+from torch._six import nan
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.gamma import Gamma
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class FisherSnedecor(Distribution):
+ r"""
+ Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
+
+ Example::
+
+ >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
+ >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2
+ tensor([ 0.2453])
+
+ Args:
+ df1 (float or Tensor): degrees of freedom parameter 1
+ df2 (float or Tensor): degrees of freedom parameter 2
+ """
+ arg_constraints = {'df1': constraints.positive, 'df2': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+
+ def __init__(self, df1, df2, validate_args=None):
+ self.df1, self.df2 = broadcast_all(df1, df2)
+ self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
+ self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
+
+ if isinstance(df1, Number) and isinstance(df2, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.df1.size()
+ super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(FisherSnedecor, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.df1 = self.df1.expand(batch_shape)
+ new.df2 = self.df2.expand(batch_shape)
+ new._gamma1 = self._gamma1.expand(batch_shape)
+ new._gamma2 = self._gamma2.expand(batch_shape)
+ super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ df2 = self.df2.clone()
+ df2[df2 <= 2] = nan
+ return df2 / (df2 - 2)
+
+ @property
+ def variance(self):
+ df2 = self.df2.clone()
+ df2[df2 <= 4] = nan
+ return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
+
+[docs] def rsample(self, sample_shape=torch.Size(())):
+ shape = self._extended_shape(sample_shape)
+ # X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
+ # Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
+ X1 = self._gamma1.rsample(sample_shape).view(shape)
+ X2 = self._gamma2.rsample(sample_shape).view(shape)
+ tiny = torch.finfo(X2.dtype).tiny
+ X2.clamp_(min=tiny)
+ Y = X1 / X2
+ Y.clamp_(min=tiny)
+ return Y
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ ct1 = self.df1 * 0.5
+ ct2 = self.df2 * 0.5
+ ct3 = self.df1 / self.df2
+ t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
+ t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
+ t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
+ return t1 + t2 - t3
+
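+# Usage sketch (illustrative, not part of the module above): for df2 > 2 the mean is
+# df2 / (df2 - 2), and rsample() returns strictly positive values because the ratio is
+# clamped away from zero. Parameter values are arbitrary assumptions.
+import torch
+from torch.distributions import FisherSnedecor
+
+f = FisherSnedecor(torch.tensor([3.0]), torch.tensor([6.0]))
+assert torch.allclose(f.mean, torch.tensor([6.0 / (6.0 - 2.0)]))
+samples = f.rsample(torch.Size([1000]))
+assert (samples > 0).all()
+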
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+
+def _standard_gamma(concentration):
+ return torch._standard_gamma(concentration)
+
+
+[docs]class Gamma(ExponentialFamily):
+ r"""
+ Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
+
+ Example::
+
+ >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
+ >>> m.sample() # Gamma distributed with concentration=1 and rate=1
+ tensor([ 0.1046])
+
+ Args:
+ concentration (float or Tensor): shape parameter of the distribution
+ (often referred to as alpha)
+ rate (float or Tensor): rate = 1 / scale of the distribution
+ (often referred to as beta)
+ """
+ arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+ _mean_carrier_measure = 0
+
+ @property
+ def mean(self):
+ return self.concentration / self.rate
+
+ @property
+ def variance(self):
+ return self.concentration / self.rate.pow(2)
+
+ def __init__(self, concentration, rate, validate_args=None):
+ self.concentration, self.rate = broadcast_all(concentration, rate)
+ if isinstance(concentration, Number) and isinstance(rate, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.concentration.size()
+ super(Gamma, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Gamma, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.concentration = self.concentration.expand(batch_shape)
+ new.rate = self.rate.expand(batch_shape)
+ super(Gamma, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
+ value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph
+ return value
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return (self.concentration * torch.log(self.rate) +
+ (self.concentration - 1) * torch.log(value) -
+ self.rate * value - torch.lgamma(self.concentration))
+
+[docs] def entropy(self):
+ return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
+ (1.0 - self.concentration) * torch.digamma(self.concentration))
+
+ @property
+ def _natural_params(self):
+ return (self.concentration - 1, -self.rate)
+
+ def _log_normalizer(self, x, y):
+ return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
+
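+# Usage sketch (illustrative, not part of the module above): checks the moments and the
+# log-density implemented by Gamma, i.e. mean = concentration/rate,
+# variance = concentration/rate**2, and
+# log_prob(x) = concentration*log(rate) + (concentration-1)*log(x) - rate*x - lgamma(concentration).
+import torch
+from torch.distributions import Gamma
+
+concentration, rate = torch.tensor([2.0]), torch.tensor([4.0])
+g = Gamma(concentration, rate)
+assert torch.allclose(g.mean, concentration / rate)
+assert torch.allclose(g.variance, concentration / rate.pow(2))
+x = torch.tensor([0.5])
+expected = (concentration * rate.log() + (concentration - 1) * x.log() -
+            rate * x - torch.lgamma(concentration))
+assert torch.allclose(g.log_prob(x), expected)
+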
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
+from torch.nn.functional import binary_cross_entropy_with_logits
+
+
+[docs]class Geometric(Distribution):
+ r"""
+ Creates a Geometric distribution parameterized by :attr:`probs`,
+ where :attr:`probs` is the probability of success of Bernoulli trials.
+    It represents the probability that in :math:`k + 1` Bernoulli trials, the
+    first :math:`k` trials fail before a success is seen.
+
+    Samples are non-negative integers in :math:`[0, \infty)`.
+
+ Example::
+
+ >>> m = Geometric(torch.tensor([0.3]))
+ >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0
+ tensor([ 2.])
+
+ Args:
+ probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1]
+ logits (Number, Tensor): the log-odds of sampling `1`.
+ """
+ arg_constraints = {'probs': constraints.unit_interval,
+ 'logits': constraints.real}
+ support = constraints.nonnegative_integer
+
+ def __init__(self, probs=None, logits=None, validate_args=None):
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ self.probs, = broadcast_all(probs)
+ if not self.probs.gt(0).all():
+ raise ValueError('All elements of probs must be greater than 0')
+ else:
+ self.logits, = broadcast_all(logits)
+ probs_or_logits = probs if probs is not None else logits
+ if isinstance(probs_or_logits, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = probs_or_logits.size()
+ super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Geometric, _instance)
+ batch_shape = torch.Size(batch_shape)
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(batch_shape)
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(batch_shape)
+ super(Geometric, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ return 1. / self.probs - 1.
+
+ @property
+ def variance(self):
+ return (1. / self.probs - 1.) / self.probs
+
+    @lazy_property
+    def logits(self):
+        return probs_to_logits(self.probs, is_binary=True)
+
+    @lazy_property
+    def probs(self):
+        return logits_to_probs(self.logits, is_binary=True)
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ tiny = torch.finfo(self.probs.dtype).tiny
+ with torch.no_grad():
+ if torch._C._get_tracing_state():
+ # [JIT WORKAROUND] lack of support for .uniform_()
+ u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
+ u = u.clamp(min=tiny)
+ else:
+ u = self.probs.new(shape).uniform_(tiny, 1)
+ return (u.log() / (-self.probs).log1p()).floor()
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ value, probs = broadcast_all(value, self.probs.clone())
+ probs[(probs == 1) & (value == 0)] = 0
+ return value * (-probs).log1p() + self.probs.log()
+
+[docs] def entropy(self):
+ return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none') / self.probs
+
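+# Usage sketch (illustrative, not part of the module above): with success probability p,
+# the mean is (1 - p)/p and the pmf at k is (1 - p)**k * p, matching log_prob above.
+import torch
+from torch.distributions import Geometric
+
+p = torch.tensor([0.3])
+geom = Geometric(p)
+assert torch.allclose(geom.mean, (1 - p) / p)
+k = torch.tensor([2.0])
+assert torch.allclose(geom.log_prob(k), k * torch.log1p(-p) + p.log())
+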
+from numbers import Number
+import math
+import torch
+from torch.distributions import constraints
+from torch.distributions.uniform import Uniform
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
+from torch.distributions.utils import broadcast_all
+
+euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant
+
+
+[docs]class Gumbel(TransformedDistribution):
+ r"""
+ Samples from a Gumbel Distribution.
+
+ Examples::
+
+ >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
+ >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
+ tensor([ 1.0124])
+
+ Args:
+ loc (float or Tensor): Location parameter of the distribution
+ scale (float or Tensor): Scale parameter of the distribution
+ """
+ arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
+
+ def __init__(self, loc, scale, validate_args=None):
+ self.loc, self.scale = broadcast_all(loc, scale)
+ finfo = torch.finfo(self.loc.dtype)
+ if isinstance(loc, Number) and isinstance(scale, Number):
+ base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
+ else:
+ base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
+ torch.full_like(self.loc, 1 - finfo.eps))
+ transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
+ ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
+ super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Gumbel, _instance)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ return super(Gumbel, self).expand(batch_shape, _instance=new)
+
+ # Explicitly defining the log probability function for Gumbel due to precision issues
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ y = (self.loc - value) / self.scale
+ return (y - y.exp()) - self.scale.log()
+
+ @property
+ def mean(self):
+ return self.loc + self.scale * euler_constant
+
+ @property
+ def stddev(self):
+ return (math.pi / math.sqrt(6)) * self.scale
+
+ @property
+ def variance(self):
+ return self.stddev.pow(2)
+
+
+
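+# Usage sketch (illustrative, not part of the module above): checks mean = loc + scale * gamma
+# (Euler-Mascheroni constant) and the explicit log_prob above, (y - exp(y)) - log(scale)
+# with y = (loc - x)/scale. Parameter values are arbitrary assumptions.
+import torch
+from torch.distributions import Gumbel
+
+loc, scale = torch.tensor([1.0]), torch.tensor([2.0])
+gum = Gumbel(loc, scale)
+assert torch.allclose(gum.mean, loc + scale * 0.57721566490153286060)
+x = torch.tensor([0.5])
+y = (loc - x) / scale
+assert torch.allclose(gum.log_prob(x), (y - y.exp()) - scale.log())
+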
+import math
+
+from torch._six import inf
+from torch.distributions import constraints
+from torch.distributions.transforms import AbsTransform
+from torch.distributions.cauchy import Cauchy
+from torch.distributions.transformed_distribution import TransformedDistribution
+
+
+[docs]class HalfCauchy(TransformedDistribution):
+ r"""
+    Creates a half-Cauchy distribution parameterized by `scale` where::
+
+ X ~ Cauchy(0, scale)
+ Y = |X| ~ HalfCauchy(scale)
+
+ Example::
+
+ >>> m = HalfCauchy(torch.tensor([1.0]))
+ >>> m.sample() # half-cauchy distributed with scale=1
+ tensor([ 2.3214])
+
+ Args:
+ scale (float or Tensor): scale of the full Cauchy distribution
+ """
+ arg_constraints = {'scale': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+
+ def __init__(self, scale, validate_args=None):
+ base_dist = Cauchy(0, scale)
+ super(HalfCauchy, self).__init__(base_dist, AbsTransform(),
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(HalfCauchy, _instance)
+ return super(HalfCauchy, self).expand(batch_shape, _instance=new)
+
+ @property
+ def scale(self):
+ return self.base_dist.scale
+
+ @property
+ def mean(self):
+ return self.base_dist.mean
+
+ @property
+ def variance(self):
+ return self.base_dist.variance
+
+[docs] def log_prob(self, value):
+ log_prob = self.base_dist.log_prob(value) + math.log(2)
+ log_prob[value.expand(log_prob.shape) < 0] = -inf
+ return log_prob
+
+
+
+
+
+
+
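+# Usage sketch (illustrative, not part of the module above): on the positive half-line the
+# density is twice the base Cauchy density, so log_prob gains log(2); negative values map to -inf.
+import math
+import torch
+from torch.distributions import Cauchy, HalfCauchy
+
+scale = torch.tensor([1.0])
+hc = HalfCauchy(scale)
+x = torch.tensor([2.0])
+assert torch.allclose(hc.log_prob(x), Cauchy(0., scale).log_prob(x) + math.log(2))
+assert torch.isinf(hc.log_prob(torch.tensor([-1.0]))).all()
+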
+import math
+
+from torch._six import inf
+from torch.distributions import constraints
+from torch.distributions.transforms import AbsTransform
+from torch.distributions.normal import Normal
+from torch.distributions.transformed_distribution import TransformedDistribution
+
+
+[docs]class HalfNormal(TransformedDistribution):
+ r"""
+ Creates a half-normal distribution parameterized by `scale` where::
+
+ X ~ Normal(0, scale)
+ Y = |X| ~ HalfNormal(scale)
+
+ Example::
+
+ >>> m = HalfNormal(torch.tensor([1.0]))
+ >>> m.sample() # half-normal distributed with scale=1
+ tensor([ 0.1046])
+
+ Args:
+ scale (float or Tensor): scale of the full Normal distribution
+ """
+ arg_constraints = {'scale': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+
+ def __init__(self, scale, validate_args=None):
+ base_dist = Normal(0, scale)
+ super(HalfNormal, self).__init__(base_dist, AbsTransform(),
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(HalfNormal, _instance)
+ return super(HalfNormal, self).expand(batch_shape, _instance=new)
+
+ @property
+ def scale(self):
+ return self.base_dist.scale
+
+ @property
+ def mean(self):
+ return self.scale * math.sqrt(2 / math.pi)
+
+ @property
+ def variance(self):
+ return self.scale.pow(2) * (1 - 2 / math.pi)
+
+[docs] def log_prob(self, value):
+ log_prob = self.base_dist.log_prob(value) + math.log(2)
+ log_prob[value.expand(log_prob.shape) < 0] = -inf
+ return log_prob
+
+
+
+
+
+
+
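+# Usage sketch (illustrative, not part of the module above): checks the moments defined
+# above, mean = scale * sqrt(2/pi) and variance = scale**2 * (1 - 2/pi).
+import math
+import torch
+from torch.distributions import HalfNormal
+
+scale = torch.tensor([2.0])
+hn = HalfNormal(scale)
+assert torch.allclose(hn.mean, scale * math.sqrt(2 / math.pi))
+assert torch.allclose(hn.variance, scale.pow(2) * (1 - 2 / math.pi))
+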
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _sum_rightmost
+
+
+[docs]class Independent(Distribution):
+ r"""
+ Reinterprets some of the batch dims of a distribution as event dims.
+
+ This is mainly useful for changing the shape of the result of
+ :meth:`log_prob`. For example to create a diagonal Normal distribution with
+ the same shape as a Multivariate Normal distribution (so they are
+ interchangeable), you can::
+
+ >>> loc = torch.zeros(3)
+ >>> scale = torch.ones(3)
+ >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
+ >>> [mvn.batch_shape, mvn.event_shape]
+ [torch.Size(()), torch.Size((3,))]
+ >>> normal = Normal(loc, scale)
+ >>> [normal.batch_shape, normal.event_shape]
+ [torch.Size((3,)), torch.Size(())]
+ >>> diagn = Independent(normal, 1)
+ >>> [diagn.batch_shape, diagn.event_shape]
+ [torch.Size(()), torch.Size((3,))]
+
+ Args:
+ base_distribution (torch.distributions.distribution.Distribution): a
+ base distribution
+ reinterpreted_batch_ndims (int): the number of batch dims to
+ reinterpret as event dims
+ """
+ arg_constraints = {}
+
+ def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
+ if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
+ raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
+ "actual {} vs {}".format(reinterpreted_batch_ndims,
+ len(base_distribution.batch_shape)))
+ shape = base_distribution.batch_shape + base_distribution.event_shape
+ event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
+ batch_shape = shape[:len(shape) - event_dim]
+ event_shape = shape[len(shape) - event_dim:]
+ self.base_dist = base_distribution
+ self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
+ super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Independent, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.base_dist = self.base_dist.expand(batch_shape +
+ self.event_shape[:self.reinterpreted_batch_ndims])
+ new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
+ super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def has_rsample(self):
+ return self.base_dist.has_rsample
+
+ @property
+ def has_enumerate_support(self):
+ if self.reinterpreted_batch_ndims > 0:
+ return False
+ return self.base_dist.has_enumerate_support
+
+ @constraints.dependent_property
+ def support(self):
+ return self.base_dist.support
+
+ @property
+ def mean(self):
+ return self.base_dist.mean
+
+ @property
+ def variance(self):
+ return self.base_dist.variance
+
+
+
+
+
+[docs] def log_prob(self, value):
+ log_prob = self.base_dist.log_prob(value)
+ return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
+
+[docs] def entropy(self):
+ entropy = self.base_dist.entropy()
+ return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
+
+[docs] def enumerate_support(self, expand=True):
+ if self.reinterpreted_batch_ndims > 0:
+ raise NotImplementedError("Enumeration over cartesian product is not implemented")
+ return self.base_dist.enumerate_support(expand=expand)
+
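+# Usage sketch (illustrative, not part of the module above): wrapping a batch of 3 univariate
+# Normals with reinterpreted_batch_ndims=1 moves that dimension from batch_shape to
+# event_shape, so log_prob sums over it.
+import torch
+from torch.distributions import Independent, Normal
+
+normal = Normal(torch.zeros(3), torch.ones(3))
+diagn = Independent(normal, 1)
+assert diagn.batch_shape == torch.Size(()) and diagn.event_shape == torch.Size((3,))
+x = torch.randn(3)
+assert torch.allclose(diagn.log_prob(x), normal.log_prob(x).sum(-1))
+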
+import math
+import warnings
+from functools import total_ordering
+
+import torch
+from torch._six import inf
+
+from .bernoulli import Bernoulli
+from .beta import Beta
+from .binomial import Binomial
+from .categorical import Categorical
+from .dirichlet import Dirichlet
+from .distribution import Distribution
+from .exponential import Exponential
+from .exp_family import ExponentialFamily
+from .gamma import Gamma
+from .geometric import Geometric
+from .gumbel import Gumbel
+from .half_normal import HalfNormal
+from .independent import Independent
+from .laplace import Laplace
+from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
+ _batch_lowrank_mahalanobis)
+from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
+from .normal import Normal
+from .one_hot_categorical import OneHotCategorical
+from .pareto import Pareto
+from .poisson import Poisson
+from .transformed_distribution import TransformedDistribution
+from .uniform import Uniform
+from .utils import _sum_rightmost
+
+_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions.
+_KL_MEMOIZE = {} # Memoized version mapping many specific (type, type) pairs to functions.
+
+
+[docs]def register_kl(type_p, type_q):
+ """
+ Decorator to register a pairwise function with :meth:`kl_divergence`.
+ Usage::
+
+ @register_kl(Normal, Normal)
+ def kl_normal_normal(p, q):
+ # insert implementation here
+
+ Lookup returns the most specific (type,type) match ordered by subclass. If
+ the match is ambiguous, a `RuntimeWarning` is raised. For example to
+ resolve the ambiguous situation::
+
+ @register_kl(BaseP, DerivedQ)
+ def kl_version1(p, q): ...
+ @register_kl(DerivedP, BaseQ)
+ def kl_version2(p, q): ...
+
+ you should register a third most-specific implementation, e.g.::
+
+ register_kl(DerivedP, DerivedQ)(kl_version1) # Break the tie.
+
+ Args:
+ type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
+ type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
+ """
+    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
+        raise TypeError('Expected type_p to be a Distribution subclass but got {}'.format(type_p))
+    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
+        raise TypeError('Expected type_q to be a Distribution subclass but got {}'.format(type_q))
+
+ def decorator(fun):
+ _KL_REGISTRY[type_p, type_q] = fun
+ _KL_MEMOIZE.clear() # reset since lookup order may have changed
+ return fun
+
+ return decorator
+
+
+@total_ordering
+class _Match(object):
+ __slots__ = ['types']
+
+ def __init__(self, *types):
+ self.types = types
+
+ def __eq__(self, other):
+ return self.types == other.types
+
+ def __le__(self, other):
+ for x, y in zip(self.types, other.types):
+ if not issubclass(x, y):
+ return False
+ if x is not y:
+ break
+ return True
+
+
+def _dispatch_kl(type_p, type_q):
+ """
+ Find the most specific approximate match, assuming single inheritance.
+ """
+ matches = [(super_p, super_q) for super_p, super_q in _KL_REGISTRY
+ if issubclass(type_p, super_p) and issubclass(type_q, super_q)]
+ if not matches:
+ return NotImplemented
+ # Check that the left- and right- lexicographic orders agree.
+ left_p, left_q = min(_Match(*m) for m in matches).types
+ right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
+ left_fun = _KL_REGISTRY[left_p, left_q]
+ right_fun = _KL_REGISTRY[right_p, right_q]
+ if left_fun is not right_fun:
+ warnings.warn('Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
+ type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__),
+ RuntimeWarning)
+ return left_fun
+
+
+def _infinite_like(tensor):
+ """
+ Helper function for obtaining infinite KL Divergence throughout
+ """
+ return tensor.new_tensor(inf).expand_as(tensor)
+
+
+def _x_log_x(tensor):
+ """
+ Utility function for calculating x log x
+ """
+ return tensor * tensor.log()
+
+
+def _batch_trace_XXT(bmat):
+ """
+    Utility function for calculating the trace of XX^{T} with X having arbitrary leading batch dimensions
+ """
+ n = bmat.size(-1)
+ m = bmat.size(-2)
+ flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
+ return flat_trace.reshape(bmat.shape[:-2])
+
+
+[docs]def kl_divergence(p, q):
+ r"""
+ Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
+
+ .. math::
+
+ KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx
+
+ Args:
+ p (Distribution): A :class:`~torch.distributions.Distribution` object.
+ q (Distribution): A :class:`~torch.distributions.Distribution` object.
+
+ Returns:
+ Tensor: A batch of KL divergences of shape `batch_shape`.
+
+ Raises:
+ NotImplementedError: If the distribution types have not been registered via
+ :meth:`register_kl`.
+ """
+ try:
+ fun = _KL_MEMOIZE[type(p), type(q)]
+ except KeyError:
+ fun = _dispatch_kl(type(p), type(q))
+ _KL_MEMOIZE[type(p), type(q)] = fun
+ if fun is NotImplemented:
+ raise NotImplementedError
+ return fun(p, q)
+
+
+################################################################################
+# KL Divergence Implementations
+################################################################################
+
+_euler_gamma = 0.57721566490153286060
+
+# Same distributions
+
+
+@register_kl(Bernoulli, Bernoulli)
+def _kl_bernoulli_bernoulli(p, q):
+ t1 = p.probs * (p.probs / q.probs).log()
+ t1[q.probs == 0] = inf
+ t1[p.probs == 0] = 0
+ t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()
+ t2[q.probs == 1] = inf
+ t2[p.probs == 1] = 0
+ return t1 + t2
+
+
+@register_kl(Beta, Beta)
+def _kl_beta_beta(p, q):
+ sum_params_p = p.concentration1 + p.concentration0
+ sum_params_q = q.concentration1 + q.concentration0
+ t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
+ t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
+ t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
+ t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
+ t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
+ return t1 - t2 + t3 + t4 + t5
+
+
+@register_kl(Binomial, Binomial)
+def _kl_binomial_binomial(p, q):
+ # from https://math.stackexchange.com/questions/2214993/
+ # kullback-leibler-divergence-for-binomial-distributions-p-and-q
+ if (p.total_count < q.total_count).any():
+ raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
+ kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
+ inf_idxs = p.total_count > q.total_count
+ kl[inf_idxs] = _infinite_like(kl[inf_idxs])
+ return kl
+
+
+@register_kl(Categorical, Categorical)
+def _kl_categorical_categorical(p, q):
+ t = p.probs * (p.logits - q.logits)
+ t[(q.probs == 0).expand_as(t)] = inf
+ t[(p.probs == 0).expand_as(t)] = 0
+ return t.sum(-1)
+
+
+@register_kl(Dirichlet, Dirichlet)
+def _kl_dirichlet_dirichlet(p, q):
+ # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
+ sum_p_concentration = p.concentration.sum(-1)
+ sum_q_concentration = q.concentration.sum(-1)
+ t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
+ t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
+ t3 = p.concentration - q.concentration
+ t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
+ return t1 - t2 + (t3 * t4).sum(-1)
+
+
+@register_kl(Exponential, Exponential)
+def _kl_exponential_exponential(p, q):
+ rate_ratio = q.rate / p.rate
+ t1 = -rate_ratio.log()
+ return t1 + rate_ratio - 1
+
+
+@register_kl(ExponentialFamily, ExponentialFamily)
+def _kl_expfamily_expfamily(p, q):
+ if not type(p) == type(q):
+ raise NotImplementedError("The cross KL-divergence between different exponential families cannot \
+ be computed using Bregman divergences")
+ p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
+ q_nparams = q._natural_params
+ lg_normal = p._log_normalizer(*p_nparams)
+ gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
+ result = q._log_normalizer(*q_nparams) - lg_normal.clone()
+ for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
+ term = (qnp - pnp) * g
+ result -= _sum_rightmost(term, len(q.event_shape))
+ return result
+
+
+@register_kl(Gamma, Gamma)
+def _kl_gamma_gamma(p, q):
+ t1 = q.concentration * (p.rate / q.rate).log()
+ t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
+ t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
+ t4 = (q.rate - p.rate) * (p.concentration / p.rate)
+ return t1 + t2 + t3 + t4
+
+
+@register_kl(Gumbel, Gumbel)
+def _kl_gumbel_gumbel(p, q):
+ ct1 = p.scale / q.scale
+ ct2 = q.loc / q.scale
+ ct3 = p.loc / q.scale
+ t1 = -ct1.log() - ct2 + ct3
+ t2 = ct1 * _euler_gamma
+ t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
+ return t1 + t2 + t3 - (1 + _euler_gamma)
+
+
+@register_kl(Geometric, Geometric)
+def _kl_geometric_geometric(p, q):
+ return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits
+
+
+@register_kl(HalfNormal, HalfNormal)
+def _kl_halfnormal_halfnormal(p, q):
+ return _kl_normal_normal(p.base_dist, q.base_dist)
+
+
+@register_kl(Laplace, Laplace)
+def _kl_laplace_laplace(p, q):
+ # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
+ scale_ratio = p.scale / q.scale
+ loc_abs_diff = (p.loc - q.loc).abs()
+ t1 = -scale_ratio.log()
+ t2 = loc_abs_diff / q.scale
+ t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
+ return t1 + t2 + t3 - 1
+
+
+@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
+def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
+ if p.event_shape != q.event_shape:
+ raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\
+ different event shapes cannot be computed")
+
+ term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q._capacitance_tril) -
+ _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+ p._capacitance_tril))
+ term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q.loc - p.loc,
+ q._capacitance_tril)
+ # Expands term2 according to
+ # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
+ # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
+ qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+ q._unbroadcasted_cov_diag.unsqueeze(-2))
+ A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+ term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
+ term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
+ q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
+ term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
+ term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
+ term2 = term21 + term22 - term23 - term24
+ return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(MultivariateNormal, LowRankMultivariateNormal)
+def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
+ if p.event_shape != q.event_shape:
+ raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
+ different event shapes cannot be computed")
+
+ term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q._capacitance_tril) -
+ 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
+ term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q.loc - p.loc,
+ q._capacitance_tril)
+ # Expands term2 according to
+ # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
+ # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
+ qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+ q._unbroadcasted_cov_diag.unsqueeze(-2))
+ A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+ term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
+ q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
+ term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
+ term2 = term21 - term22
+ return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(LowRankMultivariateNormal, MultivariateNormal)
+def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
+ if p.event_shape != q.event_shape:
+ raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
+ different event shapes cannot be computed")
+
+ term1 = (2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
+ _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+ p._capacitance_tril))
+ term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
+ # Expands term2 according to
+ # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
+ combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
+ p._unbroadcasted_cov_factor.shape[:-2])
+ n = p.event_shape[0]
+ q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+ p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
+ (n, p.cov_factor.size(-1)))
+ p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
+ .expand(combined_batch_shape + (n, n)))
+ term21 = _batch_trace_XXT(torch.triangular_solve(p_cov_factor, q_scale_tril, upper=False)[0])
+ term22 = _batch_trace_XXT(torch.triangular_solve(p_cov_diag, q_scale_tril, upper=False)[0])
+ term2 = term21 + term22
+ return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
+
+
+@register_kl(MultivariateNormal, MultivariateNormal)
+def _kl_multivariatenormal_multivariatenormal(p, q):
+ # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
+ if p.event_shape != q.event_shape:
+ raise ValueError("KL-divergence between two Multivariate Normals with\
+ different event shapes cannot be computed")
+
+ half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
+ p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
+ combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
+ p._unbroadcasted_scale_tril.shape[:-2])
+ n = p.event_shape[0]
+ q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+ p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
+ term2 = _batch_trace_XXT(torch.triangular_solve(p_scale_tril, q_scale_tril, upper=False)[0])
+ term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
+ return half_term1 + 0.5 * (term2 + term3 - n)
+
+
+@register_kl(Normal, Normal)
+def _kl_normal_normal(p, q):
+ var_ratio = (p.scale / q.scale).pow(2)
+ t1 = ((p.loc - q.loc) / q.scale).pow(2)
+ return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
+
+
+@register_kl(OneHotCategorical, OneHotCategorical)
+def _kl_onehotcategorical_onehotcategorical(p, q):
+ return _kl_categorical_categorical(p._categorical, q._categorical)
+
+
+@register_kl(Pareto, Pareto)
+def _kl_pareto_pareto(p, q):
+ # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
+ scale_ratio = p.scale / q.scale
+ alpha_ratio = q.alpha / p.alpha
+ t1 = q.alpha * scale_ratio.log()
+ t2 = -alpha_ratio.log()
+ result = t1 + t2 + alpha_ratio - 1
+ result[p.support.lower_bound < q.support.lower_bound] = inf
+ return result
+
+
+@register_kl(Poisson, Poisson)
+def _kl_poisson_poisson(p, q):
+ return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)
+
+
+@register_kl(TransformedDistribution, TransformedDistribution)
+def _kl_transformed_transformed(p, q):
+ if p.transforms != q.transforms:
+ raise NotImplementedError
+ if p.event_shape != q.event_shape:
+ raise NotImplementedError
+ # extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
+ extra_event_dim = len(p.event_shape)
+ base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
+ return _sum_rightmost(base_kl_divergence, extra_event_dim)
+
+
+@register_kl(Uniform, Uniform)
+def _kl_uniform_uniform(p, q):
+ result = ((q.high - q.low) / (p.high - p.low)).log()
+ result[(q.low > p.low) | (q.high < p.high)] = inf
+ return result
+
+
+# Different distributions
+@register_kl(Bernoulli, Poisson)
+def _kl_bernoulli_poisson(p, q):
+ return -p.entropy() - (p.probs * q.rate.log() - q.rate)
+
+
+@register_kl(Beta, Pareto)
+def _kl_beta_infinity(p, q):
+ return _infinite_like(p.concentration1)
+
+
+@register_kl(Beta, Exponential)
+def _kl_beta_exponential(p, q):
+ return -p.entropy() - q.rate.log() + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
+
+
+@register_kl(Beta, Gamma)
+def _kl_beta_gamma(p, q):
+ t1 = -p.entropy()
+ t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+ t3 = (q.concentration - 1) * (p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma())
+ t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
+ return t1 + t2 - t3 + t4
+
+# TODO: Add Beta-Laplace KL Divergence
+
+
+@register_kl(Beta, Normal)
+def _kl_beta_normal(p, q):
+ E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
+ var_normal = q.scale.pow(2)
+ t1 = -p.entropy()
+ t2 = 0.5 * (var_normal * 2 * math.pi).log()
+ t3 = (E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + E_beta.pow(2)) * 0.5
+ t4 = q.loc * E_beta
+ t5 = q.loc.pow(2) * 0.5
+ return t1 + t2 + (t3 - t4 + t5) / var_normal
+
+
+@register_kl(Beta, Uniform)
+def _kl_beta_uniform(p, q):
+ result = -p.entropy() + (q.high - q.low).log()
+ result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
+ return result
+
+
+@register_kl(Exponential, Beta)
+@register_kl(Exponential, Pareto)
+@register_kl(Exponential, Uniform)
+def _kl_exponential_infinity(p, q):
+ return _infinite_like(p.rate)
+
+
+@register_kl(Exponential, Gamma)
+def _kl_exponential_gamma(p, q):
+ ratio = q.rate / p.rate
+ t1 = -q.concentration * torch.log(ratio)
+ return t1 + ratio + q.concentration.lgamma() + q.concentration * _euler_gamma - (1 + _euler_gamma)
+
+
+@register_kl(Exponential, Gumbel)
+def _kl_exponential_gumbel(p, q):
+ scale_rate_prod = p.rate * q.scale
+ loc_scale_ratio = q.loc / q.scale
+ t1 = scale_rate_prod.log() - 1
+ t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
+ t3 = scale_rate_prod.reciprocal()
+ return t1 - loc_scale_ratio + t2 + t3
+
+# TODO: Add Exponential-Laplace KL Divergence
+
+
+@register_kl(Exponential, Normal)
+def _kl_exponential_normal(p, q):
+ var_normal = q.scale.pow(2)
+ rate_sqr = p.rate.pow(2)
+ t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
+ t2 = rate_sqr.reciprocal()
+ t3 = q.loc / p.rate
+ t4 = q.loc.pow(2) * 0.5
+ return t1 - 1 + (t2 - t3 + t4) / var_normal
+
+
+@register_kl(Gamma, Beta)
+@register_kl(Gamma, Pareto)
+@register_kl(Gamma, Uniform)
+def _kl_gamma_infinity(p, q):
+ return _infinite_like(p.concentration)
+
+
+@register_kl(Gamma, Exponential)
+def _kl_gamma_exponential(p, q):
+ return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate
+
+
+@register_kl(Gamma, Gumbel)
+def _kl_gamma_gumbel(p, q):
+ beta_scale_prod = p.rate * q.scale
+ loc_scale_ratio = q.loc / q.scale
+ t1 = (p.concentration - 1) * p.concentration.digamma() - p.concentration.lgamma() - p.concentration
+ t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
+ t3 = torch.exp(loc_scale_ratio) * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) - loc_scale_ratio
+ return t1 + t2 + t3
+
+# TODO: Add Gamma-Laplace KL Divergence
+
+
+@register_kl(Gamma, Normal)
+def _kl_gamma_normal(p, q):
+ var_normal = q.scale.pow(2)
+ beta_sqr = p.rate.pow(2)
+ t1 = 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) - p.concentration - p.concentration.lgamma()
+ t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
+ t3 = q.loc * p.concentration / p.rate
+ t4 = 0.5 * q.loc.pow(2)
+ return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal
+
+
+@register_kl(Gumbel, Beta)
+@register_kl(Gumbel, Exponential)
+@register_kl(Gumbel, Gamma)
+@register_kl(Gumbel, Pareto)
+@register_kl(Gumbel, Uniform)
+def _kl_gumbel_infinity(p, q):
+ return _infinite_like(p.loc)
+
+# TODO: Add Gumbel-Laplace KL Divergence
+
+
+@register_kl(Gumbel, Normal)
+def _kl_gumbel_normal(p, q):
+ param_ratio = p.scale / q.scale
+ t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
+ t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
+ t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
+ return -t1 + t2 + t3 - (_euler_gamma + 1)
+
+
+@register_kl(Laplace, Beta)
+@register_kl(Laplace, Exponential)
+@register_kl(Laplace, Gamma)
+@register_kl(Laplace, Pareto)
+@register_kl(Laplace, Uniform)
+def _kl_laplace_infinity(p, q):
+ return _infinite_like(p.loc)
+
+
+@register_kl(Laplace, Normal)
+def _kl_laplace_normal(p, q):
+ var_normal = q.scale.pow(2)
+ scale_sqr_var_ratio = p.scale.pow(2) / var_normal
+ t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
+ t2 = 0.5 * p.loc.pow(2)
+ t3 = p.loc * q.loc
+ t4 = 0.5 * q.loc.pow(2)
+ return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1
+
+
+@register_kl(Normal, Beta)
+@register_kl(Normal, Exponential)
+@register_kl(Normal, Gamma)
+@register_kl(Normal, Pareto)
+@register_kl(Normal, Uniform)
+def _kl_normal_infinity(p, q):
+ return _infinite_like(p.loc)
+
+
+@register_kl(Normal, Gumbel)
+def _kl_normal_gumbel(p, q):
+ mean_scale_ratio = p.loc / q.scale
+ var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
+ loc_scale_ratio = q.loc / q.scale
+ t1 = var_scale_sqr_ratio.log() * 0.5
+ t2 = mean_scale_ratio - loc_scale_ratio
+ t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
+ return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))
+
+# TODO: Add Normal-Laplace KL Divergence
+
+
+@register_kl(Pareto, Beta)
+@register_kl(Pareto, Uniform)
+def _kl_pareto_infinity(p, q):
+ return _infinite_like(p.scale)
+
+
+@register_kl(Pareto, Exponential)
+def _kl_pareto_exponential(p, q):
+ scale_rate_prod = p.scale * q.rate
+ t1 = (p.alpha / scale_rate_prod).log()
+ t2 = p.alpha.reciprocal()
+ t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
+ result = t1 - t2 + t3 - 1
+ result[p.alpha <= 1] = inf
+ return result
+
+
+@register_kl(Pareto, Gamma)
+def _kl_pareto_gamma(p, q):
+ common_term = p.scale.log() + p.alpha.reciprocal()
+ t1 = p.alpha.log() - common_term
+ t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+ t3 = (1 - q.concentration) * common_term
+ t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
+ result = t1 + t2 + t3 + t4 - 1
+ result[p.alpha <= 1] = inf
+ return result
+
+# TODO: Add Pareto-Laplace KL Divergence
+
+
+@register_kl(Pareto, Normal)
+def _kl_pareto_normal(p, q):
+ var_normal = 2 * q.scale.pow(2)
+ common_term = p.scale / (p.alpha - 1)
+ t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
+ t2 = p.alpha.reciprocal()
+ t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
+ t4 = (p.alpha * common_term - q.loc).pow(2)
+ result = t1 - t2 + (t3 + t4) / var_normal - 1
+ result[p.alpha <= 2] = inf
+ return result
+
+
+@register_kl(Poisson, Bernoulli)
+@register_kl(Poisson, Binomial)
+def _kl_poisson_infinity(p, q):
+ return _infinite_like(p.rate)
+
+
+@register_kl(Uniform, Beta)
+def _kl_uniform_beta(p, q):
+ common_term = p.high - p.low
+ t1 = torch.log(common_term)
+ t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
+ t3 = (q.concentration0 - 1) * (_x_log_x((1 - p.high)) - _x_log_x((1 - p.low)) + common_term) / common_term
+ t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma()
+ result = t3 + t4 - t1 - t2
+ result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
+ return result
+
+
+@register_kl(Uniform, Exponential)
+def _kl_uniform_exponential(p, q):
+ result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
+ result[p.low < q.support.lower_bound] = inf
+ return result
+
+
+@register_kl(Uniform, Gamma)
+def _kl_uniform_gamma(p, q):
+ common_term = p.high - p.low
+ t1 = common_term.log()
+ t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
+ t3 = (1 - q.concentration) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
+ t4 = q.rate * (p.high + p.low) / 2
+ result = -t1 + t2 + t3 + t4
+ result[p.low < q.support.lower_bound] = inf
+ return result
+
+
+@register_kl(Uniform, Gumbel)
+def _kl_uniform_gumbel(p, q):
+ common_term = q.scale / (p.high - p.low)
+ high_loc_diff = (p.high - q.loc) / q.scale
+ low_loc_diff = (p.low - q.loc) / q.scale
+ t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
+ t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
+ return t1 - t2
+
+# TODO: Uniform-Laplace KL Divergence
+
+
+@register_kl(Uniform, Normal)
+def _kl_uniform_normal(p, q):
+ common_term = p.high - p.low
+ t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
+ t2 = (common_term).pow(2) / 12
+ t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
+ return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)
+
+
+@register_kl(Uniform, Pareto)
+def _kl_uniform_pareto(p, q):
+ support_uniform = p.high - p.low
+ t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
+ t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
+ result = t2 * (q.alpha + 1) - t1
+ result[p.low < q.support.lower_bound] = inf
+ return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+ if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+ raise NotImplementedError
+ result = kl_divergence(p.base_dist, q.base_dist)
+ return _sum_rightmost(result, p.reinterpreted_batch_ndims)
+
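+# Usage sketch (illustrative, not part of the module above): kl_divergence dispatches through
+# the registry built by register_kl; for two Normals it reduces to the closed form in
+# _kl_normal_normal. Parameter values are arbitrary assumptions.
+import torch
+from torch.distributions import Normal, kl_divergence
+
+p = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+q = Normal(torch.tensor([1.0]), torch.tensor([2.0]))
+var_ratio = (p.scale / q.scale).pow(2)
+t1 = ((p.loc - q.loc) / q.scale).pow(2)
+assert torch.allclose(kl_divergence(p, q), 0.5 * (var_ratio + t1 - 1 - var_ratio.log()))
+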
+from numbers import Number
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Laplace(Distribution):
+ r"""
+    Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.
+
+ Example::
+
+ >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
+ >>> m.sample() # Laplace distributed with loc=0, scale=1
+ tensor([ 0.1046])
+
+ Args:
+ loc (float or Tensor): mean of the distribution
+ scale (float or Tensor): scale of the distribution
+ """
+ arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
+ has_rsample = True
+
+ @property
+ def mean(self):
+ return self.loc
+
+ @property
+ def variance(self):
+ return 2 * self.scale.pow(2)
+
+ @property
+ def stddev(self):
+ return (2 ** 0.5) * self.scale
+
+ def __init__(self, loc, scale, validate_args=None):
+ self.loc, self.scale = broadcast_all(loc, scale)
+ if isinstance(loc, Number) and isinstance(scale, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.loc.size()
+ super(Laplace, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Laplace, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ super(Laplace, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ finfo = torch.finfo(self.loc.dtype)
+ if torch._C._get_tracing_state():
+ # [JIT WORKAROUND] lack of support for .uniform_()
+ u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
+ return self.loc - self.scale * u.sign() * torch.log1p(-u.abs().clamp(min=finfo.tiny))
+ u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
+ # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
+ # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
+ return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale
+
+[docs] def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(-(value - self.loc).abs() / self.scale)
+
+[docs] def icdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ term = value - 0.5
+ return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs())
+
+
+
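+# Usage sketch (illustrative, not part of the module above): checks the closed forms above,
+# log_prob(x) = -log(2*scale) - |x - loc|/scale, and that icdf inverts cdf.
+import torch
+from torch.distributions import Laplace
+
+loc, scale = torch.tensor([0.0]), torch.tensor([1.0])
+lap = Laplace(loc, scale)
+x = torch.tensor([0.5])
+assert torch.allclose(lap.log_prob(x), -torch.log(2 * scale) - (x - loc).abs() / scale)
+assert torch.allclose(lap.icdf(lap.cdf(x)), x)
+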
+from torch.distributions import constraints
+from torch.distributions.transforms import ExpTransform
+from torch.distributions.normal import Normal
+from torch.distributions.transformed_distribution import TransformedDistribution
+
+
+[docs]class LogNormal(TransformedDistribution):
+ r"""
+ Creates a log-normal distribution parameterized by
+ :attr:`loc` and :attr:`scale` where::
+
+ X ~ Normal(loc, scale)
+ Y = exp(X) ~ LogNormal(loc, scale)
+
+ Example::
+
+ >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
+ >>> m.sample() # log-normal distributed with mean=0 and stddev=1
+ tensor([ 0.1046])
+
+ Args:
+ loc (float or Tensor): mean of log of distribution
+ scale (float or Tensor): standard deviation of log of the distribution
+ """
+ arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.positive
+ has_rsample = True
+
+ def __init__(self, loc, scale, validate_args=None):
+ base_dist = Normal(loc, scale)
+ super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(LogNormal, _instance)
+ return super(LogNormal, self).expand(batch_shape, _instance=new)
+
+ @property
+ def loc(self):
+ return self.base_dist.loc
+
+ @property
+ def scale(self):
+ return self.base_dist.scale
+
+ @property
+ def mean(self):
+ return (self.loc + self.scale.pow(2) / 2).exp()
+
+ @property
+ def variance(self):
+ return (self.scale.pow(2).exp() - 1) * (2 * self.loc + self.scale.pow(2)).exp()
+
+
+
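+# Usage sketch (illustrative, not part of the module above): checks mean = exp(loc + scale**2/2)
+# and that samples are strictly positive, as expected from Y = exp(X).
+import torch
+from torch.distributions import LogNormal
+
+loc, scale = torch.tensor([0.0]), torch.tensor([0.5])
+ln = LogNormal(loc, scale)
+assert torch.allclose(ln.mean, (loc + scale.pow(2) / 2).exp())
+assert (ln.sample(torch.Size([100])) > 0).all()
+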
+import math
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
+from torch.distributions.utils import _standard_normal, lazy_property
+
+
+def _batch_capacitance_tril(W, D):
+ r"""
+ Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
+ and a batch of vectors :math:`D`.
+ """
+ m = W.size(-1)
+ Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
+ K = torch.matmul(Wt_Dinv, W).contiguous()
+ K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K
+ return torch.cholesky(K)
+
+
+def _batch_lowrank_logdet(W, D, capacitance_tril):
+ r"""
+ Uses "matrix determinant lemma"::
+ log|W @ W.T + D| = log|C| + log|D|,
+ where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
+ the log determinant.
+ """
+ return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)
+
+
+def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
+ r"""
+ Uses "Woodbury matrix identity"::
+ inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
+ where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
+ Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
+ """
+ Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
+ Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
+ mahalanobis_term1 = (x.pow(2) / D).sum(-1)
+ mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
+ return mahalanobis_term1 - mahalanobis_term2
+
+
+[docs]class LowRankMultivariateNormal(Distribution):
+ r"""
+ Creates a multivariate normal distribution with covariance matrix having a low-rank form
+ parameterized by :attr:`cov_factor` and :attr:`cov_diag`::
+ covariance_matrix = cov_factor @ cov_factor.T + cov_diag
+
+    Example::
+
+        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.tensor([1.0, 1.0]))
+        >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]`
+        tensor([-0.2102, -0.5429])
+
+ Args:
+ loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
+ cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
+ `batch_shape + event_shape + (rank,)`
+ cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
+ `batch_shape + event_shape`
+
+ Note:
+        Computing the determinant and inverse of the covariance matrix is avoided when
+ `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
+ <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
+ `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
+ Thanks to these formulas, we just need to compute the determinant and inverse of
+ the small size "capacitance" matrix::
+ capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
+ """
+ arg_constraints = {"loc": constraints.real,
+ "cov_factor": constraints.real,
+ "cov_diag": constraints.positive}
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
+ if loc.dim() < 1:
+ raise ValueError("loc must be at least one-dimensional.")
+ event_shape = loc.shape[-1:]
+ if cov_factor.dim() < 2:
+ raise ValueError("cov_factor must be at least two-dimensional, "
+ "with optional leading batch dimensions")
+ if cov_factor.shape[-2:-1] != event_shape:
+ raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
+ .format(event_shape[0]))
+ if cov_diag.shape[-1:] != event_shape:
+ raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))
+
+ loc_ = loc.unsqueeze(-1)
+ cov_diag_ = cov_diag.unsqueeze(-1)
+ try:
+ loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
+ except RuntimeError:
+ raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
+ .format(loc.shape, cov_factor.shape, cov_diag.shape))
+ self.loc = loc_[..., 0]
+ self.cov_diag = cov_diag_[..., 0]
+ batch_shape = self.loc.shape[:-1]
+
+ self._unbroadcasted_cov_factor = cov_factor
+ self._unbroadcasted_cov_diag = cov_diag
+ self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
+ super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
+ batch_shape = torch.Size(batch_shape)
+ loc_shape = batch_shape + self.event_shape
+ new.loc = self.loc.expand(loc_shape)
+ new.cov_diag = self.cov_diag.expand(loc_shape)
+ new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
+ new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
+ new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
+ new._capacitance_tril = self._capacitance_tril
+ super(LowRankMultivariateNormal, new).__init__(batch_shape,
+ self.event_shape,
+ validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ return self.loc
+
+[docs] @lazy_property
+ def variance(self):
+ return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
+ + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)
+
+[docs] @lazy_property
+ def scale_tril(self):
+        # The following identity is used to improve the numerical stability of the
+        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
+ # W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
+ # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
+ # hence it is well-conditioned and safe to take Cholesky decomposition.
+ n = self._event_shape[0]
+ cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
+ Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
+ K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
+ K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K
+ scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
+ return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
+
+[docs] @lazy_property
+ def covariance_matrix(self):
+ covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_factor.transpose(-1, -2))
+ + torch.diag_embed(self._unbroadcasted_cov_diag))
+ return covariance_matrix.expand(self._batch_shape + self._event_shape +
+ self._event_shape)
+
+[docs] @lazy_property
+ def precision_matrix(self):
+ # We use "Woodbury matrix identity" to take advantage of low rank form::
+ # inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
+ # where :math:`C` is the capacitance matrix.
+ Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
+ / self._unbroadcasted_cov_diag.unsqueeze(-2))
+ A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0]
+ precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
+ - torch.matmul(A.transpose(-1, -2), A))
+ return precision_matrix.expand(self._batch_shape + self._event_shape +
+ self._event_shape)
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ W_shape = shape[:-1] + self.cov_factor.shape[-1:]
+ eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
+ eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+ return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+ + self._unbroadcasted_cov_diag.sqrt() * eps_D)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ diff = value - self.loc
+ M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
+ diff,
+ self._capacitance_tril)
+ log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
+ self._capacitance_tril)
+ return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
+
+[docs] def entropy(self):
+ log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
+ self._capacitance_tril)
+ H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
+ if len(self._batch_shape) == 0:
+ return H
+ else:
+ return H.expand(self._batch_shape)
+
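+# Usage sketch (illustrative, not part of the module above): the low-rank parameterization
+# satisfies covariance_matrix == cov_factor @ cov_factor.T + diag(cov_diag), checked directly
+# on a rank-1 example with assumed values.
+import torch
+from torch.distributions import LowRankMultivariateNormal
+
+loc = torch.zeros(3)
+cov_factor = torch.tensor([[1.0], [0.5], [0.0]])  # shape (3, 1): event size 3, rank 1
+cov_diag = torch.tensor([1.0, 1.0, 1.0])
+lr = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
+full_cov = cov_factor @ cov_factor.t() + torch.diag(cov_diag)
+assert torch.allclose(lr.covariance_matrix, full_cov)
+assert lr.event_shape == torch.Size((3,))
+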
+import torch
+from torch._six import inf
+from torch.distributions.distribution import Distribution
+from torch.distributions import Categorical
+from numbers import Number
+from torch.distributions import constraints
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Multinomial(Distribution):
+ r"""
+ Creates a Multinomial distribution parameterized by :attr:`total_count` and
+ either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
+ :attr:`probs` indexes over categories. All other dimensions index over batches.
+
+ Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
+ called (see example below)
+
+ .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
+ and it will be normalized to sum to 1.
+
+ - :meth:`sample` requires a single shared `total_count` for all
+ parameters and samples.
+ - :meth:`log_prob` allows different `total_count` for each parameter and
+ sample.
+
+ Example::
+
+ >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
+ >>> x = m.sample() # equal probability of 0, 1, 2, 3
+ tensor([ 21., 24., 30., 25.])
+
+ >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
+ tensor([-4.1338])
+
+ Args:
+ total_count (int): number of trials
+ probs (Tensor): event probabilities
+ logits (Tensor): event log probabilities
+ """
+ arg_constraints = {'probs': constraints.simplex,
+ 'logits': constraints.real}
+
+ @property
+ def mean(self):
+ return self.probs * self.total_count
+
+ @property
+ def variance(self):
+ return self.total_count * self.probs * (1 - self.probs)
+
+ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
+ if not isinstance(total_count, Number):
+ raise NotImplementedError('inhomogeneous total_count is not supported')
+ self.total_count = total_count
+ self._categorical = Categorical(probs=probs, logits=logits)
+ batch_shape = self._categorical.batch_shape
+ event_shape = self._categorical.param_shape[-1:]
+ super(Multinomial, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Multinomial, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.total_count = self.total_count
+ new._categorical = self._categorical.expand(batch_shape)
+ super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._categorical._new(*args, **kwargs)
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.integer_interval(0, self.total_count)
+
+ @property
+ def logits(self):
+ return self._categorical.logits
+
+ @property
+ def probs(self):
+ return self._categorical.probs
+
+ @property
+ def param_shape(self):
+ return self._categorical.param_shape
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ sample_shape = torch.Size(sample_shape)
+ samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
+ # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
+ # (sample_shape, batch_shape, total_count)
+ shifted_idx = list(range(samples.dim()))
+ shifted_idx.append(shifted_idx.pop(0))
+ samples = samples.permute(*shifted_idx)
+ counts = samples.new(self._extended_shape(sample_shape)).zero_()
+ counts.scatter_add_(-1, samples, torch.ones_like(samples))
+ return counts.type_as(self.probs)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ logits, value = broadcast_all(self.logits.clone(), value)
+ log_factorial_n = torch.lgamma(value.sum(-1) + 1)
+ log_factorial_xs = torch.lgamma(value + 1).sum(-1)
+ logits[(value == 0) & (logits == -inf)] = 0
+ log_powers = (logits * value).sum(-1)
+ return log_factorial_n - log_factorial_xs + log_powers
+
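+# Editor's sketch (not part of the module above): sample() requires one shared
+# integer total_count, while log_prob() scores counts whose totals differ per
+# row, as noted in the class docstring.
+import torch
+from torch.distributions import Multinomial
+
+m = Multinomial(10, torch.tensor([0.1, 0.2, 0.3, 0.4]))
+counts = m.sample()                   # shape (4,), entries sum to 10
+print(counts.sum())                   # tensor(10.)
+values = torch.tensor([[2., 3., 4., 1.], [0., 0., 5., 15.]])
+print(m.log_prob(values))             # shape (2,), row totals 10 and 20
+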
+import math
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _standard_normal, lazy_property
+
+
+def _batch_mv(bmat, bvec):
+ r"""
+ Performs a batched matrix-vector product, with compatible but different batch shapes.
+
+ This function takes as input `bmat`, containing :math:`n \times n` matrices, and
+ `bvec`, containing length :math:`n` vectors.
+
+ Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
+ to a batch shape. They are not necessarily assumed to have the same batch shape,
+ just ones which can be broadcasted.
+ """
+ return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
+
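+# Editor's note (illustrative, not part of the module): the batch shapes only
+# need to broadcast, e.g. a (2, 1)-batch of matrices against a (3,)-batch of
+# vectors yields a (2, 3)-batch of results:
+#
+#     >>> bmat = torch.randn(2, 1, 4, 4)
+#     >>> bvec = torch.randn(3, 4)
+#     >>> _batch_mv(bmat, bvec).shape
+#     torch.Size([2, 3, 4])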
+
+def _batch_mahalanobis(bL, bx):
+ r"""
+ Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
+ for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.
+
+ Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
+ shape, but the batch shape of `bL` should be broadcastable to that of `bx`.
+ """
+ n = bx.size(-1)
+ bx_batch_shape = bx.shape[:-1]
+
+ # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
+ # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
+ bx_batch_dims = len(bx_batch_shape)
+ bL_batch_dims = bL.dim() - 2
+ outer_batch_dims = bx_batch_dims - bL_batch_dims
+ old_batch_dims = outer_batch_dims + bL_batch_dims
+ new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
+ # Reshape bx with the shape (..., 1, i, j, 1, n)
+ bx_new_shape = bx.shape[:outer_batch_dims]
+ for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
+ bx_new_shape += (sx // sL, sL)
+ bx_new_shape += (n,)
+ bx = bx.reshape(bx_new_shape)
+ # Permute bx to make it have shape (..., 1, j, i, 1, n)
+ permute_dims = (list(range(outer_batch_dims)) +
+ list(range(outer_batch_dims, new_batch_dims, 2)) +
+ list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
+ [new_batch_dims])
+ bx = bx.permute(permute_dims)
+
+ flat_L = bL.reshape(-1, n, n) # shape = b x n x n
+ flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
+ flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
+ M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c
+ M = M_swap.t() # shape = c x b
+
+ # Now we revert the above reshape and permute operators.
+ permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
+ permute_inv_dims = list(range(outer_batch_dims))
+ for i in range(bL_batch_dims):
+ permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
+ reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
+ return reshaped_M.reshape(bx_batch_shape)
+
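+# Editor's note (illustrative, not part of the module): for an unbatched lower
+# factor L and vector x, the value above reduces to the plain quadratic form
+# x^T (L L^T)^{-1} x, which can be checked directly:
+#
+#     >>> L = torch.cholesky(torch.eye(4) + torch.ones(4, 4))
+#     >>> x = torch.randn(4)
+#     >>> direct = x @ torch.inverse(L @ L.t()) @ x
+#     >>> torch.allclose(_batch_mahalanobis(L, x), direct, atol=1e-5)
+#     True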
+
+[docs]class MultivariateNormal(Distribution):
+ r"""
+ Creates a multivariate normal (also called Gaussian) distribution
+ parameterized by a mean vector and a covariance matrix.
+
+ The multivariate normal distribution can be parameterized either
+ in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
+ or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
+ or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
+ diagonal entries, such that
+ :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
+ can be obtained via e.g. Cholesky decomposition of the covariance.
+
+ Example:
+
+ >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
+ >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
+ tensor([-0.2102, -0.5429])
+
+ Args:
+ loc (Tensor): mean of the distribution
+ covariance_matrix (Tensor): positive-definite covariance matrix
+ precision_matrix (Tensor): positive-definite precision matrix
+ scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
+
+ Note:
+ Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
+ :attr:`scale_tril` can be specified.
+
+ Using :attr:`scale_tril` will be more efficient: all computations internally
+ are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
+ :attr:`precision_matrix` is passed instead, it is only used to compute
+ the corresponding lower triangular matrices using a Cholesky decomposition.
+ """
+ arg_constraints = {'loc': constraints.real_vector,
+ 'covariance_matrix': constraints.positive_definite,
+ 'precision_matrix': constraints.positive_definite,
+ 'scale_tril': constraints.lower_cholesky}
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
+ if loc.dim() < 1:
+ raise ValueError("loc must be at least one-dimensional.")
+ if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
+ raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
+
+ loc_ = loc.unsqueeze(-1) # temporarily add dim on right
+ if scale_tril is not None:
+ if scale_tril.dim() < 2:
+ raise ValueError("scale_tril matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions")
+ self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
+ elif covariance_matrix is not None:
+ if covariance_matrix.dim() < 2:
+ raise ValueError("covariance_matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions")
+ self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
+ else:
+ if precision_matrix.dim() < 2:
+ raise ValueError("precision_matrix must be at least two-dimensional, "
+ "with optional leading batch dimensions")
+ self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
+ self.loc = loc_[..., 0] # drop rightmost dim
+
+ batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
+ super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+ if scale_tril is not None:
+ self._unbroadcasted_scale_tril = scale_tril
+ else:
+ if precision_matrix is not None:
+ self.covariance_matrix = torch.inverse(precision_matrix).expand_as(loc_)
+ self._unbroadcasted_scale_tril = torch.cholesky(self.covariance_matrix)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(MultivariateNormal, _instance)
+ batch_shape = torch.Size(batch_shape)
+ loc_shape = batch_shape + self.event_shape
+ cov_shape = batch_shape + self.event_shape + self.event_shape
+ new.loc = self.loc.expand(loc_shape)
+ new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
+ if 'covariance_matrix' in self.__dict__:
+ new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
+ if 'scale_tril' in self.__dict__:
+ new.scale_tril = self.scale_tril.expand(cov_shape)
+ if 'precision_matrix' in self.__dict__:
+ new.precision_matrix = self.precision_matrix.expand(cov_shape)
+ super(MultivariateNormal, new).__init__(batch_shape,
+ self.event_shape,
+ validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] @lazy_property
+ def scale_tril(self):
+ return self._unbroadcasted_scale_tril.expand(
+ self._batch_shape + self._event_shape + self._event_shape)
+
+[docs] @lazy_property
+ def covariance_matrix(self):
+ return (torch.matmul(self._unbroadcasted_scale_tril,
+ self._unbroadcasted_scale_tril.transpose(-1, -2))
+ .expand(self._batch_shape + self._event_shape + self._event_shape))
+
+[docs] @lazy_property
+ def precision_matrix(self):
+ # TODO: use `torch.potri` on `scale_tril` once a backwards pass is implemented.
+ scale_tril_inv = torch.inverse(self._unbroadcasted_scale_tril)
+ return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv).expand(
+ self._batch_shape + self._event_shape + self._event_shape)
+
+ @property
+ def mean(self):
+ return self.loc
+
+ @property
+ def variance(self):
+ return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand(
+ self._batch_shape + self._event_shape)
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+ return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ diff = value - self.loc
+ M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
+ half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
+
+[docs] def entropy(self):
+ half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
+ if len(self._batch_shape) == 0:
+ return H
+ else:
+ return H.expand(self._batch_shape)
+
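+# Editor's sketch (not part of the module above): the covariance_matrix and
+# scale_tril parameterizations describe the same distribution; scale_tril is
+# cheaper because no Cholesky factorization has to be performed internally.
+import torch
+from torch.distributions import MultivariateNormal
+
+cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
+tril = torch.cholesky(cov)
+x = torch.tensor([0.5, -1.0])
+d_cov = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
+d_tril = MultivariateNormal(torch.zeros(2), scale_tril=tril)
+print(torch.allclose(d_cov.log_prob(x), d_tril.log_prob(x)))   # True
+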
+import torch
+import torch.nn.functional as F
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
+
+
+[docs]class NegativeBinomial(Distribution):
+ r"""
+ Creates a Negative Binomial distribution, i.e. the distribution of the
+ number of successful independent and identical Bernoulli trials before
+ :attr:`total_count` failures are achieved. The probability
+ of success of each Bernoulli trial is :attr:`probs`.
+
+ Args:
+ total_count (float or Tensor): non-negative number of negative Bernoulli
+ trials at which to stop, although the distribution remains valid for
+ real-valued counts
+ probs (Tensor): Event probabilities of success in the half-open interval [0, 1)
+ logits (Tensor): Event log-odds for probabilities of success
+ """
+ arg_constraints = {'total_count': constraints.greater_than_eq(0),
+ 'probs': constraints.half_open_interval(0., 1.),
+ 'logits': constraints.real}
+ support = constraints.nonnegative_integer
+
+ def __init__(self, total_count, probs=None, logits=None, validate_args=None):
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ self.total_count, self.probs, = broadcast_all(total_count, probs)
+ self.total_count = self.total_count.type_as(self.probs)
+ else:
+ self.total_count, self.logits, = broadcast_all(total_count, logits)
+ self.total_count = self.total_count.type_as(self.logits)
+
+ self._param = self.probs if probs is not None else self.logits
+ batch_shape = self._param.size()
+ super(NegativeBinomial, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(NegativeBinomial, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.total_count = self.total_count.expand(batch_shape)
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(batch_shape)
+ new._param = new.probs
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(batch_shape)
+ new._param = new.logits
+ super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._param.new(*args, **kwargs)
+
+ @property
+ def mean(self):
+ return self.total_count * torch.exp(self.logits)
+
+ @property
+ def variance(self):
+ return self.mean / torch.sigmoid(-self.logits)
+
+ @property
+ def param_shape(self):
+ return self._param.size()
+
+ @lazy_property
+ def _gamma(self):
+ return torch.distributions.Gamma(concentration=self.total_count,
+ rate=torch.exp(-self.logits))
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ with torch.no_grad():
+ rate = self._gamma.sample(sample_shape=sample_shape)
+ return torch.poisson(rate)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+
+ log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +
+ value * F.logsigmoid(self.logits))
+
+ log_normalization = (-torch.lgamma(self.total_count + value) + torch.lgamma(1. + value) +
+ torch.lgamma(self.total_count))
+
+ return log_unnormalized_prob - log_normalization
+
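+# Editor's sketch (not part of the module above): sampling goes through the
+# Gamma-Poisson mixture defined by `_gamma` above, and the sample mean should
+# be close to total_count * p / (1 - p).
+import torch
+from torch.distributions import NegativeBinomial
+
+nb = NegativeBinomial(total_count=5.0, probs=torch.tensor([0.3, 0.7]))
+samples = nb.sample(torch.Size([1000]))   # shape (1000, 2), counts >= 0
+print(samples.mean(0))                    # roughly equal to nb.mean
+print(nb.mean, nb.variance)
+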
+import math
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import _standard_normal, broadcast_all
+
+
+[docs]class Normal(ExponentialFamily):
+ r"""
+ Creates a normal (also called Gaussian) distribution parameterized by
+ :attr:`loc` and :attr:`scale`.
+
+ Example::
+
+ >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+ >>> m.sample() # normally distributed with loc=0 and scale=1
+ tensor([ 0.1046])
+
+ Args:
+ loc (float or Tensor): mean of the distribution (often referred to as mu)
+ scale (float or Tensor): standard deviation of the distribution
+ (often referred to as sigma)
+ """
+ arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
+ has_rsample = True
+ _mean_carrier_measure = 0
+
+ @property
+ def mean(self):
+ return self.loc
+
+ @property
+ def stddev(self):
+ return self.scale
+
+ @property
+ def variance(self):
+ return self.stddev.pow(2)
+
+ def __init__(self, loc, scale, validate_args=None):
+ self.loc, self.scale = broadcast_all(loc, scale)
+ if isinstance(loc, Number) and isinstance(scale, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.loc.size()
+ super(Normal, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Normal, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ super(Normal, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ with torch.no_grad():
+ return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+ return self.loc + eps * self.scale
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ # compute the variance
+ var = (self.scale ** 2)
+ log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
+ return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
+
+[docs] def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))
+
+[docs] def icdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+
+ @property
+ def _natural_params(self):
+ return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
+
+ def _log_normalizer(self, x, y):
+ return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
+
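+# Editor's sketch (not part of the module above): rsample() is differentiable
+# through loc and scale (the reparameterization trick), while sample() above is
+# wrapped in torch.no_grad(); cdf and icdf are inverses of each other.
+import torch
+from torch.distributions import Normal
+
+loc = torch.tensor(0.0, requires_grad=True)
+scale = torch.tensor(1.0, requires_grad=True)
+d = Normal(loc, scale)
+d.rsample(torch.Size([3])).sum().backward()
+print(loc.grad, scale.grad)               # both gradients are populated
+print(d.icdf(d.cdf(torch.tensor(0.3))))   # ~0.3
+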
+import torch
+from torch.distributions import constraints
+from torch.distributions.categorical import Categorical
+from torch.distributions.distribution import Distribution
+
+
+[docs]class OneHotCategorical(Distribution):
+ r"""
+ Creates a one-hot categorical distribution parameterized by :attr:`probs` or
+ :attr:`logits`.
+
+ Samples are one-hot coded vectors of size ``probs.size(-1)``.
+
+ .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
+ and it will be normalized to sum to 1.
+
+ See also: :func:`torch.distributions.Categorical` for specifications of
+ :attr:`probs` and :attr:`logits`.
+
+ Example::
+
+ >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
+ >>> m.sample() # equal probability of 0, 1, 2, 3
+ tensor([ 0., 0., 0., 1.])
+
+ Args:
+ probs (Tensor): event probabilities
+ logits (Tensor): event log probabilities
+ """
+ arg_constraints = {'probs': constraints.simplex,
+ 'logits': constraints.real}
+ support = constraints.simplex
+ has_enumerate_support = True
+
+ def __init__(self, probs=None, logits=None, validate_args=None):
+ self._categorical = Categorical(probs, logits)
+ batch_shape = self._categorical.batch_shape
+ event_shape = self._categorical.param_shape[-1:]
+ super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(OneHotCategorical, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new._categorical = self._categorical.expand(batch_shape)
+ super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._categorical._new(*args, **kwargs)
+
+ @property
+ def _param(self):
+ return self._categorical._param
+
+ @property
+ def probs(self):
+ return self._categorical.probs
+
+ @property
+ def logits(self):
+ return self._categorical.logits
+
+ @property
+ def mean(self):
+ return self._categorical.probs
+
+ @property
+ def variance(self):
+ return self._categorical.probs * (1 - self._categorical.probs)
+
+ @property
+ def param_shape(self):
+ return self._categorical.param_shape
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ sample_shape = torch.Size(sample_shape)
+ probs = self._categorical.probs
+ num_events = self._categorical._num_events
+ indices = self._categorical.sample(sample_shape)
+ return torch.nn.functional.one_hot(indices, num_events).to(probs)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ indices = value.max(-1)[1]
+ return self._categorical.log_prob(indices)
+
+[docs] def enumerate_support(self, expand=True):
+ n = self.event_shape[0]
+ values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
+ values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
+ if expand:
+ values = values.expand((n,) + self.batch_shape + (n,))
+ return values
+
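+# Editor's sketch (not part of the module above): enumerate_support() returns
+# every one-hot vector, which makes exact expectations over a small support cheap.
+import torch
+from torch.distributions import OneHotCategorical
+
+d = OneHotCategorical(probs=torch.tensor([0.1, 0.2, 0.7]))
+support = d.enumerate_support()           # shape (3, 3): rows of the identity
+print(d.log_prob(support).exp())          # recovers tensor([0.1, 0.2, 0.7])
+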
+from torch.distributions import constraints
+from torch.distributions.exponential import Exponential
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, ExpTransform
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Pareto(TransformedDistribution):
+ r"""
+ Samples from a Pareto Type 1 distribution.
+
+ Example::
+
+ >>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
+ >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1
+ tensor([ 1.5623])
+
+ Args:
+ scale (float or Tensor): Scale parameter of the distribution
+ alpha (float or Tensor): Shape parameter of the distribution
+ """
+ arg_constraints = {'alpha': constraints.positive, 'scale': constraints.positive}
+
+ def __init__(self, scale, alpha, validate_args=None):
+ self.scale, self.alpha = broadcast_all(scale, alpha)
+ base_dist = Exponential(self.alpha)
+ transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
+ super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Pareto, _instance)
+ new.scale = self.scale.expand(batch_shape)
+ new.alpha = self.alpha.expand(batch_shape)
+ return super(Pareto, self).expand(batch_shape, _instance=new)
+
+ @property
+ def mean(self):
+ # mean is inf for alpha <= 1
+ a = self.alpha.clone().clamp(min=1)
+ return a * self.scale / (a - 1)
+
+ @property
+ def variance(self):
+ # var is inf for alpha <= 2
+ a = self.alpha.clone().clamp(min=2)
+ return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.greater_than(self.scale)
+
+[docs] def entropy(self):
+ return ((self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()))
+
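+# Editor's sketch (not part of the module above): if X ~ Exponential(alpha),
+# then scale * exp(X) follows the Pareto(scale, alpha) law built above.
+import torch
+from torch.distributions import Pareto
+
+m = Pareto(torch.tensor([1.0]), torch.tensor([3.0]))   # scale=1, alpha=3
+print(m.mean)                          # alpha * scale / (alpha - 1) = 1.5
+print(m.sample(torch.Size([4])))       # every sample is greater than scale
+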
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Poisson(ExponentialFamily):
+ r"""
+ Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
+
+ Samples are nonnegative integers, with a pmf given by
+
+ .. math::
+ \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
+
+ Example::
+
+ >>> m = Poisson(torch.tensor([4]))
+ >>> m.sample()
+ tensor([ 3.])
+
+ Args:
+ rate (Number, Tensor): the rate parameter
+ """
+ arg_constraints = {'rate': constraints.positive}
+ support = constraints.nonnegative_integer
+
+ @property
+ def mean(self):
+ return self.rate
+
+ @property
+ def variance(self):
+ return self.rate
+
+ def __init__(self, rate, validate_args=None):
+ self.rate, = broadcast_all(rate)
+ if isinstance(rate, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.rate.size()
+ super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Poisson, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.rate = self.rate.expand(batch_shape)
+ super(Poisson, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ with torch.no_grad():
+ return torch.poisson(self.rate.expand(shape))
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ rate, value = broadcast_all(self.rate, value)
+ return (rate.log() * value) - rate - (value + 1).lgamma()
+
+ @property
+ def _natural_params(self):
+ return (torch.log(self.rate), )
+
+ def _log_normalizer(self, x):
+ return torch.exp(x)
+
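+# Editor's sketch (not part of the module above): log_prob implements
+# k * log(rate) - rate - log(k!) with the factorial written via lgamma.
+import torch
+from torch.distributions import Poisson
+
+d = Poisson(torch.tensor([4.0]))
+k = torch.tensor([3.0])
+manual = k * torch.log(torch.tensor(4.0)) - 4.0 - torch.lgamma(k + 1)
+print(torch.allclose(d.log_prob(k), manual))   # True
+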
+import torch
+from numbers import Number
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import SigmoidTransform
+from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs
+
+
+[docs]class LogitRelaxedBernoulli(Distribution):
+ r"""
+ Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
+ or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
+ distribution.
+
+ Samples are logits of values in (0, 1). See [1] for more details.
+
+ Args:
+ temperature (Tensor): relaxation temperature
+ probs (Number, Tensor): the probability of sampling `1`
+ logits (Number, Tensor): the log-odds of sampling `1`
+
+ [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
+ Variables (Maddison et al, 2017)
+
+ [2] Categorical Reparametrization with Gumbel-Softmax
+ (Jang et al, 2017)
+ """
+ arg_constraints = {'probs': constraints.unit_interval,
+ 'logits': constraints.real}
+ support = constraints.real
+
+ def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+ self.temperature = temperature
+ if (probs is None) == (logits is None):
+ raise ValueError("Either `probs` or `logits` must be specified, but not both.")
+ if probs is not None:
+ is_scalar = isinstance(probs, Number)
+ self.probs, = broadcast_all(probs)
+ else:
+ is_scalar = isinstance(logits, Number)
+ self.logits, = broadcast_all(logits)
+ self._param = self.probs if probs is not None else self.logits
+ if is_scalar:
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self._param.size()
+ super(LogitRelaxedBernoulli, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.temperature = self.temperature
+ if 'probs' in self.__dict__:
+ new.probs = self.probs.expand(batch_shape)
+ new._param = new.probs
+ if 'logits' in self.__dict__:
+ new.logits = self.logits.expand(batch_shape)
+ new._param = new.logits
+ super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._param.new(*args, **kwargs)
+
+
+ @property
+ def param_shape(self):
+ return self._param.size()
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ probs = clamp_probs(self.probs.expand(shape))
+ uniforms = clamp_probs(torch.rand(shape, dtype=probs.dtype, device=probs.device))
+ return (uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()) / self.temperature
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ logits, value = broadcast_all(self.logits, value)
+ diff = logits - value.mul(self.temperature)
+ return self.temperature.log() + diff - 2 * diff.exp().log1p()
+
+
+[docs]class RelaxedBernoulli(TransformedDistribution):
+ r"""
+ Creates a RelaxedBernoulli distribution, parametrized by
+ :attr:`temperature`, and either :attr:`probs` or :attr:`logits`
+ (but not both). This is a relaxed version of the `Bernoulli` distribution,
+ so its values are in (0, 1) and it has reparametrizable samples.
+
+ Example::
+
+ >>> m = RelaxedBernoulli(torch.tensor([2.2]),
+ torch.tensor([0.1, 0.2, 0.3, 0.99]))
+ >>> m.sample()
+ tensor([ 0.2951, 0.3442, 0.8918, 0.9021])
+
+ Args:
+ temperature (Tensor): relaxation temperature
+ probs (Number, Tensor): the probability of sampling `1`
+ logits (Number, Tensor): the log-odds of sampling `1`
+ """
+ arg_constraints = {'probs': constraints.unit_interval,
+ 'logits': constraints.real}
+ support = constraints.unit_interval
+ has_rsample = True
+
+ def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+ base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
+ super(RelaxedBernoulli, self).__init__(base_dist,
+ SigmoidTransform(),
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(RelaxedBernoulli, _instance)
+ return super(RelaxedBernoulli, self).expand(batch_shape, _instance=new)
+
+ @property
+ def temperature(self):
+ return self.base_dist.temperature
+
+ @property
+ def logits(self):
+ return self.base_dist.logits
+
+ @property
+ def probs(self):
+ return self.base_dist.probs
+
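+# Editor's sketch (not part of the module above): lower temperatures push the
+# relaxed samples toward {0, 1} while rsample() stays differentiable in the logits.
+import torch
+from torch.distributions import RelaxedBernoulli
+
+logits = torch.tensor([1.5, -0.5], requires_grad=True)
+hot = RelaxedBernoulli(torch.tensor(2.0), logits=logits).rsample()
+cold = RelaxedBernoulli(torch.tensor(0.1), logits=logits).rsample()
+print(hot, cold)                       # cold lies much closer to 0/1
+cold.sum().backward()                  # gradients flow back into logits
+print(logits.grad)
+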
+import torch
+from torch.distributions import constraints
+from torch.distributions.categorical import Categorical
+from torch.distributions.utils import clamp_probs, broadcast_all
+from torch.distributions.distribution import Distribution
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import ExpTransform
+
+
+class ExpRelaxedCategorical(Distribution):
+ r"""
+ Creates an ExpRelaxedCategorical distribution parameterized by
+ :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
+ Returns the log of a point in the simplex. Based on the interface to
+ :class:`OneHotCategorical`.
+
+ Implementation based on [1].
+
+ See also: :func:`torch.distributions.OneHotCategorical`
+
+ Args:
+ temperature (Tensor): relaxation temperature
+ probs (Tensor): event probabilities
+ logits (Tensor): the log probability of each event.
+
+ [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
+ (Maddison et al, 2017)
+
+ [2] Categorical Reparametrization with Gumbel-Softmax
+ (Jang et al, 2017)
+ """
+ arg_constraints = {'probs': constraints.simplex,
+ 'logits': constraints.real}
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+ self._categorical = Categorical(probs, logits)
+ self.temperature = temperature
+ batch_shape = self._categorical.batch_shape
+ event_shape = self._categorical.param_shape[-1:]
+ super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+ def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.temperature = self.temperature
+ new._categorical = self._categorical.expand(batch_shape)
+ super(ExpRelaxedCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def _new(self, *args, **kwargs):
+ return self._categorical._new(*args, **kwargs)
+
+ @property
+ def param_shape(self):
+ return self._categorical.param_shape
+
+ @property
+ def logits(self):
+ return self._categorical.logits
+
+ @property
+ def probs(self):
+ return self._categorical.probs
+
+ def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ uniforms = clamp_probs(torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device))
+ gumbels = -((-(uniforms.log())).log())
+ scores = (self.logits + gumbels) / self.temperature
+ return scores - scores.logsumexp(dim=-1, keepdim=True)
+
+ def log_prob(self, value):
+ K = self._categorical._num_events
+ if self._validate_args:
+ self._validate_sample(value)
+ logits, value = broadcast_all(self.logits, value)
+ log_scale = (self.temperature.new_tensor(float(K)).lgamma() -
+ self.temperature.log().mul(-(K - 1)))
+ score = logits - value.mul(self.temperature)
+ score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
+ return score + log_scale
+
+
+[docs]class RelaxedOneHotCategorical(TransformedDistribution):
+ r"""
+ Creates a RelaxedOneHotCategorical distribution parametrized by
+ :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
+ This is a relaxed version of the :class:`OneHotCategorical` distribution, so
+ its samples lie on the simplex and are reparametrizable.
+
+ Example::
+
+ >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
+ torch.tensor([0.1, 0.2, 0.3, 0.4]))
+ >>> m.sample()
+ tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
+
+ Args:
+ temperature (Tensor): relaxation temperature
+ probs (Tensor): event probabilities
+ logits (Tensor): the log probability of each event.
+ """
+ arg_constraints = {'probs': constraints.simplex,
+ 'logits': constraints.real}
+ support = constraints.simplex
+ has_rsample = True
+
+ def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+ base_dist = ExpRelaxedCategorical(temperature, probs, logits)
+ super(RelaxedOneHotCategorical, self).__init__(base_dist,
+ ExpTransform(),
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
+ return super(RelaxedOneHotCategorical, self).expand(batch_shape, _instance=new)
+
+ @property
+ def temperature(self):
+ return self.base_dist.temperature
+
+ @property
+ def logits(self):
+ return self.base_dist.logits
+
+ @property
+ def probs(self):
+ return self.base_dist.probs
+
+import math
+
+import torch
+from torch._six import inf, nan
+from torch.distributions import Chi2, constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import _standard_normal, broadcast_all
+
+
+[docs]class StudentT(Distribution):
+ r"""
+ Creates a Student's t-distribution parameterized by degrees of
+ freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`.
+
+ Example::
+
+ >>> m = StudentT(torch.tensor([2.0]))
+ >>> m.sample() # Student's t-distributed with degrees of freedom=2
+ tensor([ 0.1046])
+
+ Args:
+ df (float or Tensor): degrees of freedom
+ loc (float or Tensor): mean of the distribution
+ scale (float or Tensor): scale of the distribution
+ """
+ arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
+ support = constraints.real
+ has_rsample = True
+
+ @property
+ def mean(self):
+ m = self.loc.clone()
+ m[self.df <= 1] = nan
+ return m
+
+ @property
+ def variance(self):
+ m = self.df.clone()
+ m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
+ m[(self.df <= 2) & (self.df > 1)] = inf
+ m[self.df <= 1] = nan
+ return m
+
+ def __init__(self, df, loc=0., scale=1., validate_args=None):
+ self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
+ self._chi2 = Chi2(self.df)
+ batch_shape = self.df.size()
+ super(StudentT, self).__init__(batch_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(StudentT, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.df = self.df.expand(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ new._chi2 = self._chi2.expand(batch_shape)
+ super(StudentT, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ # NOTE: This does not agree with scipy implementation as much as other distributions.
+ # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
+ # parameters seems to help.
+
+ # X ~ Normal(0, 1)
+ # Z ~ Chi2(df)
+ # Y = X / sqrt(Z / df) ~ StudentT(df)
+ shape = self._extended_shape(sample_shape)
+ X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
+ Z = self._chi2.rsample(sample_shape)
+ Y = X * torch.rsqrt(Z / self.df)
+ return self.loc + self.scale * Y
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ y = (value - self.loc) / self.scale
+ Z = (self.scale.log() +
+ 0.5 * self.df.log() +
+ 0.5 * math.log(math.pi) +
+ torch.lgamma(0.5 * self.df) -
+ torch.lgamma(0.5 * (self.df + 1.)))
+ return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z
+
+[docs] def entropy(self):
+ lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
+ return (self.scale.log() +
+ 0.5 * (self.df + 1) *
+ (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
+ 0.5 * self.df.log() + lbeta)
+
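+# Editor's sketch (not part of the module above), mirroring rsample():
+# X ~ Normal(0, 1), Z ~ Chi2(df), and X / sqrt(Z / df) is Student-t distributed
+# with df degrees of freedom.
+import torch
+from torch.distributions import StudentT
+
+t = StudentT(torch.tensor([3.0]))
+print(t.mean, t.variance)                 # 0 and df / (df - 2) = 3 for df = 3
+print(t.rsample(torch.Size([5])).shape)   # torch.Size([5, 1])
+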
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.transforms import Transform
+from torch.distributions.utils import _sum_rightmost
+
+
+[docs]class TransformedDistribution(Distribution):
+ r"""
+ Extension of the Distribution class, which applies a sequence of Transforms
+ to a base distribution. Let f be the composition of transforms applied::
+
+ X ~ BaseDistribution
+ Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
+ log p(Y) = log p(X) + log |det (dX/dY)|
+
+ Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
+ maximum shape of its base distribution and its transforms, since transforms
+ can introduce correlations among events.
+
+ An example for the usage of :class:`TransformedDistribution` would be::
+
+ # Building a Logistic Distribution
+ # X ~ Uniform(0, 1)
+ # f = a + b * logit(X)
+ # Y ~ f(X) ~ Logistic(a, b)
+ base_distribution = Uniform(0, 1)
+ transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
+ logistic = TransformedDistribution(base_distribution, transforms)
+
+ For more examples, please look at the implementations of
+ :class:`~torch.distributions.gumbel.Gumbel`,
+ :class:`~torch.distributions.half_cauchy.HalfCauchy`,
+ :class:`~torch.distributions.half_normal.HalfNormal`,
+ :class:`~torch.distributions.log_normal.LogNormal`,
+ :class:`~torch.distributions.pareto.Pareto`,
+ :class:`~torch.distributions.weibull.Weibull`,
+ :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
+ :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
+ """
+ arg_constraints = {}
+
+ def __init__(self, base_distribution, transforms, validate_args=None):
+ self.base_dist = base_distribution
+ if isinstance(transforms, Transform):
+ self.transforms = [transforms, ]
+ elif isinstance(transforms, list):
+ if not all(isinstance(t, Transform) for t in transforms):
+ raise ValueError("transforms must be a Transform or a list of Transforms")
+ self.transforms = transforms
+ else:
+ raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
+ shape = self.base_dist.batch_shape + self.base_dist.event_shape
+ event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
+ batch_shape = shape[:len(shape) - event_dim]
+ event_shape = shape[len(shape) - event_dim:]
+ super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(TransformedDistribution, _instance)
+ batch_shape = torch.Size(batch_shape)
+ base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):]
+ new.base_dist = self.base_dist.expand(base_dist_batch_shape)
+ new.transforms = self.transforms
+ super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @constraints.dependent_property
+ def support(self):
+ return self.transforms[-1].codomain if self.transforms else self.base_dist.support
+
+ @property
+ def has_rsample(self):
+ return self.base_dist.has_rsample
+
+[docs] def sample(self, sample_shape=torch.Size()):
+ """
+ Generates a sample_shape shaped sample or sample_shape shaped batch of
+ samples if the distribution parameters are batched. Samples first from
+ base distribution and applies `transform()` for every transform in the
+ list.
+ """
+ with torch.no_grad():
+ x = self.base_dist.sample(sample_shape)
+ for transform in self.transforms:
+ x = transform(x)
+ return x
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ """
+ Generates a sample_shape shaped reparameterized sample or sample_shape
+ shaped batch of reparameterized samples if the distribution parameters
+ are batched. Samples first from the base distribution and then applies
+ `transform()` for every transform in the list.
+ """
+ x = self.base_dist.rsample(sample_shape)
+ for transform in self.transforms:
+ x = transform(x)
+ return x
+
+[docs] def log_prob(self, value):
+ """
+ Scores the sample by inverting the transform(s) and computing the score
+ from the log probability of the base distribution and the log abs det jacobian.
+ """
+ event_dim = len(self.event_shape)
+ log_prob = 0.0
+ y = value
+ for transform in reversed(self.transforms):
+ x = transform.inv(y)
+ log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
+ event_dim - transform.event_dim)
+ y = x
+
+ log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
+ event_dim - len(self.base_dist.event_shape))
+ return log_prob
+
+ def _monotonize_cdf(self, value):
+ """
+ This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
+ monotone increasing.
+ """
+ sign = 1
+ for transform in self.transforms:
+ sign = sign * transform.sign
+ if isinstance(sign, int) and sign == 1:
+ return value
+ return sign * (value - 0.5) + 0.5
+
+[docs] def cdf(self, value):
+ """
+ Computes the cumulative distribution function by inverting the
+ transform(s) and then evaluating the cdf of the base distribution.
+ """
+ for transform in self.transforms[::-1]:
+ value = transform.inv(value)
+ if self._validate_args:
+ self.base_dist._validate_sample(value)
+ value = self.base_dist.cdf(value)
+ value = self._monotonize_cdf(value)
+ return value
+
+[docs] def icdf(self, value):
+ """
+ Computes the inverse cumulative distribution function by applying the
+ transform(s) to the inverse cdf of the base distribution.
+ """
+ value = self._monotonize_cdf(value)
+ if self._validate_args:
+ self.base_dist._validate_sample(value)
+ value = self.base_dist.icdf(value)
+ for transform in self.transforms:
+ value = transform(value)
+ return value
+
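+# Editor's sketch (not part of the module above): a log-normal built by pushing
+# a Normal base through ExpTransform; log_prob subtracts the log-abs-det-Jacobian
+# of exp, which is simply log(value).
+import torch
+from torch.distributions import Normal, TransformedDistribution
+from torch.distributions.transforms import ExpTransform
+
+base = Normal(torch.tensor(0.0), torch.tensor(1.0))
+log_normal = TransformedDistribution(base, [ExpTransform()])
+x = torch.tensor(2.0)
+manual = base.log_prob(x.log()) - x.log()
+print(torch.allclose(log_normal.log_prob(x), manual))   # True
+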
+import math
+import numbers
+import weakref
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.utils import (_sum_rightmost, broadcast_all,
+ lazy_property)
+from torch.nn.functional import pad
+
+__all__ = [
+ 'AbsTransform',
+ 'AffineTransform',
+ 'ComposeTransform',
+ 'ExpTransform',
+ 'LowerCholeskyTransform',
+ 'PowerTransform',
+ 'SigmoidTransform',
+ 'SoftmaxTransform',
+ 'StickBreakingTransform',
+ 'Transform',
+ 'identity_transform',
+]
+
+
+[docs]class Transform(object):
+ """
+ Abstract class for invertible transformations with computable log
+ det jacobians. They are primarily used in
+ :class:`torch.distributions.TransformedDistribution`.
+
+ Caching is useful for transforms whose inverses are either expensive or
+ numerically unstable. Note that care must be taken with memoized values
+ since the autograd graph may be reversed. For example while the following
+ works with or without caching::
+
+ y = t(x)
+ t.log_abs_det_jacobian(x, y).backward() # x will receive gradients.
+
+ However the following will error when caching due to dependency reversal::
+
+ y = t(x)
+ z = t.inv(y)
+ grad(z.sum(), [y]) # error because z is x
+
+ Derived classes should implement one or both of :meth:`_call` or
+ :meth:`_inverse`. Derived classes that set `bijective=True` should also
+ implement :meth:`log_abs_det_jacobian`.
+
+ Args:
+ cache_size (int): Size of cache. If zero, no caching is done. If one,
+ the latest single value is cached. Only 0 and 1 are supported.
+
+ Attributes:
+ domain (:class:`~torch.distributions.constraints.Constraint`):
+ The constraint representing valid inputs to this transform.
+ codomain (:class:`~torch.distributions.constraints.Constraint`):
+ The constraint representing valid outputs to this transform
+ which are inputs to the inverse transform.
+ bijective (bool): Whether this transform is bijective. A transform
+ ``t`` is bijective iff ``t.inv(t(x)) == x`` and
+ ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
+ the codomain. Transforms that are not bijective should at least
+ maintain the weaker pseudoinverse properties
+ ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
+ sign (int or Tensor): For bijective univariate transforms, this
+ should be +1 or -1 depending on whether transform is monotone
+ increasing or decreasing.
+ event_dim (int): Number of dimensions that are correlated together in
+ the transform ``event_shape``. This should be 0 for pointwise
+ transforms, 1 for transforms that act jointly on vectors, 2 for
+ transforms that act jointly on matrices, etc.
+ """
+ bijective = False
+ event_dim = 0
+
+ def __init__(self, cache_size=0):
+ self._cache_size = cache_size
+ self._inv = None
+ if cache_size == 0:
+ pass # default behavior
+ elif cache_size == 1:
+ self._cached_x_y = None, None
+ else:
+ raise ValueError('cache_size must be 0 or 1')
+ super(Transform, self).__init__()
+
+ @property
+ def inv(self):
+ """
+ Returns the inverse :class:`Transform` of this transform.
+ This should satisfy ``t.inv.inv is t``.
+ """
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = _InverseTransform(self)
+ self._inv = weakref.ref(inv)
+ return inv
+
+ @property
+ def sign(self):
+ """
+ Returns the sign of the determinant of the Jacobian, if applicable.
+ In general this only makes sense for bijective transforms.
+ """
+ raise NotImplementedError
+
+ def __eq__(self, other):
+ return self is other
+
+ def __ne__(self, other):
+ # Necessary for Python2
+ return not self.__eq__(other)
+
+ def __call__(self, x):
+ """
+ Computes the transform `x => y`.
+ """
+ if self._cache_size == 0:
+ return self._call(x)
+ x_old, y_old = self._cached_x_y
+ if x is x_old:
+ return y_old
+ y = self._call(x)
+ self._cached_x_y = x, y
+ return y
+
+ def _inv_call(self, y):
+ """
+ Inverts the transform `y => x`.
+ """
+ if self._cache_size == 0:
+ return self._inverse(y)
+ x_old, y_old = self._cached_x_y
+ if y is y_old:
+ return x_old
+ x = self._inverse(y)
+ self._cached_x_y = x, y
+ return x
+
+ def _call(self, x):
+ """
+ Abstract method to compute forward transformation.
+ """
+ raise NotImplementedError
+
+ def _inverse(self, y):
+ """
+ Abstract method to compute inverse transformation.
+ """
+ raise NotImplementedError
+
+[docs] def log_abs_det_jacobian(self, x, y):
+ """
+ Computes the log det jacobian `log |dy/dx|` given input and output.
+ """
+ raise NotImplementedError
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
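+# Editor's sketch (hypothetical, not part of the module): a user-defined
+# bijection following the contract documented above -- implement _call,
+# _inverse and log_abs_det_jacobian, and declare bijective/sign. `Transform`
+# and `constraints` are the names defined/imported in this module.
+import torch
+from torch.distributions import constraints
+
+class ShiftTransform(Transform):
+    """y = x + shift; the Jacobian is the identity, so log|det J| = 0."""
+    domain = constraints.real
+    codomain = constraints.real
+    bijective = True
+    sign = +1
+
+    def __init__(self, shift, cache_size=0):
+        super(ShiftTransform, self).__init__(cache_size=cache_size)
+        self.shift = shift
+
+    def __eq__(self, other):
+        return isinstance(other, ShiftTransform) and self.shift == other.shift
+
+    def _call(self, x):
+        return x + self.shift
+
+    def _inverse(self, y):
+        return y - self.shift
+
+    def log_abs_det_jacobian(self, x, y):
+        return torch.zeros_like(x)
+
+t = ShiftTransform(2.0)
+x = torch.tensor([0.5])
+print(t(x), t.inv(t(x)))   # shifted value, then x recovered via the inverse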
+
+class _InverseTransform(Transform):
+ """
+ Inverts a single :class:`Transform`.
+ This class is private; please instead use the ``Transform.inv`` property.
+ """
+ def __init__(self, transform):
+ super(_InverseTransform, self).__init__()
+ self._inv = transform
+
+ @constraints.dependent_property
+ def domain(self):
+ return self._inv.codomain
+
+ @constraints.dependent_property
+ def codomain(self):
+ return self._inv.domain
+
+ @property
+ def bijective(self):
+ return self._inv.bijective
+
+ @property
+ def sign(self):
+ return self._inv.sign
+
+ @property
+ def event_dim(self):
+ return self._inv.event_dim
+
+ @property
+ def inv(self):
+ return self._inv
+
+ def __eq__(self, other):
+ if not isinstance(other, _InverseTransform):
+ return False
+ return self._inv == other._inv
+
+ def __call__(self, x):
+ return self._inv._inv_call(x)
+
+ def log_abs_det_jacobian(self, x, y):
+ return -self._inv.log_abs_det_jacobian(y, x)
+
+
+[docs]class ComposeTransform(Transform):
+ """
+ Composes multiple transforms in a chain.
+ The transforms being composed are responsible for caching.
+
+ Args:
+ parts (list of :class:`Transform`): A list of transforms to compose.
+ """
+ def __init__(self, parts):
+ super(ComposeTransform, self).__init__()
+ self.parts = parts
+
+ def __eq__(self, other):
+ if not isinstance(other, ComposeTransform):
+ return False
+ return self.parts == other.parts
+
+ @constraints.dependent_property
+ def domain(self):
+ if not self.parts:
+ return constraints.real
+ return self.parts[0].domain
+
+ @constraints.dependent_property
+ def codomain(self):
+ if not self.parts:
+ return constraints.real
+ return self.parts[-1].codomain
+
+ @lazy_property
+ def bijective(self):
+ return all(p.bijective for p in self.parts)
+
+ @lazy_property
+ def sign(self):
+ sign = 1
+ for p in self.parts:
+ sign = sign * p.sign
+ return sign
+
+ @lazy_property
+ def event_dim(self):
+ return max(p.event_dim for p in self.parts) if self.parts else 0
+
+ @property
+ def inv(self):
+ inv = None
+ if self._inv is not None:
+ inv = self._inv()
+ if inv is None:
+ inv = ComposeTransform([p.inv for p in reversed(self.parts)])
+ self._inv = weakref.ref(inv)
+ inv._inv = weakref.ref(self)
+ return inv
+
+ def __call__(self, x):
+ for part in self.parts:
+ x = part(x)
+ return x
+
+ def log_abs_det_jacobian(self, x, y):
+ if not self.parts:
+ return torch.zeros_like(x)
+ result = 0
+ for part in self.parts:
+ y = part(x)
+ result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
+ self.event_dim - part.event_dim)
+ x = y
+ return result
+
+ def __repr__(self):
+ fmt_string = self.__class__.__name__ + '(\n '
+ fmt_string += ',\n '.join([p.__repr__() for p in self.parts])
+ fmt_string += '\n)'
+ return fmt_string
+
+
+identity_transform = ComposeTransform([])
+
+
+[docs]class ExpTransform(Transform):
+ r"""
+ Transform via the mapping :math:`y = \exp(x)`.
+ """
+ domain = constraints.real
+ codomain = constraints.positive
+ bijective = True
+ sign = +1
+
+ def __eq__(self, other):
+ return isinstance(other, ExpTransform)
+
+ def _call(self, x):
+ return x.exp()
+
+ def _inverse(self, y):
+ return y.log()
+
+ def log_abs_det_jacobian(self, x, y):
+ return x
+
+
+[docs]class PowerTransform(Transform):
+ r"""
+ Transform via the mapping :math:`y = x^{\text{exponent}}`.
+ """
+ domain = constraints.positive
+ codomain = constraints.positive
+ bijective = True
+ sign = +1
+
+ def __init__(self, exponent, cache_size=0):
+ super(PowerTransform, self).__init__(cache_size=cache_size)
+ self.exponent, = broadcast_all(exponent)
+
+ def __eq__(self, other):
+ if not isinstance(other, PowerTransform):
+ return False
+ return self.exponent.eq(other.exponent).all().item()
+
+ def _call(self, x):
+ return x.pow(self.exponent)
+
+ def _inverse(self, y):
+ return y.pow(1 / self.exponent)
+
+ def log_abs_det_jacobian(self, x, y):
+ return (self.exponent * y / x).abs().log()
+
+
+[docs]class SigmoidTransform(Transform):
+ r"""
+ Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
+ """
+ domain = constraints.real
+ codomain = constraints.unit_interval
+ bijective = True
+ sign = +1
+
+ def __eq__(self, other):
+ return isinstance(other, SigmoidTransform)
+
+ def _call(self, x):
+ return torch.sigmoid(x)
+
+ def _inverse(self, y):
+ return y.log() - (-y).log1p()
+
+ def log_abs_det_jacobian(self, x, y):
+ return -(y.reciprocal() + (1 - y).reciprocal()).log()
+
+
+[docs]class AbsTransform(Transform):
+ r"""
+ Transform via the mapping :math:`y = |x|`.
+ """
+ domain = constraints.real
+ codomain = constraints.positive
+
+ def __eq__(self, other):
+ return isinstance(other, AbsTransform)
+
+ def _call(self, x):
+ return x.abs()
+
+ def _inverse(self, y):
+ return y
+
+
+[docs]class AffineTransform(Transform):
+ r"""
+ Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
+
+ Args:
+ loc (Tensor or float): Location parameter.
+ scale (Tensor or float): Scale parameter.
+ event_dim (int): Optional size of `event_shape`. This should be zero
+ for univariate random variables, 1 for distributions over vectors,
+ 2 for distributions over matrices, etc.
+ """
+ domain = constraints.real
+ codomain = constraints.real
+ bijective = True
+
+ def __init__(self, loc, scale, event_dim=0, cache_size=0):
+ super(AffineTransform, self).__init__(cache_size=cache_size)
+ self.loc = loc
+ self.scale = scale
+ self.event_dim = event_dim
+
+ def __eq__(self, other):
+ if not isinstance(other, AffineTransform):
+ return False
+
+ if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
+ if self.loc != other.loc:
+ return False
+ else:
+ if not (self.loc == other.loc).all().item():
+ return False
+
+ if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
+ if self.scale != other.scale:
+ return False
+ else:
+ if not (self.scale == other.scale).all().item():
+ return False
+
+ return True
+
+ @property
+ def sign(self):
+ if isinstance(self.scale, numbers.Number):
+ return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
+ return self.scale.sign()
+
+ def _call(self, x):
+ return self.loc + self.scale * x
+
+ def _inverse(self, y):
+ return (y - self.loc) / self.scale
+
+ def log_abs_det_jacobian(self, x, y):
+ shape = x.shape
+ scale = self.scale
+ if isinstance(scale, numbers.Number):
+ result = x.new_empty(shape).fill_(math.log(abs(scale)))
+ else:
+ result = torch.abs(scale).log()
+ if self.event_dim:
+ result_size = result.size()[:-self.event_dim] + (-1,)
+ result = result.view(result_size).sum(-1)
+ shape = shape[:-self.event_dim]
+ return result.expand(shape)
+
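+# Editor's note (illustrative, not part of the module): with event_dim=1 the
+# per-coordinate log|scale| terms are summed over the last dimension, matching
+# a transform treated as acting jointly on vectors:
+#
+#     >>> x = torch.zeros(5, 3)
+#     >>> t0 = AffineTransform(loc=0., scale=2.)                # pointwise
+#     >>> t1 = AffineTransform(loc=0., scale=2., event_dim=1)   # on vectors
+#     >>> t0.log_abs_det_jacobian(x, t0(x)).shape
+#     torch.Size([5, 3])
+#     >>> t1.log_abs_det_jacobian(x, t1(x)).shape
+#     torch.Size([5])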
+
+[docs]class SoftmaxTransform(Transform):
+ r"""
+ Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
+ normalizing.
+
+ This is not bijective and cannot be used for HMC. However this acts mostly
+ coordinate-wise (except for the final normalization), and thus is
+ appropriate for coordinate-wise optimization algorithms.
+ """
+ domain = constraints.real
+ codomain = constraints.simplex
+ event_dim = 1
+
+ def __eq__(self, other):
+ return isinstance(other, SoftmaxTransform)
+
+ def _call(self, x):
+ logprobs = x
+ probs = (logprobs - logprobs.max(-1, True)[0]).exp()
+ return probs / probs.sum(-1, True)
+
+ def _inverse(self, y):
+ probs = y
+ return probs.log()
+
+
+[docs]class StickBreakingTransform(Transform):
+ """
+ Transform from unconstrained space to the simplex of one additional
+ dimension via a stick-breaking process.
+
+ This transform arises as an iterated sigmoid transform in a stick-breaking
+ construction of the `Dirichlet` distribution: the first logit is
+ transformed via sigmoid to the first probability and the probability of
+ everything else, and then the process recurses.
+
+ This is bijective and appropriate for use in HMC; however it mixes
+ coordinates together and is less appropriate for optimization.
+ """
+ domain = constraints.real
+ codomain = constraints.simplex
+ bijective = True
+ event_dim = 1
+
+ def __eq__(self, other):
+ return isinstance(other, StickBreakingTransform)
+
+ def _call(self, x):
+ offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
+ z = torch.sigmoid(x - offset.log())
+ z_cumprod = (1 - z).cumprod(-1)
+ y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
+ return y
+
+ def _inverse(self, y):
+ shape = y.shape[:-1] + (y.shape[-1] - 1,)
+ offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
+ sf = (1 - y.cumsum(-1))[..., :-1]
+ x = y[..., :-1].log() - sf.log() + offset.log()
+ return x
+
+ def log_abs_det_jacobian(self, x, y):
+ offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
+ z = torch.sigmoid(x - offset.log())
+ detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
+ return detJ
+
+
+[docs]class LowerCholeskyTransform(Transform):
+ """
+ Transform from unconstrained matrices to lower-triangular matrices with
+ nonnegative diagonal entries.
+
+ This is useful for parameterizing positive definite matrices in terms of
+ their Cholesky factorization.
+ """
+ domain = constraints.real
+ codomain = constraints.lower_cholesky
+ event_dim = 2
+
+ def __eq__(self, other):
+ return isinstance(other, LowerCholeskyTransform)
+
+ def _call_on_event(self, x):
+ return x.tril(-1) + x.diag().exp().diag()
+
+ def _inverse_on_event(self, y):
+ return y.tril(-1) + y.diag().log().diag()
+
+ def _call(self, x):
+ flat_x = x.contiguous().view((-1,) + x.shape[-2:])
+ return torch.stack([self._call_on_event(flat_x[i]) for i in range(flat_x.size(0))]).view(x.shape)
+
+ def _inverse(self, y):
+ flat_y = y.contiguous().view((-1,) + y.shape[-2:])
+ return torch.stack([self._inverse_on_event(flat_y[i]) for i in range(flat_y.size(0))]).view(y.shape)
+
+from numbers import Number
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.distribution import Distribution
+from torch.distributions.utils import broadcast_all
+
+
+[docs]class Uniform(Distribution):
+ r"""
+ Generates uniformly distributed random samples from the half-open interval
+ ``[low, high)``.
+
+ Example::
+
+ >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
+ >>> m.sample() # uniformly distributed in the range [0.0, 5.0)
+ tensor([ 2.3418])
+
+ Args:
+ low (float or Tensor): lower range (inclusive).
+ high (float or Tensor): upper range (exclusive).
+ """
+ # TODO allow (loc,scale) parameterization to allow independent constraints.
+ arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent}
+ has_rsample = True
+
+ @property
+ def mean(self):
+ return (self.high + self.low) / 2
+
+ @property
+ def stddev(self):
+ return (self.high - self.low) / 12**0.5
+
+ @property
+ def variance(self):
+ return (self.high - self.low).pow(2) / 12
+
+ def __init__(self, low, high, validate_args=None):
+ self.low, self.high = broadcast_all(low, high)
+
+ if isinstance(low, Number) and isinstance(high, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.low.size()
+ super(Uniform, self).__init__(batch_shape, validate_args=validate_args)
+
+ if self._validate_args and not torch.lt(self.low, self.high).all():
+            raise ValueError("Uniform is not defined when low >= high")
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Uniform, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.low = self.low.expand(batch_shape)
+ new.high = self.high.expand(batch_shape)
+ super(Uniform, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @constraints.dependent_property
+ def support(self):
+ return constraints.interval(self.low, self.high)
+
+[docs] def rsample(self, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
+ return self.low + rand * (self.high - self.low)
+
+[docs] def log_prob(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ lb = value.ge(self.low).type_as(self.low)
+ ub = value.lt(self.high).type_as(self.low)
+ return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
+
+[docs] def cdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ result = (value - self.low) / (self.high - self.low)
+ return result.clamp(min=0, max=1)
+
+[docs] def icdf(self, value):
+ if self._validate_args:
+ self._validate_sample(value)
+ result = value * (self.high - self.low) + self.low
+ return result
+
+
+
+import torch
+from torch.distributions import constraints
+from torch.distributions.exponential import Exponential
+from torch.distributions.transformed_distribution import TransformedDistribution
+from torch.distributions.transforms import AffineTransform, PowerTransform
+from torch.distributions.utils import broadcast_all
+from torch.distributions.gumbel import euler_constant
+
+
+[docs]class Weibull(TransformedDistribution):
+ r"""
+ Samples from a two-parameter Weibull distribution.
+
+ Example:
+
+ >>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
+ >>> m.sample() # sample from a Weibull distribution with scale=1, concentration=1
+ tensor([ 0.4784])
+
+ Args:
+ scale (float or Tensor): Scale parameter of distribution (lambda).
+ concentration (float or Tensor): Concentration parameter of distribution (k/shape).
+ """
+ arg_constraints = {'scale': constraints.positive, 'concentration': constraints.positive}
+ support = constraints.positive
+
+ def __init__(self, scale, concentration, validate_args=None):
+ self.scale, self.concentration = broadcast_all(scale, concentration)
+ self.concentration_reciprocal = self.concentration.reciprocal()
+ base_dist = Exponential(torch.ones_like(self.scale))
+ transforms = [PowerTransform(exponent=self.concentration_reciprocal),
+ AffineTransform(loc=0, scale=self.scale)]
+ super(Weibull, self).__init__(base_dist,
+ transforms,
+ validate_args=validate_args)
+
+[docs] def expand(self, batch_shape, _instance=None):
+ new = self._get_checked_instance(Weibull, _instance)
+ new.scale = self.scale.expand(batch_shape)
+ new.concentration = self.concentration.expand(batch_shape)
+ new.concentration_reciprocal = new.concentration.reciprocal()
+ base_dist = self.base_dist.expand(batch_shape)
+ transforms = [PowerTransform(exponent=new.concentration_reciprocal),
+ AffineTransform(loc=0, scale=new.scale)]
+ super(Weibull, new).__init__(base_dist,
+ transforms,
+ validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self):
+ return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal))
+
+ @property
+ def variance(self):
+ return self.scale.pow(2) * (torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal)) -
+ torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal)))
+
+[docs] def entropy(self):
+ return euler_constant * (1 - self.concentration_reciprocal) + \
+ torch.log(self.scale * self.concentration_reciprocal) + 1
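+
+# A minimal sanity-check sketch (illustrative only, not part of the library API): the
+# closed-form mean above can be compared against a Monte Carlo estimate from samples.
+def _example_weibull_mean():
+    m = Weibull(torch.tensor([2.0]), torch.tensor([1.5]))
+    analytic = m.mean                        # scale * Gamma(1 + 1/concentration)
+    empirical = m.sample((100000,)).mean(0)  # should be close to the analytic mean
+    return analytic, empirical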
+
+import torch
+import torch.nn.functional as F
+from torch._six import inf
+from itertools import product
+import warnings
+
+__all__ = [
+ 'broadcast_tensors',
+ 'btrifact',
+ 'btrifact_with_info',
+ 'btrisolve',
+ 'btriunpack',
+ 'cartesian_prod',
+ 'chain_matmul',
+ 'einsum',
+ 'gesv',
+ 'isfinite',
+ 'isinf',
+ 'lu',
+ 'lu_unpack',
+ 'norm',
+ 'meshgrid',
+ 'pstrf',
+ 'potrf',
+ 'potri',
+ 'potrs',
+ 'split',
+ 'stft',
+ 'tensordot',
+ 'trtrs',
+ 'unique',
+ 'unique_consecutive',
+]
+
+
+[docs]def broadcast_tensors(*tensors):
+ r"""broadcast_tensors(*tensors) -> List of Tensors
+
+ Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
+
+ Args:
+ *tensors: any number of tensors of the same type
+
+ .. warning::
+
+ More than one element of a broadcasted tensor may refer to a single
+ memory location. As a result, in-place operations (especially ones that
+ are vectorized) may result in incorrect behavior. If you need to write
+ to the tensors, please clone them first.
+
+ Example::
+
+ >>> x = torch.arange(3).view(1, 3)
+ >>> y = torch.arange(2).view(2, 1)
+ >>> a, b = torch.broadcast_tensors(x, y)
+ >>> a.size()
+ torch.Size([2, 3])
+ >>> a
+ tensor([[0, 1, 2],
+ [0, 1, 2]])
+ """
+ return torch._C._VariableFunctions.broadcast_tensors(tensors)
+
+
+[docs]def split(tensor, split_size_or_sections, dim=0):
+ r"""Splits the tensor into chunks.
+
+ If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
+ be split into equally sized chunks (if possible). Last chunk will be smaller if
+ the tensor size along the given dimension :attr:`dim` is not divisible by
+    :attr:`split_size_or_sections`.
+
+ If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
+ into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
+ to :attr:`split_size_or_sections`.
+
+ Arguments:
+ tensor (Tensor): tensor to split.
+ split_size_or_sections (int) or (list(int)): size of a single chunk or
+ list of sizes for each chunk
+ dim (int): dimension along which to split the tensor.
+ """
+ # Overwriting reason:
+ # This dispatches to two ATen functions depending on the type of
+ # split_size_or_sections. The branching code is in tensor.py, which we
+ # call here.
+ return tensor.split(split_size_or_sections, dim)
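+
+# A minimal usage sketch (illustrative only, not part of the library API), mirroring the
+# docstring above: an int gives equally sized chunks, a list gives explicit section sizes.
+def _example_split():
+    x = torch.arange(10).reshape(5, 2)
+    even_chunks = torch.split(x, 2)        # three chunks of 2, 2 and 1 rows along dim 0
+    sections = torch.split(x, [1, 4])      # chunks of 1 and 4 rows along dim 0
+    return even_chunks, sections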
+
+
+[docs]def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+ r"""Unpacks the data and pivots from a LU factorization of a tensor.
+
+ Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
+
+ Arguments:
+ LU_data (Tensor): the packed LU factorization data
+ LU_pivots (Tensor): the packed LU factorization pivots
+ unpack_data (bool): flag indicating if the data should be unpacked
+ unpack_pivots (bool): flag indicating if the pivots should be unpacked
+
+ Example::
+
+ >>> A = torch.randn(2, 3, 3)
+ >>> A_LU, pivots = A.lu()
+ >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
+ >>>
+ >>> # can recover A from factorization
+ >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+ """
+
+ sz = LU_data.size(-1)
+
+ if unpack_data:
+ U = LU_data.triu()
+ L = LU_data.tril()
+ L.diagonal(dim1=-2, dim2=-1).fill_(1)
+ else:
+ L = U = None
+
+ if unpack_pivots:
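+        # LU pivots are 1-indexed row swaps: pivots[k] = j means that rows k and j were
+        # exchanged at step k of the factorization. Convert them to 0-indexed values and
+        # replay the swaps on an identity matrix to recover the permutation matrix P.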
+ LU_pivots_zero_idx = LU_pivots - 1
+ if LU_data.dim() > 2:
+ P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
+ for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
+ final_order = list(range(sz))
+ for k, j in enumerate(LU_pivots_zero_idx[idx]):
+ final_order[k], final_order[j] = final_order[j], final_order[k]
+ P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+ else:
+ P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype)
+ final_order = list(range(sz))
+        for k, j in enumerate(LU_pivots_zero_idx):
+ final_order[k], final_order[j] = final_order[j], final_order[k]
+ P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+ else:
+ P = None
+
+ return P, L, U
+
+
+[docs]def einsum(equation, *operands):
+ r"""einsum(equation, *operands) -> Tensor
+
+This function provides a way of computing multilinear expressions (i.e. sums of products) using the
+Einstein summation convention.
+
+Args:
+ equation (string): The equation is given in terms of lower case letters (indices) to be associated
+ with each dimension of the operands and result. The left hand side lists the operands
+ dimensions, separated by commas. There should be one index letter per tensor dimension.
+ The right hand side follows after `->` and gives the indices for the output.
+           If the `->` and right hand side are omitted, it is implicitly defined as the alphabetically
+ sorted list of all indices appearing exactly once in the left hand side.
+           The indices not appearing in the output are summed over after multiplying the operand
+           entries.
+ If an index appears several times for the same operand, a diagonal is taken.
+ Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,
+ the ellipsis dimensions are at the beginning of the output.
+ operands (list of Tensors): The operands to compute the Einstein sum of.
+
+Examples::
+
+ >>> x = torch.randn(5)
+ >>> y = torch.randn(4)
+ >>> torch.einsum('i,j->ij', x, y) # outer product
+ tensor([[-0.0570, -0.0286, -0.0231, 0.0197],
+ [ 1.2616, 0.6335, 0.5113, -0.4351],
+ [ 1.4452, 0.7257, 0.5857, -0.4984],
+ [-0.4647, -0.2333, -0.1883, 0.1603],
+ [-1.1130, -0.5588, -0.4510, 0.3838]])
+
+
+ >>> A = torch.randn(3,5,4)
+ >>> l = torch.randn(2,5)
+ >>> r = torch.randn(2,4)
+ >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear
+ tensor([[-0.3430, -5.2405, 0.4494],
+ [ 0.3311, 5.5201, -3.0356]])
+
+
+ >>> As = torch.randn(3,2,5)
+ >>> Bs = torch.randn(3,5,4)
+ >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication
+ tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
+ [-1.6706, -0.8097, -0.8025, -2.1183]],
+
+ [[ 4.2239, 0.3107, -0.5756, -0.2354],
+ [-1.4558, -0.3460, 1.5087, -0.8530]],
+
+ [[ 2.8153, 1.8787, -4.3839, -1.2112],
+ [ 0.3728, -2.1131, 0.0921, 0.8305]]])
+
+ >>> A = torch.randn(3, 3)
+ >>> torch.einsum('ii->i', A) # diagonal
+ tensor([-0.7825, 0.8291, -0.1936])
+
+ >>> A = torch.randn(4, 3, 3)
+ >>> torch.einsum('...ii->...i', A) # batch diagonal
+ tensor([[-1.0864, 0.7292, 0.0569],
+ [-0.9725, -1.0270, 0.6493],
+ [ 0.5832, -1.1716, -1.5084],
+ [ 0.4041, -1.1690, 0.8570]])
+
+ >>> A = torch.randn(2, 3, 4, 5)
+ >>> torch.einsum('...ij->...ji', A).shape # batch permute
+ torch.Size([2, 3, 5, 4])
+"""
+ if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
+ # the old interface of passing the operands as one list argument
+ operands = operands[0]
+ return torch._C._VariableFunctions.einsum(equation, operands)
+
+
+[docs]def isfinite(tensor):
+    r"""Returns a new tensor with boolean elements representing whether each element is finite or not.
+
+ Arguments:
+ tensor (Tensor): A tensor to check
+
+ Returns:
+ Tensor: A ``torch.ByteTensor`` containing a 1 at each location of finite elements and 0 otherwise
+
+ Example::
+
+ >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
+ tensor([ 1, 0, 1, 0, 0], dtype=torch.uint8)
+ """
+ if not isinstance(tensor, torch.Tensor):
+ raise ValueError("The argument is not a tensor", str(tensor))
+
+ # Support int input, nan and inf are concepts in floating point numbers.
+ # Numpy uses type 'Object' when the int overflows long, but we don't
+ # have a similar concept. It's safe to assume any created LongTensor doesn't
+ # overflow and it's finite.
+ if not tensor.is_floating_point():
+ return torch.ones_like(tensor, dtype=torch.uint8)
+ return (tensor == tensor) & (tensor.abs() != inf)
+
+
+[docs]def isinf(tensor):
+ r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not.
+
+ Arguments:
+ tensor (Tensor): A tensor to check
+
+ Returns:
+ Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `+/-INF` elements and 0 otherwise
+
+ Example::
+
+ >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
+ tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8)
+ """
+ if not isinstance(tensor, torch.Tensor):
+ raise ValueError("The argument is not a tensor", str(tensor))
+ if tensor.dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
+ return torch.zeros_like(tensor, dtype=torch.uint8)
+ return tensor.abs() == inf
+
+
+[docs]def meshgrid(*tensors, **kwargs):
+    r"""Take :math:`N` tensors, each of which can be either a scalar or a 1-dimensional
+vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
+expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.
+
+
+ Args:
+ tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
+ treated as tensors of size :math:`(1,)` automatically
+
+ Returns:
+ seq (sequence of Tensors): If the input has :math:`k` tensors of size
+        :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors,
+ where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`.
+
+ Example::
+
+ >>> x = torch.tensor([1, 2, 3])
+ >>> y = torch.tensor([4, 5, 6])
+ >>> grid_x, grid_y = torch.meshgrid(x, y)
+ >>> grid_x
+ tensor([[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]])
+ >>> grid_y
+ tensor([[4, 5, 6],
+ [4, 5, 6],
+ [4, 5, 6]])
+ """
+ if kwargs:
+ raise TypeError("meshgrid() got an unexpected keyword argument '%s'" % (list(kwargs)[0],))
+ if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
+ # the old interface of passing the operands as one list argument
+ tensors = tensors[0]
+ return torch._C._VariableFunctions.meshgrid(tensors)
+
+
+[docs]def stft(input, n_fft, hop_length=None, win_length=None, window=None,
+ center=True, pad_mode='reflect', normalized=False, onesided=True):
+ r"""Short-time Fourier transform (STFT).
+
+ Ignoring the optional batch dimension, this method computes the following
+ expression:
+
+ .. math::
+ X[m, \omega] = \sum_{k = 0}^{\text{win\_length-1}}%
+ \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
+ \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
+
+ where :math:`m` is the index of the sliding window, and :math:`\omega` is
+    the frequency such that :math:`0 \leq \omega < \text{n\_fft}`. When
+ :attr:`onesided` is the default value ``True``,
+
+ * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
+ sequences.
+
+ * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
+ ``floor(n_fft / 4)``.
+
+ * If :attr:`win_length` is ``None`` (default), it is treated as equal to
+ :attr:`n_fft`.
+
+ * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
+ :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
+ treated as if having :math:`1` everywhere in the window. If
+ :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
+ both sides to length :attr:`n_fft` before being applied.
+
+ * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
+ both sides so that the :math:`t`-th frame is centered at time
+ :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
+ begins at time :math:`t \times \text{hop\_length}`.
+
+ * :attr:`pad_mode` determines the padding method used on :attr:`input` when
+ :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
+ all available options. Default is ``"reflect"``.
+
+ * If :attr:`onesided` is ``True`` (default), only values for :math:`\omega`
+ in :math:`\left[0, 1, 2, \dots, \left\lfloor \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]`
+ are returned because the real-to-complex Fourier transform satisfies the
+ conjugate symmetry, i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
+
+ * If :attr:`normalized` is ``True`` (default is ``False``), the function
+ returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
+
+ Returns the real and the imaginary parts together as one tensor of size
+ :math:`(* \times N \times T \times 2)`, where :math:`*` is the optional
+ batch size of :attr:`input`, :math:`N` is the number of frequencies where
+ STFT is applied, :math:`T` is the total number of frames used, and each pair
+ in the last dimension represents a complex number as the real part and the
+ imaginary part.
+
+ .. warning::
+ This function changed signature at version 0.4.1. Calling with the
+        previous signature may cause an error or return an incorrect result.
+
+ Arguments:
+ input (Tensor): the input tensor
+ n_fft (int): size of Fourier transform
+ hop_length (int, optional): the distance between neighboring sliding window
+ frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
+ win_length (int, optional): the size of window frame and STFT filter.
+ Default: ``None`` (treated as equal to :attr:`n_fft`)
+ window (Tensor, optional): the optional window function.
+ Default: ``None`` (treated as window of all :math:`1` s)
+ center (bool, optional): whether to pad :attr:`input` on both sides so
+ that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+ Default: ``True``
+ pad_mode (string, optional): controls the padding method used when
+ :attr:`center` is ``True``. Default: ``"reflect"``
+ normalized (bool, optional): controls whether to return the normalized STFT results
+ Default: ``False``
+ onesided (bool, optional): controls whether to return half of results to
+ avoid redundancy Default: ``True``
+
+ Returns:
+ Tensor: A tensor containing the STFT result with shape described above
+
+ """
+ # TODO: after having proper ways to map Python strings to ATen Enum, move
+ # this and F.pad to ATen.
+ if center:
+ signal_dim = input.dim()
+ extended_shape = [1] * (3 - signal_dim) + list(input.size())
+ pad = int(n_fft // 2)
+ input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
+ input = input.view(input.shape[-signal_dim:])
+ return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
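+
+# A minimal usage sketch (illustrative only, not part of the library API), using the default
+# centered, one-sided settings described above; the sample rate and window length are
+# arbitrary choices for the example.
+def _example_stft():
+    signal = torch.randn(2, 16000)                 # a batch of two 1-second signals at 16 kHz
+    window = torch.hann_window(400)
+    spec = torch.stft(signal, n_fft=400, hop_length=160, window=window)
+    return spec.shape                              # (2, n_fft // 2 + 1, num_frames, 2)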
+
+
+del torch.unique_dim
+
+
+[docs]def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+ r"""Returns the unique elements of the input tensor.
+
+ Arguments:
+ input (Tensor): the input tensor
+ sorted (bool): Whether to sort the unique elements in ascending order
+ before returning as output.
+ return_inverse (bool): Whether to also return the indices for where
+ elements in the original input ended up in the returned unique list.
+ return_counts (bool): Whether to also return the counts for each unique
+ element.
+ dim (int): the dimension to apply unique. If ``None``, the unique of the
+ flattened input is returned. default: ``None``
+
+ Returns:
+        (Tensor, Tensor (optional), Tensor (optional)):
+ A tensor or a tuple of tensors containing
+
+ - **output** (*Tensor*): the output list of unique scalar elements.
+ - **inverse_indices** (*Tensor*): (optional) if
+ :attr:`return_inverse` is True, there will be an additional
+ returned tensor (same shape as input) representing the indices
+ for where elements in the original input map to in the output;
+ otherwise, this function will only return a single tensor.
+ - **counts** (*Tensor*): (optional) if
+ :attr:`return_counts` is True, there will be an additional
+ returned tensor (same shape as output or output.size(dim),
+ if dim was specified) representing the number of occurrences
+ for each unique value or tensor.
+
+ Example::
+
+ >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
+ >>> output
+ tensor([ 2, 3, 1])
+
+ >>> output, inverse_indices = torch.unique(
+ torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
+ >>> output
+ tensor([ 1, 2, 3])
+ >>> inverse_indices
+ tensor([ 0, 2, 1, 2])
+
+ >>> output, inverse_indices = torch.unique(
+ torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
+ >>> output
+ tensor([ 1, 2, 3])
+ >>> inverse_indices
+ tensor([[ 0, 2],
+ [ 1, 2]])
+
+ """
+ if dim is not None:
+ output, inverse_indices, counts = torch._C._VariableFunctions.unique_dim(
+ input,
+ dim,
+ sorted=sorted,
+ return_inverse=return_inverse,
+ return_counts=return_counts,
+ )
+ else:
+ output, inverse_indices, counts = torch._unique2(
+ input,
+ sorted=sorted,
+ return_inverse=return_inverse,
+ return_counts=return_counts,
+ )
+ if return_inverse and return_counts:
+ return output, inverse_indices, counts
+ elif return_inverse:
+ return output, inverse_indices
+ elif return_counts:
+ return output, counts
+ else:
+ return output
+
+
+[docs]def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None):
+ r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+ .. note:: This function is different from :func:`torch.unique` in the sense that this function
+ only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
+ in C++.
+
+ Arguments:
+ input (Tensor): the input tensor
+ return_inverse (bool): Whether to also return the indices for where
+ elements in the original input ended up in the returned unique list.
+ return_counts (bool): Whether to also return the counts for each unique
+ element.
+ dim (int): the dimension to apply unique. If ``None``, the unique of the
+ flattened input is returned. default: ``None``
+
+ Returns:
+ (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
+
+ - **output** (*Tensor*): the output list of unique scalar elements.
+ - **inverse_indices** (*Tensor*): (optional) if
+ :attr:`return_inverse` is True, there will be an additional
+ returned tensor (same shape as input) representing the indices
+ for where elements in the original input map to in the output;
+ otherwise, this function will only return a single tensor.
+ - **counts** (*Tensor*): (optional) if
+ :attr:`return_counts` is True, there will be an additional
+ returned tensor (same shape as output or output.size(dim),
+ if dim was specified) representing the number of occurrences
+ for each unique value or tensor.
+
+ Example::
+
+ >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
+ >>> output = torch.unique_consecutive(x)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+
+ >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+ >>> inverse_indices
+ tensor([0, 0, 1, 1, 2, 3, 3, 4])
+
+ >>> output, counts = torch.unique_consecutive(x, return_counts=True)
+ >>> output
+ tensor([1, 2, 3, 1, 2])
+ >>> counts
+ tensor([2, 2, 1, 2, 1])
+ """
+ output, inverse_indices, counts = torch._C._VariableFunctions.unique_consecutive(
+ input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+ if return_inverse and return_counts:
+ return output, inverse_indices, counts
+ if return_inverse:
+ return output, inverse_indices
+ if return_counts:
+ return output, counts
+ return output
+
+
+[docs]def tensordot(a, b, dims=2):
+ r"""Returns a contraction of a and b over multiple dimensions.
+
+    :attr:`tensordot` implements a generalized matrix product.
+
+ Args:
+ a (Tensor): Left tensor to contract
+ b (Tensor): Right tensor to contract
+ dims (int or tuple of two lists of integers): number of dimensions to
+ contract or explicit lists of dimensions for :attr:`a` and
+ :attr:`b` respectively
+
+ When called with an integer argument :attr:`dims` = :math:`d`, and the number of
+ dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively,
+ it computes
+
+ .. math::
+ r_{i_0,...,i_{m-d}, i_d,...,i_n}
+ = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
+
+ When called with :attr:`dims` of the list form, the given dimensions will be contracted
+ in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
+ in these dimensions must match, but :attr:`tensordot` will deal with broadcasted
+ dimensions.
+
+ Examples::
+
+ >>> a = torch.arange(60.).reshape(3, 4, 5)
+ >>> b = torch.arange(24.).reshape(4, 3, 2)
+ >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
+ tensor([[4400., 4730.],
+ [4532., 4874.],
+ [4664., 5018.],
+ [4796., 5162.],
+ [4928., 5306.]])
+
+ >>> a = torch.randn(3, 4, 5, device='cuda')
+ >>> b = torch.randn(4, 5, 6, device='cuda')
+ >>> c = torch.tensordot(a, b, dims=2).cpu()
+ tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741],
+ [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744],
+ [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]])
+
+ """
+ if isinstance(dims, (list, tuple)) or \
+ (isinstance(dims, torch.Tensor) and dims.numel() > 1):
+ dims_a, dims_b = dims
+ else:
+ if isinstance(dims, torch.Tensor):
+ dims = dims.item()
+ dims_a = list(range(-dims, 0))
+ dims_b = list(range(dims))
+ return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
+
+
+[docs]def cartesian_prod(*tensors):
+    """Computes the Cartesian product of the given sequence of tensors. The behavior is similar to
+    Python's `itertools.product`.
+
+ Arguments:
+ *tensors: any number of 1 dimensional tensors.
+
+ Returns:
+        Tensor: A tensor equivalent to converting all the input tensors into lists,
+        running `itertools.product` on these lists, and finally converting the resulting list
+        into a tensor.
+
+ Example::
+
+ >>> a = [1, 2, 3]
+ >>> b = [4, 5]
+ >>> list(itertools.product(a, b))
+ [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
+ >>> tensor_a = torch.tensor(a)
+ >>> tensor_b = torch.tensor(b)
+ >>> torch.cartesian_prod(tensor_a, tensor_b)
+ tensor([[1, 4],
+ [1, 5],
+ [2, 4],
+ [2, 5],
+ [3, 4],
+ [3, 5]])
+ """
+ return torch._C._VariableFunctions.cartesian_prod(tensors)
+
+
+[docs]def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+ r"""Returns the matrix norm or vector norm of a given tensor.
+
+ Args:
+ input (Tensor): the input tensor
+ p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
+ The following norms can be calculated:
+
+ ===== ============================ ==========================
+ ord matrix norm vector norm
+ ===== ============================ ==========================
+ None Frobenius norm 2-norm
+ 'fro' Frobenius norm --
+ 'nuc' nuclear norm --
+ Other as vec norm when dim is None sum(abs(x)**ord)**(1./ord)
+ ===== ============================ ==========================
+
+ dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,
+ vector norm will be calculated, if it is 2-tuple of ints, matrix norm
+ will be calculated. If the value is None, matrix norm will be calculated
+ when the input tensor only has two dimensions, vector norm will be
+ calculated when the input tensor only has one dimension. If the input
+            tensor has more than two dimensions, the vector norm will be applied to
+            the last dimension.
+ keepdim (bool, optional): whether the output tensors have :attr:`dim`
+ retained or not. Ignored if :attr:`dim` = ``None`` and
+ :attr:`out` = ``None``. Default: ``False``
+ out (Tensor, optional): the output tensor. Ignored if
+ :attr:`dim` = ``None`` and :attr:`out` = ``None``.
+ dtype (:class:`torch.dtype`, optional): the desired data type of
+            returned tensor. If specified, the input tensor is cast to
+            :attr:`dtype` while performing the operation. Default: None.
+
+
+ Example::
+
+ >>> import torch
+ >>> a = torch.arange(9, dtype= torch.float) - 4
+ >>> b = a.reshape((3, 3))
+ >>> torch.norm(a)
+ tensor(7.7460)
+ >>> torch.norm(b)
+ tensor(7.7460)
+ >>> torch.norm(a, float('inf'))
+ tensor(4.)
+ >>> torch.norm(b, float('inf'))
+ tensor(4.)
+ >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)
+ >>> torch.norm(c, dim=0)
+ tensor([1.4142, 2.2361, 5.0000])
+ >>> torch.norm(c, dim=1)
+ tensor([3.7417, 4.2426])
+ >>> torch.norm(c, p=1, dim=1)
+ tensor([6., 6.])
+ >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)
+ >>> torch.norm(d, dim=(1,2))
+ tensor([ 3.7417, 11.2250])
+ >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+ (tensor(3.7417), tensor(11.2250))
+ """
+ ndim = input.dim()
+
+ # catch default case
+ if dim is None and out is None and dtype is None:
+ if p == "fro":
+ return torch._C._VariableFunctions.frobenius_norm(input)
+ elif p != "nuc":
+ return torch._C._VariableFunctions.norm(input, p)
+
+ if p == "fro":
+ if dtype is not None:
+ raise ValueError("dtype argument is not supported in frobenius norm")
+ if dim is None:
+ dim = tuple(range(ndim))
+ if out is None:
+ return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
+ return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
+ elif p == "nuc":
+ if dtype is not None:
+ raise ValueError("dtype argument is not supported in nuclear norm")
+ if out is None:
+            return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
+ return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+ else:
+ if dim is None:
+ dim = tuple(range(ndim))
+ if out is None and dtype is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
+ elif out is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype)
+ elif dtype is None:
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
+ return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out)
+
+
+[docs]def chain_matmul(*matrices):
+ r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
+    using the matrix chain order algorithm which selects the order that incurs the lowest cost in terms
+ of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
+ needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
+ If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
+
+
+ Args:
+ matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
+
+
+ Returns:
+ Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
+ would be of dimensions :math:`p_{1} \times p_{N + 1}`.
+
+ Example::
+
+ >>> a = torch.randn(3, 4)
+ >>> b = torch.randn(4, 5)
+ >>> c = torch.randn(5, 6)
+ >>> d = torch.randn(6, 7)
+ >>> torch.chain_matmul(a, b, c, d)
+ tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614],
+ [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163],
+ [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])
+
+ .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
+ """
+ return torch._C._VariableFunctions.chain_matmul(matrices)
+
+
+[docs]def pstrf(a, upper=True, out=None):
+ r"""Computes the pivoted Cholesky decomposition of a symmetric positive-definite
+    matrix :attr:`a`, returning a namedtuple ``(u, pivot)`` of matrices.
+
+ If :attr:`upper` is ``True`` or not provided, `u` is upper triangular
+ such that :math:`a = p^T u^T u p`, with `p` the permutation given by `pivot`.
+
+ If :attr:`upper` is ``False``, `u` is lower triangular such that
+ :math:`a = p^T u u^T p`.
+
+ .. warning::
+ :func:`torch.pstrf` is deprecated in favour of :func:`torch.cholesky` and will
+ be removed in the next release.
+
+ Args:
+ a (Tensor): the input 2-D tensor
+        upper (bool, optional): whether to return an upper (default) or lower triangular matrix
+ out (tuple, optional): namedtuple of `u` and `pivot` tensors
+
+ Example::
+
+ >>> a = torch.randn(3, 3)
+ >>> a = torch.mm(a, a.t()) # make symmetric positive definite
+ >>> a
+ tensor([[ 3.5405, -0.4577, 0.8342],
+ [-0.4577, 1.8244, -0.1996],
+ [ 0.8342, -0.1996, 3.7493]])
+ >>> u,piv = torch.pstrf(a)
+ >>> u
+ tensor([[ 1.9363, 0.4308, -0.1031],
+ [ 0.0000, 1.8316, -0.2256],
+ [ 0.0000, 0.0000, 1.3277]])
+ >>> piv
+ tensor([ 2, 0, 1], dtype=torch.int32)
+ >>> p = torch.eye(3).index_select(0,piv.long()).index_select(0,piv.long()).t() # make pivot permutation
+ >>> torch.mm(torch.mm(p.t(),torch.mm(u.t(),u)),p) # reconstruct
+ tensor([[ 3.5405, -0.4577, 0.8342],
+ [-0.4577, 1.8244, -0.1996],
+ [ 0.8342, -0.1996, 3.7493]])
+ """
+ warnings.warn("torch.pstrf is deprecated in favour of torch.cholesky and will be removed "
+ "in the next release.", stacklevel=2)
+ return torch._C._VariableFunctions.pstrf(a, upper=upper, out=out)
+
+
+[docs]def potrf(a, upper=True, out=None):
+ r"""Computes the Cholesky decomposition of a symmetric positive-definite
+ matrix :math:`A`.
+
+ For more information regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.
+
+ .. warning::
+ :func:`torch.potrf` is deprecated in favour of :func:`torch.cholesky` and will be removed
+ in the next release. Please use :func:`torch.cholesky` instead and note that the :attr:`upper`
+ argument in :func:`torch.cholesky` defaults to ``False``.
+ """
+ warnings.warn("torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next "
+ "release. Please use torch.cholesky instead and note that the :attr:`upper` argument in"
+ " torch.cholesky defaults to ``False``.", stacklevel=2)
+ return torch.cholesky(a, upper=upper, out=out)
+
+
+[docs]def potri(a, upper=True, out=None):
+ r"""Computes the inverse of a symmetric positive-definite matrix :math:`A` using its
+ Cholesky factor.
+
+ For more information regarding :func:`torch.potri`, please check :func:`torch.cholesky_inverse`.
+
+ .. warning::
+ :func:`torch.potri` is deprecated in favour of :func:`torch.cholesky_inverse` and will be removed
+ in the next release. Please use :func:`torch.cholesky_inverse` instead and note that the :attr:`upper`
+ argument in :func:`torch.cholesky_inverse` defaults to ``False``.
+ """
+ warnings.warn("torch.potri is deprecated in favour of torch.cholesky_inverse and will be removed in "
+ "the next release. Please use torch.cholesky_inverse instead and note that the :attr:`upper` "
+ "argument in torch.cholesky_inverse defaults to ``False``.", stacklevel=2)
+ return torch.cholesky_inverse(a, upper=upper, out=out)
+
+
+[docs]def potrs(b, u, upper=True, out=None):
+ r"""Solves a linear system of equations with a positive semidefinite
+ matrix to be inverted given its Cholesky factor matrix :attr:`u`.
+
+ For more information regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
+
+ .. warning::
+ :func:`torch.potrs` is deprecated in favour of :func:`torch.cholesky_solve` and will be
+ removed in the next release. Please use :func:`torch.cholesky_solve` instead and note that
+ the :attr:`upper` argument in :func:`torch.cholesky_solve` defaults to ``False``.
+ """
+ warnings.warn("torch.potrs is deprecated in favour of torch.cholesky_solve and will be removed "
+                  "in the next release. Please use torch.cholesky_solve instead and note that the "
+ ":attr:`upper` argument in torch.cholesky_solve defaults to ``False``.", stacklevel=2)
+ return torch.cholesky_solve(b, u, upper=upper, out=out)
+
+
+[docs]def gesv(b, A, out=None):
+ r"""This function returns the solution to the system of linear equations represented
+ by :math:`AX = B` and the LU factorization of A, in order as a tuple `X, LU`.
+
+ For more information regarding :func:`torch.gesv`, please check :func:`torch.solve`.
+
+ .. warning::
+ :func:`torch.gesv` is deprecated in favour of :func:`torch.solve` and will be removed in the
+ next release. Please use :func:`torch.solve` instead.
+ """
+ warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+ "next release. Please use torch.solve instead.", stacklevel=2)
+ return torch.solve(b, A, out=out)
+
+
+[docs]def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
+ r"""Solves a system of equations with a triangular coefficient matrix :math:`A`
+ and multiple right-hand sides :attr:`b`.
+
+ In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
+ with the default keyword arguments.
+
+ For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.
+
+ .. warning::
+ :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
+ removed in the next release. Please use :func:`torch.triangular_solve` instead.
+ """
+ warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+ "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
+ return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
+
+
+[docs]def btrifact(A, pivot=True, out=None):
+ r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+ Pivoting is done if :attr:`pivot` is set.
+
+ For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`.
+
+ .. warning::
+ :func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be
+ removed in the next release. Please use :func:`torch.lu` instead.
+ """
+ warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be "
+ "removed in the next release. Please use torch.lu instead.", stacklevel=2)
+ return lu(A, pivot=pivot, get_infos=False, out=out)
+
+
+[docs]def btrifact_with_info(A, pivot=True, out=None):
+ r"""Performs LU factorization and returns additional status information along with the LU
+ factorization and pivots.
+
+ For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`.
+
+ .. warning::
+ :func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will
+ be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos`
+ argument set to ``True`` instead.
+ """
+ warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be "
+ "removed in the next release. Please use torch.lu with the get_infos argument "
+ "set to True instead.",
+ stacklevel=2)
+ return lu(A, pivot=pivot, get_infos=True, out=out)
+
+
+[docs]def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+ r"""Unpacks the data and pivots from a LU factorization of a tensor.
+
+ For more information regarding :func:`torch.btriunpack`, please check :func:`torch.lu_unpack`.
+
+ .. warning::
+ :func:`torch.btriunpack` is deprecated in favour of :func:`torch.lu_unpack` and will be
+ removed in the next release. Please use :func:`torch.lu_unpack` instead.
+ """
+ warnings.warn("torch.btriunpack is deprecated in favour of torch.lu_unpack and will be "
+ "removed in the next release. Please use torch.lu_unpack instead.", stacklevel=2)
+ return lu_unpack(LU_data=LU_data, LU_pivots=LU_pivots,
+ unpack_data=unpack_data, unpack_pivots=unpack_pivots)
+
+
+[docs]def btrisolve(b, LU_data, LU_pivots, out=None):
+ r"""Solves the system of equations :math:`Ax = b` using the partially pivoted LU
+ factorization of :math:`A` given by :attr:`LU_data` and :attr:`LU_pivots`.
+
+ For more information regarding :func:`torch.btrisolve`, please check
+ :func:`torch.lu_solve`.
+
+ .. warning::
+ :func:`torch.btrisolve` is deprecated in favour of :func:`torch.lu_solve` and will be
+ removed in the next release. Please use :func:`torch.lu_solve` instead.
+ """
+ warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+ "removed in the next release. Please use torch.lu_solve instead.", stacklevel=2)
+ return torch.lu_solve(b, LU_data=LU_data, LU_pivots=LU_pivots, out=out)
+
+
+[docs]def lu(A, pivot=True, get_infos=False, out=None):
+ r"""Computes the LU factorization of a square matrix or batches of square matrices
+ :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+ Pivoting is done if :attr:`pivot` is set to ``True``.
+
+ .. note::
+ The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
+ then the returned pivots is a tensor filled with zeros of the appropriate size.
+
+ .. note::
+ LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting
+ to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is
+ available for CUDA.
+
+ .. note::
+ This function does not check if the factorization was successful or not if
+ :attr:`get_infos` is ``True`` since the status of the factorization is present in the
+ third element of the return tuple.
+
+ Arguments:
+ A (Tensor): the tensor to factor of size :math:`(*, m, m)`
+ pivot (bool, optional): controls whether pivoting is done. Default: ``True``
+ get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
+ Default: ``False``
+ out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
+ then the elements in the tuple are Tensor, IntTensor,
+ and IntTensor. If :attr:`get_infos` is ``False``, then the
+ elements in the tuple are Tensor, IntTensor. Default: ``None``
+
+ Returns:
+ (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
+
+ - **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`
+
+ - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`
+
+ - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
+ size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
+ each minibatch has succeeded or failed
+
+ Example::
+
+ >>> A = torch.randn(2, 3, 3)
+ >>> A_LU, pivots = torch.lu(A)
+ >>> A_LU
+ tensor([[[ 1.3506, 2.5558, -0.0816],
+ [ 0.1684, 1.1551, 0.1940],
+ [ 0.1193, 0.6189, -0.5497]],
+
+ [[ 0.4526, 1.2526, -0.3285],
+ [-0.7988, 0.7175, -0.9701],
+ [ 0.2634, -0.9255, -0.3459]]])
+ >>> pivots
+ tensor([[ 3, 3, 3],
+ [ 3, 3, 3]], dtype=torch.int32)
+ >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
+ >>> if info.nonzero().size(0) == 0:
+ ... print('LU factorization succeeded for all samples!')
+ LU factorization succeeded for all samples!
+ """
+ # If get_infos is True, then we don't need to check for errors and vice versa
+ result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
+ if out is not None:
+ if not isinstance(out, (tuple, list)):
+ raise TypeError("argument 'out' must be tuple of Tensors, not {}"
+ .format(type(out).__name__))
+ if len(out) - int(get_infos) != 2:
+ raise TypeError("expected tuple of {} elements but got {}"
+ .format(2 + int(get_infos), len(out)))
+ return (out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
+ if get_infos:
+ return result # A_LU, pivots, infos
+ else:
+ return result[0], result[1] # A_LU, pivots
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+import errno
+import hashlib
+import os
+import re
+import shutil
+import sys
+import tempfile
+import torch
+import warnings
+import zipfile
+
+if sys.version_info[0] == 2:
+ from urlparse import urlparse
+ from urllib2 import urlopen # noqa f811
+else:
+ from urllib.request import urlopen
+ from urllib.parse import urlparse # noqa: F401
+
+try:
+ from tqdm import tqdm
+except ImportError:
+ # fake tqdm if it's not installed
+ class tqdm(object):
+
+ def __init__(self, total=None, disable=False):
+ self.total = total
+ self.disable = disable
+ self.n = 0
+
+ def update(self, n):
+ if self.disable:
+ return
+
+ self.n += n
+ if self.total is None:
+ sys.stderr.write("\r{0:.1f} bytes".format(self.n))
+ else:
+ sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
+ sys.stderr.flush()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.disable:
+ return
+
+ sys.stderr.write('\n')
+
+# matches bfd8deac from resnet18-bfd8deac.pth
+HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+
+MASTER_BRANCH = 'master'
+ENV_TORCH_HOME = 'TORCH_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+VAR_DEPENDENCY = 'dependencies'
+MODULE_HUBCONF = 'hubconf.py'
+READ_DATA_CHUNK = 8192
+hub_dir = None
+
+
+# Copied from tools/shared/module_loader to be included in torch package
+def import_module(name, path):
+ if sys.version_info >= (3, 5):
+ import importlib.util
+ spec = importlib.util.spec_from_file_location(name, path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
+ elif sys.version_info >= (3, 0):
+ from importlib.machinery import SourceFileLoader
+ return SourceFileLoader(name, path).load_module()
+ else:
+ import imp
+ return imp.load_source(name, path)
+
+
+def _remove_if_exists(path):
+ if os.path.exists(path):
+ if os.path.isfile(path):
+ os.remove(path)
+ else:
+ shutil.rmtree(path)
+
+
+def _git_archive_link(repo_owner, repo_name, branch):
+ return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
+
+
+def _download_archive_zip(url, filename):
+ sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
+ response = urlopen(url)
+ with open(filename, 'wb') as f:
+ while True:
+ data = response.read(READ_DATA_CHUNK)
+ if len(data) == 0:
+ break
+ f.write(data)
+
+
+def _load_attr_from_module(module, func_name):
+ # Check if callable is defined in the module
+ if func_name not in dir(module):
+ return None
+ return getattr(module, func_name)
+
+
+def _get_torch_home():
+ torch_home = os.path.expanduser(
+ os.getenv(ENV_TORCH_HOME,
+ os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
+ return torch_home
+
+
+def _setup_hubdir():
+ global hub_dir
+ # Issue warning to move data if old env is set
+ if os.getenv('TORCH_HUB'):
+ warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+
+ if hub_dir is None:
+ torch_home = _get_torch_home()
+ hub_dir = os.path.join(torch_home, 'hub')
+
+ if not os.path.exists(hub_dir):
+ os.makedirs(hub_dir)
+
+
+def _parse_repo_info(github):
+ branch = MASTER_BRANCH
+ if ':' in github:
+ repo_info, branch = github.split(':')
+ else:
+ repo_info = github
+ repo_owner, repo_name = repo_info.split('/')
+ return repo_owner, repo_name, branch
+
+
+def _get_cache_or_reload(github, force_reload):
+ # Parse github repo information
+ repo_owner, repo_name, branch = _parse_repo_info(github)
+
+ # Github renames folder repo-v1.x.x to repo-1.x.x
+    # We don't know the repo folder name until we download the zip file
+    # and inspect the name from it.
+    # To check if a cached repo exists, we need to normalize folder names.
+ repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, branch]))
+
+ use_cache = (not force_reload) and os.path.exists(repo_dir)
+
+ if use_cache:
+ sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
+ else:
+ cached_file = os.path.join(hub_dir, branch + '.zip')
+ _remove_if_exists(cached_file)
+
+ url = _git_archive_link(repo_owner, repo_name, branch)
+ _download_archive_zip(url, cached_file)
+
+ cached_zipfile = zipfile.ZipFile(cached_file)
+        extracted_repo_name = cached_zipfile.infolist()[0].filename
+        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
+ _remove_if_exists(extracted_repo)
+ # Unzip the code and rename the base folder
+ cached_zipfile.extractall(hub_dir)
+
+ _remove_if_exists(cached_file)
+ _remove_if_exists(repo_dir)
+ shutil.move(extracted_repo, repo_dir) # rename the repo
+
+ return repo_dir
+
+
+def _check_module_exists(name):
+ if sys.version_info >= (3, 4):
+ import importlib.util
+ return importlib.util.find_spec(name) is not None
+ elif sys.version_info >= (3, 3):
+ # Special case for python3.3
+ import importlib.find_loader
+ return importlib.find_loader(name) is not None
+ else:
+ # NB: imp doesn't handle hierarchical module names (names contains dots).
+ try:
+ import imp
+ imp.find_module(name)
+ except Exception:
+ return False
+ return True
+
+
+def _check_dependencies(m):
+ dependencies = _load_attr_from_module(m, VAR_DEPENDENCY)
+
+ if dependencies is not None:
+ missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
+ if len(missing_deps):
+ raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
+
+
+def _load_entry_from_hubconf(m, model):
+ if not isinstance(model, str):
+ raise ValueError('Invalid input: model should be a string of function name')
+
+ # Note that if a missing dependency is imported at top level of hubconf, it will
+ # throw before this function. It's a chicken and egg situation where we have to
+    # load hubconf to know what the dependencies are, but importing hubconf requires
+    # the missing package. This is fine; Python will raise a proper error message for users.
+ _check_dependencies(m)
+
+ func = _load_attr_from_module(m, model)
+
+ if func is None or not callable(func):
+ raise RuntimeError('Cannot find callable {} in hubconf'.format(model))
+
+ return func
+
+
+[docs]def set_dir(d):
+ r"""
+ Optionally set hub_dir to a local dir to save downloaded models & weights.
+
+    If ``set_dir`` is not called, the default path is ``$TORCH_HOME/hub`` where the
+    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
+    ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
+    filesystem layout, with a default value of ``~/.cache`` if the environment
+    variable is not set.
+
+
+ Args:
+ d: path to a local folder to save downloaded models & weights.
+ """
+ global hub_dir
+ hub_dir = d
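+
+# A minimal usage sketch (illustrative only, not part of the library API); the cache path is a
+# hypothetical example.
+def _example_set_dir():
+    set_dir('/tmp/my_hub_cache')       # subsequent hub calls download into this directory
+    return list('pytorch/vision')      # entrypoint names defined in the repo's hubconf.py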
+
+
+[docs]def list(github, force_reload=False):
+ r"""
+ List all entrypoints available in `github` hubconf.
+
+ Args:
+ github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
+ tag/branch. The default branch is `master` if not specified.
+ Example: 'pytorch/vision[:hub]'
+ force_reload: Optional, whether to discard the existing cache and force a fresh download.
+ Default is `False`.
+ Returns:
+ entrypoints: a list of available entrypoint names
+
+ Example:
+ >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
+ """
+ # Setup hub_dir to save downloaded files
+ _setup_hubdir()
+
+ repo_dir = _get_cache_or_reload(github, force_reload)
+
+ sys.path.insert(0, repo_dir)
+
+ hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
+
+ sys.path.remove(repo_dir)
+
+    # We treat functions starting with '_' as internal helper functions
+ entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+
+ return entrypoints
+
+
+[docs]def help(github, model, force_reload=False):
+ r"""
+ Show the docstring of entrypoint `model`.
+
+ Args:
+ github: Required, a string with format <repo_owner/repo_name[:tag_name]> with an optional
+ tag/branch. The default branch is `master` if not specified.
+ Example: 'pytorch/vision[:hub]'
+ model: Required, a string of entrypoint name defined in repo's hubconf.py
+ force_reload: Optional, whether to discard the existing cache and force a fresh download.
+ Default is `False`.
+ Example:
+ >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
+ """
+ # Setup hub_dir to save downloaded files
+ _setup_hubdir()
+
+ repo_dir = _get_cache_or_reload(github, force_reload)
+
+ sys.path.insert(0, repo_dir)
+
+ hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
+
+ sys.path.remove(repo_dir)
+
+ entry = _load_entry_from_hubconf(hub_module, model)
+
+ return entry.__doc__
+
+
+# Ideally this should be `def load(github, model, *args, force_reload=False, **kwargs):`,
+# but Python2 raises a syntax error for it. We have to skip force_reload in the function
+# signature here but detect it in kwargs instead.
+# TODO: fix it after Python2 EOL
+[docs]def load(github, model, *args, **kwargs):
+ r"""
+ Load a model from a github repo, with pretrained weights.
+
+ Args:
+ github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
+ tag/branch. The default branch is `master` if not specified.
+ Example: 'pytorch/vision[:hub]'
+ model: Required, a string of entrypoint name defined in repo's hubconf.py
+ *args: Optional, the corresponding args for callable `model`.
+ force_reload: Optional, whether to force a fresh download of github repo unconditionally.
+ Default is `False`.
+ **kwargs: Optional, the corresponding kwargs for callable `model`.
+
+ Returns:
+ a single model with corresponding pretrained weights.
+
+ Example:
+ >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
+ """
+ # Setup hub_dir to save downloaded files
+ _setup_hubdir()
+
+ force_reload = kwargs.get('force_reload', False)
+ kwargs.pop('force_reload', None)
+
+ repo_dir = _get_cache_or_reload(github, force_reload)
+
+ sys.path.insert(0, repo_dir)
+
+ hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
+
+ entry = _load_entry_from_hubconf(hub_module, model)
+
+ model = entry(*args, **kwargs)
+
+ sys.path.remove(repo_dir)
+
+ return model
+
+
+def _download_url_to_file(url, dst, hash_prefix, progress):
+ file_size = None
+ u = urlopen(url)
+ meta = u.info()
+ if hasattr(meta, 'getheaders'):
+ content_length = meta.getheaders("Content-Length")
+ else:
+ content_length = meta.get_all("Content-Length")
+ if content_length is not None and len(content_length) > 0:
+ file_size = int(content_length[0])
+
+ f = tempfile.NamedTemporaryFile(delete=False)
+ try:
+ if hash_prefix is not None:
+ sha256 = hashlib.sha256()
+ with tqdm(total=file_size, disable=not progress) as pbar:
+ while True:
+ buffer = u.read(8192)
+ if len(buffer) == 0:
+ break
+ f.write(buffer)
+ if hash_prefix is not None:
+ sha256.update(buffer)
+ pbar.update(len(buffer))
+
+ f.close()
+ if hash_prefix is not None:
+ digest = sha256.hexdigest()
+ if digest[:len(hash_prefix)] != hash_prefix:
+ raise RuntimeError('invalid hash value (expected "{}", got "{}")'
+ .format(hash_prefix, digest))
+ shutil.move(f.name, dst)
+ finally:
+ f.close()
+ if os.path.exists(f.name):
+ os.remove(f.name)
+
+
+def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
+ r"""Loads the Torch serialized object at the given URL.
+
+ If the object is already present in `model_dir`, it's deserialized and
+ returned. The filename part of the URL should follow the naming convention
+ ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+ digits of the SHA256 hash of the contents of the file. The hash is used to
+ ensure unique names and to verify the contents of the file.
+
+ The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
+ environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
+    ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
+    filesystem layout, with a default value of ``~/.cache`` if not set.
+
+ Args:
+ url (string): URL of the object to download
+ model_dir (string, optional): directory in which to save the object
+ map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
+ progress (bool, optional): whether or not to display a progress bar to stderr
+
+ Example:
+ >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
+
+ """
+ # Issue warning to move data if old env is set
+ if os.getenv('TORCH_MODEL_ZOO'):
+ warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+ if model_dir is None:
+ torch_home = _get_torch_home()
+ model_dir = os.path.join(torch_home, 'checkpoints')
+
+ try:
+ os.makedirs(model_dir)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # Directory already exists, ignore.
+ pass
+ else:
+ # Unexpected OSError, re-raise.
+ raise
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(model_dir, filename)
+ if not os.path.exists(cached_file):
+ sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
+ hash_prefix = HASH_REGEX.search(filename).group(1)
+ _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+ return torch.load(cached_file, map_location=map_location)
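+
+
+# Illustrative sketch (not part of the module): the ``filename-<sha256>.ext``
+# convention documented above lets the hash prefix be recovered from the file
+# name and checked against the file contents.  The regex mirrors the
+# module-level HASH_REGEX but is spelled out here so the sketch stays
+# self-contained; the example path reuses the file name from the docstring.
+def _example_check_hash_prefix(path='resnet18-5c106cde.pth'):
+    import re
+    match = re.search(r'-([a-f0-9]*)\.', os.path.basename(path))
+    if match is None:
+        return False
+    prefix = match.group(1)
+    sha256 = hashlib.sha256()
+    with open(path, 'rb') as fh:
+        for chunk in iter(lambda: fh.read(8192), b''):
+            sha256.update(chunk)
+    return sha256.hexdigest().startswith(prefix)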
+
+import torch._C
+from torch.autograd import Variable, function
+from torch.serialization import validate_cuda_device
+from torch.nn import Module, ModuleList, Parameter, Sequential
+from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
+import torch.backends.cudnn as cudnn
+import torch.jit.annotations
+import torch._jit_internal as _jit_internal
+from torch._six import with_metaclass, get_function_from_type, \
+ string_classes
+from torch._jit_internal import ignore # noqa: F401
+from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
+ _list_with_default
+import torch.testing
+
+import math
+from collections import OrderedDict, namedtuple
+import textwrap
+import sys
+import warnings
+import weakref
+import types
+import contextlib
+import os
+import functools
+import copy
+import collections
+import inspect
+import pickle
+if sys.version_info[0] > 2:
+ import pathlib
+
+
+def _parse_env(name, default, true_message, false_message):
+ value = os.environ.get(name)
+ if value is None:
+ return default
+ if value.lower() in {'1', 'true', 'yes'}:
+ return True
+ elif value.lower() in {'0', 'false', 'no'}:
+ return False
+ if value == '1v':
+ print(true_message)
+ return True
+ elif value == '0v':
+ print(false_message)
+ return False
+ raise ValueError('Unknown setting of {}. Try using 0 or 1.'.format(name))
+
+
+_enabled = _parse_env('PYTORCH_JIT', True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED")
+_flatten = torch._C._jit_flatten
+_unflatten = torch._C._jit_unflatten
+_jit_script_class_compile = torch._C._jit_script_class_compile
+
+Future = torch._C.Future
+_fork = torch._C.fork
+_wait = torch._C.wait
+
+
+@contextlib.contextmanager
+def scope(scope_name):
+ tracing_state = torch._C._get_tracing_state()
+ if tracing_state:
+ tracing_state.push_scope(scope_name)
+ try:
+ yield
+ finally:
+ if tracing_state:
+ tracing_state.pop_scope()
+
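+
+# Illustrative sketch (hypothetical usage, not part of the module): ``scope``
+# names a region of a trace.  When no trace is being recorded it is a no-op,
+# and the scope is popped even if the body raises.
+def _example_scoped_computation(x):
+    with scope('Encoder'):
+        x = x * 2
+    return x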
+
+DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap()
+
+
+def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP):
+ r"""
+ Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
+
+    All previously saved modules, no matter their device, are first loaded onto CPU,
+    and then moved to the devices they were saved from. If this fails (e.g. because
+    the runtime system doesn't have certain devices), an exception is raised.
+    However, storages can be dynamically remapped to an alternative set of devices
+    using the `map_location` argument. Compared to :func:`torch.load`, `map_location`
+    in this function is simplified: it only accepts a string (e.g., 'cpu', 'cuda:0')
+    or a torch.device (e.g., torch.device('cpu')).
+
+ Arguments:
+ f: a file-like object (has to implement read, readline, tell, and seek),
+ or a string containing a file name
+        map_location: can be a string (e.g., 'cpu', 'cuda:0') or a device (e.g.,
+            torch.device('cpu'))
+        _extra_files: map from filename to content. The extra
+            filenames given in the map will be loaded and their content
+            will be stored in the provided map.
+
+
+ Returns:
+ A ``ScriptModule`` object.
+
+ Example: ::
+
+ torch.jit.load('scriptmodule.pt')
+
+ # Load ScriptModule from io.BytesIO object
+ with open('scriptmodule.pt', 'rb') as f:
+ buffer = io.BytesIO(f.read())
+
+ # Load all tensors to the original device
+ torch.jit.load(buffer)
+
+ # Load all tensors onto CPU, using a device
+ torch.jit.load(buffer, map_location=torch.device('cpu'))
+
+ # Load all tensors onto CPU, using a string
+ torch.jit.load(buffer, map_location='cpu')
+
+ # Load with extra files.
+ files = {'metadata.json' : ''}
+ torch.jit.load('scriptmodule.pt', _extra_files = files)
+        print(files['metadata.json'])
+ """
+ m = ScriptModule()
+
+ def module_lookup(names):
+ curr = m
+ for name in names:
+ if not hasattr(curr, name):
+ setattr(curr, name, ScriptModule())
+ curr = getattr(curr, name)
+ return curr._c
+ if isinstance(f, string_classes):
+ if not os.path.exists(f):
+ raise ValueError("The provided filename {} does not exist".format(f))
+ if isinstance(map_location, string_classes):
+ map_location = torch.device(map_location)
+ elif not (map_location is None or
+ isinstance(map_location, torch.device)):
+ raise ValueError("map_location should be either None, string or torch.device, "
+ "but got type: " + str(type(map_location)))
+ if (str(map_location).startswith('cuda')):
+ validate_cuda_device(map_location)
+
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+ (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ torch._C.import_ir_module(module_lookup, f, map_location, _extra_files)
+ else:
+ torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location, _extra_files)
+
+ return m
+
+
+def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
+ """
+ Save an offline version of this module for use in a separate process. The saved
+ module serializes all of the methods, submodules, parameters, and attributes of this
+ module. It can be loaded into the C++ API using ``torch::jit::load(filename)`` or into the Python
+ API with ``torch.jit.load(filename)``.
+
+ To be able to save a module, it must not make any calls to native Python functions.
+ This means that all submodules must be subclasses of ``torch.jit.ScriptModule`` as well.
+
+ .. DANGER::
+ All modules, no matter their device, are always loaded onto the CPU during loading.
+ This is different from :func:`torch.load`'s semantics and may change in the future.
+
+ Arguments:
+ m: a ScriptModule to save
+ f: a file-like object (has to implement write and flush) or a string
+ containing a file name
+ _extra_files: Map from filename to contents which will be stored as part of 'f'
+
+ .. warning::
+ If you are using Python 2, ``torch.save`` does NOT support ``StringIO.StringIO``
+ as a valid file-like object. This is because the write method should return
+ the number of bytes written; ``StringIO.write()`` does not do this.
+
+ Please use something like ``io.BytesIO`` instead.
+
+ Example: ::
+
+ m = torch.jit.ScriptModule()
+
+ # Save to file
+ torch.jit.save(m, 'scriptmodule.pt')
+
+ # Save to io.BytesIO buffer
+ buffer = io.BytesIO()
+ torch.jit.save(m, buffer)
+
+ # Save with extra files
+ extra_files = torch._C.ExtraFilesMap()
+ extra_files['foo.txt'] = 'bar'
+ torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
+ """
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+ (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ m.save(f, _extra_files=_extra_files)
+ else:
+ ret = m.save_to_buffer(_extra_files=_extra_files)
+ f.write(ret)
+
+
+def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False):
+ """
+    Trace a function or model, returning a tuple consisting of both the
+    *trace* of an execution and the original return value. If ``return_inputs``
+    is True, the trace inputs are also returned as part of the tuple.
+
+ Tracing is guaranteed not to change the semantics of the function/module
+ that is traced.
+
+ Arguments:
+ f (torch.nn.Module or function): the function or module
+ to be traced.
+ args (tuple or Tensor): the positional arguments to pass to the
+ function/module to be traced. A non-tuple is assumed to
+ be a single positional argument to be passed to the model.
+ kwargs (dict): the keyword arguments to pass to the function/module
+ to be traced.
+
+ Example: Trace a cell.
+
+ >>> trace, out = jit.trace(nn.LSTMCell(), (input, hidden))
+ >>> print(trace)
+ """
+ if kwargs is None:
+ kwargs = {}
+ if not isinstance(args, tuple):
+ args = (args,)
+ return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
+
+
+def _unique_state_dict(module, keep_vars=False):
+ # since Parameter.data always creates a new torch.Tensor instance,
+ # id(v) doesn't work with it. So we always get the Parameter or Buffer
+ # as values, and deduplicate the params using Parameters and Buffers
+ state_dict = module.state_dict(keep_vars=True)
+ filtered_dict = type(state_dict)()
+ seen_ids = set()
+ for k, v in state_dict.items():
+ if id(v) in seen_ids:
+ continue
+ seen_ids.add(id(v))
+ if keep_vars:
+ filtered_dict[k] = v
+ else:
+ filtered_dict[k] = v.data
+ return filtered_dict
+
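+
+# Illustrative sketch (hypothetical modules): when two submodules share the
+# same Parameter object, ``_unique_state_dict`` above keeps a single entry for
+# it, whereas an ordinary state_dict lists it under both names.
+def _example_unique_state_dict():
+    import torch.nn as nn
+    first = nn.Linear(2, 2, bias=False)
+    second = nn.Linear(2, 2, bias=False)
+    second.weight = first.weight            # tie the weights
+    tied = nn.Sequential(first, second)
+    # len(tied.state_dict()) == 2, but the shared weight is counted once here
+    return len(_unique_state_dict(tied))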
+
+def _create_interpreter_name_lookup_fn(frames_up=1):
+ def _get_interpreter_name_for_var(var):
+ frame = inspect.currentframe()
+ i = 0
+ while i < frames_up + 1:
+ frame = frame.f_back
+ i += 1
+
+ f_locals = frame.f_locals
+ f_globals = frame.f_globals
+
+ for k, v in f_locals.items():
+ if isinstance(v, torch.Tensor) and var is v:
+ return k if k != 'self' else ''
+ for k, v in f_globals.items():
+ if isinstance(v, torch.Tensor) and var is v:
+ return k if k != 'self' else ''
+ return ''
+ return _get_interpreter_name_for_var
+
+
+class LegacyTracedModule(Module):
+ def __init__(self, inner, force_outplace=False, return_inputs=False):
+ super(LegacyTracedModule, self).__init__()
+ # inner may be a Module, or it may be an arbitrary callable
+ # If it's a Module, we get its parameters automatically, which lets
+        # us avoid special-casing functions versus modules.
+ self.inner = inner
+ self._force_outplace = force_outplace
+ self._return_inputs = return_inputs
+
+ def forward(self, *args):
+ in_vars, in_desc = _flatten(args)
+ # NOTE: use full state, because we need it for BatchNorm export
+ # This differs from the compiler path, which doesn't support it at the moment.
+ module_state = list(_unique_state_dict(self, keep_vars=True).values())
+ trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))
+ ret_inputs = tuple(x.clone() for x in all_trace_inputs)
+ torch._C._tracer_set_force_outplace(self._force_outplace)
+ torch._C._tracer_set_get_unique_name_fn(_create_interpreter_name_lookup_fn())
+ try:
+ trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
+ out = self.inner(*trace_inputs)
+ out_vars, _ = _flatten(out)
+ torch._C._tracer_exit(tuple(out_vars))
+ except Exception:
+ torch._C._tracer_abandon()
+ raise
+ if self._return_inputs:
+ return trace, out, ret_inputs
+ else:
+ return trace, out
+
+
+def _clone_inputs(args):
+ def clone_input(a):
+ if a is None:
+ return None
+ elif isinstance(a, torch.Tensor):
+ # TODO: figure out one liner to .clone() and set requires_grad
+ v = Variable(a.data.clone(), requires_grad=a.requires_grad)
+ if a.grad is not None:
+ v.grad = clone_input(v.grad)
+ return v
+ else:
+ return a.clone()
+ return function._nested_map(lambda x: isinstance(x, torch.Tensor),
+ clone_input, condition_msg="tensors")(args)
+
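+
+# Illustrative sketch (hypothetical inputs): ``_clone_inputs`` copies tensor
+# storage while preserving ``requires_grad``, which is what the trace checker
+# relies on when it re-runs a function on pristine copies of the inputs.
+def _example_clone_inputs():
+    a = torch.rand(2, requires_grad=True)
+    b = torch.rand(2)
+    cloned_a, cloned_b = _clone_inputs((a, b))
+    return cloned_a.requires_grad            # True; the data itself is a copy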
+
+# This is purely for developer debugging. We are not going to advertise it.
+_JIT_DUMP = os.environ.get('PYTORCH_JIT_DUMP', False)
+_JIT_TIME = os.environ.get('PYTORCH_JIT_TIME', False) # CUDA-only timing
+_JIT_DISABLE = os.environ.get('PYTORCH_JIT_DISABLE', False)
+_JIT_STATS = os.environ.get('PYTORCH_JIT_STATS', False)
+
+
+def _dump_trace(trace_name, pass_name, input_key, trace):
+ if not _JIT_DUMP:
+ return
+
+ import torch.contrib._graph_vis as graph_vis
+
+ filename = "{}_{}".format(trace_name, pass_name)
+ # TODO: Also paste out the backtrace when the trace was compiled
+ # (and maybe also when it was run?)
+ with open(filename + ".ir", "w") as f:
+ f.write("Input key: {}\n\n{}".format(input_key, str(trace)))
+ graph_vis.write(trace.graph(), filename + ".html")
+
+
+@contextlib.contextmanager
+def _time(trace_name, name, time=True):
+ if (not _JIT_TIME and not time) or not torch.cuda.is_available():
+ yield
+ return
+ stream = torch.cuda.current_stream()
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ stream.record_event(start)
+ try:
+ yield
+ finally:
+ stream.record_event(end)
+ end.synchronize()
+ print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
+
+
+def verify(model, args, loss_fn=torch.sum, devices=None):
+ """
+ Verify that a JIT compiled model has the same behavior as its uncompiled
+ version along with its backwards pass. If your model returns multiple
+ outputs, you must also specify a `loss_fn` to produce a loss for which
+ the backwards will be computed.
+
+ This function has side-effects (e.g., it executes your model / saves and loads
+ parameters), so don't expect the model to come out exactly the same as what
+ you passed in.
+
+ Arguments:
+ model (compiled torch.nn.Module or function): the module/function to be
+ verified. The module/function definition MUST have been decorated with
+ `@torch.jit.compile`.
+ args (tuple or Tensor): the positional arguments to pass to the
+ compiled function/module to be verified. A non-tuple is assumed to
+ be a single positional argument to be passed to the model.
+ loss_fn (function, optional): the loss function to be applied to
+ the output of the model, before backwards is invoked. By default,
+ we assume that a model returns a single result, and we :func:`torch.sum`
+            it before calling backwards; if this is inappropriate, you can pass your
+ own loss function. Note that if a model returns a tuple of results,
+ these are passed as separate positional arguments to `loss_fn`.
+ devices (iterable of device IDs, optional): the GPU devices which the
+ compiled module will be run on. This determines the RNG state we
+ must save when running both compiled and uncompiled versions of the model.
+ """
+ # TODO: In principle, we track device information in our trace, so it
+ # should be possible to check if our execution actually obeyed the 'devices'
+ # the user provided.
+
+ # TODO: Consider adding a utility function to torch.jit to test
+ # for this case
+ if not isinstance(model, torch._C.CompiledFunction):
+ raise TypeError("Cannot verify an uncompiled module. Add @torch.jit.compile to compile it")
+ is_module = isinstance(model, Module)
+
+ if not isinstance(args, tuple):
+ args = (args,)
+
+ saved_args = _clone_inputs(args)
+ if is_module:
+ saved_state = copy.deepcopy(model.state_dict())
+
+ def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
+ params = list(model.parameters()) if is_module else []
+ in_vars, _ = _flatten((args, params))
+ # We use a special API to reset the trace and compile it from scratch.
+ compiled_fn = model
+ if force_trace:
+ compiled_fn.clear_cache()
+ if assert_compiled:
+ hits = compiled_fn.hits
+ out = model(*args)
+ if assert_compiled and compiled_fn.hits == hits:
+ raise RuntimeError("failed to use the compiled function")
+ if not isinstance(out, tuple):
+ out = (out, )
+ if loss_fn == torch.sum and len(out) != 1:
+ raise ValueError(("Model returns {} outputs, but default loss function "
+ "(torch.sum) can only handle a single output").format(len(out)))
+ out_vars, _ = _flatten(out)
+ saved_outs = [v.data.clone() for v in out_vars]
+ loss = loss_fn(*out)
+ grads = torch.autograd.grad([loss], in_vars)
+ # TODO: I'm not sure if the clone here is necessary but it is safer
+ saved_grads = [v.data.clone() for v in grads]
+ return (saved_outs, saved_grads)
+
+ with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
+ uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
+ assert model.has_trace_for(*args)
+
+ if is_module:
+ model.load_state_dict(saved_state)
+ compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
+
+ _verify_equal(uncompiled_outs, compiled_outs)
+ _verify_equal(uncompiled_grads, compiled_grads)
+
+
+def _verify_equal(xs, ys):
+ for x, y in zip(xs, ys):
+ if x.sub(y).abs().max() > 1e-6:
+ raise RuntimeError("JIT and real computation mismatch")
+
+
+def indent(s):
+ return '\n'.join(['\t' + line for line in s.splitlines()])
+
+
+class TracingCheckError(Exception):
+ def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
+ self.message = 'Tracing failed sanity checks!\n'
+ if extra_msg is not None:
+ self.message += extra_msg + '\n'
+ if graph_diff_error is not None:
+ self.message += 'ERROR: Graphs differed across invocations!\n'
+ self.message += indent(graph_diff_error) + '\n'
+ if tensor_compare_error is not None:
+ self.message += 'ERROR: Tensor-valued Constant nodes differed in value ' \
+ 'across invocations. This often indicates that the tracer has' \
+ ' encountered untraceable code.\n'
+ self.message += indent(tensor_compare_error) + '\n'
+ super(TracingCheckError, self).__init__(self.message)
+
+
+# Check the traced module against a set of user-provided validation inputs
+@torch.no_grad()
+def _check_trace(check_inputs, func, executor_options, module, check_tolerance, force_outplace):
+ # Note: tracing is independent of optimizations, which consume the trace
+ executor_options['optimize'] = False
+ for inputs in check_inputs:
+ if isinstance(inputs, torch.Tensor):
+ inputs = (inputs,)
+ check_mod = torch.jit.trace(
+ func,
+ _clone_inputs(inputs),
+ check_trace=False,
+ _force_outplace=force_outplace,
+ **executor_options)
+
+ def graph_diagnostic_info():
+ mod_canonicalized = torch._C._jit_pass_canonicalize(module.graph)
+ torch._C._jit_pass_erase_shape_information(mod_canonicalized)
+ check_canonicalized = torch._C._jit_pass_canonicalize(check_mod.graph)
+ torch._C._jit_pass_erase_shape_information(check_canonicalized)
+
+ graph_diff_errors = None
+ if str(mod_canonicalized) != str(check_canonicalized):
+ import difflib
+ graph_diff = difflib.ndiff(str(mod_canonicalized).splitlines(True),
+ str(check_canonicalized).splitlines(True))
+ graph_diff_errors = 'Graph diff:\n' + indent(''.join(graph_diff)) + '\n'
+
+ for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
+ if str(n_mod) != str(n_check):
+ graph_diff_errors += 'First diverging operator:\n'
+ node_diff = difflib.ndiff(str(n_mod).splitlines(True),
+ str(n_check).splitlines(True))
+ source_printout = 'Node diff:\n' + indent(''.join(node_diff)) + '\n'
+ mod_stack = n_mod.getSourceLocation()
+ if mod_stack:
+ source_printout += 'Trace source location:\n' + indent(mod_stack) + '\n'
+ check_stack = n_check.getSourceLocation()
+ if check_stack:
+ source_printout += 'Check source location:\n' + indent(check_stack) + '\n'
+ graph_diff_errors += source_printout
+
+ break # For now, only print out the first pair of nodes that diverges
+
+ tensor_compare_errors = None
+ # Check Tensor-valued constant nodes
+ for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
+ if n_mod.kind() != n_check.kind():
+ break # Graphs have already diverged
+
+ if n_mod.kind() == 'prim::Constant' and not (n_mod.mustBeNone() or n_check.mustBeNone()):
+ if n_mod.kindOf('value') != 't' or n_check.kindOf('value') != 't':
+ continue
+
+ mod_tensor_val = n_mod.t('value')
+ check_tensor_val = n_check.t('value')
+
+ try:
+ torch.testing.assert_allclose(mod_tensor_val, check_tensor_val)
+ except (RuntimeError, AssertionError) as e:
+ if tensor_compare_errors is None:
+ tensor_compare_errors = ''
+ tensor_compare_errors += 'Node:\n' + indent(str(n_mod)) + '\n'
+ compare_stack = n_mod.getSourceLocation()
+ if compare_stack:
+ tensor_compare_errors += 'Source Location:\n' + indent(compare_stack) + '\n'
+ tensor_compare_errors += 'Comparison exception: ' + indent(str(e))
+
+ break # For now, only print the first diverging pair
+
+ return graph_diff_errors, tensor_compare_errors
+
+ def wrap_retval(x):
+ return x if isinstance(x, tuple) else (x,)
+
+ def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
+ try:
+ outs = wrap_retval(mod(*_clone_inputs(inputs)))
+ outs = [out for out in outs if isinstance(out, torch.Tensor)]
+ return outs
+ except Exception as e:
+ raise TracingCheckError(*graph_diagnostic_info(),
+ extra_msg='Encountered an exception while running the ' + running_what +
+ ' with test inputs.\nException:\n' + indent(str(e)))
+
+ has_warned = [False]
+
+ def maybe_warn_nondeterministic():
+ if has_warned[0]:
+ return
+ has_warned[0] = True
+ nondeterm_ops = [op for op in module.graph.nodes() if op.isNondeterministic()]
+ if len(nondeterm_ops) > 0:
+ nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
+            nondeterministic_ops_warning += "Did you forget to call .eval() on your model? Nodes:\n"
+ nondeterministic_ops_warning += "\n".join([indent(str(op)) for op in nondeterm_ops][:20])
+ nondeterministic_ops_warning += "\nThis may cause errors in trace checking. To disable trace checking,"\
+ " pass check_trace=False to torch.jit.trace()"
+ warnings.warn(nondeterministic_ops_warning, category=TracerWarning, stacklevel=5)
+
+ def compare_outputs(original, reference, match_what):
+ all_ok = True
+ for i, (orig, ref) in enumerate(zip(original, reference)):
+ try:
+ torch.testing.assert_allclose(orig.double(), ref.double(), rtol=check_tolerance,
+ atol=torch.testing._get_default_tolerance(orig, ref)[1])
+ except AssertionError as e:
+ maybe_warn_nondeterministic()
+ warnings.warn('Output nr ' + str(i + 1) + '. of the traced function does not match '
+ 'the corresponding output of the ' + match_what + '. Detailed error:\n' + str(e),
+ category=TracerWarning, stacklevel=4)
+ all_ok = False
+
+ return all_ok
+
+ traced_outs = run_mod_and_filter_tensor_outputs(module, inputs, 'trace')
+ fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, 'Python function')
+ if compare_outputs(traced_outs, fn_outs, 'Python function'):
+ check_outs = run_mod_and_filter_tensor_outputs(check_mod, inputs, 'repeated trace')
+ compare_outputs(traced_outs, check_outs, 'repeated trace')
+
+ diag_info = graph_diagnostic_info()
+ if any(info is not None for info in diag_info):
+ raise TracingCheckError(*diag_info)
+
+
+class TracerWarning(Warning):
+ @staticmethod
+ def ignore_lib_warnings():
+ # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
+ warnings.filterwarnings('ignore', category=TracerWarning, module='torch.(?!jit)')
+
+
+# We ignore the tracer warnings coming from inside the library, because all our shape
+# checks in nn will trigger them.
+TracerWarning.ignore_lib_warnings()
+torch._C._tracer_warn_use_python()
+
+
+def trace(func,
+ example_inputs,
+ optimize=True,
+ check_trace=True,
+ check_inputs=None,
+ check_tolerance=1e-5,
+ _force_outplace=False,
+ _module_class=None):
+ """
+ Trace a function and return an executable ``ScriptModule`` that will be optimized
+ using just-in-time compilation.
+
+ .. warning::
+
+ Tracing only correctly records functions and modules which are not data
+ dependent (e.g., do not have conditionals on data in tensors) and do not have
+ any untracked external dependencies (e.g., perform input/output or
+ access global variables). If you trace such models, you may silently get
+ incorrect results on subsequent invocations of the model. The tracer
+ will try to emit warnings when doing something that may cause an
+ incorrect trace to be produced.
+
+ Arguments:
+        func (callable or torch.nn.Module): a Python function or ``torch.nn.Module``
+                                             that will be run with ``example_inputs``.
+                                             Arguments and returns to ``func`` must be tensors
+                                             or (possibly nested) tuples that
+                                             contain tensors.
+ example_inputs (tuple): a tuple of example inputs that will be passed to the function
+ while tracing. The resulting trace can be run with
+ inputs of different types and shapes assuming the traced operations
+ support those types and shapes. ``example_inputs`` may also be a single
+ Tensor in which case it is automatically wrapped in a tuple
+
+ Keyword arguments:
+ optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
+        check_trace (bool, optional): check if the same inputs run through
+                                     traced code produce the same outputs. Default: ``True``. You might want
+                                     to disable this if, for example, your network contains
+                                     non-deterministic ops or if you are sure that the network is correct despite
+                                     a checker failure.
+
+ check_inputs (list of tuples, optional): A list of tuples of input arguments that should be used
+ to check the trace against what is expected. Each tuple
+ is equivalent to a set of input arguments that would
+ be specified in ``example_inputs``. For best results, pass in a
+ set of checking inputs representative of the space of
+ shapes and types of inputs you expect the network to see.
+ If not specified, the original ``example_inputs`` are used for checking
+ check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
+ This can be used to relax the checker strictness in the event that
+ results diverge numerically for a known reason, such as operator fusion.
+
+ Returns:
+ A ``ScriptModule`` object with a single ``forward()`` method containing the traced code.
+ When ``func`` is a ``torch.nn.Module``, the returned ``ScriptModule`` will have the same set of
+ sub-modules and parameters as ``func``.
+
+ Example::
+
+ def f(x):
+ return x * 2
+ traced_f = torch.jit.trace(f, torch.rand(1))
+
+ """
+ if not _enabled:
+ return func
+ executor_options = {'optimize': bool(optimize)}
+ # Special case for common case of passing a single Tensor
+ if isinstance(example_inputs, (torch.Tensor, dict)):
+ example_inputs = (example_inputs,)
+ # done primarily so that weird iterables fail here and not pybind11 code
+ elif not isinstance(example_inputs, tuple):
+ example_inputs = tuple(example_inputs)
+ var_lookup_fn = _create_interpreter_name_lookup_fn(0)
+
+ if isinstance(func, torch.nn.Module):
+ if _module_class is None:
+ _module_class = TopLevelTracedModule
+ traced = _module_class(func, **executor_options)
+ traced._c._create_method_from_trace('forward', func, example_inputs,
+ var_lookup_fn, _force_outplace)
+ else:
+ name = getattr(func, '__name__', 'forward')
+ if name == '<lambda>':
+ name = '_lambda' # make name a valid identifier
+ traced = torch._C._create_function_from_trace(name, func, example_inputs,
+ var_lookup_fn,
+ _force_outplace)
+
+ # Check the trace against new traces created from user-specified inputs
+ if check_trace:
+ if check_inputs is not None:
+ _check_trace(check_inputs, func, executor_options, traced, check_tolerance, _force_outplace)
+ else:
+ _check_trace([example_inputs], func, executor_options, traced, check_tolerance, _force_outplace)
+
+ return traced
+
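+
+# Illustrative sketch of the ``check_inputs`` argument documented above
+# (hypothetical function and shapes): the recorded trace is re-run on extra
+# input shapes and compared against the original Python function.
+def _example_trace_with_check_inputs():
+    def double(x):
+        return x * 2
+    return trace(double,
+                 torch.rand(3),
+                 check_inputs=[(torch.rand(4),), (torch.rand(2, 2),)])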
+
+class CompilationUnit(object):
+ def __init__(self, lang=None, optimize=True, _frames_up=0):
+ self._c = torch._C.CompilationUnit()
+ self._c.set_optimized(optimize)
+ if lang is not None:
+ self.define(lang, _frames_up=_frames_up + 1)
+
+ def define(self, lang, rcb=None, _frames_up=0):
+ if not rcb:
+ rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+ self._c.define(lang, rcb)
+
+ def __getattr__(self, attr):
+ r = self._c.find_function(attr)
+ if r is None:
+ raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
+ return r
+
+ def _import(self, src, constants):
+ """ test import logic for single function, use only for testing """
+ src = "op_version_set = 0\n{}".format(src)
+ torch._C._jit_import_functions(self._c, src, constants, None)
+ return self
+
+
+def _try_get_dispatched_fn(fn):
+ if not callable(fn):
+ return None
+ return _jit_internal.boolean_dispatched.get(fn)
+
+
+def _try_get_overloaded_fn(mod, field):
+ return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
+
+
+def _try_compile_weak_script(fn):
+ entry = _jit_internal.compiled_weak_fns.get(fn)
+ if entry is None:
+ return None
+ if entry["status"] == _jit_internal.COMPILATION_PENDING:
+ compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
+ del entry["rcb"]
+ _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
+ entry["status"] = _jit_internal.COMPILED
+ return compiled_fn
+ else:
+ return entry["compiled_fn"]
+
+
+# ScriptClasses must be new-style classes because we construct them using their
+# __new__ method.
+def _is_new_style_class(cls):
+ if hasattr(cls, '__class__'):
+ return ('__dict__' in dir(cls) or hasattr(cls, '__slots__'))
+
+
+def whichmodule(obj):
+    """Find the module an object belongs to."""
+ module_name = getattr(obj, '__module__', None)
+ # Protect the iteration by using a list copy of sys.modules against dynamic
+ # modules that trigger imports of other modules upon calls to getattr.
+ for name, module in list(sys.modules.items()):
+ if name == '__main__' or module is None:
+ continue
+ try:
+ if _getattribute(module, name)[0] is obj:
+ return module_name
+ except AttributeError:
+ pass
+ return '__main__'
+
+
+# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
+def _qualified_name(obj):
+ name = obj.__name__
+ module_name = obj.__module__
+
+ # The Python docs are very clear that `__module__` can be None, but I can't
+ # figure out when it actually would be.
+ if module_name is None:
+ raise RuntimeError("Could not get qualified name for class '{}': "
+ "__module__ can't be None.".format(name))
+
+ # if getattr(sys.modules[module_name], name) is not obj:
+ # raise RuntimeError("Could not get qualified name for class '{}': "
+    #             "the attr {} on module {} is not the class".format(name, name, module_name))
+
+ # __main__ is a builtin module, so rewrite it to "__torch__".
+ if module_name == "__main__":
+ module_name = "__torch__"
+ else:
+ # Everything else gets a "__torch__" prefix to avoid name collisions
+ # with the names of user values.
+ module_name = "__torch__." + module_name
+
+ if "." in name:
+ raise RuntimeError("Could not get qualified name for class '{}': "
+ "'{}' is not a valid identifier".format(name, name))
+
+ return module_name + "." + name
+
+
+def script(obj, optimize=True, _frames_up=0, _rcb=None):
+ if not _enabled:
+ return obj
+ if _rcb is None:
+ _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+ if inspect.isclass(obj):
+ if not _is_new_style_class(obj):
+ raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
+ name = _qualified_name(obj)
+ ast = get_jit_class_def(obj, name)
+ _jit_script_class_compile(ast, _rcb)
+ _add_script_class(obj, name)
+ return obj
+ else:
+ ast = get_jit_def(obj)
+ fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))
+ # Forward docstrings
+ fn.__doc__ = obj.__doc__
+ return fn
+
+
+ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
+
+
+def script_method(fn, _rcb=None):
+ if not _enabled:
+ return fn
+ # NOTE: we need to traverse two frames here because the meta-class frame
+ # for ScriptModule will be present, as opposed to invoking @script on a
+ # a function or invoking define() on a CompilationUnit.
+ # The stack will look like:
+ #
+ # 0. createResolutionCallback()
+ # 1. script_method()
+ # 2. ScriptModule metaclass frame
+ # 3. Surrounding scope
+ #
+ # createResolutionCallback internally adds 1 to get us to the scope of this
+ # function (the calling function). Adding 2 gets us to the proper surrounding scope.
+ if _rcb is None:
+ _rcb = _jit_internal.createResolutionCallback(frames_up=2)
+ ast = get_jit_def(fn, self_name="ScriptModule")
+ return ScriptMethodStub(_rcb, ast, fn)
+
+
+def _try_get_weak_module(mod):
+ """
+ Get the WeakScriptModuleProxy corresponding to mod if it exists
+ """
+ if not isinstance(mod, Module):
+ return None
+ return _jit_internal.weak_modules.get(mod)
+
+
+def _try_get_ignored_op(fn):
+ if not callable(fn):
+ return False
+ if hasattr(fn, '__func__'):
+ fn = fn.__func__
+ return fn in _jit_internal.ignored_fns
+
+
+def _is_weak_type(cls):
+ """
+ Check if a type has been annotated with `weak_module`
+ """
+ return cls in _jit_internal.weak_types
+
+
+# These OrderedDictWrapper classes replace the actual OrderedDicts in
+# module with versions that get/set properties inside of script::Module.
+# This allows us to reuse most of nn.Module while still storing the
+# data in C++.
+# Each OrderedDict needs to support:
+# x not in view
+# x in view
+# view[name] = ...
+# view.values()
+# del view[name]
+# view.items()
+# view.keys()
+# len(view)
+
+class OrderedDictWrapper(object):
+ def __init__(self, module):
+ self.module = module
+
+ def keys(self):
+ return [k for k, v in self.items()]
+
+ def values(self):
+ return [v for k, v in self.items()]
+
+ def __delitem__(self, k):
+ raise RuntimeError("cannot delete methods or parameters of a script module")
+
+ def items(self):
+ raise NotImplementedError
+
+ def __contains__(self, k):
+ raise NotImplementedError
+
+ def __getitem__(self, k):
+ raise NotImplementedError
+
+ def __setitem__(self, k, v):
+ raise NotImplementedError
+
+
+class OrderedModuleDict(OrderedDictWrapper):
+ def __init__(self, module):
+ super(OrderedModuleDict, self).__init__(module)
+ # contains _both_ script modules and non-script python-only modules
+
+ # because script modules are subclassed in python and the
+ # C++ script::Module class will not hold references to them,
+ # to ensure that you always get the same python value here
+ # we store it in the python dict as well
+ self._python_modules = OrderedDict()
+
+ def items(self):
+ r = self._python_modules.items()
+ return r
+
+ def __contains__(self, k):
+ return k in self._python_modules
+
+ def __setitem__(self, k, v):
+ if k in self._python_modules:
+ raise RuntimeError("cannot re-assign modules in a ScriptModule")
+ if isinstance(v, ScriptModule):
+ self.module._register_module(k, v._c)
+
+ self._python_modules[k] = v
+
+ def __getitem__(self, k):
+ return self._python_modules[k]
+
+
+class OrderedParameterDict(OrderedDictWrapper):
+ def __init__(self, module):
+ super(OrderedParameterDict, self).__init__(module)
+
+ def items(self):
+ return [(name, param) for name, param in self.module._get_parameters()]
+
+ def __setitem__(self, k, v):
+ self.module._register_parameter(k, v, False)
+
+ def __contains__(self, k):
+ return self.module._has_parameter(k)
+
+ def __getitem__(self, k):
+ if k not in self:
+ raise KeyError(k)
+ return self.module._get_parameter(k)
+
+
+class OrderedBufferDict(OrderedDictWrapper):
+ def __init__(self, module):
+ super(OrderedBufferDict, self).__init__(module)
+
+ def items(self):
+ return [(name, param) for name, _, param in
+ self.module._get_attributes() if isinstance(param, torch.Tensor)]
+
+ def __setitem__(self, k, v):
+ self.module._register_buffer(k, v)
+
+ def __contains__(self, k):
+ return self.module._has_buffer(k)
+
+ def __getitem__(self, k):
+ if k not in self:
+ raise KeyError(k)
+ return self.module._get_buffer(k)
+
+# base types that can be constants
+# in addition, tuples and lists of these base types are also considered constants
+# If you edit this list, then you also need to edit the handlers in
+# ConstantValue in jit/script/init.cpp
+_constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)
+
+
+def _get_valid_constant(attr, v):
+ if isinstance(v, _constant_types):
+ return v
+ elif isinstance(v, tuple) or isinstance(v, list):
+ return tuple(_get_valid_constant(attr, x) for x in v)
+ constants = ", ".join(typ.__name__ for typ in _constant_types)
+ raise TypeError(textwrap.dedent("""
+ '{}' object for attribute '{}' is not a valid constant.
+ Valid constants are:
+ 1. a nn.ModuleList
+ 2. a value of type {{{}}}
+ 3. a list or tuple of (2)
+ """.format(type(v).__name__, attr, constants)))
+
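+
+# Illustrative sketch of what the check above accepts (hypothetical attribute
+# names): base types, and tuples or lists of them, become constants (lists are
+# converted to tuples); anything else raises the TypeError built above.
+def _example_valid_constants():
+    dims = _get_valid_constant('dims', [1, 2, 3])      # -> (1, 2, 3)
+    mode = _get_valid_constant('mode', 'bilinear')     # returned unchanged
+    return dims, mode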
+
+def _create_methods_from_stubs(self, stubs):
+ defs = [m.def_ for m in stubs]
+ rcbs = [m.resolution_callback for m in stubs]
+ defaults = [get_default_args(m.original_method) for m in stubs]
+ self._c._create_methods(self, defs, rcbs, defaults)
+
+# For each user-defined class that subclasses ScriptModule, this meta-class
+# (1) finds all the methods annotated with @script_method
+# in a ScriptModule and removes them from the class attributes, and
+# (2) puts a wrapper around the class's __init__ method to register
+# all of the script_methods with the module after the original __init__
+# has run. This has to occur after the user-defined __init__ so that
+# submodules and parameters are initialized _before_ the script compiler
+# resolves references to `self.param` or `self.module`.
+
+
+class ScriptMeta(type):
+ # this has to inherit from pybind11's metaclass otherwise we get
+ # issues because ScriptModule inherits from torch._C.ScriptModule,
+ # a pybind11 type
+ def __init__(cls, name, bases, attrs):
+ # find all the script methods
+ cls._original_methods = {}
+ methods = []
+ for k, v in sorted(attrs.items()):
+ if isinstance(v, ScriptMethodStub):
+ delattr(cls, k)
+ methods.append(v)
+ cls._original_methods[v.original_method.__name__] = v.original_method
+ # after the user's __init__ register all the script methods
+ # with the module
+ original_init = getattr(cls, '__init__', lambda self: None)
+ super_constants = getattr(super(cls), '_constants_set', set())
+ cls._constants_set = set(getattr(cls, '__constants__', ())).union(super_constants)
+ cls._overloads = dict(getattr(cls, '__overloads__', {}))
+
+ @functools.wraps(original_init)
+ def init_then_register(self, *args, **kwargs):
+ original_init(self, *args, **kwargs)
+ _create_methods_from_stubs(self, methods)
+
+ cls.__init__ = init_then_register
+ return super(ScriptMeta, cls).__init__(name, bases, attrs)
+
+
+if _enabled:
+
+    # this is a Python 'non-data descriptor' that causes the first access
+    # to ScriptModule's forward to look up the forward method and stash
+    # it in the object's dict. Due to the standard rules for attribute lookup,
+    # subsequent lookups will just directly return the previously looked up method.
+ # This is necessary because nn.Module defines forward as a method. If we
+ # did nothing __getattr__ would not be called. Instead we'd get nn.Module.forward
+ # which always throws an exception.
+ class _CachedForward(object):
+ def __get__(self, obj, cls):
+ return self.__getattr__('forward')
+
+ class ScriptModule(with_metaclass(ScriptMeta, Module)):
+ r"""
+ The core data structure in TorchScript is the ``ScriptModule``. It is an
+ analogue of torch's ``nn.Module`` and represents an entire model as a tree of
+ submodules. Like normal modules, each individual module in a ``ScriptModule`` can
+ have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
+ as Python functions, but in ``ScriptModule``\s methods are implemented as
+ TorchScript functions, a statically-typed subset of Python that contains all
+ of PyTorch's built-in Tensor operations. This difference allows your
+        ScriptModule's code to run without the need for a Python interpreter.
+
+        ``ScriptModule``\s can be created in two ways:
+
+ **Tracing:**
+
+ Using ``torch.jit.trace``, you can turn an existing module or Python
+ function into a TorchScript program. You must provide example inputs,
+ and we run the function, recording the operations performed on all the tensors. We turn the resulting recording
+ into a TorchScript method that is installed as the ``forward`` method of a
+        ``ScriptModule``. This module also contains any parameters that the original
+        module had.
+
+ Example (tracing a function)::
+
+ import torch
+ def foo(x, y):
+ return 2 * x + y
+ traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
+
+ .. note::
+ Tracing a function will construct a ``ScriptModule`` with a single
+ ``forward`` method that implements the function. The resulting
+ ``ScriptModule`` has no parameters or attributes.
+
+ Example (tracing an existing module)::
+
+ import torch
+ import torchvision
+ traced_net = torch.jit.trace(torchvision.models.resnet18(),
+ torch.rand(1, 3, 224, 224))
+
+ .. note::
+
+ Tracing only records operations done when the given function is run on the given
+ tensors. Therefore, the returned ``ScriptModule`` will always run the same traced
+ graph on any input. This has some important implications when your module is
+ expected to run different sets of operations, depending on the input and/or the
+ module state. For example,
+
+ + Tracing will not record any control-flow like if-statements or loops. When
+ this control-flow is constant across your module, this is fine and it often
+ inlines the control-flow decisions. But sometimes the control-flow is
+ actually part of the model itself. For instance, a recurrent network is
+ a loop over the (possibly dynamic) length of an input sequence.
+
+            + In the returned ``ScriptModule``, operations that have different behaviors
+              in ``training`` and ``eval`` modes will always behave as if they were in the
+              mode they were in during tracing, no matter which mode the ``ScriptModule``
+              is in.
+
+ In cases like these, tracing would not be appropriate and scripting is a better
+ choice.
+
+ **Scripting:**
+
+ You can write TorchScript code directly using Python syntax. You do this
+ using the ``@torch.jit.script`` decorator (for functions) or
+ ``@torch.jit.script_method`` decorator (for methods) on subclasses of
+ ``ScriptModule``. With this decorator the body of the annotated function is
+ directly translated into TorchScript. TorchScript itself is a subset of
+ the Python language, so not all features in Python work, but we provide
+ enough functionality to compute on tensors and do control-dependent
+ operations.
+
+ Example (scripting a function)::
+
+ import torch
+ @torch.jit.script
+ def foo(x, y):
+ if x.max() > y.max():
+ r = x
+ else:
+ r = y
+ return r
+
+ .. note::
+ A ``@torch.jit.script`` decorator will construct a ``ScriptModule`` with a single
+ ``forward`` method that implements the function. The resulting
+ ``ScriptModule`` has no parameters or attributes.
+
+ Example (scripting a simple module with a Parameter)::
+
+ import torch
+ class MyModule(torch.jit.ScriptModule):
+ def __init__(self, N, M):
+ super(MyModule, self).__init__()
+ self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+ @torch.jit.script_method
+ def forward(self, input):
+ return self.weight.mv(input)
+
+ Example (scripting a module with traced submodules)::
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class MyScriptModule(torch.jit.ScriptModule):
+ def __init__(self):
+ super(MyScriptModule, self).__init__()
+                    # torch.jit.trace produces ScriptModules for conv1 and conv2
+ self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
+ self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
+
+ @torch.jit.script_method
+ def forward(self, input):
+ input = F.relu(self.conv1(input))
+ input = F.relu(self.conv2(input))
+ return input
+ """
+
+ def __init__(self, optimize=True):
+ self.__dict__['_c'] = torch._C.ScriptModule()
+ Module.__init__(self)
+ self._c._set_optimized(optimize)
+ self._parameters = OrderedParameterDict(self._c)
+ self._buffers = OrderedBufferDict(self._c)
+ self._modules = OrderedModuleDict(self._c)
+
+ @property
+ def graph(self):
+ return self.forward.graph
+
+ @property
+ def code(self):
+ return self.forward.code
+
+ def save(self, *args, **kwargs):
+ return self._c.save(*args, **kwargs)
+
+ def save_to_buffer(self, *args, **kwargs):
+ return self._c.save_to_buffer(*args, **kwargs)
+
+ def get_debug_state(self, *args, **kwargs):
+ return self._c.get_debug_state()
+
+ forward = _CachedForward()
+
+ def __getattr__(self, attr):
+ if '_c' not in self.__dict__:
+ raise RuntimeError("ScriptModule has not been initialized, did you forget to call super's init?")
+ if self._c._has_method(attr):
+ if attr in self.__class__._original_methods:
+ original_method = self.__class__._original_methods[attr]
+ script_method = self._c._get_method(attr)
+ script_method = functools.wraps(original_method)(script_method)
+ else:
+ script_method = self._c._get_method(attr)
+ # cache method so future calls do not go through __getattr__
+ # to improve invocation performance
+ self.__dict__[attr] = script_method
+ return script_method
+
+ if self._c._has_attribute(attr):
+ return self._c._get_attribute(attr)
+ return Module.__getattr__(self, attr)
+
+ def __setattr__(self, attr, value):
+ if attr not in self._constants_set:
+ if isinstance(value, Module) and _is_weak_type(type(value)):
+ # Compile weak script module
+ value = _make_strong(value)
+ if attr == 'training':
+ if self._c._has_buffer('training'):
+ self.__dict__['training'] = value
+ self._c._get_buffer('training').fill_(int(value))
+ return
+ if isinstance(value, Attribute):
+ the_type = torch.jit.annotations.ann_to_type(value.type)
+ try:
+ self._c._register_attribute(attr, the_type, value.value)
+ except RuntimeError:
+ raise RuntimeError("Could not register attribute '{}' of type '{}' for a value of type '{}'"
+ .format(attr, value.type, type(value.value)))
+ return
+ return super(ScriptModule, self).__setattr__(attr, value)
+
+ if hasattr(self, attr):
+ raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
+
+ def conv_module_to_const(module_value):
+ if not isinstance(module_value, (ModuleList, Sequential)):
+ return module_value
+ for i in range(len(module_value)):
+ module_value[i] = conv_module_to_const(module_value[i])
+ if isinstance(module_value, Sequential):
+ return _ConstSequential(module_value)
+ else:
+ return _ConstModuleList(module_value)
+
+ if isinstance(value, (ModuleList, Sequential)):
+ # special case for list of modules. Modules need to be registered with their
+ # parent module. To do this, we create a ConstModuleList, which is itself a module, that
+ # contains each of these modules as submodules. The ConstModuleList then
+ # is set as an attribute of the parent module.
+ super(ScriptModule, self).__setattr__(attr, conv_module_to_const(value))
+ else:
+ super(ScriptModule, self).__setattr__(attr, _get_valid_constant(attr, value))
+
+ def __dir__(self):
+ return sorted(Module.__dir__(self) + self._method_names())
+
+ def define(self, lang):
+ # We use frames_up=1 to get to the proper surrounding scope. The stack
+ # will look like:
+ # 0. createResolutionCallback
+ # 1. define()
+ # 2. surrounding scope.
+ #
+ # createResolutionCallback internally adds 1 to get us to our frame, then
+ # we add 1 to get to the proper surrounding scope.
+ rcb = _jit_internal.createResolutionCallback(frames_up=1)
+ self._c._define(self, lang, rcb)
+
+ def copy(self):
+ m = ScriptModule()
+
+ def module_lookup(names):
+ curr = m
+ for name in names:
+ if not hasattr(curr, name):
+ setattr(curr, name, ScriptModule())
+ curr = getattr(curr, name)
+ return curr._c
+ self._c._copy_into(module_lookup, {}, [])
+ return m
+
+ def __getstate__(self):
+ raise pickle.PickleError(
+ "ScriptModules cannot be saved using torch.save. " +
+ "Mixed serialization of script and non-script modules is not supported. " +
+ "For purely script modules use my_script_module.save(<filename>) instead.")
+
+ def graph_for(self, *args, **kwargs):
+ return self.forward.graph_for(*args, **kwargs)
+
+ class WeakScriptModuleProxy(ScriptModule):
+ def __init__(self, original, stubs):
+ # Guards behavior of __setattr__ and __getattr__ so ScriptModule
+ # __init__ can run correctly
+ self.__dict__['_initialized'] = False
+ super(WeakScriptModuleProxy, self).__init__()
+
+ self.__dict__["_original"] = weakref.ref(original)
+
+ # Copy Parameters / Modules / Buffers
+ for name in dir(original):
+ item = getattr(original, name)
+ if item is None and name in original._parameters:
+                    # XXX: treat None values simply as module attributes instead of adding them to the parameter list
+ # TODO: need to handle this more generally when non-tensor attributes added to module
+ object.__setattr__(self, name, item)
+ elif isinstance(item, Parameter) or (isinstance(item, Module) and item is not self):
+ ScriptModule.__setattr__(self, name, item)
+ for name in original._buffers:
+ if original._buffers[name] is None:
+ object.__setattr__(self, name, None)
+ else:
+ self.register_buffer(name, original._buffers[name])
+
+ # Copy constants
+ self.__dict__["_constants_set"] = set(getattr(original, "__constants__", []))
+
+ # Copy overloads
+ self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
+
+ self.__dict__["_initialized"] = True
+ _create_methods_from_stubs(self, stubs)
+
+ def __getattr__(self, attr):
+ # Try to get the attribute directly, if that fails, fall back to the
+ # weak module itself
+ try:
+ return ScriptModule.__getattr__(self, attr)
+ except AttributeError:
+ if self.__dict__["_initialized"]:
+ return getattr(self.__dict__["_original"](), attr)
+ else:
+ # Only fall back to original once __init__() is done
+ raise AttributeError("Weak module has no attribute '{}'"
+ .format(attr))
+
+ def __setattr__(self, attr, value):
+ # Once constructed, no new properties can be set
+
+ if not self.__dict__["_initialized"]:
+ # If constructing, don't fall back to original module
+ return ScriptModule.__setattr__(self, attr, value)
+
+ if hasattr(self, attr):
+ return ScriptModule.__setattr__(self, attr, value)
+ else:
+ raise AttributeError("Cannot set new attribute '{}' on "
+ "weak script module once it has been "
+ "created".format(attr))
+
+else:
+    class ScriptModule(torch.nn.Module):
+ def __init__(self, optimize=True):
+ super(ScriptModule, self).__init__()
+
+
+def _get_weak_stubs(cls):
+ """
+ Calls script_method for each method on the type of the object passed in and
+ returns the generated ScriptMethodStubs
+ """
+ stubs = []
+ for name in dir(cls):
+ func = get_function_from_type(cls, name)
+ if func in _jit_internal.weak_script_methods:
+ entry = _jit_internal.weak_script_methods[func]
+ stub = script_method(entry["original_method"], entry["rcb"])
+ stubs.append(stub)
+ return stubs
+
+
+def _make_strong(mod):
+ """
+ Converts a weak module into a subclass of ScriptModule
+ """
+ if mod in _jit_internal.weak_modules:
+ return _jit_internal.weak_modules[mod]
+
+ stubs = _jit_internal.weak_types.get(type(mod))["method_stubs"]
+
+ if stubs is None:
+        # Generate stubs and store them on weak_types in case this type is
+ # used again
+ stubs = _get_weak_stubs(type(mod))
+ _jit_internal.weak_types[type(mod)]["method_stubs"] = stubs
+
+ # Create proxy with stubs
+ proxy = WeakScriptModuleProxy(mod, stubs)
+
+ _jit_internal.weak_modules[mod] = proxy
+
+ return proxy
+
+
+def _get_methods(cls):
+ import inspect
+ # In Python 3 unbound methods are functions, but in Python 2 they are methods
+ return inspect.getmembers(cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x))
+
+
+_compiled_methods_whitelist = {
+ 'forward', 'register_buffer', 'register_parameter', 'add_module',
+ '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
+ 'state_dict', 'load_state_dict', '_load_from_state_dict',
+ '_named_members', 'parameters', 'named_parameters',
+ 'buffers', 'named_buffers', 'children', 'named_children', 'modules',
+ 'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
+ '_slow_forward', '_tracing_name', 'eval', 'train',
+}
+
+
+def _make_fail(name):
+ def fail(self, *args, **kwargs):
+ raise RuntimeError(name + " is not supported on ScriptModules")
+ return fail
+
+
+for name, method in _get_methods(torch.nn.Module):
+ if name.startswith('__'):
+ continue
+ if name not in ScriptModule.__dict__ and name not in _compiled_methods_whitelist:
+ setattr(ScriptModule, method.__name__, _make_fail(name))
+
+
+class TracedModule(ScriptModule):
+ __frozen = False
+
+ def __init__(self, orig, id_set=None, optimize=True):
+ # XXX: orig can be a nn.Module or a function!
+ super(TracedModule, self).__init__(optimize=optimize)
+ if id_set is None:
+ id_set = set()
+
+ assert(isinstance(orig, torch.nn.Module))
+ self._name = 'TracedModule[' + type(orig).__name__ + ']'
+
+ def check_unique(param):
+ if param in id_set:
+ raise ValueError("TracedModules don't support parameter sharing between modules")
+ id_set.add(param)
+
+ self.training = orig.training
+
+ for name, param in orig._parameters.items():
+ if param is not None:
+ self._parameters[name] = param
+ check_unique(param)
+ for name, buf in orig._buffers.items():
+ if buf is not None:
+ self._buffers[name] = buf
+ check_unique(buf)
+
+ if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
+ raise ValueError("Modules that have hooks assigned can't be compiled")
+
+ for name, submodule in orig._modules.items():
+ if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule):
+ self._modules[name] = submodule.copy()
+ else:
+ self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
+
+ self._freeze()
+
+ def forward(self, *args, **kwargs):
+ raise RuntimeError('Trace submodules cannot be called.')
+
+ def _freeze(self):
+ self.__frozen = True
+
+ def _get_name(self):
+ return self._name
+
+ def __setattr__(self, attr, value):
+ if not self.__frozen or hasattr(self, attr):
+ return super(TracedModule, self).__setattr__(attr, value)
+ raise RuntimeError("Cannot set new properties on a traced module.")
+
+
+class TopLevelTracedModule(TracedModule):
+ forward = _CachedForward()
+
+
+class _ConstModuleList(ScriptModule):
+ def __init__(self, modules):
+ super(_ConstModuleList, self).__init__()
+ for i, module in enumerate(modules):
+ if _is_weak_type(type(module)):
+ module = _make_strong(module)
+ self.add_module(str(i), module)
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return _ConstModuleList(list(self._modules.values())[idx])
+ else:
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError('index {} is out of range'.format(idx))
+ if idx < 0:
+ idx += len(self)
+ return self._modules[str(idx)]
+
+ def __len__(self):
+ return len(self._modules)
+
+ def __iter__(self):
+ return iter(self._modules.values())
+
+ def __dir__(self):
+ keys = super(_ConstModuleList, self).__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+
+class _ConstSequential(_ConstModuleList):
+ __constants__ = ['mods']
+
+ def __init__(self, mods):
+ super(_ConstSequential, self).__init__(mods._modules.values())
+
+ # we define the forward method via self.define rather than
+        # making it a direct class member (with a @script annotation)
+        # because, in optimized runtime environments where only .pyc files
+        # are shipped, we can't retrieve the source code.
+ # TODO: find a workaround for this and remove this hack
+ self.define("""
+ def forward(self, input):
+ for m in self:
+ input = m(input)
+ return input
+ """)
+
+
+_builtin_table = None
+
+_modules_containing_builtins = (torch, torch._C._nn)
+
+
+def _unwrap_optional(x):
+ assert x is not None, "Unwrapping null optional"
+ return x
+
+
+# lazily built to ensure the correct initialization order
+def _get_builtin_table():
+ global _builtin_table
+ if _builtin_table is not None:
+ return _builtin_table
+ _builtin_table = {}
+
+ def register_all(mod):
+ for name in dir(mod):
+ v = getattr(mod, name)
+ if callable(v):
+ _builtin_table[id(v)] = "aten::" + name
+ for mod in _modules_containing_builtins:
+ register_all(mod)
+
+ _builtin_table[id(warnings.warn)] = "aten::warn"
+ _builtin_table[id(_single)] = "aten::_single"
+ _builtin_table[id(_pair)] = "aten::_pair"
+ _builtin_table[id(_triple)] = "aten::_triple"
+ _builtin_table[id(_quadruple)] = "aten::_quadruple"
+ _builtin_table[id(_list_with_default)] = "aten::list_with_default"
+ _builtin_table[id(_unwrap_optional)] = "aten::_unwrap_optional"
+ _builtin_table[id(cudnn.is_acceptable)] = "aten::cudnn_is_acceptable"
+ _builtin_table[id(torch._C._infer_size)] = "aten::_infer_size"
+ _builtin_table[id(torch.nn.functional._no_grad_embedding_renorm_)] = "aten::_no_grad_embedding_renorm_"
+
+ _builtin_table[id(math.floor)] = "aten::floor"
+ _builtin_table[id(math.ceil)] = "aten::ceil"
+ _builtin_table[id(math.log)] = "aten::log"
+ _builtin_table[id(math.log1p)] = "aten::log1p"
+ _builtin_table[id(math.log10)] = "aten::log10"
+ _builtin_table[id(math.exp)] = "aten::exp"
+ _builtin_table[id(math.sqrt)] = "aten::sqrt"
+ _builtin_table[id(math.pow)] = "aten::pow"
+ _builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
+ _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
+ _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
+ _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
+ _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair"
+ _builtin_table[id(torch.nn.utils.rnn.get_packed_sequence)] = "aten::_pack_sequence"
+
+ _builtin_table[id(torch.nn.init._no_grad_fill_)] = "aten::_no_grad_fill_"
+ _builtin_table[id(torch.nn.init._no_grad_normal_)] = "aten::_no_grad_normal_"
+ _builtin_table[id(torch.nn.init._no_grad_uniform_)] = "aten::_no_grad_uniform_"
+ _builtin_table[id(torch.nn.init._no_grad_zero_)] = "aten::_no_grad_zero_"
+
+ return _builtin_table
+
+
+def _register_builtin(fn, op):
+ _get_builtin_table()[id(fn)] = op
+
+
+def _find_builtin(fn):
+ return _get_builtin_table().get(id(fn))
+
+
+_register_builtin(len, 'aten::len')
+_register_builtin(_wait, 'aten::wait')
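+
+# Illustrative sketch only (not part of the API): once the table has been
+# built lazily, Python callables resolve to their ATen schema names, e.g.
+#
+#     >>> _find_builtin(torch.add)
+#     'aten::add'
+#     >>> _find_builtin(math.floor)
+#     'aten::floor'
+#     >>> _find_builtin(print) is None   # not a registered builtin
+#     True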
+
+# qualified_name => ScriptClass mapping
+_script_classes = {}
+
+
+def _add_script_class(cls, name):
+ global _script_classes
+ _script_classes[name] = cls
+
+
+def _get_script_class(name):
+ global _script_classes
+ if name not in _script_classes:
+ raise RuntimeError("Unknown reference to ScriptClass '{}'. "
+ "Did you forget to import it?".format(name))
+ return _script_classes[name]
+
+# torch.jit.Error
+Error = torch._C.JITException
+
+
+class _disable_tracing(object):
+ def __enter__(self):
+ self.state = torch._C._get_tracing_state()
+ torch._C._set_tracing_state(None)
+
+ def __exit__(self, *args):
+ torch._C._set_tracing_state(self.state)
+ self.state = None
+
+
+# for use in python if using annotate
+def annotate(the_type, the_value):
+ # noop in python
+ return the_value
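+
+# A minimal sketch of typical usage (illustrative only): inside a script
+# function, ``annotate`` tells the compiler the type of an otherwise ambiguous
+# expression (such as an empty list), while staying a no-op in eager Python.
+#
+#     from typing import List
+#
+#     @torch.jit.script
+#     def make_list():
+#         xs = torch.jit.annotate(List[int], [])
+#         xs.append(1)
+#         return xs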
+
+
+Attribute = collections.namedtuple('Attribute', ['value', 'type'])
+
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+ self(*args, **kwargs)
+ return last_executed_optimized_graph()
+
+torch._C.ScriptMethod.graph_for = _graph_for
+torch._C.Function.graph_for = _graph_for
+Function = torch._C.Function
+
+if not torch._C._jit_init():
+ raise RuntimeError("JIT initialization failed")
+
+"""
+torch.multiprocessing is a wrapper around the native :mod:`multiprocessing`
+module. It registers custom reducers, that use shared memory to provide shared
+views on the same data in different processes. Once the tensor/storage is moved
+to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
+to send it to other processes without making any copies.
+
+The API is 100% compatible with the original module - it's enough to change
+``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
+tensors sent through the queues or shared via other mechanisms, moved to shared
+memory.
+
+Because the APIs are so similar, we do not document most of this package's
+contents, and we recommend referring to the excellent documentation of the
+original module.
+"""
+import torch
+import sys
+from .reductions import init_reductions
+import multiprocessing
+
+__all__ = ['set_sharing_strategy', 'get_sharing_strategy',
+ 'get_all_sharing_strategies']
+
+
+from multiprocessing import * # noqa: F401
+
+
+__all__ += multiprocessing.__all__
+
+
+# This call adds a Linux specific prctl(2) wrapper function to this module.
+# See https://github.com/pytorch/pytorch/pull/14391 for more information.
+torch._C._multiprocessing_init()
+
+
+if sys.version_info < (3, 3):
+    """Override basic classes in Python 2.7 and Python 3.2 and earlier to use
+    ForkingPickler for serialization. Python 3.3 and later already use
+    ForkingPickler."""
+ from .queue import Queue, SimpleQueue # noqa: F401
+ from .pool import Pool # noqa: F401
+
+
+"""Add helper function to spawn N processes and wait for completion of any of
+them. This depends `mp.get_context` which was added in Python 3.4."""
+from .spawn import spawn, SpawnContext # noqa: F401
+
+
+if sys.platform == 'darwin' or sys.platform == 'win32':
+ _sharing_strategy = 'file_system'
+ _all_sharing_strategies = {'file_system'}
+else:
+ _sharing_strategy = 'file_descriptor'
+ _all_sharing_strategies = {'file_descriptor', 'file_system'}
+
+
+def set_sharing_strategy(new_strategy):
+ """Sets the strategy for sharing CPU tensors.
+
+ Arguments:
+ new_strategy (str): Name of the selected strategy. Should be one of
+ the values returned by :func:`get_all_sharing_strategies()`.
+ """
+ global _sharing_strategy
+ assert new_strategy in _all_sharing_strategies
+ _sharing_strategy = new_strategy
+
+
+def get_sharing_strategy():
+ """Returns the current strategy for sharing CPU tensors."""
+ return _sharing_strategy
+
+
+def get_all_sharing_strategies():
+ """Returns a set of sharing strategies supported on a current system."""
+ return _all_sharing_strategies
+
+
+init_reductions()
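+
+# A minimal usage sketch (illustrative only).  Defining the worker at module
+# top level keeps it picklable under the 'spawn' start method; the queue sends
+# only a handle, so the in-place update is visible to the parent process.
+#
+#     import torch
+#     import torch.multiprocessing as mp
+#
+#     def worker(q):
+#         t = q.get()        # receives a view onto the same shared storage
+#         t.add_(1)          # visible to the parent after join()
+#
+#     if __name__ == '__main__':
+#         tensor = torch.zeros(3).share_memory_()
+#         queue = mp.Queue()
+#         process = mp.Process(target=worker, args=(queue,))
+#         process.start()
+#         queue.put(tensor)
+#         process.join()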
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import multiprocessing
+import multiprocessing.connection
+import signal
+import sys
+
+from . import _prctl_pr_set_pdeathsig
+
+
+def _wrap(fn, i, args, error_queue):
+ # prctl(2) is a Linux specific system call.
+ # On other systems the following function call has no effect.
+ # This is set to ensure that non-daemonic child processes can
+ # terminate if their parent terminates before they do.
+ _prctl_pr_set_pdeathsig(signal.SIGINT)
+
+ try:
+ fn(i, *args)
+ except KeyboardInterrupt:
+ pass # SIGINT; Killed by parent, do nothing
+ except Exception:
+ # Propagate exception to parent process, keeping original traceback
+ import traceback
+ error_queue.put(traceback.format_exc())
+ sys.exit(1)
+
+
+def _python_version_check():
+ if sys.version_info < (3, 4):
+ raise RuntimeError("Requires python 3.4 or higher to use "
+ "torch.multiprocessing.spawn and "
+ "torch.multiprocessing.SpawnContext helper "
+ "to launch multiple processes. If you are using "
+ "this for distributed training and have a lower "
+ "version of python, please use "
+ "torch.distributed.launch instead.")
+
+
+class SpawnContext:
+ def __init__(self, processes, error_queues):
+ _python_version_check()
+ self.error_queues = error_queues
+ self.processes = processes
+ self.sentinels = {
+ process.sentinel: index
+ for index, process in enumerate(processes)
+ }
+
+ def pids(self):
+ return [int(process.pid) for process in self.processes]
+
+    def join(self, timeout=None):
+ r"""
+ Tries to join one or more processes in this spawn context.
+ If one of them exited with a non-zero exit status, this function
+ kills the remaining processes and raises an exception with the cause
+ of the first process exiting.
+
+ Returns ``True`` if all processes have been joined successfully,
+ ``False`` if there are more processes that need to be joined.
+
+ Arguments:
+ timeout (float): Wait this long before giving up on waiting.
+ """
+ # Ensure this function can be called even when we're done.
+ if len(self.sentinels) == 0:
+ return True
+
+ # Wait for any process to fail or all of them to succeed.
+ ready = multiprocessing.connection.wait(
+ self.sentinels.keys(),
+ timeout=timeout,
+ )
+
+ error_index = None
+ for sentinel in ready:
+ index = self.sentinels.pop(sentinel)
+ process = self.processes[index]
+ process.join()
+ if process.exitcode != 0:
+ error_index = index
+ break
+
+ # Return if there was no error.
+ if error_index is None:
+ # Return whether or not all processes have been joined.
+ return len(self.sentinels) == 0
+
+ # Assume failure. Terminate processes that are still alive.
+ for process in self.processes:
+ if process.is_alive():
+ process.terminate()
+ process.join()
+
+ # There won't be an error on the queue if the process crashed.
+ if self.error_queues[error_index].empty():
+ exitcode = self.processes[error_index].exitcode
+ if exitcode < 0:
+ name = signal.Signals(-exitcode).name
+ raise Exception(
+ "process %d terminated with signal %s" %
+ (error_index, name)
+ )
+ else:
+ raise Exception(
+ "process %d terminated with exit code %d" %
+ (error_index, exitcode)
+ )
+
+ original_trace = self.error_queues[error_index].get()
+ msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
+ msg += original_trace
+ raise Exception(msg)
+
+
+def spawn(fn, args=(), nprocs=1, join=True, daemon=False):
+ r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
+
+ If one of the processes exits with a non-zero exit status, the
+ remaining processes are killed and an exception is raised with the
+ cause of termination. In the case an exception was caught in the
+ child process, it is forwarded and its traceback is included in
+ the exception raised in the parent process.
+
+ Arguments:
+ fn (function): Function is called as the entrypoint of the
+ spawned process. This function must be defined at the top
+ level of a module so it can be pickled and spawned. This
+ is a requirement imposed by multiprocessing.
+
+ The function is called as ``fn(i, *args)``, where ``i`` is
+ the process index and ``args`` is the passed through tuple
+ of arguments.
+
+ args (tuple): Arguments passed to ``fn``.
+ nprocs (int): Number of processes to spawn.
+ join (bool): Perform a blocking join on all processes.
+ daemon (bool): The spawned processes' daemon flag. If set to True,
+ daemonic processes will be created.
+
+ Returns:
+ None if ``join`` is ``True``,
+ :class:`~SpawnContext` if ``join`` is ``False``
+
+ """
+ _python_version_check()
+ mp = multiprocessing.get_context('spawn')
+ error_queues = []
+ processes = []
+ for i in range(nprocs):
+ error_queue = mp.SimpleQueue()
+ process = mp.Process(
+ target=_wrap,
+ args=(fn, i, args, error_queue),
+ daemon=daemon,
+ )
+ process.start()
+ error_queues.append(error_queue)
+ processes.append(process)
+
+ spawn_context = SpawnContext(processes, error_queues)
+ if not join:
+ return spawn_context
+
+ # Loop on join until it returns True or raises an exception.
+ while not spawn_context.join():
+ pass
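+
+# A minimal, self-contained usage sketch (illustrative only; ``_example_worker``
+# and ``_example_spawn_usage`` are not part of the public API).  The worker is
+# defined at module top level so the 'spawn' start method can pickle it, and it
+# is invoked as ``fn(i, *args)`` where ``i`` is the process index.
+def _example_worker(i, offset):
+    print('process %d received offset %d' % (i, offset))
+
+
+def _example_spawn_usage():
+    # Launches two worker processes and blocks until both exit; returns None
+    # because ``join=True`` (the default).
+    return spawn(_example_worker, args=(10,), nprocs=2)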
+
+r"""Functional interface"""
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch._C import _infer_size, _add_docstr
+from . import _reduction as _Reduction
+from .modules import utils
+from ._functions import vision
+from .modules.utils import _single, _pair, _triple, _list_with_default
+from . import grad # noqa: F401
+from . import _VF
+from .._jit_internal import weak_script, List
+
+
+conv1d = _add_docstr(torch.conv1d, r"""
+conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
+
+Applies a 1D convolution over an input signal composed of several input
+planes.
+
+See :class:`~torch.nn.Conv1d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
+ weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
+ bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None``
+ stride: the stride of the convolving kernel. Can be a single number or
+ a one-element tuple `(sW,)`. Default: 1
+ padding: implicit paddings on both sides of the input. Can be a
+ single number or a one-element tuple `(padW,)`. Default: 0
+ dilation: the spacing between kernel elements. Can be a single number or
+ a one-element tuple `(dW,)`. Default: 1
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
+ the number of groups. Default: 1
+    padding_mode: the type of padding applied to both sides; can be `zeros` or `circular`. Default: `zeros`
+
+Examples::
+
+ >>> filters = torch.randn(33, 16, 3)
+ >>> inputs = torch.randn(20, 16, 50)
+ >>> F.conv1d(inputs, filters)
+""")
+
+conv2d = _add_docstr(torch.conv2d, r"""
+conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
+
+Applies a 2D convolution over an input image composed of several input
+planes.
+
+See :class:`~torch.nn.Conv2d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
+ weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
+ bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None``
+ stride: the stride of the convolving kernel. Can be a single number or a
+ tuple `(sH, sW)`. Default: 1
+ padding: implicit paddings on both sides of the input. Can be a
+ single number or a tuple `(padH, padW)`. Default: 0
+ dilation: the spacing between kernel elements. Can be a single number or
+ a tuple `(dH, dW)`. Default: 1
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+ number of groups. Default: 1
+    padding_mode: the type of padding applied to both sides; can be `zeros` or `circular`. Default: `zeros`
+
+Examples::
+
+ >>> # With square kernels and equal stride
+ >>> filters = torch.randn(8,4,3,3)
+ >>> inputs = torch.randn(1,4,5,5)
+ >>> F.conv2d(inputs, filters, padding=1)
+""") # noqa: E501
+
+conv3d = _add_docstr(torch.conv3d, r"""
+conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros') -> Tensor
+
+Applies a 3D convolution over an input image composed of several input
+planes.
+
+See :class:`~torch.nn.Conv3d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
+ weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)`
+ bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None
+ stride: the stride of the convolving kernel. Can be a single number or a
+ tuple `(sT, sH, sW)`. Default: 1
+ padding: implicit paddings on both sides of the input. Can be a
+ single number or a tuple `(padT, padH, padW)`. Default: 0
+ dilation: the spacing between kernel elements. Can be a single number or
+ a tuple `(dT, dH, dW)`. Default: 1
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by
+ the number of groups. Default: 1
+    padding_mode: the type of padding applied to both sides; can be `zeros` or `circular`. Default: `zeros`
+
+Examples::
+
+ >>> filters = torch.randn(33, 16, 3, 3, 3)
+ >>> inputs = torch.randn(20, 16, 50, 10, 20)
+ >>> F.conv3d(inputs, filters)
+""") # noqa: E501
+
+conv_transpose1d = _add_docstr(torch.conv_transpose1d, r"""
+conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor
+
+Applies a 1D transposed convolution operator over an input signal
+composed of several input planes, sometimes also called "deconvolution".
+
+See :class:`~torch.nn.ConvTranspose1d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
+ weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)`
+ bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
+ stride: the stride of the convolving kernel. Can be a single number or a
+ tuple ``(sW,)``. Default: 1
+ padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
+ sides of each dimension in the input. Can be a single number or a tuple
+ ``(padW,)``. Default: 0
+ output_padding: additional size added to one side of each dimension in the
+ output shape. Can be a single number or a tuple ``(out_padW)``. Default: 0
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+ number of groups. Default: 1
+ dilation: the spacing between kernel elements. Can be a single number or
+ a tuple ``(dW,)``. Default: 1
+
+Examples::
+
+ >>> inputs = torch.randn(20, 16, 50)
+ >>> weights = torch.randn(16, 33, 5)
+ >>> F.conv_transpose1d(inputs, weights)
+""")
+
+conv_transpose2d = _add_docstr(torch.conv_transpose2d, r"""
+conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor
+
+Applies a 2D transposed convolution operator over an input image
+composed of several input planes, sometimes also called "deconvolution".
+
+See :class:`~torch.nn.ConvTranspose2d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
+ weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)`
+ bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
+ stride: the stride of the convolving kernel. Can be a single number or a
+ tuple ``(sH, sW)``. Default: 1
+ padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
+ sides of each dimension in the input. Can be a single number or a tuple
+ ``(padH, padW)``. Default: 0
+ output_padding: additional size added to one side of each dimension in the
+ output shape. Can be a single number or a tuple ``(out_padH, out_padW)``.
+ Default: 0
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+ number of groups. Default: 1
+ dilation: the spacing between kernel elements. Can be a single number or
+ a tuple ``(dH, dW)``. Default: 1
+
+Examples::
+
+ >>> # With square kernels and equal stride
+ >>> inputs = torch.randn(1, 4, 5, 5)
+ >>> weights = torch.randn(4, 8, 3, 3)
+ >>> F.conv_transpose2d(inputs, weights, padding=1)
+""") # noqa: E501
+
+conv_transpose3d = _add_docstr(torch.conv_transpose3d, r"""
+conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor
+
+Applies a 3D transposed convolution operator over an input image
+composed of several input planes, sometimes also called "deconvolution".
+
+See :class:`~torch.nn.ConvTranspose3d` for details and output shape.
+
+.. include:: cudnn_deterministic.rst
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
+ weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)`
+ bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None
+ stride: the stride of the convolving kernel. Can be a single number or a
+ tuple ``(sT, sH, sW)``. Default: 1
+ padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both
+ sides of each dimension in the input. Can be a single number or a tuple
+ ``(padT, padH, padW)``. Default: 0
+ output_padding: additional size added to one side of each dimension in the
+ output shape. Can be a single number or a tuple
+ ``(out_padT, out_padH, out_padW)``. Default: 0
+ groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
+ number of groups. Default: 1
+ dilation: the spacing between kernel elements. Can be a single number or
+ a tuple `(dT, dH, dW)`. Default: 1
+
+Examples::
+
+ >>> inputs = torch.randn(20, 16, 50, 10, 20)
+ >>> weights = torch.randn(16, 33, 3, 3, 3)
+ >>> F.conv_transpose3d(inputs, weights)
+""") # noqa: E501
+
+conv_tbc = _add_docstr(torch.conv_tbc, r"""
+Applies a 1-dimensional sequence convolution over an input sequence.
+Input and output dimensions are (Time, Batch, Channels) - hence TBC.
+
+Args:
+ input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})`
+ weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`)
+ bias: bias of shape (:math:`\text{out\_channels}`)
+ pad: number of timesteps to pad. Default: 0
+""")
+
+
+# Pooling
+avg_pool1d = _add_docstr(torch.avg_pool1d, r"""
+avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
+
+Applies a 1D average pooling over an input signal composed of several
+input planes.
+
+See :class:`~torch.nn.AvgPool1d` for details and output shape.
+
+Args:
+ input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
+ kernel_size: the size of the window. Can be a single number or a
+ tuple `(kW,)`
+ stride: the stride of the window. Can be a single number or a tuple
+ `(sW,)`. Default: :attr:`kernel_size`
+ padding: implicit zero paddings on both sides of the input. Can be a
+ single number or a tuple `(padW,)`. Default: 0
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the
+ output shape. Default: ``False``
+ count_include_pad: when True, will include the zero-padding in the
+ averaging calculation. Default: ``True``
+
+Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32)
+ >>> F.avg_pool1d(input, kernel_size=3, stride=2)
+ tensor([[[ 2., 4., 6.]]])
+
+""")
+
+
+avg_pool2d = _add_docstr(torch._C._nn.avg_pool2d, r"""
+avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
+
+Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
+:math:`sH \times sW` steps. The number of output features is equal to the number of
+input planes.
+
+See :class:`~torch.nn.AvgPool2d` for details and output shape.
+
+Args:
+ input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
+ kernel_size: size of the pooling region. Can be a single number or a
+ tuple `(kH, kW)`
+ stride: stride of the pooling operation. Can be a single number or a
+ tuple `(sH, sW)`. Default: :attr:`kernel_size`
+ padding: implicit zero paddings on both sides of the input. Can be a
+ single number or a tuple `(padH, padW)`. Default: 0
+ ceil_mode: when True, will use `ceil` instead of `floor` in the formula
+ to compute the output shape. Default: ``False``
+ count_include_pad: when True, will include the zero-padding in the
+ averaging calculation. Default: ``True``
+""")
+
+avg_pool3d = _add_docstr(torch._C._nn.avg_pool3d, r"""
+avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor
+
+Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step
+size :math:`sT \times sH \times sW` steps. The number of output features is equal to
+:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`.
+
+See :class:`~torch.nn.AvgPool3d` for details and output shape.
+
+Args:
+    input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)`
+ kernel_size: size of the pooling region. Can be a single number or a
+ tuple `(kT, kH, kW)`
+ stride: stride of the pooling operation. Can be a single number or a
+ tuple `(sT, sH, sW)`. Default: :attr:`kernel_size`
+ padding: implicit zero paddings on both sides of the input. Can be a
+        single number or a tuple `(padT, padH, padW)`. Default: 0
+ ceil_mode: when True, will use `ceil` instead of `floor` in the formula
+ to compute the output shape
+ count_include_pad: when True, will include the zero-padding in the
+ averaging calculation
+""")
+
+
+@weak_script
+def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
+ output_ratio=None, return_indices=False,
+ _random_samples=None):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa
+ r"""Applies 2D fractional max pooling over an input signal composed of several input planes.
+
+    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.
+
+ The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
+ step size determined by the target output size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ kernel_size: the size of the window to take a max over.
+ Can be a single number :math:`k` (for a square kernel of :math:`k \times k`)
+ or a tuple `(kH, kW)`
+ output_size: the target output size of the image of the form :math:`oH \times oW`.
+ Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH`
+ output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
+ This has to be a number or tuple in the range (0, 1)
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to :func:`~torch.nn.functional.max_unpool2d`.
+
+ Examples::
+ >>> input = torch.randn(20, 16, 50, 32)
+ >>> # pool of square window of size=3, and target output size 13x12
+ >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12))
+ >>> # pool of square window and target output size being half of input image size
+ >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5))
+
+ .. _Fractional MaxPooling:
+ http://arxiv.org/abs/1412.6071
+ """
+ if output_size is None and output_ratio is None:
+ raise ValueError("fractional_max_pool2d requires specifying either "
+ "an output_size or an output_ratio")
+ if output_size is None:
+ _output_ratio = _pair(torch.jit._unwrap_optional(output_ratio))
+ output_size = [int(input.size(2) * _output_ratio[0]),
+ int(input.size(3) * _output_ratio[1])]
+
+ if _random_samples is None:
+ _random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device)
+ return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
+
+
+@weak_script
+def _fractional_max_pool2d(input, kernel_size, output_size=None,
+ output_ratio=None, return_indices=False,
+ _random_samples=None):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tensor # noqa
+ return fractional_max_pool2d_with_indices(input, kernel_size, output_size,
+ output_ratio, return_indices,
+ _random_samples)[0]
+
+fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=4,
+ default=False,
+ if_true=fractional_max_pool2d_with_indices,
+ if_false=_fractional_max_pool2d,
+ module_name=__name__,
+ func_name='fractional_max_pool2d')
+
+
+@weak_script
+def fractional_max_pool3d_with_indices(input, kernel_size, output_size=None,
+ output_ratio=None, return_indices=False,
+ _random_samples=None):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa
+ r"""Applies 3D fractional max pooling over an input signal composed of several input planes.
+
+    Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham.
+
+ The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
+ step size determined by the target output size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ kernel_size: the size of the window to take a max over.
+ Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`)
+ or a tuple `(kT, kH, kW)`
+ output_size: the target output size of the form :math:`oT \times oH \times oW`.
+ Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output
+ :math:`oH \times oH \times oH`
+ output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
+ This has to be a number or tuple in the range (0, 1)
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to :func:`~torch.nn.functional.max_unpool3d`.
+
+ Examples::
+ >>> input = torch.randn(20, 16, 50, 32, 16)
+ >>> # pool of cubic window of size=3, and target output size 13x12x11
+ >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11))
+ >>> # pool of cubic window and target output size being half of input size
+ >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5))
+
+ .. _Fractional MaxPooling:
+ http://arxiv.org/abs/1412.6071
+ """
+ if output_size is None and output_ratio is None:
+ raise ValueError("fractional_max_pool3d requires specifying either "
+ "an output_size or an output_ratio")
+ if output_size is None:
+ _output_ratio = _triple(torch.jit._unwrap_optional(output_ratio))
+ output_size = [int(input.size(2) * _output_ratio[0]),
+ int(input.size(3) * _output_ratio[1]),
+ int(input.size(4) * _output_ratio[2])]
+
+ if _random_samples is None:
+ _random_samples = torch.rand(input.size(0), input.size(1), 3, dtype=input.dtype, device=input.device)
+ return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples)
+
+
+@weak_script
+def _fractional_max_pool3d(input, kernel_size, output_size=None,
+ output_ratio=None, return_indices=False,
+ _random_samples=None):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], Optional[BroadcastingList3[float]], bool, Optional[Tensor]) -> Tensor # noqa
+ return fractional_max_pool3d_with_indices(input, kernel_size, output_size,
+ output_ratio, return_indices,
+ _random_samples)[0]
+
+fractional_max_pool3d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=4,
+ default=False,
+ if_true=fractional_max_pool3d_with_indices,
+ if_false=_fractional_max_pool3d,
+ module_name=__name__,
+ func_name='fractional_max_pool3d')
+
+
+@weak_script
+def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
+ dilation=1, ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
+ r"""Applies a 1D max pooling over an input signal composed of several input
+ planes.
+
+ See :class:`~torch.nn.MaxPool1d` for details.
+ """
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool1d_with_indices(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
+
+
+@weak_script
+def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor # noqa
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool1d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
+
+max_pool1d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool1d_with_indices,
+ if_false=_max_pool1d,
+ module_name=__name__,
+ func_name='max_pool1d')
+
+
+@weak_script
+def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
+ r"""Applies a 2D max pooling over an input signal composed of several input
+ planes.
+
+ See :class:`~torch.nn.MaxPool2d` for details.
+ """
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
+
+
+@weak_script
+def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor # noqa
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool2d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
+
+max_pool2d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool2d_with_indices,
+ if_false=_max_pool2d,
+ module_name=__name__,
+ func_name='max_pool2d')
+
+
+@weak_script
+def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
+ dilation=1, ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor] # noqa
+ r"""Applies a 3D max pooling over an input signal composed of several input
+ planes.
+
+ See :class:`~torch.nn.MaxPool3d` for details.
+ """
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch._C._nn.max_pool3d_with_indices(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
+
+
+@weak_script
+def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor # noqa
+ if stride is None:
+ stride = torch.jit.annotate(List[int], [])
+ return torch.max_pool3d(
+ input, kernel_size, stride, padding, dilation, ceil_mode)
+
+max_pool3d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool3d_with_indices,
+ if_false=_max_pool3d,
+ module_name=__name__,
+ func_name='max_pool3d')
+
+
+@weak_script
+def _unpool_output_size(input, kernel_size, stride, padding, output_size):
+ # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
+ input_size = input.size()
+ default_size = torch.jit.annotate(List[int], [])
+ for d in range(len(kernel_size)):
+ default_size.append((input_size[d + 2] - 1) * stride[d] +
+ kernel_size[d] - 2 * padding[d])
+ if output_size is None:
+ ret = default_size
+ else:
+ if len(output_size) == len(kernel_size) + 2:
+ output_size = output_size[2:]
+ if len(output_size) != len(kernel_size):
+ raise ValueError("output_size should be a sequence containing "
+ "{} or {} elements, but it has a length of '{}'"
+ .format(len(kernel_size), len(kernel_size) + 2,
+ len(output_size)))
+ for d in range(len(kernel_size)):
+ min_size = default_size[d] - stride[d]
+ max_size = default_size[d] + stride[d]
+ if not (min_size < output_size[d] < max_size):
+ raise ValueError(
+ 'invalid output_size "{}" (dim {} must be between {} and {})'
+ .format(output_size, d, min_size, max_size))
+
+ ret = output_size
+ return ret
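+
+# Worked example (illustrative): for an input of spatial size 4 with
+# kernel_size=2, stride=2 and padding=0, the default unpooled size is
+# (4 - 1) * 2 + 2 - 2 * 0 = 8, and an explicit ``output_size`` is accepted
+# only if it lies strictly between default - stride and default + stride.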
+
+
+@weak_script
+def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
+ output_size=None):
+ # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor # noqa
+ r"""Computes a partial inverse of :class:`MaxPool1d`.
+
+ See :class:`~torch.nn.MaxUnpool1d` for details.
+ """
+ kernel_size = _single(kernel_size)
+ if stride is not None:
+ _stride = _single(stride)
+ else:
+ _stride = kernel_size
+ padding = _single(padding)
+ output_size = _unpool_output_size(input, kernel_size, _stride, padding,
+ output_size)
+ if isinstance(output_size, list):
+ output_size = output_size + [1]
+ else:
+ output_size = output_size + (1,)
+ return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3),
+ output_size).squeeze(3)
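+
+# A minimal round-trip sketch for the unpooling family (illustrative only):
+#
+#     >>> x = torch.tensor([[[1., 2., 3., 4., 5., 6., 7., 8.]]])
+#     >>> out, idx = F.max_pool1d(x, kernel_size=2, return_indices=True)
+#     >>> F.max_unpool1d(out, idx, kernel_size=2)
+#     tensor([[[0., 2., 0., 4., 0., 6., 0., 8.]]])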
+
+
+@weak_script
+def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
+ output_size=None):
+ # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor # noqa
+ r"""Computes a partial inverse of :class:`MaxPool2d`.
+
+ See :class:`~torch.nn.MaxUnpool2d` for details.
+ """
+ kernel_size = _pair(kernel_size)
+ if stride is not None:
+ _stride = _pair(stride)
+ else:
+ _stride = kernel_size
+ padding = _pair(padding)
+ output_size = _unpool_output_size(input, kernel_size, _stride, padding,
+ output_size)
+ return torch._C._nn.max_unpool2d(input, indices, output_size)
+
+
+@weak_script
+def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
+ output_size=None):
+ # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa
+ r"""Computes a partial inverse of :class:`MaxPool3d`.
+
+ See :class:`~torch.nn.MaxUnpool3d` for details.
+ """
+ kernel_size = _triple(kernel_size)
+ if stride is not None:
+ _stride = _triple(stride)
+ else:
+ _stride = kernel_size
+ padding = _triple(padding)
+ output_size = _unpool_output_size(input, kernel_size, _stride, padding,
+ output_size)
+ return torch._C._nn.max_unpool3d(
+ input, indices, output_size, _stride, padding)
+
+
+@weak_script
+def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
+ # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
+ r"""Applies a 2D power-average pooling over an input signal composed of
+ several input planes. If the sum of all inputs to the power of `p` is
+ zero, the gradient is set to zero as well.
+
+ See :class:`~torch.nn.LPPool2d` for details.
+ """
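+    # avg_pool2d averages input.pow(p) over each window; multiplying by the
+    # window area (kw * kh) turns that mean back into a sum before taking the
+    # final 1/p power.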
+ kw, kh = utils._pair(kernel_size)
+ if stride is not None:
+ out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
+ else:
+ out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode)
+
+ return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
+
+
+@weak_script
+def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
+ # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
+ r"""Applies a 1D power-average pooling over an input signal composed of
+ several input planes. If the sum of all inputs to the power of `p` is
+ zero, the gradient is set to zero as well.
+
+ See :class:`~torch.nn.LPPool1d` for details.
+ """
+ if stride is not None:
+ out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode)
+ else:
+ out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode)
+
+ return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
+
+
+@weak_script
+def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
+ r"""Applies a 1D adaptive max pooling over an input signal composed of
+ several input planes.
+
+ See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape.
+
+ Args:
+ output_size: the target output size (single integer)
+ return_indices: whether to return pooling indices. Default: ``False``
+ """
+ return torch.adaptive_max_pool1d(input, output_size)
+
+
+@weak_script
+def _adaptive_max_pool1d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool1d_with_indices(input, output_size)[0]
+
+adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool1d_with_indices,
+ if_false=_adaptive_max_pool1d,
+ module_name=__name__,
+ func_name='adaptive_max_pool1d')
+
+
+@weak_script
+def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
+ r"""Applies a 2D adaptive max pooling over an input signal composed of
+ several input planes.
+
+ See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape.
+
+ Args:
+ output_size: the target output size (single integer or
+ double-integer tuple)
+ return_indices: whether to return pooling indices. Default: ``False``
+ """
+ output_size = _list_with_default(output_size, input.size())
+ return torch._C._nn.adaptive_max_pool2d(input, output_size)
+
+
+@weak_script
+def _adaptive_max_pool2d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool2d_with_indices(input, output_size)[0]
+
+adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool2d_with_indices,
+ if_false=_adaptive_max_pool2d,
+ module_name=__name__,
+ func_name='adaptive_max_pool2d')
+
+
+@weak_script
+def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
+ r"""Applies a 3D adaptive max pooling over an input signal composed of
+ several input planes.
+
+ See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape.
+
+ Args:
+ output_size: the target output size (single integer or
+ triple-integer tuple)
+ return_indices: whether to return pooling indices. Default: ``False``
+ """
+ output_size = _list_with_default(output_size, input.size())
+ return torch._C._nn.adaptive_max_pool3d(input, output_size)
+
+
+@weak_script
+def _adaptive_max_pool3d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool3d_with_indices(input, output_size)[0]
+
+adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool3d_with_indices,
+ if_false=_adaptive_max_pool3d,
+ module_name=__name__,
+ func_name='adaptive_max_pool3d')
+
+
+adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r"""
+adaptive_avg_pool1d(input, output_size) -> Tensor
+
+Applies a 1D adaptive average pooling over an input signal composed of
+several input planes.
+
+See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape.
+
+Args:
+ output_size: the target output size (single integer)
+""")
+
+
+@weak_script
+def adaptive_avg_pool2d(input, output_size):
+ # type: (Tensor, BroadcastingList2[int]) -> Tensor
+ r"""
+ Applies a 2D adaptive average pooling over an input signal composed of
+ several input planes.
+
+ See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape.
+
+ Args:
+ output_size: the target output size (single integer or
+ double-integer tuple)
+ """
+ _output_size = _list_with_default(output_size, input.size())
+ return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
+
+
+@weak_script
+def adaptive_avg_pool3d(input, output_size):
+ # type: (Tensor, BroadcastingList3[int]) -> Tensor
+ r"""
+ Applies a 3D adaptive average pooling over an input signal composed of
+ several input planes.
+
+ See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape.
+
+ Args:
+ output_size: the target output size (single integer or
+ triple-integer tuple)
+ """
+ _output_size = _list_with_default(output_size, input.size())
+ return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
+
+
+# Activation functions
+@weak_script
+def dropout(input, p=0.5, training=True, inplace=False):
+ # type: (Tensor, float, bool, bool) -> Tensor
+ r"""
+ During training, randomly zeroes some of the elements of the input
+ tensor with probability :attr:`p` using samples from a Bernoulli
+ distribution.
+
+ See :class:`~torch.nn.Dropout` for details.
+
+ Args:
+ p: probability of an element to be zeroed. Default: 0.5
+        training: apply dropout if ``training`` is ``True``. Default: ``True``
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+ """
+ if p < 0. or p > 1.:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ return (_VF.dropout_(input, p, training)
+ if inplace
+ else _VF.dropout(input, p, training))
+
+
+@weak_script
+def alpha_dropout(input, p=0.5, training=False, inplace=False):
+ # type: (Tensor, float, bool, bool) -> Tensor
+ r"""Applies alpha dropout to the input.
+
+ See :class:`~torch.nn.AlphaDropout` for details.
+ """
+ if p < 0. or p > 1.:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ return (_VF.alpha_dropout_(input, p, training)
+ if inplace
+ else _VF.alpha_dropout(input, p, training))
+
+
+@weak_script
+def dropout2d(input, p=0.5, training=True, inplace=False):
+ # type: (Tensor, float, bool, bool) -> Tensor
+ r"""
+ Randomly zero out entire channels (a channel is a 2D feature map,
+ e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+    batched input is a 2D tensor :math:`\text{input}[i, j]`) of the input tensor.
+ Each channel will be zeroed out independently on every forward call with
+ probability :attr:`p` using samples from a Bernoulli distribution.
+
+ See :class:`~torch.nn.Dropout2d` for details.
+
+ Args:
+ p: probability of a channel to be zeroed. Default: 0.5
+        training: apply dropout if ``training`` is ``True``. Default: ``True``
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+ """
+ if p < 0. or p > 1.:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ return (_VF.feature_dropout_(input, p, training)
+ if inplace
+ else _VF.feature_dropout(input, p, training))
+
+
+@weak_script
+def dropout3d(input, p=0.5, training=True, inplace=False):
+ # type: (Tensor, float, bool, bool) -> Tensor
+ r"""
+ Randomly zero out entire channels (a channel is a 3D feature map,
+ e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+    batched input is a 3D tensor :math:`\text{input}[i, j]`) of the input tensor.
+ Each channel will be zeroed out independently on every forward call with
+ probability :attr:`p` using samples from a Bernoulli distribution.
+
+ See :class:`~torch.nn.Dropout3d` for details.
+
+ Args:
+ p: probability of a channel to be zeroed. Default: 0.5
+        training: apply dropout if ``training`` is ``True``. Default: ``True``
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+ """
+ # This is 100% the same code as dropout2d. We duplicate this code so that
+ # stack traces are not confusing.
+ if p < 0. or p > 1.:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ return (_VF.feature_dropout_(input, p, training)
+ if inplace
+ else _VF.feature_dropout(input, p, training))
+
+
+@weak_script
+def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
+ # type: (Tensor, float, bool, bool) -> Tensor
+ if p < 0. or p > 1.:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ return (_VF.feature_alpha_dropout_(input, p, training)
+ if inplace
+ else _VF.feature_alpha_dropout(input, p, training))
+
+
+@weak_script
+def threshold(input, threshold, value, inplace=False):
+ # type: (Tensor, float, float, bool) -> Tensor
+ r"""Thresholds each element of the input Tensor.
+
+ See :class:`~torch.nn.Threshold` for more details.
+ """
+ if inplace:
+ result = _VF.threshold_(input, threshold, value)
+ else:
+ result = _VF.threshold(input, threshold, value)
+ return result
+
+
+threshold_ = _add_docstr(_VF.threshold_, r"""
+threshold_(input, threshold, value) -> Tensor
+
+In-place version of :func:`~threshold`.
+""")
+
+
+@weak_script
+def relu(input, inplace=False):
+ # type: (Tensor, bool) -> Tensor
+ r"""relu(input, inplace=False) -> Tensor
+
+ Applies the rectified linear unit function element-wise. See
+ :class:`~torch.nn.ReLU` for more details.
+ """
+ if inplace:
+ result = torch.relu_(input)
+ else:
+ result = torch.relu(input)
+ return result
+
+
+relu_ = _add_docstr(torch.relu_, r"""
+relu_(input) -> Tensor
+
+In-place version of :func:`~relu`.
+""")
+
+
+@weak_script
+def glu(input, dim=-1):
+ # type: (Tensor, int) -> Tensor
+ r"""
+ glu(input, dim=-1) -> Tensor
+
+ The gated linear unit. Computes:
+
+ .. math ::
+ \text{GLU}(a, b) = a \otimes \sigma(b)
+
+ where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma`
+ is the sigmoid function and :math:`\otimes` is the element-wise product between matrices.
+
+ See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_.
+
+ Args:
+ input (Tensor): input tensor
+ dim (int): dimension on which to split the input. Default: -1
+ """
+ if input.dim() == 0:
+        raise RuntimeError("glu does not support scalars because halving size must be even")
+ return torch._C._nn.glu(input, dim)
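+
+# Illustrative sketch only: for ``x = torch.randn(4, 6)``, ``F.glu(x, dim=-1)``
+# returns a ``(4, 3)`` tensor equal to ``x[:, :3] * torch.sigmoid(x[:, 3:])``.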
+
+
+@weak_script
+def hardtanh(input, min_val=-1., max_val=1., inplace=False):
+ # type: (Tensor, float, float, bool) -> Tensor
+ r"""
+ hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor
+
+ Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more
+ details.
+ """
+ if inplace:
+ result = torch._C._nn.hardtanh_(input, min_val, max_val)
+ else:
+ result = torch._C._nn.hardtanh(input, min_val, max_val)
+ return result
+
+
+hardtanh_ = _add_docstr(torch._C._nn.hardtanh_, r"""
+hardtanh_(input, min_val=-1., max_val=1.) -> Tensor
+
+In-place version of :func:`~hardtanh`.
+""")
+
+
+@weak_script
+def relu6(input, inplace=False):
+ # type: (Tensor, bool) -> Tensor
+ r"""relu6(input, inplace=False) -> Tensor
+
+ Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.
+
+ See :class:`~torch.nn.ReLU6` for more details.
+ """
+ return hardtanh(input, 0., 6., inplace)
+
+
+@weak_script
+def elu(input, alpha=1., inplace=False):
+ # type: (Tensor, float, bool) -> Tensor
+ r"""Applies element-wise,
+ :math:`\text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))`.
+
+ See :class:`~torch.nn.ELU` for more details.
+ """
+ if inplace:
+ result = torch._C._nn.elu_(input, alpha)
+ else:
+ result = torch._C._nn.elu(input, alpha)
+ return result
+
+
+elu_ = _add_docstr(torch._C._nn.elu_, r"""
+elu_(input, alpha=1.) -> Tensor
+
+In-place version of :func:`~elu`.
+""")
+
+
+@weak_script
+def selu(input, inplace=False):
+ # type: (Tensor, bool) -> Tensor
+ r"""selu(input, inplace=False) -> Tensor
+
+ Applies element-wise,
+ :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`,
+ with :math:`\alpha=1.6732632423543772848170429916717` and
+ :math:`scale=1.0507009873554804934193349852946`.
+
+ See :class:`~torch.nn.SELU` for more details.
+ """
+ if inplace:
+ result = torch.selu_(input)
+ else:
+ result = torch.selu(input)
+ return result
+
+
+selu_ = _add_docstr(torch.selu_, r"""
+selu_(input) -> Tensor
+
+In-place version of :func:`~selu`.
+""")
+
+
+@weak_script
+def celu(input, alpha=1., inplace=False):
+ # type: (Tensor, float, bool) -> Tensor
+ r"""celu(input, alpha=1., inplace=False) -> Tensor
+
+ Applies element-wise,
+ :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`.
+
+ See :class:`~torch.nn.CELU` for more details.
+ """
+ if inplace:
+ result = torch.celu_(input, alpha)
+ else:
+ result = torch.celu(input, alpha)
+ return result
+
+celu_ = _add_docstr(torch.celu_, r"""
+celu_(input, alpha=1.) -> Tensor
+
+In-place version of :func:`~celu`.
+""")
+
+
+@weak_script
+def leaky_relu(input, negative_slope=0.01, inplace=False):
+ # type: (Tensor, float, bool) -> Tensor
+ r"""
+ leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor
+
+ Applies element-wise,
+ :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
+
+ See :class:`~torch.nn.LeakyReLU` for more details.
+ """
+ if inplace:
+ result = torch._C._nn.leaky_relu_(input, negative_slope)
+ else:
+ result = torch._C._nn.leaky_relu(input, negative_slope)
+ return result
+
+
+leaky_relu_ = _add_docstr(torch._C._nn.leaky_relu_, r"""
+leaky_relu_(input, negative_slope=0.01) -> Tensor
+
+In-place version of :func:`~leaky_relu`.
+""")
+
+
+@weak_script
+def prelu(input, weight):
+ # type: (Tensor, Tensor) -> Tensor
+ r"""prelu(input, weight) -> Tensor
+
+ Applies element-wise the function
+ :math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a
+ learnable parameter.
+
+ See :class:`~torch.nn.PReLU` for more details.
+ """
+ return torch.prelu(input, weight)
+
+
+@weak_script
+def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
+ # type: (Tensor, float, float, bool, bool) -> Tensor
+ r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
+
+ Randomized leaky ReLU.
+
+ See :class:`~torch.nn.RReLU` for more details.
+ """
+ if inplace:
+ result = torch.rrelu_(input, lower, upper, training)
+ else:
+ result = torch.rrelu(input, lower, upper, training)
+ return result
+
+
+rrelu_ = _add_docstr(torch.rrelu_, r"""
+rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor
+
+In-place version of :func:`~rrelu`.
+""")
+
+logsigmoid = _add_docstr(torch._C._nn.log_sigmoid, r"""
+logsigmoid(input) -> Tensor
+
+Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)`
+
+See :class:`~torch.nn.LogSigmoid` for more details.
+""")
+
+
+@weak_script
+def hardshrink(input, lambd=0.5):
+ # type: (Tensor, float) -> Tensor
+ r"""
+ hardshrink(input, lambd=0.5) -> Tensor
+
+ Applies the hard shrinkage function element-wise
+
+ See :class:`~torch.nn.Hardshrink` for more details.
+ """
+ return torch.hardshrink(input, lambd)
+
+
+@weak_script
+def tanhshrink(input):
+ r"""tanhshrink(input) -> Tensor
+
+ Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)`
+
+ See :class:`~torch.nn.Tanhshrink` for more details.
+ """
+ return input - input.tanh()
+
+
+@weak_script
+def softsign(input):
+ r"""softsign(input) -> Tensor
+
+ Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}`
+
+ See :class:`~torch.nn.Softsign` for more details.
+ """
+ return input / (input.abs() + 1)
+
+
+softplus = _add_docstr(torch._C._nn.softplus, r"""
+softplus(input, beta=1, threshold=20) -> Tensor
+""")
+
+
+@weak_script
+def _get_softmax_dim(name, ndim, stacklevel):
+ # type: (str, int, int) -> int
+ warnings.warn("Implicit dimension choice for {} has been deprecated. "
+ "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
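+    # Historical behavior: 0-D, 1-D and 3-D inputs are reduced over dim 0,
+    # while other (batched, e.g. 2-D and 4-D) inputs are reduced over dim 1.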
+ if ndim == 0 or ndim == 1 or ndim == 3:
+ ret = 0
+ else:
+ ret = 1
+ return ret
+
+
+@weak_script
+def softmin(input, dim=None, _stacklevel=3, dtype=None):
+ # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
+ r"""Applies a softmin function.
+
+ Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula.
+
+ See :class:`~torch.nn.Softmin` for more details.
+
+ Arguments:
+ input (Tensor): input
+ dim (int): A dimension along which softmin will be computed (so every slice
+ along dim will sum to 1).
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ If specified, the input tensor is casted to :attr:`dtype` before the operation
+ is performed. This is useful for preventing data type overflows. Default: None.
+ """
+ if dim is None:
+ dim = _get_softmax_dim('softmin', input.dim(), _stacklevel)
+ if dtype is None:
+ ret = (-input).softmax(dim)
+ else:
+ ret = (-input).softmax(dim, dtype=dtype)
+ return ret
+
+
+@weak_script
+def softmax(input, dim=None, _stacklevel=3, dtype=None):
+ # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
+ r"""Applies a softmax function.
+
+ Softmax is defined as:
+
+    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
+
+ It is applied to all slices along dim, and will re-scale them so that the elements
+ lie in the range `[0, 1]` and sum to 1.
+
+ See :class:`~torch.nn.Softmax` for more details.
+
+ Arguments:
+ input (Tensor): input
+ dim (int): A dimension along which softmax will be computed.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ If specified, the input tensor is casted to :attr:`dtype` before the operation
+ is performed. This is useful for preventing data type overflows. Default: None.
+
+ .. note::
+ This function doesn't work directly with NLLLoss,
+ which expects the Log to be computed between the Softmax and itself.
+ Use log_softmax instead (it's faster and has better numerical properties).
+
+ """
+ if dim is None:
+ dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
+ if dtype is None:
+ ret = input.softmax(dim)
+ else:
+ ret = input.softmax(dim, dtype=dtype)
+ return ret
+
+
+@weak_script
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
+ # type: (Tensor, float, bool, float, int) -> Tensor
+ r"""
+ Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes.
+
+ Args:
+ logits: `[..., num_features]` unnormalized log probabilities
+ tau: non-negative scalar temperature
+ hard: if ``True``, the returned samples will be discretized as one-hot vectors,
+ but will be differentiated as if it is the soft sample in autograd
+ dim (int): A dimension along which softmax will be computed. Default: -1.
+
+ Returns:
+ Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
+ If ``hard=True``, the returned samples will be one-hot, otherwise they will
+ be probability distributions that sum to 1 across `dim`.
+
+ .. note::
+        This function is here for legacy reasons and may be removed from nn.functional in the future.
+
+ .. note::
+ The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft`
+
+ It achieves two things:
+ - makes the output value exactly one-hot
+ (since we add then subtract y_soft value)
+ - makes the gradient equal to y_soft gradient
+ (since we strip all other gradients)
+
+ Examples::
+ >>> logits = torch.randn(20, 32)
+ >>> # Sample soft categorical using reparametrization trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=False)
+ >>> # Sample hard categorical using "Straight-through" trick:
+ >>> F.gumbel_softmax(logits, tau=1, hard=True)
+
+ .. _Gumbel-Softmax distribution:
+ https://arxiv.org/abs/1611.00712
+ https://arxiv.org/abs/1611.01144
+ """
+
+ if eps != 1e-10:
+ warnings.warn("`eps` parameter is deprecated and has no effect.")
+
+ gumbels = -torch.empty_like(logits).exponential_().log() # ~Gumbel(0,1)
+ gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
+ y_soft = gumbels.softmax(dim)
+
+ if hard:
+ # Straight through.
+ index = y_soft.max(dim, keepdim=True)[1]
+ y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
+ ret = y_hard - y_soft.detach() + y_soft
+ else:
+ # Reparametrization trick.
+ ret = y_soft
+ return ret
+
+
+[docs]@weak_script
+def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
+ # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
+ r"""Applies a softmax followed by a logarithm.
+
+ While mathematically equivalent to log(softmax(x)), doing these two
+    operations separately is slower and numerically unstable. This function
+ uses an alternative formulation to compute the output and gradient correctly.
+
+ See :class:`~torch.nn.LogSoftmax` for more details.
+
+ Arguments:
+ input (Tensor): input
+ dim (int): A dimension along which log_softmax will be computed.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+          If specified, the input tensor is cast to :attr:`dtype` before the operation
+ is performed. This is useful for preventing data type overflows. Default: None.
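+
+    Example (a minimal illustrative call; the input values below are arbitrary)::
+
+        >>> input = torch.tensor([[1.0, 2.0, 3.0]])
+        >>> F.log_softmax(input, dim=1)  # log of the softmax, computed stably
+        tensor([[-2.4076, -1.4076, -0.4076]])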
+ """
+ if dim is None:
+ dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel)
+ if dtype is None:
+ ret = input.log_softmax(dim)
+ else:
+ ret = input.log_softmax(dim, dtype=dtype)
+ return ret
+
+
+softshrink = _add_docstr(torch._C._nn.softshrink, r"""
+softshrink(input, lambd=0.5) -> Tensor
+
+Applies the soft shrinkage function elementwise
+
+See :class:`~torch.nn.Softshrink` for more details.
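+
+Example (a minimal illustrative call using the default ``lambd=0.5``)::
+
+    >>> input = torch.tensor([-1.0, -0.3, 0.0, 0.3, 1.0])
+    >>> F.softshrink(input)
+    tensor([-0.5000,  0.0000,  0.0000,  0.0000,  0.5000])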
+""")
+
+
+[docs]@weak_script
+def tanh(input):
+ r"""tanh(input) -> Tensor
+
+ Applies element-wise,
+ :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}`
+
+ See :class:`~torch.nn.Tanh` for more details.
+ """
+ warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
+ return input.tanh()
+
+
+[docs]@weak_script
+def sigmoid(input):
+ r"""sigmoid(input) -> Tensor
+
+ Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}`
+
+ See :class:`~torch.nn.Sigmoid` for more details.
+ """
+ warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
+ return input.sigmoid()
+
+
+[docs]@weak_script
+def linear(input, weight, bias=None):
+ # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
+ r"""
+ Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
+
+ Shape:
+
+ - Input: :math:`(N, *, in\_features)` where `*` means any number of
+ additional dimensions
+ - Weight: :math:`(out\_features, in\_features)`
+ - Bias: :math:`(out\_features)`
+ - Output: :math:`(N, *, out\_features)`
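+
+    Example (an illustrative sketch; the sizes below are arbitrary)::
+
+        >>> input = torch.randn(128, 20)
+        >>> weight = torch.randn(30, 20)
+        >>> bias = torch.randn(30)
+        >>> output = F.linear(input, weight, bias)
+        >>> output.size()
+        torch.Size([128, 30])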
+ """
+ if input.dim() == 2 and bias is not None:
+ # fused op is marginally faster
+ ret = torch.addmm(bias, input, weight.t())
+ else:
+ output = input.matmul(weight.t())
+ if bias is not None:
+ output += bias
+ ret = output
+ return ret
+
+
+[docs]@weak_script
+def bilinear(input1, input2, weight, bias=None):
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
+ return torch.bilinear(input1, input2, weight, bias)
+
+
+def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
+ # type: (Tensor, Tensor, float, float) -> Tensor
+ with torch.no_grad():
+ torch.embedding_renorm_(weight, input, max_norm, norm_type)
+
+
+[docs]@weak_script
+def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
+ scale_grad_by_freq=False, sparse=False):
+ # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
+ r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
+
+ This module is often used to retrieve word embeddings using indices.
+    The input to the module is a list of indices and the embedding matrix,
+    and the output is the corresponding word embeddings.
+
+ See :class:`torch.nn.Embedding` for more details.
+
+ Args:
+ input (LongTensor): Tensor containing indices into the embedding matrix
+ weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
+ and number of columns equal to the embedding size
+ padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
+ (initialized to zeros) whenever it encounters the index.
+ max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+ is renormalized to have norm :attr:`max_norm`.
+ Note: this will modify :attr:`weight` in-place.
+ norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
+ scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
+ the words in the mini-batch. Default ``False``.
+ sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
+ :class:`torch.nn.Embedding` for more details regarding sparse gradients.
+
+ Shape:
+ - Input: LongTensor of arbitrary shape containing the indices to extract
+ - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`,
+ where V = maximum index + 1 and embedding_dim = the embedding size
+ - Output: `(*, embedding_dim)`, where `*` is the input shape
+
+ Examples::
+
+ >>> # a batch of 2 samples of 4 indices each
+ >>> input = torch.tensor([[1,2,4,5],[4,3,2,9]])
+ >>> # an embedding matrix containing 10 tensors of size 3
+ >>> embedding_matrix = torch.rand(10, 3)
+ >>> F.embedding(input, embedding_matrix)
+ tensor([[[ 0.8490, 0.9625, 0.6753],
+ [ 0.9666, 0.7761, 0.6108],
+ [ 0.6246, 0.9751, 0.3618],
+ [ 0.4161, 0.2419, 0.7383]],
+
+ [[ 0.6246, 0.9751, 0.3618],
+ [ 0.0237, 0.7794, 0.0528],
+ [ 0.9666, 0.7761, 0.6108],
+ [ 0.3385, 0.8612, 0.1867]]])
+
+ >>> # example with padding_idx
+ >>> weights = torch.rand(10, 3)
+ >>> weights[0, :].zero_()
+ >>> embedding_matrix = weights
+ >>> input = torch.tensor([[0,2,0,5]])
+ >>> F.embedding(input, embedding_matrix, padding_idx=0)
+ tensor([[[ 0.0000, 0.0000, 0.0000],
+ [ 0.5609, 0.5384, 0.8720],
+ [ 0.0000, 0.0000, 0.0000],
+ [ 0.6262, 0.2438, 0.7471]]])
+ """
+ if padding_idx is not None:
+ if padding_idx > 0:
+ assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
+ elif padding_idx < 0:
+ assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
+ padding_idx = weight.size(0) + padding_idx
+ else:
+ padding_idx = -1
+ if max_norm is not None:
+ # `embedding_renorm_` will call .contiguous() on input anyways, so we
+ # call it here and take advantage of the improved locality in the
+ # `embedding` call below too.
+ input = input.contiguous()
+ # XXX: equivalent to
+ # with torch.no_grad():
+        #   torch.embedding_renorm_
+ # remove once script supports set_grad_enabled
+ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+ return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+
+
+[docs]@weak_script
+def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
+ scale_grad_by_freq=False, mode='mean', sparse=False,
+ per_sample_weights=None):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool, Optional[Tensor]) -> Tensor
+ r"""Computes sums, means or maxes of `bags` of embeddings, without instantiating the
+ intermediate embeddings.
+
+ See :class:`torch.nn.EmbeddingBag` for more details.
+
+ .. include:: cuda_deterministic_backward.rst
+
+ Args:
+ input (LongTensor): Tensor containing bags of indices into the embedding matrix
+ weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1,
+ and number of columns equal to the embedding size
+ offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
+ the starting index position of each bag (sequence) in :attr:`input`.
+ max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+ is renormalized to have norm :attr:`max_norm`.
+ Note: this will modify :attr:`weight` in-place.
+ norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option.
+ Default ``2``.
+ scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
+ the words in the mini-batch. Default ``False``.
+ Note: this option is not supported when ``mode="max"``.
+ mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
+ Default: ``"mean"``
+ sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under
+ :class:`torch.nn.Embedding` for more details regarding sparse gradients.
+ Note: this option is not supported when ``mode="max"``.
+ per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+ to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights`
+ must have exactly the same shape as input and is treated as having the same
+ :attr:`offsets`, if those are not None.
+
+
+ Shape:
+
+ - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional)
+
+ - If :attr:`input` is 2D of shape `(B, N)`,
+
+ it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
+ this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
+ :attr:`offsets` is ignored and required to be ``None`` in this case.
+
+ - If :attr:`input` is 1D of shape `(N)`,
+
+ it will be treated as a concatenation of multiple bags (sequences).
+ :attr:`offsets` is required to be a 1D tensor containing the
+ starting index positions of each bag in :attr:`input`. Therefore,
+ for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
+ having ``B`` bags. Empty bags (i.e., having 0-length) will have
+ returned vectors filled by zeros.
+
+ - :attr:`weight` (Tensor): the learnable weights of the module of
+ shape `(num_embeddings, embedding_dim)`
+
+ - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as
+ :attr:`input`.
+
+ - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)`
+
+ Examples::
+
+ >>> # an Embedding module containing 10 tensors of size 3
+ >>> embedding_matrix = torch.rand(10, 3)
+ >>> # a batch of 2 samples of 4 indices each
+ >>> input = torch.tensor([1,2,4,5,4,3,2,9])
+ >>> offsets = torch.tensor([0,4])
+        >>> F.embedding_bag(input, embedding_matrix, offsets)
+ tensor([[ 0.3397, 0.3552, 0.5545],
+ [ 0.5893, 0.4386, 0.5882]])
+ """
+ # Check for backward compatibility.
+ # Used to be embedding_bag(weight, input, ...)
+ # Now is embedding_bag(input, weight, ...)
+ if weight.dtype == torch.long and input.is_floating_point():
+ warnings.warn("Argument order of nn.functional.embedding_bag was changed. "
+ "Usage `embedding_bag(weight, input, ...)` is deprecated, "
+ "and should now be `embedding_bag(input, weight, ...)`.")
+ weight, input = input, weight
+
+ if per_sample_weights is not None and input.size() != per_sample_weights.size():
+ raise ValueError("embedding_bag: If per_sample_weights ({}) is not None, "
+ "then it must have the same shape as the input ({})"
+ .format(per_sample_weights.shape, input.shape))
+
+ if input.dim() == 2:
+ if offsets is not None:
+ raise ValueError("if input is 2D, then offsets has to be None"
+                             ", as input is treated as a mini-batch of"
+ " fixed length sequences. However, found "
+ "offsets of type {}".format(type(offsets)))
+ offsets = torch.arange(0, input.numel(), input.size(1),
+ dtype=torch.long, device=input.device)
+
+ input = input.reshape(-1)
+ if per_sample_weights is not None:
+ per_sample_weights = per_sample_weights.reshape(-1)
+ elif input.dim() == 1:
+ if offsets is None:
+ raise ValueError("offsets has to be a 1D Tensor but got None")
+ offsets = torch.jit._unwrap_optional(offsets)
+ if offsets.dim() != 1:
+ raise ValueError("offsets has to be a 1D Tensor")
+ if int(offsets[0]) != 0:
+ raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
+ "in the mini-batch has to start from position 0. "
+ "However, got {}".format(offsets[0].item()))
+ if int(offsets[-1]) > input.size(0):
+ raise ValueError("offsets[-1] can not be greater than input's length"
+ " ({}), but got offsets[-1] of {}"
+ .format(input.size(0), offsets[-1].item()))
+ else:
+ raise ValueError("input has to be 1D or 2D Tensor,"
+ " but got Tensor of dimension {}".format(input.dim()))
+    offsets = torch.jit._unwrap_optional(offsets)  # TODO: remove once the JIT supports exception control flow
+ if mode == 'sum':
+ mode_enum = 0
+ elif mode == 'mean':
+ mode_enum = 1
+ elif mode == 'max':
+ mode_enum = 2
+
+ if scale_grad_by_freq:
+ raise ValueError("max mode does not support scaling the gradient by the frequency")
+
+ if sparse:
+ raise ValueError("max mode does not support sparse weights")
+
+ else:
+        mode_enum = -1  # TODO: remove once the JIT supports exception control flow
+ raise ValueError("mode has to be one of sum, mean or max")
+
+ if max_norm is not None:
+ # XXX: equivalent to
+ # with torch.no_grad():
+        #   torch.embedding_renorm_
+ # remove once script supports set_grad_enabled
+ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
+
+ if per_sample_weights is not None and mode != 'sum':
+ raise NotImplementedError("embedding_bag: per_sample_weights was not None. "
+ "per_sample_weights is only supported for mode='sum' "
+ "(got mode='{}'). Please open a feature request on GitHub."
+ .format(mode))
+
+ ret, _, _, _ = torch.embedding_bag(
+ weight,
+ input,
+ offsets,
+ scale_grad_by_freq,
+ mode_enum,
+ sparse,
+ per_sample_weights)
+ return ret
+
+
+[docs]@weak_script
+def batch_norm(input, running_mean, running_var, weight=None, bias=None,
+ training=False, momentum=0.1, eps=1e-5):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
+ r"""Applies Batch Normalization for each channel across a batch of data.
+
+ See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
+ :class:`~torch.nn.BatchNorm3d` for details.
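+
+    Example (an illustrative sketch; the running statistics here are freshly initialized)::
+
+        >>> input = torch.randn(20, 3, 35, 45)
+        >>> running_mean = torch.zeros(3)
+        >>> running_var = torch.ones(3)
+        >>> output = F.batch_norm(input, running_mean, running_var, training=True)
+        >>> output.size()
+        torch.Size([20, 3, 35, 45])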
+ """
+ if training:
+ size = input.size()
+ # XXX: JIT script does not support the reduce from functools, and mul op is a
+ # builtin, which cannot be used as a value to a func yet, so rewrite this size
+ # check to a simple equivalent for loop
+ #
+ # TODO: make use of reduce like below when JIT is ready with the missing features:
+ # from operator import mul
+ # from functools import reduce
+ #
+ # if reduce(mul, size[2:], size[0]) == 1
+ size_prods = size[0]
+ for i in range(len(size) - 2):
+ size_prods *= size[i + 2]
+ if size_prods == 1:
+ raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+ return torch.batch_norm(
+ input, weight, bias, running_mean, running_var,
+ training, momentum, eps, torch.backends.cudnn.enabled
+ )
+
+
+[docs]@weak_script
+def instance_norm(input, running_mean=None, running_var=None, weight=None,
+ bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor # noqa
+ r"""Applies Instance Normalization for each channel in each data sample in a
+ batch.
+
+ See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`,
+ :class:`~torch.nn.InstanceNorm3d` for details.
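+
+    Example (an illustrative sketch using per-sample statistics only)::
+
+        >>> input = torch.randn(20, 100, 35, 45)
+        >>> output = F.instance_norm(input)
+        >>> output.size()
+        torch.Size([20, 100, 35, 45])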
+ """
+ return torch.instance_norm(
+ input, weight, bias, running_mean, running_var,
+ use_input_stats, momentum, eps, torch.backends.cudnn.enabled
+ )
+
+
+[docs]@weak_script
+def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
+ # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
+ r"""Applies Layer Normalization for last certain number of dimensions.
+
+ See :class:`~torch.nn.LayerNorm` for details.
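+
+    Example (an illustrative sketch normalizing over the last dimension)::
+
+        >>> input = torch.randn(20, 5, 10)
+        >>> output = F.layer_norm(input, [10])
+        >>> output.size()
+        torch.Size([20, 5, 10])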
+ """
+ return torch.layer_norm(input, normalized_shape, weight, bias, eps,
+ torch.backends.cudnn.enabled)
+
+
+@weak_script
+def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
+ # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
+ r"""Applies Group Normalization for last certain number of dimensions.
+
+ See :class:`~torch.nn.GroupNorm` for details.
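+
+    Example (an illustrative sketch; 6 channels split into 3 groups)::
+
+        >>> input = torch.randn(20, 6, 10, 10)
+        >>> output = F.group_norm(input, 3)
+        >>> output.size()
+        torch.Size([20, 6, 10, 10])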
+ """
+ return torch.group_norm(input, num_groups, weight, bias, eps,
+ torch.backends.cudnn.enabled)
+
+
+[docs]@weak_script
+def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
+ # type: (Tensor, int, float, float, float) -> Tensor
+ r"""Applies local response normalization over an input signal composed of
+ several input planes, where channels occupy the second dimension.
+ Applies normalization across channels.
+
+ See :class:`~torch.nn.LocalResponseNorm` for details.
+ """
+ dim = input.dim()
+ if dim < 3:
+        raise ValueError('Expected 3D or higher dimensionality '
+                         'input (got {} dimensions)'.format(dim))
+ div = input.mul(input).unsqueeze(1)
+ if dim == 3:
+ div = pad(div, (0, 0, size // 2, (size - 1) // 2))
+ div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
+ else:
+ sizes = input.size()
+ div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
+ div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
+ div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
+ div = div.view(sizes)
+ div = div.mul(alpha).add(k).pow(beta)
+ return input / div
+
+
+# loss
+
+[docs]@weak_script
+def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
+ reduction='mean', zero_infinity=False):
+ # type: (Tensor, Tensor, Tensor, Tensor, int, str, bool) -> Tensor
+ r"""The Connectionist Temporal Classification loss.
+
+ See :class:`~torch.nn.CTCLoss` for details.
+
+ .. include:: cudnn_deterministic.rst
+ .. include:: cuda_deterministic_backward.rst
+
+ Args:
+ log_probs: :math:`(T, N, C)` where `C = number of characters in alphabet including blank`,
+ `T = input length`, and `N = batch size`.
+ The logarithmized probabilities of the outputs
+ (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
+ targets: :math:`(N, S)` or `(sum(target_lengths))`.
+ Targets cannot be blank. In the second form, the targets are assumed to be concatenated.
+ input_lengths: :math:`(N)`.
+ Lengths of the inputs (must each be :math:`\leq T`)
+ target_lengths: :math:`(N)`.
+ Lengths of the targets
+ blank (int, optional):
+ Blank label. Default :math:`0`.
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the output losses will be divided by the target lengths and
+ then the mean over the batch is taken, ``'sum'``: the output will be
+ summed. Default: ``'mean'``
+ zero_infinity (bool, optional):
+ Whether to zero infinite losses and the associated gradients.
+ Default: ``False``
+ Infinite losses mainly occur when the inputs are too short
+ to be aligned to the targets.
+
+ Example::
+
+ >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
+ >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
+ >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
+ >>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
+ >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
+ >>> loss.backward()
+ """
+ return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction),
+ zero_infinity)
+
+
+[docs]@weak_script
+def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
+ r"""The negative log likelihood loss.
+
+ See :class:`~torch.nn.NLLLoss` for details.
+
+ Args:
+ input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
+            in the case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
+ in the case of K-dimensional loss.
+ target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
+ or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
+ K-dimensional loss.
+ weight (Tensor, optional): a manual rescaling weight given to each
+ class. If given, has to be a Tensor of size `C`
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When :attr:`size_average` is
+ ``True``, the loss is averaged over non-ignored targets. Default: -100
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Example::
+
+ >>> # input is of size N x C = 3 x 5
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> # each element in target has to have 0 <= value < C
+ >>> target = torch.tensor([1, 0, 4])
+        >>> output = F.nll_loss(F.log_softmax(input, dim=1), target)
+ >>> output.backward()
+ """
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ dim = input.dim()
+ if dim < 2:
+ raise ValueError('Expected 2 or more dimensions (got {})'.format(dim))
+
+ if input.size(0) != target.size(0):
+ raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
+ .format(input.size(0), target.size(0)))
+ if dim == 2:
+ ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
+ elif dim == 4:
+ ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
+ else:
+ # dim == 3 or dim > 4
+ n = input.size(0)
+ c = input.size(1)
+ out_size = (n,) + input.size()[2:]
+ if target.size()[1:] != input.size()[2:]:
+ raise ValueError('Expected target size {}, got {}'.format(
+ out_size, target.size()))
+ input = input.contiguous().view(n, c, 1, -1)
+ target = target.contiguous().view(n, 1, -1)
+ reduction_enum = _Reduction.get_enum(reduction)
+ if reduction != 'none':
+ ret = torch._C._nn.nll_loss2d(
+ input, target, weight, reduction_enum, ignore_index)
+ else:
+ out = torch._C._nn.nll_loss2d(
+ input, target, weight, reduction_enum, ignore_index)
+ ret = out.view(out_size)
+ return ret
+
+
+[docs]@weak_script
+def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
+ r"""Poisson negative log likelihood loss.
+
+ See :class:`~torch.nn.PoissonNLLLoss` for details.
+
+ Args:
+ input: expectation of underlying Poisson distribution.
+ target: random sample :math:`target \sim \text{Poisson}(input)`.
+ log_input: if ``True`` the loss is computed as
+ :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is
+ :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True``
+        full: whether to compute the full loss, i.e. to add the Stirling
+            approximation term
+            :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`.
+            Default: ``False``
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
+ :attr:`log_input`=``False``. Default: 1e-8
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
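+    Example (an illustrative sketch with log-rate inputs and integer-count targets)::
+
+        >>> input = torch.randn(5, 2, requires_grad=True)
+        >>> target = torch.randint(5, (5, 2)).float()
+        >>> loss = F.poisson_nll_loss(input, target)
+        >>> loss.backward()
+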
+ """
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ if log_input:
+ loss = torch.exp(input) - target * input
+ else:
+ loss = input - target * torch.log(input + eps)
+ if full:
+ mask = target > 1
+ loss[mask] += (target * torch.log(target) - target + 0.5 * torch.log(2 * math.pi * target))[mask]
+ if reduction == 'none':
+ ret = loss
+ elif reduction == 'mean':
+ ret = torch.mean(loss)
+ elif reduction == 'sum':
+ ret = torch.sum(loss)
+ else:
+        ret = input  # dummy assignment so `ret` is defined on all code paths for the JIT
+ raise ValueError(reduction + " is not valid")
+ return ret
+
+
+[docs]@weak_script
+def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""The `Kullback-Leibler divergence`_ Loss.
+
+ See :class:`~torch.nn.KLDivLoss` for details.
+
+ Args:
+ input: Tensor of arbitrary shape
+ target: Tensor of the same shape as input
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
+ ``'none'``: no reduction will be applied
+ ``'batchmean'``: the sum of the output will be divided by the batchsize
+ ``'sum'``: the output will be summed
+ ``'mean'``: the output will be divided by the number of elements in the output
+ Default: ``'mean'``
+
+ .. note::
+ :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
+ and in the meantime, specifying either of those two args will override :attr:`reduction`.
+
+ .. note::
+        :attr:`reduction` = ``'mean'`` doesn't return the true KL divergence value; please use
+        :attr:`reduction` = ``'batchmean'``, which aligns with the mathematical definition of KL divergence.
+        In the next major release, ``'mean'`` will be changed to behave the same as ``'batchmean'``.
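+
+    Example (an illustrative sketch; ``input`` holds log-probabilities and ``target`` holds probabilities)::
+
+        >>> input = F.log_softmax(torch.randn(3, 5), dim=1)
+        >>> target = F.softmax(torch.randn(3, 5), dim=1)
+        >>> loss = F.kl_div(input, target, reduction='batchmean')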
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ if reduction == 'mean':
+            warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size. "
+                          "'batchmean' divides only by the batch size, and aligns with the KL div math definition. "
+                          "'mean' will be changed to behave the same as 'batchmean' in the next major release.")
+
+ # special case for batchmean
+ if reduction == 'batchmean':
+ reduction_enum = _Reduction.get_enum('sum')
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+
+ reduced = torch.kl_div(input, target, reduction_enum)
+
+ if reduction == 'batchmean' and input.dim() != 0:
+ reduced = reduced / input.size()[0]
+
+ return reduced
+
+
+[docs]@weak_script
+def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
+ r"""This criterion combines `log_softmax` and `nll_loss` in a single
+ function.
+
+ See :class:`~torch.nn.CrossEntropyLoss` for details.
+
+ Args:
+ input (Tensor) : :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
+            in the case of 2D loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1`
+ in the case of K-dimensional loss.
+ target (Tensor) : :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`,
+ or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for
+ K-dimensional loss.
+ weight (Tensor, optional): a manual rescaling weight given to each
+ class. If given, has to be a Tensor of size `C`
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When :attr:`size_average` is
+ ``True``, the loss is averaged over non-ignored targets. Default: -100
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Examples::
+
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> target = torch.randint(5, (3,), dtype=torch.int64)
+ >>> loss = F.cross_entropy(input, target)
+ >>> loss.backward()
+ """
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
+
+
+[docs]@weak_script
+def binary_cross_entropy(input, target, weight=None, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
+ r"""Function that measures the Binary Cross Entropy
+ between the target and the output.
+
+ See :class:`~torch.nn.BCELoss` for details.
+
+ Args:
+ input: Tensor of arbitrary shape
+ target: Tensor of the same shape as input
+ weight (Tensor, optional): a manual rescaling weight
+                if provided, it is repeated to match the input tensor shape
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Examples::
+
+ >>> input = torch.randn((3, 2), requires_grad=True)
+ >>> target = torch.rand((3, 2), requires_grad=False)
+        >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
+ >>> loss.backward()
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ if target.size() != input.size():
+ warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. "
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
+ if input.numel() != target.numel():
+ raise ValueError("Target and input must have the same number of elements. target nelement ({}) "
+ "!= input nelement ({})".format(target.numel(), input.numel()))
+
+ if weight is not None:
+ new_size = _infer_size(target.size(), weight.size())
+ weight = weight.expand(new_size)
+
+ return torch._C._nn.binary_cross_entropy(
+ input, target, weight, reduction_enum)
+
+
+[docs]@weak_script
+def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
+ reduce=None, reduction='mean', pos_weight=None):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
+ r"""Function that measures Binary Cross Entropy between target and output
+ logits.
+
+ See :class:`~torch.nn.BCEWithLogitsLoss` for details.
+
+ Args:
+ input: Tensor of arbitrary shape
+ target: Tensor of the same shape as input
+ weight (Tensor, optional): a manual rescaling weight
+                if provided, it is repeated to match the input tensor shape
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+            some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+ pos_weight (Tensor, optional): a weight of positive examples.
+ Must be a vector with length equal to the number of classes.
+
+ Examples::
+
+ >>> input = torch.randn(3, requires_grad=True)
+ >>> target = torch.empty(3).random_(2)
+ >>> loss = F.binary_cross_entropy_with_logits(input, target)
+ >>> loss.backward()
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+
+ if not (target.size() == input.size()):
+ raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
+
+ return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
+
+
+def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
+ if target.requires_grad:
+ d = lambd(input, target)
+ if reduction == 'none':
+ return d
+ return torch.mean(d) if reduction == 'mean' else torch.sum(d)
+ else:
+ expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+ return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
+
+
+@weak_script
+def _smooth_l1_loss(input, target):
+ # type: (Tensor, Tensor) -> Tensor
+ t = torch.abs(input - target)
+ return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
+
+
+[docs]@weak_script
+def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""Function that uses a squared term if the absolute
+ element-wise error falls below 1 and an L1 term otherwise.
+
+ See :class:`~torch.nn.SmoothL1Loss` for details.
+ """
+ if not (target.size() == input.size()):
+ warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ if target.requires_grad:
+ ret = _smooth_l1_loss(input, target)
+ if reduction != 'none':
+ ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
+ else:
+ expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+ ret = torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
+ return ret
+
+
+[docs]@weak_script
+def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ Function that takes the mean element-wise absolute value difference.
+
+ See :class:`~torch.nn.L1Loss` for details.
+ """
+ if not (target.size() == input.size()):
+ warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ if target.requires_grad:
+ ret = torch.abs(input - target)
+ if reduction != 'none':
+ ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
+ else:
+ expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+ ret = torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
+ return ret
+
+
+[docs]@weak_script
+def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ Measures the element-wise mean squared error.
+
+ See :class:`~torch.nn.MSELoss` for details.
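+
+    Example (a minimal illustrative call)::
+
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.randn(3, 5)
+        >>> loss = F.mse_loss(input, target)
+        >>> loss.backward()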
+ """
+ if not (target.size() == input.size()):
+ warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+ if target.requires_grad:
+ ret = (input - target) ** 2
+ if reduction != 'none':
+ ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
+ else:
+ expanded_input, expanded_target = torch.broadcast_tensors(input, target)
+ ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
+ return ret
+
+
+[docs]@weak_script
+def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
+ r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.MarginRankingLoss` for details.
+ """ # noqa
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ if input1.dim() == 0 or input2.dim() == 0 or target.dim() == 0:
+ raise RuntimeError(("margin_ranking_loss does not support scalars, got sizes: "
+ "input1: {}, input2: {}, target: {} ".format(input1.size(), input2.size(), target.size())))
+ return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
+
+
+[docs]@weak_script
+def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
+ r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.HingeEmbeddingLoss` for details.
+ """ # noqa
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
+
+
+[docs]@weak_script
+def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.MultiLabelMarginLoss` for details.
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
+
+
+[docs]@weak_script
+def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
+ r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.SoftMarginLoss` for details.
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
+
+
+[docs]@weak_script
+def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
+    r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
+                                    reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details.
+ """
+ if size_average is not None or reduce is not None:
+ reduction = _Reduction.legacy_get_string(size_average, reduce)
+
+ loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
+
+ if weight is not None:
+ loss = loss * weight
+
+ loss = loss.sum(dim=1) / input.size(1) # only return N loss values
+
+ if reduction == 'none':
+ ret = loss
+ elif reduction == 'mean':
+ ret = loss.mean()
+ elif reduction == 'sum':
+ ret = loss.sum()
+ else:
+        ret = input  # dummy assignment so `ret` is defined on all code paths for the JIT
+ raise ValueError(reduction + " is not valid")
+ return ret
+
+
+[docs]@weak_script
+def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
+ r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.CosineEmbeddingLoss` for details.
+ """ # noqa
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
+
+
+[docs]@weak_script
+def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
+ reduce=None, reduction='mean'):
+ # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
+ r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None,
+ reduce=None, reduction='mean') -> Tensor
+
+ See :class:`~torch.nn.MultiMarginLoss` for details.
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ if p != 1 and p != 2:
+ raise ValueError('only p == 1 and p == 2 supported')
+ if weight is not None:
+ if weight.dim() != 1:
+ raise ValueError('weight must be one-dimensional')
+
+ return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum)
+
+
+pixel_shuffle = _add_docstr(torch.pixel_shuffle, r"""
+Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a
+tensor of shape :math:`(*, C, H \times r, W \times r)`.
+
+See :class:`~torch.nn.PixelShuffle` for details.
+
+Args:
+ input (Tensor): the input tensor
+ upscale_factor (int): factor to increase spatial resolution by
+
+Examples::
+
+ >>> input = torch.randn(1, 9, 4, 4)
+ >>> output = torch.nn.functional.pixel_shuffle(input, 3)
+ >>> print(output.size())
+ torch.Size([1, 1, 12, 12])
+""")
+
+
+[docs]def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ r"""Upsamples the input to either the given :attr:`size` or the given
+ :attr:`scale_factor`
+
+ .. warning::
+ This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
+        This is equivalent to ``nn.functional.interpolate(...)``.
+
+ .. include:: cuda_deterministic_backward.rst
+
+ The algorithm used for upsampling is determined by :attr:`mode`.
+
+ Currently temporal, spatial and volumetric upsampling are supported, i.e.
+ expected inputs are 3-D, 4-D or 5-D in shape.
+
+ The input dimensions are interpreted in the form:
+ `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+ The modes available for upsampling are: `nearest`, `linear` (3D-only),
+ `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
+
+ Args:
+ input (Tensor): the input tensor
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+ output spatial size.
+ scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer.
+ mode (string): algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'``. Default: ``'nearest'``
+ align_corners (bool, optional): Geometrically, we consider the pixels of the
+ input and output as squares rather than points.
+ If set to ``True``, the input and output tensors are aligned by the
+ center points of their corner pixels. If set to ``False``, the input and
+ output tensors are aligned by the corner points of their corner
+ pixels, and the interpolation uses edge value padding for out-of-boundary values.
+ This only has effect when :attr:`mode` is ``'linear'``,
+ ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
+ Default: ``False``
+
+ .. warning::
+ With ``align_corners = True``, the linearly interpolating modes
+ (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
+ output and input pixels, and thus the output values can depend on the
+ input size. This was the default behavior for these modes up to version
+ 0.3.1. Since then, the default behavior is ``align_corners = False``.
+ See :class:`~torch.nn.Upsample` for concrete examples on how this
+ affects the outputs.
+
+ """
+ warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
+ return interpolate(input, size, scale_factor, mode, align_corners)
+
+
+[docs]def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ r"""Down/up samples the input to either the given :attr:`size` or the given
+ :attr:`scale_factor`
+
+ The algorithm used for interpolation is determined by :attr:`mode`.
+
+ Currently temporal, spatial and volumetric sampling are supported, i.e.
+ expected inputs are 3-D, 4-D or 5-D in shape.
+
+ The input dimensions are interpreted in the form:
+ `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+ The modes available for resizing are: `nearest`, `linear` (3D-only),
+ `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`
+
+ Args:
+ input (Tensor): the input tensor
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+ output spatial size.
+ scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
+ mode (str): algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'`` | ``'area'``. Default: ``'nearest'``
+ align_corners (bool, optional): Geometrically, we consider the pixels of the
+ input and output as squares rather than points.
+ If set to ``True``, the input and output tensors are aligned by the
+ center points of their corner pixels. If set to ``False``, the input and
+ output tensors are aligned by the corner points of their corner
+ pixels, and the interpolation uses edge value padding for out-of-boundary values.
+ This only has effect when :attr:`mode` is ``'linear'``,
+ ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
+ Default: ``False``
+
+ .. warning::
+ With ``align_corners = True``, the linearly interpolating modes
+ (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
+ output and input pixels, and thus the output values can depend on the
+ input size. This was the default behavior for these modes up to version
+ 0.3.1. Since then, the default behavior is ``align_corners = False``.
+ See :class:`~torch.nn.Upsample` for concrete examples on how this
+ affects the outputs.
+
+ .. include:: cuda_deterministic_backward.rst
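+
+    Example (an illustrative sketch of nearest-neighbour upsampling by a factor of 2)::
+
+        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
+        >>> F.interpolate(input, scale_factor=2, mode='nearest')
+        tensor([[[[1., 1., 2., 2.],
+                  [1., 1., 2., 2.],
+                  [3., 3., 4., 4.],
+                  [3., 3., 4., 4.]]]])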
+ """
+ from .modules.utils import _ntuple
+
+ def _check_size_scale_factor(dim):
+ if size is None and scale_factor is None:
+ raise ValueError('either size or scale_factor should be defined')
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ if scale_factor is not None and isinstance(scale_factor, tuple)\
+ and len(scale_factor) != dim:
+ raise ValueError('scale_factor shape must match input shape. '
+ 'Input is {}D, scale_factor size is {}'.format(dim, len(scale_factor)))
+
+ def _output_size(dim):
+ _check_size_scale_factor(dim)
+ if size is not None:
+ return size
+ scale_factors = _ntuple(dim)(scale_factor)
+ # math.floor might return float in py2.7
+
+ # make scale_factor a tensor in tracing so constant doesn't get baked in
+ if torch._C._get_tracing_state():
+ return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
+ else:
+ return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
+
+ if mode in ('nearest', 'area'):
+ if align_corners is not None:
+ raise ValueError("align_corners option can only be set with the "
+ "interpolating modes: linear | bilinear | bicubic | trilinear")
+ else:
+ if align_corners is None:
+ warnings.warn("Default upsampling behavior when mode={} is changed "
+ "to align_corners=False since 0.4.0. Please specify "
+ "align_corners=True if the old behavior is desired. "
+ "See the documentation of nn.Upsample for details.".format(mode))
+ align_corners = False
+
+ if input.dim() == 3 and mode == 'nearest':
+ return torch._C._nn.upsample_nearest1d(input, _output_size(1))
+ elif input.dim() == 4 and mode == 'nearest':
+ return torch._C._nn.upsample_nearest2d(input, _output_size(2))
+ elif input.dim() == 5 and mode == 'nearest':
+ return torch._C._nn.upsample_nearest3d(input, _output_size(3))
+ elif input.dim() == 3 and mode == 'area':
+ return adaptive_avg_pool1d(input, _output_size(1))
+ elif input.dim() == 4 and mode == 'area':
+ return adaptive_avg_pool2d(input, _output_size(2))
+ elif input.dim() == 5 and mode == 'area':
+ return adaptive_avg_pool3d(input, _output_size(3))
+ elif input.dim() == 3 and mode == 'linear':
+ return torch._C._nn.upsample_linear1d(input, _output_size(1), align_corners)
+ elif input.dim() == 3 and mode == 'bilinear':
+ raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+ elif input.dim() == 3 and mode == 'trilinear':
+ raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+ elif input.dim() == 4 and mode == 'linear':
+ raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
+ elif input.dim() == 4 and mode == 'bilinear':
+ return torch._C._nn.upsample_bilinear2d(input, _output_size(2), align_corners)
+ elif input.dim() == 4 and mode == 'trilinear':
+ raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+ elif input.dim() == 5 and mode == 'linear':
+ raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
+ elif input.dim() == 5 and mode == 'bilinear':
+ raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
+ elif input.dim() == 5 and mode == 'trilinear':
+ return torch._C._nn.upsample_trilinear3d(input, _output_size(3), align_corners)
+ elif input.dim() == 4 and mode == 'bicubic':
+ return torch._C._nn.upsample_bicubic2d(input, _output_size(2), align_corners)
+ else:
+ raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported"
+ " (got {}D) for the modes: nearest | linear | bilinear | bicubic | trilinear"
+ " (got {})".format(input.dim(), mode))
+
+
+[docs]def upsample_nearest(input, size=None, scale_factor=None):
+ r"""Upsamples the input, using nearest neighbours' pixel values.
+
+ .. warning::
+ This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
+        This is equivalent to ``nn.functional.interpolate(..., mode='nearest')``.
+
+ Currently spatial and volumetric upsampling are supported (i.e. expected
+ inputs are 4 or 5 dimensional).
+
+ Args:
+ input (Tensor): input
+        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
+ size.
+ scale_factor (int): multiplier for spatial size. Has to be an integer.
+
+ .. include:: cuda_deterministic_backward.rst
+ """
+ # DeprecationWarning is ignored by default
+ warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.")
+ return interpolate(input, size, scale_factor, mode='nearest')
+
+
+[docs]def upsample_bilinear(input, size=None, scale_factor=None):
+ r"""Upsamples the input, using bilinear upsampling.
+
+ .. warning::
+ This function is deprecated in favor of :func:`torch.nn.functional.interpolate`.
+        This is equivalent to
+ ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
+
+    Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for
+ volumetric (5 dimensional) inputs.
+
+ Args:
+ input (Tensor): input
+ size (int or Tuple[int, int]): output spatial size.
+ scale_factor (int or Tuple[int, int]): multiplier for spatial size
+
+ .. include:: cuda_deterministic_backward.rst
+ """
+ # DeprecationWarning is ignored by default
+ warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.")
+ return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True)
+
+
+GRID_SAMPLE_INTERPOLATION_MODES = {
+ 'bilinear': 0,
+ 'nearest': 1,
+}
+
+GRID_SAMPLE_PADDING_MODES = {
+ 'zeros': 0,
+ 'border': 1,
+ 'reflection': 2,
+}
+
+
+[docs]@weak_script
+def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
+ # type: (Tensor, Tensor, str, str) -> Tensor
+ r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
+ ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.
+
+ Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
+ supported.
+
+ In the spatial (4-D) case, for :attr:`input` with shape
+ :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
+ :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
+ :math:`(N, C, H_\text{out}, W_\text{out})`.
+
+ For each output location ``output[n, :, h, w]``, the size-2 vector
+ ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
+ which are used to interpolate the output value ``output[n, :, h, w]``.
+ In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
+ ``x``, ``y``, ``z`` pixel locations for interpolating
+ ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
+ ``bilinear`` interpolation method to sample the input pixels.
+
+ :attr:`grid` specifies the sampling pixel locations normalized by the
+ :attr:`input` spatial dimensions. Therefore, it should have most values in
+ the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the
+ left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the
+ right-bottom pixel of :attr:`input`.
+
+ If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding
+ outputs are handled as defined by :attr:`padding_mode`. Options are
+
+ * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations,
+ * ``padding_mode="border"``: use border values for out-of-bound grid locations,
+ * ``padding_mode="reflection"``: use values at locations reflected by
+ the border for out-of-bound grid locations. For location far away
+ from the border, it will keep being reflected until becoming in bound,
+ e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1``
+ and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes
+ ``x'' = -0.5``.
+
+ .. Note:: This function is often used in building `Spatial Transformer Networks`_ .
+ .. include:: cuda_deterministic_backward.rst
+
+ Args:
+ input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
+ or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
+ grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
+ or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
+ mode (str): interpolation mode to calculate output values
+ ``'bilinear'`` | ``'nearest'``. Default: ``'bilinear'``
+ padding_mode (str): padding mode for outside grid values
+ ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'``
+
+ Returns:
+ output (Tensor): output Tensor
+
+ .. _`Spatial Transformer Networks`:
+ https://arxiv.org/abs/1506.02025
+ """
+ if mode != 'bilinear' and mode != 'nearest':
+ raise ValueError("nn.functional.grid_sample(): expected mode to be "
+ "'bilinear' or 'nearest', but got: '{}'".format(mode))
+ if padding_mode != 'zeros' and padding_mode != 'border' and padding_mode != 'reflection':
+ raise ValueError("nn.functional.grid_sample(): expected padding_mode "
+ "to be 'zeros', 'border', or 'reflection', "
+ "but got: '{}'".format(padding_mode))
+
+ if mode == 'bilinear':
+ mode_enum = 0
+ else:
+ mode_enum = 1
+
+ if padding_mode == 'zeros':
+ padding_mode_enum = 0
+ elif padding_mode == 'border':
+ padding_mode_enum = 1
+ else:
+ padding_mode_enum = 2
+
+ return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
+
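+ # A hedged round-trip sketch (illustrative, not library source; it uses
+ # `affine_grid`, defined just below): an identity affine grid should reproduce
+ # the input up to interpolation error, because grid values are normalized to
+ # [-1, 1] over the input's spatial extent.
+ #
+ #   >>> inp = torch.randn(1, 1, 4, 4)
+ #   >>> theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])   # identity affine matrix
+ #   >>> g = F.affine_grid(theta, inp.size())
+ #   >>> out = F.grid_sample(inp, g, mode='bilinear', padding_mode='zeros')
+ #   >>> torch.allclose(out, inp, atol=1e-6)
+ #   True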
+
+[docs]@weak_script
+def affine_grid(theta, size):
+ # type: (Tensor, List[int]) -> Tensor
+ r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`.
+ Generally used in conjunction with :func:`grid_sample` to
+ implement Spatial Transformer Networks.
+
+ Args:
+ theta (Tensor): input batch of affine matrices (:math:`N \times 2 \times 3`)
+ size (torch.Size): the target output image size (:math:`N \times C \times H \times W`).
+ Example: torch.Size((32, 3, 24, 24))
+
+ Returns:
+ output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`)
+ """
+ return vision.affine_grid_generator(theta, size)
+
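+ # A quick shape sketch (illustrative): for a batch of N affine matrices of
+ # shape (N, 2, 3) and a target size (N, C, H, W), the returned sampling grid
+ # has shape (N, H, W, 2), ready to be consumed by grid_sample above.
+ #
+ #   >>> theta = torch.zeros(2, 2, 3)
+ #   >>> F.affine_grid(theta, torch.Size((2, 3, 24, 24))).shape
+ #   torch.Size([2, 24, 24, 2])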
+
+[docs]@weak_script
+def pad(input, pad, mode='constant', value=0):
+ # type: (Tensor, List[int], str, float) -> Tensor
+ r"""Pads tensor.
+
+ Padding size:
+ The padding size by which to pad some dimensions of :attr:`input`
+ are described starting from the last dimension and moving forward.
+ :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
+ of ``input`` will be padded.
+ For example, to pad only the last dimension of the input tensor, then
+ :attr:`pad` has the form
+ :math:`(\text{padding\_left}, \text{padding\_right})`;
+ to pad the last 2 dimensions of the input tensor, then use
+ :math:`(\text{padding\_left}, \text{padding\_right},`
+ :math:`\text{padding\_top}, \text{padding\_bottom})`;
+ to pad the last 3 dimensions, use
+ :math:`(\text{padding\_left}, \text{padding\_right},`
+ :math:`\text{padding\_top}, \text{padding\_bottom}`
+ :math:`\text{padding\_front}, \text{padding\_back})`.
+
+ Padding mode:
+ See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
+ :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
+ padding modes works. Constant padding is implemented for arbitrary dimensions.
+ Replicate padding is implemented for padding the last 3 dimensions of 5D input
+ tensor, or the last 2 dimensions of 4D input tensor, or the last dimension of
+ 3D input tensor. Reflect padding is only implemented for padding the last 2
+ dimensions of 4D input tensor, or the last dimension of 3D input tensor.
+
+ .. include:: cuda_deterministic_backward.rst
+
+ Args:
+ input (Tensor): N-dimensional tensor
+ pad (tuple): m-element tuple, where
+ :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
+ mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ Default: ``'constant'``
+ value: fill value for ``'constant'`` padding. Default: ``0``
+
+ Examples::
+
+ >>> t4d = torch.empty(3, 3, 4, 2)
+ >>> p1d = (1, 1) # pad last dim by 1 on each side
+ >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding
+ >>> print(out.data.size())
+ torch.Size([3, 3, 4, 4])
+ >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
+ >>> out = F.pad(t4d, p2d, "constant", 0)
+ >>> print(out.data.size())
+ torch.Size([3, 3, 8, 4])
+ >>> t4d = torch.empty(3, 3, 4, 2)
+ >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
+ >>> out = F.pad(t4d, p3d, "constant", 0)
+ >>> print(out.data.size())
+ torch.Size([3, 9, 7, 3])
+
+ """
+ assert len(pad) % 2 == 0, 'Padding length must be divisible by 2'
+ assert len(pad) // 2 <= input.dim(), 'Padding length too large'
+ if mode == 'constant':
+ ret = _VF.constant_pad_nd(input, pad, value)
+ else:
+ assert value == 0, 'Padding mode "{}" doesn\'t take in value argument'.format(mode)
+ if input.dim() == 3:
+ assert len(pad) == 2, '3D tensors expect 2 values for padding'
+ if mode == 'reflect':
+ ret = torch._C._nn.reflection_pad1d(input, pad)
+ elif mode == 'replicate':
+ ret = torch._C._nn.replication_pad1d(input, pad)
+ elif mode == 'circular':
+ ret = _pad_circular(input, pad)
+ else:
+ ret = input # TODO: remove this when jit raise supports control flow
+ raise NotImplementedError
+
+ elif input.dim() == 4:
+ assert len(pad) == 4, '4D tensors expect 4 values for padding'
+ if mode == 'reflect':
+ ret = torch._C._nn.reflection_pad2d(input, pad)
+ elif mode == 'replicate':
+ ret = torch._C._nn.replication_pad2d(input, pad)
+ elif mode == 'circular':
+ ret = _pad_circular(input, pad)
+ else:
+ ret = input # TODO: remove this when jit raise supports control flow
+ raise NotImplementedError
+
+ elif input.dim() == 5:
+ assert len(pad) == 6, '5D tensors expect 6 values for padding'
+ if mode == 'reflect':
+ ret = input # TODO: remove this when jit raise supports control flow
+ raise NotImplementedError
+ elif mode == 'replicate':
+ ret = torch._C._nn.replication_pad3d(input, pad)
+ elif mode == 'circular':
+ ret = _pad_circular(input, pad)
+ else:
+ ret = input # TODO: remove this when jit raise supports control flow
+ raise NotImplementedError
+ else:
+ ret = input # TODO: remove this when jit raise supports control flow
+ raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
+
+ return ret
+
+# distance
+
+
+[docs]@weak_script
+def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
+ # type: (Tensor, Tensor, float, float, bool) -> Tensor
+ r"""
+ See :class:`torch.nn.PairwiseDistance` for details
+ """
+ return torch.pairwise_distance(x1, x2, p, eps, keepdim)
+
+
+pdist = _add_docstr(torch.pdist, r"""
+pdist(input, p=2) -> Tensor
+
+Computes the p-norm distance between every pair of row vectors in the input.
+This is identical to the upper triangular portion, excluding the diagonal, of
+`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster
+if the rows are contiguous.
+
+If input has shape :math:`N \times M` then the output will have shape
+:math:`\frac{1}{2} N (N - 1)`.
+
+This function is equivalent to `scipy.spatial.distance.pdist(input,
+'minkowski', p=p)` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is
+equivalent to `scipy.spatial.distance.pdist(input, 'hamming') * M`.
+When :math:`p = \infty`, the closest scipy function is
+`scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())`.
+
+Args:
+ input: input tensor of shape :math:`N \times M`.
+ p: p value for the p-norm distance to calculate between each vector pair
+ :math:`\in [0, \infty]`.
+""")
+
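+ # A small illustrative sketch (assuming `F` is torch.nn.functional): for an
+ # input of shape (N, M) the condensed distance vector has N * (N - 1) / 2
+ # entries, matching the upper triangle of the full pairwise-distance matrix.
+ #
+ #   >>> x = torch.randn(5, 3)
+ #   >>> d = F.pdist(x, p=2)
+ #   >>> d.shape
+ #   torch.Size([10])                      # 5 * 4 / 2 pairwise distances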
+
+cosine_similarity = _add_docstr(torch.cosine_similarity, r"""
+cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor
+
+Returns cosine similarity between x1 and x2, computed along dim.
+
+.. math ::
+ \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}
+
+Args:
+ x1 (Tensor): First input.
+ x2 (Tensor): Second input (of size matching x1).
+ dim (int, optional): Dimension of vectors. Default: 1
+ eps (float, optional): Small value to avoid division by zero.
+ Default: 1e-8
+
+Shape:
+ - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
+ - Output: :math:`(\ast_1, \ast_2)` where 1 is at position `dim`.
+
+Example::
+
+ >>> input1 = torch.randn(100, 128)
+ >>> input2 = torch.randn(100, 128)
+ >>> output = F.cosine_similarity(input1, input2)
+ >>> print(output)
+""")
+
+
+one_hot = _add_docstr(torch._C._nn.one_hot, r"""
+one_hot(tensor, num_classes=0) -> LongTensor
+
+Takes LongTensor with index values of shape ``(*)`` and returns a tensor
+ of shape ``(*, num_classes)`` that has zeros everywhere except where the
+index of last dimension matches the corresponding value of the input tensor,
+in which case it will be 1.
+
+See also `One-hot on Wikipedia`_ .
+
+.. _One-hot on Wikipedia:
+ https://en.wikipedia.org/wiki/One-hot
+
+Arguments:
+ tensor (LongTensor): class values of any shape.
+ num_classes (int): Total number of classes. If set to -1, the number
+ of classes will be inferred as one greater than the largest class
+ value in the input tensor.
+
+Returns:
+ LongTensor that has one more dimension with 1 values at the
+ index of last dimension indicated by the input, and 0 everywhere
+ else.
+
+Examples:
+ >>> F.one_hot(torch.arange(0, 5) % 3)
+ tensor([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
+ tensor([[1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0]])
+ >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
+ tensor([[[1, 0, 0],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [1, 0, 0]],
+ [[0, 1, 0],
+ [0, 0, 1]]])
+""")
+
+
+[docs]@weak_script
+def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
+ reduce=None, reduction="mean"):
+ # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
+ r"""
+ See :class:`~torch.nn.TripletMarginLoss` for details
+ """
+ if size_average is not None or reduce is not None:
+ reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
+ else:
+ reduction_enum = _Reduction.get_enum(reduction)
+ return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps,
+ swap, reduction_enum)
+
+
+[docs]@weak_script
+def normalize(input, p=2, dim=1, eps=1e-12, out=None):
+ # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
+ r"""Performs :math:`L_p` normalization of inputs over specified dimension.
+
+ For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
+ :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as
+
+ .. math::
+ v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
+
+ With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.
+
+ Args:
+ input: input tensor of any shape
+ p (float): the exponent value in the norm formulation. Default: 2
+ dim (int): the dimension to reduce. Default: 1
+ eps (float): small value to avoid division by zero. Default: 1e-12
+ out (Tensor, optional): the output tensor. If :attr:`out` is used, this
+ operation won't be differentiable.
+ """
+ if out is None:
+ denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
+ ret = input / denom
+ else:
+ denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input)
+ ret = torch.div(input, denom, out=out)
+ return ret
+
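+ # A minimal sketch (illustrative, not part of the source): after L2
+ # normalization along dim=1, every row of the result has unit Euclidean norm
+ # (up to eps).
+ #
+ #   >>> x = torch.randn(4, 8)
+ #   >>> y = F.normalize(x, p=2, dim=1)
+ #   >>> y.norm(p=2, dim=1)
+ #   tensor([1.0000, 1.0000, 1.0000, 1.0000])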
+
+def assert_int_or_pair(arg, arg_name, message):
+ assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
+
+
+[docs]@weak_script
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+ # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
+ r"""Extracts sliding local blocks from an batched input tensor.
+
+ .. warning::
+ Currently, only 4-D input tensors (batched image-like tensors) are
+ supported.
+
+ .. warning::
+
+ More than one element of the unfolded tensor may refer to a single
+ memory location. As a result, in-place operations (especially ones that
+ are vectorized) may result in incorrect behavior. If you need to write
+ to the tensor, please clone it first.
+
+
+ See :class:`torch.nn.Unfold` for details
+ """
+
+ if input.dim() == 4:
+ msg = '{} must be int or 2-tuple for 4D input'
+ assert_int_or_pair(kernel_size, 'kernel_size', msg)
+ assert_int_or_pair(dilation, 'dilation', msg)
+ assert_int_or_pair(padding, 'padding', msg)
+ assert_int_or_pair(stride, 'stride', msg)
+
+ ret = torch._C._nn.thnn_im2col(input, _pair(kernel_size),
+ _pair(dilation), _pair(padding), _pair(stride))
+ else:
+ raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim()))
+ ret = input # TODO: remove when jit supports exception control flow
+ return ret
+
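+ # A hedged shape check (assuming the standard unfold semantics above): for an
+ # (N, C, H, W) input, unfold returns (N, C * prod(kernel_size), L) where L is
+ # the number of sliding-window positions.
+ #
+ #   >>> x = torch.randn(1, 3, 10, 12)
+ #   >>> blocks = F.unfold(x, kernel_size=(4, 5))
+ #   >>> blocks.shape                      # L = (10 - 4 + 1) * (12 - 5 + 1) = 56
+ #   torch.Size([1, 60, 56])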
+
+[docs]@weak_script
+def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
+ # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor # noqa
+ r"""Combines an array of sliding local blocks into a large containing
+ tensor.
+
+ .. warning::
+ Currently, only 4-D output tensors (batched image-like tensors) are
+ supported.
+
+ See :class:`torch.nn.Fold` for details
+ """
+ if input.dim() == 3:
+ msg = '{} must be int or 2-tuple for 3D input'
+ assert_int_or_pair(output_size, 'output_size', msg)
+ assert_int_or_pair(kernel_size, 'kernel_size', msg)
+ assert_int_or_pair(dilation, 'dilation', msg)
+ assert_int_or_pair(padding, 'padding', msg)
+ assert_int_or_pair(stride, 'stride', msg)
+
+ ret = torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size),
+ _pair(dilation), _pair(padding), _pair(stride))
+ else:
+ raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))
+ ret = input # TODO: remove when jit supports exception control flow
+ return ret
+
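+ # A hedged round-trip sketch (illustrative): with non-overlapping blocks
+ # (stride == kernel_size), fold is the exact inverse of unfold; with
+ # overlapping blocks it sums the contributions instead.
+ #
+ #   >>> x = torch.randn(1, 3, 8, 8)
+ #   >>> blocks = F.unfold(x, kernel_size=4, stride=4)
+ #   >>> y = F.fold(blocks, output_size=(8, 8), kernel_size=4, stride=4)
+ #   >>> torch.equal(x, y)
+ #   True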
+
+@weak_script
+def _pad_circular(input, padding):
+ # type: (Tensor, List[int]) -> Tensor
+ """
+ Arguments
+ :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))`
+ :param padding: (tuple): m-elem tuple where m is the degree of convolution
+ Returns
+ :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0],
+ H + 2 * padding[1]], W + 2 * padding[2]))`
+ """
+
+ input = torch.cat([input, input[:, :, 0:padding[-1]]], dim=2)
+ input = torch.cat([input[:, :, -(padding[-1] + padding[-2]):-padding[-1]], input], dim=2)
+
+ if len(padding) > 2:
+ input = torch.cat([input, input[:, :, :, 0:padding[-3]]], dim=3)
+ input = torch.cat([input[:, :, :, -(padding[-3] + padding[-4]):-padding[-3]], input], dim=3)
+
+ if len(padding) > 4:
+ input = torch.cat([input, input[:, :, :, :, 0:padding[-5]]], dim=4)
+ input = torch.cat([input[:, :, :, :, -(padding[-5] + padding[-6]):-padding[-5]], input], dim=4)
+
+ return input
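+
+ # A brief illustrative sketch (assuming the public entry point F.pad with
+ # mode='circular', which routes here): circular padding wraps values around,
+ # so the left pad column equals the rightmost input column and vice versa.
+ #
+ #   >>> x = torch.arange(9.).reshape(1, 1, 3, 3)
+ #   >>> F.pad(x, (1, 1, 0, 0), mode='circular')
+ #   tensor([[[[2., 0., 1., 2., 0.],
+ #             [5., 3., 4., 5., 3.],
+ #             [8., 6., 7., 8., 6.]]]])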
+
+from __future__ import division
+
+import math
+import warnings
+
+import torch
+from .._jit_internal import weak_script
+
+# These no_grad_* functions are necessary as wrappers around the parts of these
+# functions that use `with torch.no_grad()`. The JIT doesn't support context
+# managers, so these need to be implemented as builtins. Using these wrappers
+# lets us keep those builtins small and re-usable.
+def _no_grad_uniform_(tensor, a, b):
+ with torch.no_grad():
+ return tensor.uniform_(a, b)
+
+
+def _no_grad_normal_(tensor, mean, std):
+ with torch.no_grad():
+ return tensor.normal_(mean, std)
+
+
+def _no_grad_fill_(tensor, val):
+ with torch.no_grad():
+ return tensor.fill_(val)
+
+
+def _no_grad_zero_(tensor):
+ with torch.no_grad():
+ return tensor.zero_()
+
+
+[docs]def calculate_gain(nonlinearity, param=None):
+ r"""Return the recommended gain value for the given nonlinearity function.
+ The values are as follows:
+
+ ================= ====================================================
+ nonlinearity gain
+ ================= ====================================================
+ Linear / Identity :math:`1`
+ Conv{1,2,3}D :math:`1`
+ Sigmoid :math:`1`
+ Tanh :math:`\frac{5}{3}`
+ ReLU :math:`\sqrt{2}`
+ Leaky ReLU :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
+ ================= ====================================================
+
+ Args:
+ nonlinearity: the non-linear function (`nn.functional` name)
+ param: optional parameter for the non-linear function
+
+ Examples:
+ >>> gain = nn.init.calculate_gain('leaky_relu')
+ """
+ linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+ if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+ return 1
+ elif nonlinearity == 'tanh':
+ return 5.0 / 3
+ elif nonlinearity == 'relu':
+ return math.sqrt(2.0)
+ elif nonlinearity == 'leaky_relu':
+ if param is None:
+ negative_slope = 0.01
+ elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
+ # True/False are instances of int, hence check above
+ negative_slope = param
+ else:
+ raise ValueError("negative_slope {} not a valid number".format(param))
+ return math.sqrt(2.0 / (1 + negative_slope ** 2))
+ else:
+ raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+
+[docs]@weak_script
+def uniform_(tensor, a=0., b=1.):
+ # type: (Tensor, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from the uniform
+ distribution :math:`\mathcal{U}(a, b)`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ a: the lower bound of the uniform distribution
+ b: the upper bound of the uniform distribution
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.uniform_(w)
+ """
+ return _no_grad_uniform_(tensor, a, b)
+
+
+[docs]@weak_script
+def normal_(tensor, mean=0., std=1.):
+ # type: (Tensor, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std})`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.normal_(w)
+ """
+ return _no_grad_normal_(tensor, mean, std)
+
+
+[docs]@weak_script
+def constant_(tensor, val):
+ # type: (Tensor, float) -> Tensor
+ r"""Fills the input Tensor with the value :math:`\text{val}`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ val: the value to fill the tensor with
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.constant_(w, 0.3)
+ """
+ return _no_grad_fill_(tensor, val)
+
+
+@weak_script
+def ones_(tensor):
+ # type: (Tensor) -> Tensor
+ r"""Fills the input Tensor with ones`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.ones_(w)
+ """
+ return _no_grad_fill_(tensor, 1.)
+
+
+@weak_script
+def zeros_(tensor):
+ # type: (Tensor) -> Tensor
+ r"""Fills the input Tensor with zeros`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.zeros_(w)
+ """
+ return _no_grad_zero_(tensor)
+
+
+[docs]def eye_(tensor):
+ r"""Fills the 2-dimensional input `Tensor` with the identity
+ matrix. Preserves the identity of the inputs in `Linear` layers, where as
+ many inputs are preserved as possible.
+
+ Args:
+ tensor: a 2-dimensional `torch.Tensor`
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.eye_(w)
+ """
+ if tensor.ndimension() != 2:
+ raise ValueError("Only tensors with 2 dimensions are supported")
+
+ with torch.no_grad():
+ torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
+ return tensor
+
+
+[docs]def dirac_(tensor):
+ r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
+ delta function. Preserves the identity of the inputs in `Convolutional`
+ layers, where as many input channels are preserved as possible.
+
+ Args:
+ tensor: a {3, 4, 5}-dimensional `torch.Tensor`
+
+ Examples:
+ >>> w = torch.empty(3, 16, 5, 5)
+ >>> nn.init.dirac_(w)
+ """
+ dimensions = tensor.ndimension()
+ if dimensions not in [3, 4, 5]:
+ raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")
+
+ sizes = tensor.size()
+ min_dim = min(sizes[0], sizes[1])
+ with torch.no_grad():
+ tensor.zero_()
+
+ for d in range(min_dim):
+ if dimensions == 3: # Temporal convolution
+ tensor[d, d, tensor.size(2) // 2] = 1
+ elif dimensions == 4: # Spatial convolution
+ tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2] = 1
+ else: # Volumetric convolution
+ tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2, tensor.size(4) // 2] = 1
+ return tensor
+
+
+@weak_script
+def _calculate_fan_in_and_fan_out(tensor):
+ dimensions = tensor.dim()
+ if dimensions < 2:
+ raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
+
+ if dimensions == 2: # Linear
+ fan_in = tensor.size(1)
+ fan_out = tensor.size(0)
+ else:
+ num_input_fmaps = tensor.size(1)
+ num_output_fmaps = tensor.size(0)
+ receptive_field_size = 1
+ if tensor.dim() > 2:
+ receptive_field_size = tensor[0][0].numel()
+ fan_in = num_input_fmaps * receptive_field_size
+ fan_out = num_output_fmaps * receptive_field_size
+
+ return fan_in, fan_out
+
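+ # An illustrative note (not library source): for a conv weight of shape
+ # (out_channels, in_channels, kH, kW), fan_in = in_channels * kH * kW and
+ # fan_out = out_channels * kH * kW.
+ #
+ #   >>> w = torch.empty(64, 3, 3, 3)
+ #   >>> _calculate_fan_in_and_fan_out(w)
+ #   (27, 576)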
+
+[docs]@weak_script
+def xavier_uniform_(tensor, gain=1.):
+ # type: (Tensor, float) -> Tensor
+ r"""Fills the input `Tensor` with values according to the method
+ described in `Understanding the difficulty of training deep feedforward
+ neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
+ distribution. The resulting tensor will have values sampled from
+ :math:`\mathcal{U}(-a, a)` where
+
+ .. math::
+ a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
+
+ Also known as Glorot initialization.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ gain: an optional scaling factor
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
+ """
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+ a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
+
+ return _no_grad_uniform_(tensor, -a, a)
+
+
+[docs]@weak_script
+def xavier_normal_(tensor, gain=1.):
+ # type: (Tensor, float) -> Tensor
+ r"""Fills the input `Tensor` with values according to the method
+ described in `Understanding the difficulty of training deep feedforward
+ neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
+ distribution. The resulting tensor will have values sampled from
+ :math:`\mathcal{N}(0, \text{std})` where
+
+ .. math::
+ \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
+
+ Also known as Glorot initialization.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ gain: an optional scaling factor
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.xavier_normal_(w)
+ """
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+
+ return _no_grad_normal_(tensor, 0., std)
+
+
+def _calculate_correct_fan(tensor, mode):
+ mode = mode.lower()
+ valid_modes = ['fan_in', 'fan_out']
+ if mode not in valid_modes:
+ raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
+
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ return fan_in if mode == 'fan_in' else fan_out
+
+
+[docs]def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+ r"""Fills the input `Tensor` with values according to the method
+ described in `Delving deep into rectifiers: Surpassing human-level
+ performance on ImageNet classification` - He, K. et al. (2015), using a
+ uniform distribution. The resulting tensor will have values sampled from
+ :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
+
+ .. math::
+ \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}
+
+ Also known as He initialization.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ a: the negative slope of the rectifier used after this layer (0 for ReLU
+ by default)
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+ preserves the magnitude of the variance of the weights in the
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+ backwards pass.
+ nonlinearity: the non-linear function (`nn.functional` name),
+ recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
+ """
+ fan = _calculate_correct_fan(tensor, mode)
+ gain = calculate_gain(nonlinearity, a)
+ std = gain / math.sqrt(fan)
+ bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
+ with torch.no_grad():
+ return tensor.uniform_(-bound, bound)
+
+
+[docs]def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+ r"""Fills the input `Tensor` with values according to the method
+ described in `Delving deep into rectifiers: Surpassing human-level
+ performance on ImageNet classification` - He, K. et al. (2015), using a
+ normal distribution. The resulting tensor will have values sampled from
+ :math:`\mathcal{N}(0, \text{std})` where
+
+ .. math::
+ \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}
+
+ Also known as He initialization.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ a: the negative slope of the rectifier used after this layer (0 for ReLU
+ by default)
+ mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+ preserves the magnitude of the variance of the weights in the
+ forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+ backwards pass.
+ nonlinearity: the non-linear function (`nn.functional` name),
+ recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
+ """
+ fan = _calculate_correct_fan(tensor, mode)
+ gain = calculate_gain(nonlinearity, a)
+ std = gain / math.sqrt(fan)
+ with torch.no_grad():
+ return tensor.normal_(0, std)
+
+
+[docs]def orthogonal_(tensor, gain=1):
+ r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
+ described in `Exact solutions to the nonlinear dynamics of learning in deep
+ linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
+ at least 2 dimensions, and for tensors with more than 2 dimensions the
+ trailing dimensions are flattened.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
+ gain: optional scaling factor
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.orthogonal_(w)
+ """
+ if tensor.ndimension() < 2:
+ raise ValueError("Only tensors with 2 or more dimensions are supported")
+
+ rows = tensor.size(0)
+ cols = tensor.numel() // rows
+ flattened = tensor.new(rows, cols).normal_(0, 1)
+
+ if rows < cols:
+ flattened.t_()
+
+ # Compute the qr factorization
+ q, r = torch.qr(flattened)
+ # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
+ d = torch.diag(r, 0)
+ ph = d.sign()
+ q *= ph
+
+ if rows < cols:
+ q.t_()
+
+ with torch.no_grad():
+ tensor.view_as(q).copy_(q)
+ tensor.mul_(gain)
+ return tensor
+
+
+[docs]def sparse_(tensor, sparsity, std=0.01):
+ r"""Fills the 2D input `Tensor` as a sparse matrix, where the
+ non-zero elements will be drawn from the normal distribution
+ :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
+ Hessian-free optimization` - Martens, J. (2010).
+
+ Args:
+ tensor: a 2-dimensional `torch.Tensor`
+ sparsity: The fraction of elements in each column to be set to zero
+ std: the standard deviation of the normal distribution used to generate
+ the non-zero values
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.sparse_(w, sparsity=0.1)
+ """
+ if tensor.ndimension() != 2:
+ raise ValueError("Only tensors with 2 dimensions are supported")
+
+ rows, cols = tensor.shape
+ num_zeros = int(math.ceil(sparsity * rows))
+
+ with torch.no_grad():
+ tensor.normal_(0, std)
+ for col_idx in range(cols):
+ row_indices = torch.randperm(rows)
+ zero_indices = row_indices[:num_zeros]
+ tensor[zero_indices, col_idx] = 0
+ return tensor
+
+
+# for backward compatibility
+def _make_deprecate(meth):
+ new_name = meth.__name__
+ old_name = new_name[:-1]
+
+ def deprecated_init(*args, **kwargs):
+ warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
+ .format(old_name, new_name), stacklevel=2)
+ return meth(*args, **kwargs)
+
+ deprecated_init.__doc__ = r"""
+ {old_name}(...)
+
+ .. warning::
+ This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
+
+ See :func:`~torch.nn.init.{new_name}` for details.""".format(
+ old_name=old_name, new_name=new_name)
+ deprecated_init.__name__ = old_name
+ return deprecated_init
+
+
+uniform = _make_deprecate(uniform_)
+normal = _make_deprecate(normal_)
+constant = _make_deprecate(constant_)
+eye = _make_deprecate(eye_)
+dirac = _make_deprecate(dirac_)
+xavier_uniform = _make_deprecate(xavier_uniform_)
+xavier_normal = _make_deprecate(xavier_normal_)
+kaiming_uniform = _make_deprecate(kaiming_uniform_)
+kaiming_normal = _make_deprecate(kaiming_normal_)
+orthogonal = _make_deprecate(orthogonal_)
+sparse = _make_deprecate(sparse_)
+
+import warnings
+import torch
+from . import Linear
+from torch.nn.init import xavier_uniform_
+from torch.nn.init import constant_
+from torch.nn.init import xavier_normal_
+from torch.nn.parameter import Parameter
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+[docs]@weak_module
+class Threshold(Module):
+ r"""Thresholds each element of the input Tensor.
+
+ Threshold is defined as:
+
+ .. math::
+ y =
+ \begin{cases}
+ x, &\text{ if } x > \text{threshold} \\
+ \text{value}, &\text{ otherwise }
+ \end{cases}
+
+ Args:
+ threshold: The value to threshold at
+ value: The value to replace with
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ Examples::
+
+ >>> m = nn.Threshold(0.1, 20)
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['threshold', 'value', 'inplace']
+
+ def __init__(self, threshold, value, inplace=False):
+ super(Threshold, self).__init__()
+ self.threshold = threshold
+ self.value = value
+ self.inplace = inplace
+ # TODO: check in THNN (if inplace == True, then assert value <= threshold)
+
+ @weak_script_method
+ def forward(self, input):
+ return F.threshold(input, self.threshold, self.value, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'threshold={}, value={}{}'.format(
+ self.threshold, self.value, inplace_str
+ )
+
+
+[docs]@weak_module
+class ReLU(Module):
+ r"""Applies the rectified linear unit function element-wise:
+
+ :math:`\text{ReLU}(x)= \max(0, x)`
+
+ Args:
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/ReLU.png
+
+ Examples::
+
+ >>> m = nn.ReLU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+
+ An implementation of CReLU - https://arxiv.org/abs/1603.05201
+
+ >>> m = nn.ReLU()
+ >>> input = torch.randn(2).unsqueeze(0)
+ >>> output = torch.cat((m(input),m(-input)))
+ """
+ __constants__ = ['inplace']
+
+ def __init__(self, inplace=False):
+ super(ReLU, self).__init__()
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.relu(input, inplace=self.inplace)
+
+ def extra_repr(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return inplace_str
+
+
+[docs]@weak_module
+class RReLU(Module):
+ r"""Applies the randomized leaky rectified liner unit function, element-wise,
+ as described in the paper:
+
+ `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+
+ The function is defined as:
+
+ .. math::
+ \text{RReLU}(x) =
+ \begin{cases}
+ x & \text{if } x \geq 0 \\
+ ax & \text{ otherwise }
+ \end{cases}
+
+ where :math:`a` is randomly sampled from uniform distribution
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+ See: https://arxiv.org/pdf/1505.00853.pdf
+
+ Args:
+ lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+ upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ Examples::
+
+ >>> m = nn.RReLU(0.1, 0.3)
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+ .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+ https://arxiv.org/abs/1505.00853
+ """
+ __constants__ = ['lower', 'upper', 'inplace']
+
+ def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
+ super(RReLU, self).__init__()
+ self.lower = lower
+ self.upper = upper
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
+
+
+[docs]@weak_module
+class Hardtanh(Module):
+ r"""Applies the HardTanh function element-wise
+
+ HardTanh is defined as:
+
+ .. math::
+ \text{HardTanh}(x) = \begin{cases}
+ 1 & \text{ if } x > 1 \\
+ -1 & \text{ if } x < -1 \\
+ x & \text{ otherwise } \\
+ \end{cases}
+
+ The range of the linear region :math:`[-1, 1]` can be adjusted using
+ :attr:`min_val` and :attr:`max_val`.
+
+ Args:
+ min_val: minimum value of the linear region range. Default: -1
+ max_val: maximum value of the linear region range. Default: 1
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Keyword arguments :attr:`min_value` and :attr:`max_value`
+ have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Hardtanh.png
+
+ Examples::
+
+ >>> m = nn.Hardtanh(-2, 2)
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['min_val', 'max_val', 'inplace']
+
+ def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
+ super(Hardtanh, self).__init__()
+ if min_value is not None:
+ warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
+ min_val = min_value
+ if max_value is not None:
+ warnings.warn("keyword argument max_value is deprecated and renamed to max_val")
+ max_val = max_value
+
+ self.min_val = min_val
+ self.max_val = max_val
+ self.inplace = inplace
+ assert self.max_val > self.min_val
+
+ @weak_script_method
+ def forward(self, input):
+ return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'min_val={}, max_val={}{}'.format(
+ self.min_val, self.max_val, inplace_str
+ )
+
+
+[docs]@weak_module
+class ReLU6(Hardtanh):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{ReLU6}(x) = \min(\max(0,x), 6)
+
+ Args:
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/ReLU6.png
+
+ Examples::
+
+ >>> m = nn.ReLU6()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def __init__(self, inplace=False):
+ super(ReLU6, self).__init__(0., 6., inplace)
+
+ def extra_repr(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return inplace_str
+
+
+[docs]@weak_module
+class Sigmoid(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}
+
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Sigmoid.png
+
+ Examples::
+
+ >>> m = nn.Sigmoid()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return torch.sigmoid(input)
+
+
+[docs]@weak_module
+class Tanh(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{Tanh}(x) = \tanh(x) = \frac{e^x - e^{-x}} {e^x + e^{-x}}
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Tanh.png
+
+ Examples::
+
+ >>> m = nn.Tanh()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return torch.tanh(input)
+
+
+[docs]@weak_module
+class ELU(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))
+
+ Args:
+ alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/ELU.png
+
+ Examples::
+
+ >>> m = nn.ELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['alpha', 'inplace']
+
+ def __init__(self, alpha=1., inplace=False):
+ super(ELU, self).__init__()
+ self.alpha = alpha
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.elu(input, self.alpha, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'alpha={}{}'.format(self.alpha, inplace_str)
+
+
+[docs]@weak_module
+class CELU(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
+
+ More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
+
+ Args:
+ alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/CELU.png
+
+ Examples::
+
+ >>> m = nn.CELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+ .. _`Continuously Differentiable Exponential Linear Units`:
+ https://arxiv.org/abs/1704.07483
+ """
+ __constants__ = ['alpha', 'inplace']
+
+ def __init__(self, alpha=1., inplace=False):
+ super(CELU, self).__init__()
+ self.alpha = alpha
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.celu(input, self.alpha, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'alpha={}{}'.format(self.alpha, inplace_str)
+
+
+[docs]@weak_module
+class SELU(Module):
+ r"""Applied element-wise, as:
+
+ .. math::
+ \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
+
+ with :math:`\alpha = 1.6732632423543772848170429916717` and
+ :math:`\text{scale} = 1.0507009873554804934193349852946`.
+
+ More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+ Args:
+ inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/SELU.png
+
+ Examples::
+
+ >>> m = nn.SELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+ """
+ __constants__ = ['inplace']
+
+ def __init__(self, inplace=False):
+ super(SELU, self).__init__()
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.selu(input, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return inplace_str
+
+
+@weak_module
+class GLU(Module):
+ r"""Applies the gated linear unit function
+ :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
+ of the input matrices and :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+
+ Shape:
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+ Examples::
+
+ >>> m = nn.GLU()
+ >>> input = torch.randn(4, 2)
+ >>> output = m(input)
+ """
+ __constants__ = ['dim']
+
+ def __init__(self, dim=-1):
+ super(GLU, self).__init__()
+ self.dim = dim
+
+ @weak_script_method
+ def forward(self, input):
+ return F.glu(input, self.dim)
+
+ def extra_repr(self):
+ return 'dim={}'.format(self.dim)
+
+
+[docs]@weak_module
+class Hardshrink(Module):
+ r"""Applies the hard shrinkage function element-wise:
+
+ .. math::
+ \text{HardShrink}(x) =
+ \begin{cases}
+ x, & \text{ if } x > \lambda \\
+ x, & \text{ if } x < -\lambda \\
+ 0, & \text{ otherwise }
+ \end{cases}
+
+ Args:
+ lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Hardshrink.png
+
+ Examples::
+
+ >>> m = nn.Hardshrink()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['lambd']
+
+ def __init__(self, lambd=0.5):
+ super(Hardshrink, self).__init__()
+ self.lambd = lambd
+
+ @weak_script_method
+ def forward(self, input):
+ return F.hardshrink(input, self.lambd)
+
+ def extra_repr(self):
+ return '{}'.format(self.lambd)
+
+
+[docs]@weak_module
+class LeakyReLU(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
+
+
+ or
+
+ .. math::
+ \text{LeakyRELU}(x) =
+ \begin{cases}
+ x, & \text{ if } x \geq 0 \\
+ \text{negative\_slope} \times x, & \text{ otherwise }
+ \end{cases}
+
+ Args:
+ negative_slope: Controls the angle of the negative slope. Default: 1e-2
+ inplace: can optionally do the operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/LeakyReLU.png
+
+ Examples::
+
+ >>> m = nn.LeakyReLU(0.1)
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['inplace', 'negative_slope']
+
+ def __init__(self, negative_slope=1e-2, inplace=False):
+ super(LeakyReLU, self).__init__()
+ self.negative_slope = negative_slope
+ self.inplace = inplace
+
+ @weak_script_method
+ def forward(self, input):
+ return F.leaky_relu(input, self.negative_slope, self.inplace)
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
+
+
+[docs]@weak_module
+class LogSigmoid(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/LogSigmoid.png
+
+ Examples::
+
+ >>> m = nn.LogSigmoid()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.logsigmoid(input)
+
+
+[docs]@weak_module
+class Softplus(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
+
+ SoftPlus is a smooth approximation to the ReLU function and can be used
+ to constrain the output of a machine to always be positive.
+
+ For numerical stability the implementation reverts to the linear function
+ for inputs above a certain value.
+
+ Args:
+ beta: the :math:`\beta` value for the Softplus formulation. Default: 1
+ threshold: values above this revert to a linear function. Default: 20
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Softplus.png
+
+ Examples::
+
+ >>> m = nn.Softplus()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['beta', 'threshold']
+
+ def __init__(self, beta=1, threshold=20):
+ super(Softplus, self).__init__()
+ self.beta = beta
+ self.threshold = threshold
+
+ @weak_script_method
+ def forward(self, input):
+ return F.softplus(input, self.beta, self.threshold)
+
+ def extra_repr(self):
+ return 'beta={}, threshold={}'.format(self.beta, self.threshold)
+
+
+[docs]@weak_module
+class Softshrink(Module):
+ r"""Applies the soft shrinkage function elementwise:
+
+ .. math::
+ \text{SoftShrinkage}(x) =
+ \begin{cases}
+ x - \lambda, & \text{ if } x > \lambda \\
+ x + \lambda, & \text{ if } x < -\lambda \\
+ 0, & \text{ otherwise }
+ \end{cases}
+
+ Args:
+ lambd: the :math:`\lambda` value for the Softshrink formulation. Default: 0.5
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Softshrink.png
+
+ Examples::
+
+ >>> m = nn.Softshrink()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+ __constants__ = ['lambd']
+
+ def __init__(self, lambd=0.5):
+ super(Softshrink, self).__init__()
+ self.lambd = lambd
+
+ @weak_script_method
+ def forward(self, input):
+ return F.softshrink(input, self.lambd)
+
+ def extra_repr(self):
+ return str(self.lambd)
+
+
+[docs]@weak_module
+class MultiheadAttention(Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+ Args:
+ embed_dim: total dimension of the model
+ num_heads: parallel attention layers, or heads
+
+ Examples::
+
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+ """
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+ if bias:
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ xavier_uniform_(self.in_proj_weight[:self.embed_dim, :])
+ xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :])
+ xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :])
+
+ xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+[docs] @weak_script_method
+ def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
+ need_weights=True, static_kv=False, attn_mask=None):
+ """
+ Inputs of forward function
+ query: [target length, batch size, embed dim]
+ key: [sequence length, batch size, embed dim]
+ value: [sequence length, batch size, embed dim]
+ key_padding_mask: if True, mask padding based on batch size
+ incremental_state: if provided, previous time steps are cached
+ need_weights: output attn_output_weights
+ static_kv: key and value are static
+
+ Outputs of forward function
+ attn_output: [target length, batch size, embed dim]
+ attn_output_weights: [batch size, target length, sequence length]
+ """
+ qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+ kv_same = key.data_ptr() == value.data_ptr()
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ assert key.size() == value.size()
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert kv_same and not qkv_same
+ key = value = None
+ else:
+ saved_state = None
+
+ if qkv_same:
+ # self-attention
+ q, k, v = self._in_proj_qkv(query)
+ elif kv_same:
+ # encoder-decoder attention
+ q = self._in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k, v = self._in_proj_kv(key)
+ else:
+ q = self._in_proj_q(query)
+ k = self._in_proj_k(key)
+ v = self._in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float('-inf'),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_output_weights = F.softmax(
+ attn_output_weights.float(), dim=-1,
+ dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
+ attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
+ else:
+ attn_output_weights = None
+
+ return attn_output, attn_output_weights
+
+ def _in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def _in_proj_kv(self, key):
+ return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
+
+ def _in_proj_q(self, query):
+ return self._in_proj(query, end=self.embed_dim)
+
+ def _in_proj_k(self, key):
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+
+ def _in_proj_v(self, value):
+ return self._in_proj(value, start=2 * self.embed_dim)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
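+ # A short note on the projection layout (an illustration, not library source):
+ # in_proj_weight stacks the query, key and value projections row-wise, so for
+ # embed_dim = E the rows [0, E) project q, [E, 2E) project k and [2E, 3E)
+ # project v; the helpers above simply slice that combined matrix.
+ #
+ #   >>> mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
+ #   >>> mha.in_proj_weight.shape
+ #   torch.Size([24, 8])                   # 3 * embed_dim rows: q, k, v stacked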
+
+[docs]@weak_module
+class PReLU(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+ or
+
+ .. math::
+ \text{PReLU}(x) =
+ \begin{cases}
+ x, & \text{ if } x \geq 0 \\
+ ax, & \text{ otherwise }
+ \end{cases}
+
+ Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
+ parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
+ a separate :math:`a` is used for each input channel.
+
+
+ .. note::
+ For good performance, weight decay should not be used when learning :math:`a`.
+
+ .. note::
+ The channel dim is the 2nd dim of the input. When the input has fewer than 2 dims,
+ there is no channel dim and the number of channels = 1.
+
+ Args:
+ num_parameters (int): number of :math:`a` to learn.
+ Although it takes an int as input, only two values are legitimate:
+ 1, or the number of channels of the input. Default: 1
+ init (float): the initial value of :math:`a`. Default: 0.25
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ Attributes:
+ weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
+
+ .. image:: scripts/activation_images/PReLU.png
+
+ Examples::
+
+ >>> m = nn.PReLU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def __init__(self, num_parameters=1, init=0.25):
+ self.num_parameters = num_parameters
+ super(PReLU, self).__init__()
+ self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
+
+ @weak_script_method
+ def forward(self, input):
+ return F.prelu(input, self.weight)
+
+ def extra_repr(self):
+ return 'num_parameters={}'.format(self.num_parameters)
+
+
+[docs]@weak_module
+class Softsign(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Softsign.png
+
+ Examples::
+
+ >>> m = nn.Softsign()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.softsign(input)
+
+
+[docs]@weak_module
+class Tanhshrink(Module):
+ r"""Applies the element-wise function:
+
+ .. math::
+ \text{Tanhshrink}(x) = x - \text{Tanh}(x)
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/Tanhshrink.png
+
+ Examples::
+
+ >>> m = nn.Tanhshrink()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.tanhshrink(input)
+
+
+[docs]@weak_module
+class Softmin(Module):
+ r"""Applies the Softmin function to an n-dimensional input Tensor
+ rescaling them so that the elements of the n-dimensional output Tensor
+ lie in the range `[0, 1]` and sum to 1.
+
+ Softmin is defined as:
+
+ .. math::
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
+
+ Shape:
+ - Input: :math:`(*)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(*)`, same shape as the input
+
+ Arguments:
+ dim (int): A dimension along which Softmin will be computed (so every slice
+ along dim will sum to 1).
+
+ Returns:
+ a Tensor of the same dimension and shape as the input, with
+ values in the range [0, 1]
+
+ Examples::
+
+ >>> m = nn.Softmin()
+ >>> input = torch.randn(2, 3)
+ >>> output = m(input)
+ """
+ __constants__ = ['dim']
+
+ def __init__(self, dim=None):
+ super(Softmin, self).__init__()
+ self.dim = dim
+
+ @weak_script_method
+ def forward(self, input):
+ return F.softmin(input, self.dim, _stacklevel=5)
+
+
+[docs]@weak_module
+class Softmax(Module):
+ r"""Applies the Softmax function to an n-dimensional input Tensor
+ rescaling them so that the elements of the n-dimensional output Tensor
+ lie in the range [0,1] and sum to 1.
+
+ Softmax is defined as:
+
+ .. math::
+ \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+ Shape:
+ - Input: :math:`(*)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(*)`, same shape as the input
+
+ Returns:
+ a Tensor of the same dimension and shape as the input with
+ values in the range [0, 1]
+
+ Arguments:
+ dim (int): A dimension along which Softmax will be computed (so every slice
+ along dim will sum to 1).
+
+ .. note::
+ This module doesn't work directly with NLLLoss,
+ which expects the Log to be computed between the Softmax and itself.
+ Use `LogSoftmax` instead (it's faster and has better numerical properties).
+
+ Examples::
+
+ >>> m = nn.Softmax()
+ >>> input = torch.randn(2, 3)
+ >>> output = m(input)
+ """
+ __constants__ = ['dim']
+
+ def __init__(self, dim=None):
+ super(Softmax, self).__init__()
+ self.dim = dim
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ if not hasattr(self, 'dim'):
+ self.dim = None
+
+ @weak_script_method
+ def forward(self, input):
+ return F.softmax(input, self.dim, _stacklevel=5)
+
+
+[docs]@weak_module
+class Softmax2d(Module):
+ r"""Applies SoftMax over features to each spatial location.
+
+ When given an image of ``Channels x Height x Width``, it will
+ apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Returns:
+ a Tensor of the same dimension and shape as the input with
+ values in the range [0, 1]
+
+ Examples::
+
+ >>> m = nn.Softmax2d()
+ >>> # softmax is applied over the channel dimension (dim 1)
+ >>> input = torch.randn(2, 3, 12, 13)
+ >>> output = m(input)
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
+ return F.softmax(input, 1, _stacklevel=5)
+
+
+[docs]@weak_module
+class LogSoftmax(Module):
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
+ input Tensor. The LogSoftmax formulation can be simplified as:
+
+ .. math::
+ \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
+
+ Shape:
+ - Input: :math:`(*)` where `*` means any number of additional
+ dimensions
+ - Output: :math:`(*)`, same shape as the input
+
+ Arguments:
+ dim (int): A dimension along which LogSoftmax will be computed.
+
+ Returns:
+ a Tensor of the same dimension and shape as the input with
+ values in the range [-inf, 0)
+
+ Examples::
+
+ >>> m = nn.LogSoftmax()
+ >>> input = torch.randn(2, 3)
+ >>> output = m(input)
+ """
+ __constants__ = ['dim']
+
+ def __init__(self, dim=None):
+ super(LogSoftmax, self).__init__()
+ self.dim = dim
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ if not hasattr(self, 'dim'):
+ self.dim = None
+
+ @weak_script_method
+ def forward(self, input):
+ return F.log_softmax(input, self.dim, _stacklevel=5)
+
+# -*- coding: utf-8 -*-
+
+from collections import namedtuple
+
+import torch
+
+from . import Sequential, ModuleList, Linear
+from .module import Module
+from ..functional import log_softmax
+
+
+_ASMoutput = namedtuple('ASMoutput', ['output', 'loss'])
+
+
+[docs]class AdaptiveLogSoftmaxWithLoss(Module):
+ r"""Efficient softmax approximation as described in
+ `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin,
+ Moustapha Cissé, David Grangier, and Hervé Jégou.
+
+ Adaptive softmax is an approximate strategy for training models with large
+ output spaces. It is most effective when the label distribution is highly
+ imbalanced, for example in natural language modelling, where the word
+ frequency distribution approximately follows `Zipf's law`_.
+
+ Adaptive softmax partitions the labels into several clusters, according to
+ their frequency. Each cluster may contain a different number of targets.
+ Additionally, clusters containing less frequent labels assign lower
+ dimensional embeddings to those labels, which speeds up the computation.
+ For each minibatch, only clusters for which at least one target is
+ present are evaluated.
+
+ The idea is that the clusters which are accessed frequently
+ (like the first one, containing the most frequent labels) should also be cheap
+ to compute -- that is, contain a small number of assigned labels.
+
+ We highly recommend taking a look at the original paper for more details.
+
+ * :attr:`cutoffs` should be an ordered Sequence of integers sorted
+ in increasing order.
+ It controls the number of clusters and the partitioning of targets into
+ clusters. For example, setting ``cutoffs = [10, 100, 1000]``
+ means that the first `10` targets will be assigned
+ to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
+ assigned to the first cluster, and targets `101, 102, ..., 1000` will be
+ assigned to the second cluster, while targets
+ `1001, 1002, ..., n_classes - 1` will be assigned
+ to the last, third cluster.
+
+ * :attr:`div_value` is used to compute the size of each additional cluster,
+ which is given as
+ :math:`\left\lfloor\frac{in\_features}{div\_value^{idx}}\right\rfloor`,
+ where :math:`idx` is the cluster index (with clusters
+ for less frequent words having larger indices,
+ and indices starting from :math:`1`).
+
+ * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
+ adaptive softmax. See paper for details. Set to False in the official
+ implementation.
+
+ .. warning::
+ Labels passed as inputs to this module should be sorted according to
+ their frequency. This means that the most frequent label should be
+ represented by the index `0`, and the least frequent
+ label should be represented by the index `n_classes - 1`.
+
+ .. note::
+ This module returns a ``NamedTuple`` with ``output``
+ and ``loss`` fields. See further documentation for details.
+
+ .. note::
+ To compute log-probabilities for all classes, the ``log_prob``
+ method can be used.
+
+ Args:
+ in_features (int): Number of features in the input tensor
+ n_classes (int): Number of classes in the dataset
+ cutoffs (Sequence): Cutoffs used to assign targets to their buckets
+ div_value (float, optional): value used as an exponent to compute sizes
+ of the clusters. Default: 4.0
+ head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
+ adaptive softmax. Default: ``False``
+
+ Returns:
+ ``NamedTuple`` with ``output`` and ``loss`` fields:
+ * **output** is a Tensor of size ``N`` containing computed target
+ log probabilities for each example
+ * **loss** is a Scalar representing the computed negative
+ log likelihood loss
+
+ Shape:
+ - input: :math:`(N, in\_features)`
+ - target: :math:`(N)` where each value satisfies :math:`0 <= target[i] <= n\_classes - 1`
+ - output1: :math:`(N)`
+ - output2: ``Scalar``
+
+
+ .. _Efficient softmax approximation for GPUs:
+ https://arxiv.org/abs/1609.04309
+
+ .. _Zipf's law:
+ https://en.wikipedia.org/wiki/Zipf%27s_law
+ """
+
+ def __init__(self, in_features, n_classes, cutoffs, div_value=4., head_bias=False):
+ super(AdaptiveLogSoftmaxWithLoss, self).__init__()
+
+ cutoffs = list(cutoffs)
+
+ if (cutoffs != sorted(cutoffs)) \
+ or (min(cutoffs) <= 0) \
+ or (max(cutoffs) > (n_classes - 1)) \
+ or (len(set(cutoffs)) != len(cutoffs)) \
+ or any([int(c) != c for c in cutoffs]):
+
+ raise ValueError("cutoffs should be a sequence of unique, positive "
+ "integers sorted in an increasing order, where "
+ "each value is between 1 and n_classes-1")
+
+ self.in_features = in_features
+ self.n_classes = n_classes
+ self.cutoffs = cutoffs + [n_classes]
+ self.div_value = div_value
+ self.head_bias = head_bias
+
+ self.shortlist_size = self.cutoffs[0]
+ self.n_clusters = len(self.cutoffs) - 1
+ self.head_size = self.shortlist_size + self.n_clusters
+
+ self.head = Linear(self.in_features, self.head_size, bias=self.head_bias)
+ self.tail = ModuleList()
+
+ for i in range(self.n_clusters):
+
+ hsz = int(self.in_features // (self.div_value ** (i + 1)))
+ osz = self.cutoffs[i + 1] - self.cutoffs[i]
+
+ projection = Sequential(
+ Linear(self.in_features, hsz, bias=False),
+ Linear(hsz, osz, bias=False)
+ )
+
+ self.tail.append(projection)
+
+ def reset_parameters(self):
+ self.head.reset_parameters()
+ for i2h, h2o in self.tail:
+ i2h.reset_parameters()
+ h2o.reset_parameters()
+
+ def forward(self, input, target):
+ if input.size(0) != target.size(0):
+ raise RuntimeError('Input and target should have the same size '
+ 'in the batch dimension.')
+
+ used_rows = 0
+ batch_size = target.size(0)
+
+ output = input.new_zeros(batch_size)
+ gather_inds = target.new_empty(batch_size)
+
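+ # for each example, gather_inds records which column of the head output to gather:
+ # the target itself for shortlist targets, or the cluster token index
+ # (shortlist_size + cluster - 1) for targets that fall in a tail cluster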
+ cutoff_values = [0] + self.cutoffs
+ for i in range(len(cutoff_values) - 1):
+
+ low_idx = cutoff_values[i]
+ high_idx = cutoff_values[i + 1]
+
+ target_mask = (target >= low_idx) & (target < high_idx)
+ row_indices = target_mask.nonzero().squeeze()
+
+ if row_indices.numel() == 0:
+ continue
+
+ if i == 0:
+ gather_inds.index_copy_(0, row_indices, target[target_mask])
+
+ else:
+ relative_target = target[target_mask] - low_idx
+ input_subset = input.index_select(0, row_indices)
+
+ cluster_output = self.tail[i - 1](input_subset)
+ cluster_index = self.shortlist_size + i - 1
+
+ gather_inds.index_fill_(0, row_indices, cluster_index)
+
+ cluster_logprob = log_softmax(cluster_output, dim=1)
+ local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
+ output.index_copy_(0, row_indices, local_logprob.squeeze(1))
+
+ used_rows += row_indices.numel()
+
+ if used_rows != batch_size:
+ raise RuntimeError("Target values should be in [0, {}], "
+ "but values in range [{}, {}] "
+ "were found. ".format(self.n_classes - 1,
+ target.min().item(),
+ target.max().item()))
+
+ head_output = self.head(input)
+ head_logprob = log_softmax(head_output, dim=1)
+ output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
+ loss = (-output).mean()
+
+ return _ASMoutput(output, loss)
+
+ def _get_full_log_prob(self, input, head_output):
+ """ Given input tensor, and output of `self.head`,
+ compute the log of the full distribution """
+
+ out = input.new_empty((head_output.size(0), self.n_classes))
+ head_logprob = log_softmax(head_output, dim=1)
+
+ out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]
+
+ for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
+ cluster_output = self.tail[i](input)
+ cluster_logprob = log_softmax(cluster_output, dim=1)
+ output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)
+
+ out[:, start_idx:stop_idx] = output_logprob
+
+ return out
+
+[docs] def log_prob(self, input):
+ r""" Computes log probabilities for all :math:`n\_classes`
+
+ Args:
+ input (Tensor): a minibatch of examples
+
+ Returns:
+ log-probabilities for each class :math:`c`
+ in range :math:`0 <= c <= n\_classes - 1`, where :math:`n\_classes` is a
+ parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
+
+ Shape:
+ - Input: :math:`(N, in\_features)`
+ - Output: :math:`(N, n\_classes)`
+
+ """
+
+ head_output = self.head(input)
+ return self._get_full_log_prob(input, head_output)
+
+[docs] def predict(self, input):
+ r""" This is equivalent to `self.log_pob(input).argmax(dim=1)`,
+ but is more efficient in some cases.
+
+ Args:
+ input (Tensor): a minibatch of examples
+
+ Returns:
+ output (Tensor): a class with the highest probability for each example
+
+ Shape:
+ - Input: :math:`(N, in\_features)`
+ - Output: :math:`(N)`
+ """
+
+ head_output = self.head(input)
+ output = torch.argmax(head_output, dim=1)
+ not_in_shortlist = (output >= self.shortlist_size)
+ all_in_shortlist = not (not_in_shortlist.any())
+
+ if all_in_shortlist:
+ return output
+
+ elif not_in_shortlist.all():
+ log_prob = self._get_full_log_prob(input, head_output)
+ return torch.argmax(log_prob, dim=1)
+
+ else:
+ log_prob = self._get_full_log_prob(input[not_in_shortlist],
+ head_output[not_in_shortlist])
+ output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
+ return output
+
+from __future__ import division
+
+import torch
+from ._functions import SyncBatchNorm as sync_batch_norm
+from .module import Module
+from torch.nn.parameter import Parameter
+from .. import functional as F
+from .. import init
+from ..._jit_internal import weak_module, weak_script_method
+
+
+# TODO: check contiguous in THNN
+# TODO: use separate backend functions?
+@weak_module
+class _BatchNorm(Module):
+ _version = 2
+ __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
+ 'running_mean', 'running_var', 'num_batches_tracked']
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True):
+ super(_BatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ if self.track_running_stats:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
+ else:
+ self.register_parameter('running_mean', None)
+ self.register_parameter('running_var', None)
+ self.register_parameter('num_batches_tracked', None)
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ if self.affine:
+ init.uniform_(self.weight)
+ init.zeros_(self.bias)
+
+ def _check_input_dim(self, input):
+ raise NotImplementedError
+
+ @weak_script_method
+ def forward(self, input):
+ self._check_input_dim(input)
+
+ # exponential_average_factor is set to self.momentum
+ # (when it is available) only so that it gets updated
+ # in the ONNX graph when this node is exported to ONNX.
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training or not self.track_running_stats,
+ exponential_average_factor, self.eps)
+
+ def extra_repr(self):
+ return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
+ 'track_running_stats={track_running_stats}'.format(**self.__dict__)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if (version is None or version < 2) and self.track_running_stats:
+ # at version 2: added num_batches_tracked buffer
+ # this should have a default value of 0
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key not in state_dict:
+ state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
+
+ super(_BatchNorm, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+
+[docs]@weak_module
+class BatchNorm1d(_BatchNorm):
+ r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
+ inputs with optional additional channel dimension) as described in the paper
+ `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+ of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+ from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+ Also by default, during training this layer keeps running estimates of its
+ computed mean and variance, which are then used for normalization during
+ evaluation. The running estimates are kept with a default :attr:`momentum`
+ of 0.1.
+
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
+ keep running estimates, and batch statistics are instead used during
+ evaluation time as well.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters. Default: ``True``
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples::
+
+ >>> # With Learnable Parameters
+ >>> m = nn.BatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = nn.BatchNorm1d(100, affine=False)
+ >>> input = torch.randn(20, 100)
+ >>> output = m(input)
+
+ .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+ https://arxiv.org/abs/1502.03167
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+
+
+[docs]@weak_module
+class BatchNorm2d(_BatchNorm):
+ r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
+ with additional channel dimension) as described in the paper
+ `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+ of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+ from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+ Also by default, during training this layer keeps running estimates of its
+ computed mean and variance, which are then used for normalization during
+ evaluation. The running estimates are kept with a default :attr:`momentum`
+ of 0.1.
+
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
+ keep running estimates, and batch statistics are instead used during
+ evaluation time as well.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, H, W)`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters. Default: ``True``
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> # With Learnable Parameters
+ >>> m = nn.BatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = nn.BatchNorm2d(100, affine=False)
+ >>> input = torch.randn(20, 100, 35, 45)
+ >>> output = m(input)
+
+ .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+ https://arxiv.org/abs/1502.03167
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+
+
+[docs]@weak_module
+class BatchNorm3d(_BatchNorm):
+ r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
+ with additional channel dimension) as described in the paper
+ `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+ of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are sampled
+ from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+ Also by default, during training this layer keeps running estimates of its
+ computed mean and variance, which are then used for normalization during
+ evaluation. The running estimates are kept with a default :attr:`momentum`
+ of 0.1.
+
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
+ keep running estimates, and batch statistics are instead used during
+ evaluation time as well.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
+ or Spatio-temporal Batch Normalization.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, D, H, W)`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters. Default: ``True``
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> # With Learnable Parameters
+ >>> m = nn.BatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = nn.BatchNorm3d(100, affine=False)
+ >>> input = torch.randn(20, 100, 35, 45, 10)
+ >>> output = m(input)
+
+ .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+ https://arxiv.org/abs/1502.03167
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+
+
+[docs]class SyncBatchNorm(_BatchNorm):
+ r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
+ with additional channel dimension) as described in the paper
+ `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension over all
+ mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
+ are learnable parameter vectors of size `C` (where `C` is the input size).
+ By default, the elements of :math:`\gamma` are sampled from
+ :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
+
+ Also by default, during training this layer keeps running estimates of its
+ computed mean and variance, which are then used for normalization during
+ evaluation. The running estimates are kept with a default :attr:`momentum`
+ of 0.1.
+
+ If :attr:`track_running_stats` is set to ``False``, this layer then does not
+ keep running estimates, and batch statistics are instead used during
+ evaluation time as well.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ Because the Batch Normalization is done over the `C` dimension, computing statistics
+ on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
+ or Spatio-temporal Batch Normalization.
+
+ Currently SyncBatchNorm only supports DistributedDataParallel with a single GPU per process. Use
+ torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert a BatchNorm layer to SyncBatchNorm before wrapping
+ the network with DDP.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, +)`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters. Default: ``True``
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``True``
+ process_group: synchronization of stats happen within each process group
+ individually. Default behavior is synchronization across the whole
+ world
+
+ Shape:
+ - Input: :math:`(N, C, +)`
+ - Output: :math:`(N, C, +)` (same shape as input)
+
+ Examples::
+
+ >>> # With Learnable Parameters
+ >>> m = nn.SyncBatchNorm(100)
+ >>> # creating process group (optional)
+ >>> # process_ids is a list of int identifying rank ids.
+ >>> process_group = torch.distributed.new_group(process_ids)
+ >>> # Without Learnable Parameters
+ >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
+ >>> input = torch.randn(20, 100, 35, 45, 10)
+ >>> output = m(input)
+
+ >>> # network contains nn.BatchNorm layers
+ >>> sync_bn_network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
+ >>> # only single gpu per process is currently supported
+ >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
+ >>> sync_bn_network,
+ >>> device_ids=[args.local_rank],
+ >>> output_device=args.local_rank)
+
+ .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
+ https://arxiv.org/abs/1502.03167
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True, process_group=None):
+ super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+ self.process_group = process_group
+ # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
+ # under supported condition (single GPU per process)
+ self.ddp_gpu_size = None
+
+ def _check_input_dim(self, input):
+ if input.dim() <= 2:
+ raise ValueError('expected at least 3D input (got {}D input)'
+ .format(input.dim()))
+
+ def _specify_ddp_gpu_num(self, gpu_size):
+ if gpu_size > 1:
+ raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
+ self.ddp_gpu_size = gpu_size
+
+ def forward(self, input):
+ # currently only GPU input is supported
+ if not input.is_cuda:
+ raise ValueError('expected input tensor to be on GPU')
+
+ if not self.ddp_gpu_size:
+ raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
+
+ self._check_input_dim(input)
+
+ exponential_average_factor = 0.0
+
+ if self.training and self.track_running_stats:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ world_size = 1
+ process_group = torch.distributed.group.WORLD
+ if self.process_group:
+ process_group = self.process_group
+ world_size = torch.distributed.get_world_size(process_group)
+
+ # fallback to framework BN when synchronization is not necessary
+ if world_size == 1 or (not self.training and self.track_running_stats):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training or not self.track_running_stats,
+ exponential_average_factor, self.eps)
+ else:
+ return sync_batch_norm.apply(
+ input, self.weight, self.bias, self.running_mean, self.running_var,
+ self.eps, exponential_average_factor, process_group, world_size)
+
+[docs] @classmethod
+ def convert_sync_batchnorm(cls, module, process_group=None):
+ r"""Helper function to convert `torch.nn.BatchNormND` layer in the model to
+ `torch.nn.SyncBatchNorm` layer.
+
+ Args:
+ module (nn.Module): containing module
+ process_group (optional): process group to scope synchronization,
+ default is the whole world
+
+ Returns:
+ The original module with the converted `torch.nn.SyncBatchNorm` layer
+
+ Example::
+
+ >>> # Network with nn.BatchNorm layer
+ >>> module = torch.nn.Sequential(
+ >>> torch.nn.Linear(20, 100),
+ >>> torch.nn.BatchNorm1d(100)
+ >>> ).cuda()
+ >>> # creating process group (optional)
+ >>> # process_ids is a list of int identifying rank ids.
+ >>> process_group = torch.distributed.new_group(process_ids)
+ >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
+
+ """
+ module_output = module
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+ module_output = torch.nn.SyncBatchNorm(module.num_features,
+ module.eps, module.momentum,
+ module.affine,
+ module.track_running_stats,
+ process_group)
+ if module.affine:
+ module_output.weight.data = module.weight.data.clone().detach()
+ module_output.bias.data = module.bias.data.clone().detach()
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
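+ # recurse into children so nested BatchNorm layers are converted as well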
+ for name, child in module.named_children():
+ module_output.add_module(name, cls.convert_sync_batchnorm(child))
+ del module
+ return module_output
+
+import warnings
+from collections import OrderedDict
+from torch._six import container_abcs
+from itertools import islice
+import operator
+
+import torch
+from .module import Module
+
+
+class Container(Module):
+
+ def __init__(self, **kwargs):
+ super(Container, self).__init__()
+ # DeprecationWarning is ignored by default <sigh>
+ warnings.warn("nn.Container is deprecated. All of it's functionality "
+ "is now implemented in nn.Module. Subclass that instead.")
+ for key, value in kwargs.items():
+ self.add_module(key, value)
+
+
+[docs]class Sequential(Module):
+ r"""A sequential container.
+ Modules will be added to it in the order they are passed in the constructor.
+ Alternatively, an ordered dict of modules can also be passed in.
+
+ To make it easier to understand, here is a small example::
+
+ # Example of using Sequential
+ model = nn.Sequential(
+ nn.Conv2d(1,20,5),
+ nn.ReLU(),
+ nn.Conv2d(20,64,5),
+ nn.ReLU()
+ )
+
+ # Example of using Sequential with OrderedDict
+ model = nn.Sequential(OrderedDict([
+ ('conv1', nn.Conv2d(1,20,5)),
+ ('relu1', nn.ReLU()),
+ ('conv2', nn.Conv2d(20,64,5)),
+ ('relu2', nn.ReLU())
+ ]))
+ """
+
+ def __init__(self, *args):
+ super(Sequential, self).__init__()
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+
+ def _get_item_by_idx(self, iterator, idx):
+ """Get the idx-th item of the iterator"""
+ size = len(self)
+ idx = operator.index(idx)
+ if not -size <= idx < size:
+ raise IndexError('index {} is out of range'.format(idx))
+ idx %= size
+ return next(islice(iterator, idx, None))
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return self.__class__(OrderedDict(list(self._modules.items())[idx]))
+ else:
+ return self._get_item_by_idx(self._modules.values(), idx)
+
+ def __setitem__(self, idx, module):
+ key = self._get_item_by_idx(self._modules.keys(), idx)
+ return setattr(self, key, module)
+
+ def __delitem__(self, idx):
+ if isinstance(idx, slice):
+ for key in list(self._modules.keys())[idx]:
+ delattr(self, key)
+ else:
+ key = self._get_item_by_idx(self._modules.keys(), idx)
+ delattr(self, key)
+
+ def __len__(self):
+ return len(self._modules)
+
+ def __dir__(self):
+ keys = super(Sequential, self).__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+ def forward(self, input):
+ for module in self._modules.values():
+ input = module(input)
+ return input
+
+
+[docs]class ModuleList(Module):
+ r"""Holds submodules in a list.
+
+ :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
+ modules it contains are properly registered, and will be visible by all
+ :class:`~torch.nn.Module` methods.
+
+ Arguments:
+ modules (iterable, optional): an iterable of modules to add
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
+
+ def forward(self, x):
+ # ModuleList can act as an iterable, or be indexed using ints
+ for i, l in enumerate(self.linears):
+ x = self.linears[i // 2](x) + l(x)
+ return x
+ """
+
+ def __init__(self, modules=None):
+ super(ModuleList, self).__init__()
+ if modules is not None:
+ self += modules
+
+ def _get_abs_string_index(self, idx):
+ """Get the absolute index for the list of modules"""
+ idx = operator.index(idx)
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError('index {} is out of range'.format(idx))
+ if idx < 0:
+ idx += len(self)
+ return str(idx)
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return self.__class__(list(self._modules.values())[idx])
+ else:
+ return self._modules[self._get_abs_string_index(idx)]
+
+ def __setitem__(self, idx, module):
+ idx = self._get_abs_string_index(idx)
+ return setattr(self, str(idx), module)
+
+ def __delitem__(self, idx):
+ if isinstance(idx, slice):
+ for k in range(len(self._modules))[idx]:
+ delattr(self, str(k))
+ else:
+ delattr(self, self._get_abs_string_index(idx))
+ # To preserve numbering, self._modules is being reconstructed with modules after deletion
+ str_indices = [str(i) for i in range(len(self._modules))]
+ self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+ def __len__(self):
+ return len(self._modules)
+
+ def __iter__(self):
+ return iter(self._modules.values())
+
+ def __iadd__(self, modules):
+ return self.extend(modules)
+
+ def __dir__(self):
+ keys = super(ModuleList, self).__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+[docs] def insert(self, index, module):
+ r"""Insert a given module before a given index in the list.
+
+ Arguments:
+ index (int): index to insert.
+ module (nn.Module): module to insert
+ """
+ for i in range(len(self._modules), index, -1):
+ self._modules[str(i)] = self._modules[str(i - 1)]
+ self._modules[str(index)] = module
+
+[docs] def append(self, module):
+ r"""Appends a given module to the end of the list.
+
+ Arguments:
+ module (nn.Module): module to append
+ """
+ self.add_module(str(len(self)), module)
+ return self
+
+[docs] def extend(self, modules):
+ r"""Appends modules from a Python iterable to the end of the list.
+
+ Arguments:
+ modules (iterable): iterable of modules to append
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError("ModuleList.extend should be called with an "
+ "iterable, but got " + type(modules).__name__)
+ offset = len(self)
+ for i, module in enumerate(modules):
+ self.add_module(str(offset + i), module)
+ return self
+
+
+[docs]class ModuleDict(Module):
+ r"""Holds submodules in a dictionary.
+
+ :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
+ but modules it contains are properly registered, and will be visible by all
+ :class:`~torch.nn.Module` methods.
+
+ :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
+
+ * the order of insertion, and
+
+ * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged ``OrderedDict``
+ or another :class:`~torch.nn.ModuleDict` (the argument to :meth:`~torch.nn.ModuleDict.update`).
+
+ Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
+ types (e.g., Python's plain ``dict``) does not preserve the order of the
+ merged mapping.
+
+ Arguments:
+ modules (iterable, optional): a mapping (dictionary) of (string: module)
+ or an iterable of key-value pairs of type (string, module)
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.choices = nn.ModuleDict({
+ 'conv': nn.Conv2d(10, 10, 3),
+ 'pool': nn.MaxPool2d(3)
+ })
+ self.activations = nn.ModuleDict([
+ ['lrelu', nn.LeakyReLU()],
+ ['prelu', nn.PReLU()]
+ ])
+
+ def forward(self, x, choice, act):
+ x = self.choices[choice](x)
+ x = self.activations[act](x)
+ return x
+ """
+
+ def __init__(self, modules=None):
+ super(ModuleDict, self).__init__()
+ if modules is not None:
+ self.update(modules)
+
+ def __getitem__(self, key):
+ return self._modules[key]
+
+ def __setitem__(self, key, module):
+ self.add_module(key, module)
+
+ def __delitem__(self, key):
+ del self._modules[key]
+
+ def __len__(self):
+ return len(self._modules)
+
+ def __iter__(self):
+ return iter(self._modules)
+
+ def __contains__(self, key):
+ return key in self._modules
+
+
+
+[docs] def pop(self, key):
+ r"""Remove key from the ModuleDict and return its module.
+
+ Arguments:
+ key (string): key to pop from the ModuleDict
+ """
+ v = self[key]
+ del self[key]
+ return v
+
+[docs] def keys(self):
+ r"""Return an iterable of the ModuleDict keys.
+ """
+ return self._modules.keys()
+
+[docs] def items(self):
+ r"""Return an iterable of the ModuleDict key/value pairs.
+ """
+ return self._modules.items()
+
+[docs] def values(self):
+ r"""Return an iterable of the ModuleDict values.
+ """
+ return self._modules.values()
+
+[docs] def update(self, modules):
+ r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a
+ mapping or an iterable, overwriting existing keys.
+
+ .. note::
+ If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
+ an iterable of key-value pairs, the order of new elements in it is preserved.
+
+ Arguments:
+ modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
+ or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError("ModuleDict.update should be called with an "
+ "iterable of key/value pairs, but got " +
+ type(modules).__name__)
+
+ if isinstance(modules, container_abcs.Mapping):
+ if isinstance(modules, (OrderedDict, ModuleDict)):
+ for key, module in modules.items():
+ self[key] = module
+ else:
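+ # plain dicts and other unordered mappings are inserted in sorted key order
+ # so the result is deterministic (see the ordering note in the class docstring)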
+ for key, module in sorted(modules.items()):
+ self[key] = module
+ else:
+ for j, m in enumerate(modules):
+ if not isinstance(m, container_abcs.Iterable):
+ raise TypeError("ModuleDict update sequence element "
+ "#" + str(j) + " should be Iterable; is" +
+ type(m).__name__)
+ if not len(m) == 2:
+ raise ValueError("ModuleDict update sequence element "
+ "#" + str(j) + " has length " + str(len(m)) +
+ "; 2 is required")
+ self[m[0]] = m[1]
+
+
+[docs]class ParameterList(Module):
+ r"""Holds parameters in a list.
+
+ :class:`~torch.nn.ParameterList` can be indexed like a regular Python
+ list, but parameters it contains are properly registered, and will be
+ visible by all :class:`~torch.nn.Module` methods.
+
+ Arguments:
+ parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
+
+ def forward(self, x):
+ # ParameterList can act as an iterable, or be indexed using ints
+ for i, p in enumerate(self.params):
+ x = self.params[i // 2].mm(x) + p.mm(x)
+ return x
+ """
+
+ def __init__(self, parameters=None):
+ super(ParameterList, self).__init__()
+ if parameters is not None:
+ self += parameters
+
+ def _get_abs_string_index(self, idx):
+ """Get the absolute index for the list of modules"""
+ idx = operator.index(idx)
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError('index {} is out of range'.format(idx))
+ if idx < 0:
+ idx += len(self)
+ return str(idx)
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return self.__class__(list(self._parameters.values())[idx])
+ else:
+ idx = self._get_abs_string_index(idx)
+ return self._parameters[str(idx)]
+
+ def __setitem__(self, idx, param):
+ idx = self._get_abs_string_index(idx)
+ return self.register_parameter(str(idx), param)
+
+ def __len__(self):
+ return len(self._parameters)
+
+ def __iter__(self):
+ return iter(self._parameters.values())
+
+ def __iadd__(self, parameters):
+ return self.extend(parameters)
+
+ def __dir__(self):
+ keys = super(ParameterList, self).__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+[docs] def append(self, parameter):
+ """Appends a given parameter at the end of the list.
+
+ Arguments:
+ parameter (nn.Parameter): parameter to append
+ """
+ self.register_parameter(str(len(self)), parameter)
+ return self
+
+[docs] def extend(self, parameters):
+ """Appends parameters from a Python iterable to the end of the list.
+
+ Arguments:
+ parameters (iterable): iterable of parameters to append
+ """
+ if not isinstance(parameters, container_abcs.Iterable):
+ raise TypeError("ParameterList.extend should be called with an "
+ "iterable, but got " + type(parameters).__name__)
+ offset = len(self)
+ for i, param in enumerate(parameters):
+ self.register_parameter(str(offset + i), param)
+ return self
+
+ def extra_repr(self):
+ child_lines = []
+ for k, p in self._parameters.items():
+ size_str = 'x'.join(str(size) for size in p.size())
+ device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
+ parastr = 'Parameter containing: [{} of size {}{}]'.format(
+ torch.typename(p.data), size_str, device_str)
+ child_lines.append(' (' + str(k) + '): ' + parastr)
+ tmpstr = '\n'.join(child_lines)
+ return tmpstr
+
+
+[docs]class ParameterDict(Module):
+ r"""Holds parameters in a dictionary.
+
+ ParameterDict can be indexed like a regular Python dictionary, but parameters it
+ contains are properly registered, and will be visible by all Module methods.
+
+ :class:`~torch.nn.ParameterDict` is an **ordered** dictionary that respects
+
+ * the order of insertion, and
+
+ * in :meth:`~torch.nn.ParameterDict.update`, the order of the merged ``OrderedDict``
+ or another :class:`~torch.nn.ParameterDict` (the argument to
+ :meth:`~torch.nn.ParameterDict.update`).
+
+ Note that :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
+ types (e.g., Python's plain ``dict``) does not preserve the order of the
+ merged mapping.
+
+ Arguments:
+ parameters (iterable, optional): a mapping (dictionary) of
+ (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs
+ of type (string, :class:`~torch.nn.Parameter`)
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.params = nn.ParameterDict({
+ 'left': nn.Parameter(torch.randn(5, 10)),
+ 'right': nn.Parameter(torch.randn(5, 10))
+ })
+
+ def forward(self, x, choice):
+ x = self.params[choice].mm(x)
+ return x
+ """
+
+ def __init__(self, parameters=None):
+ super(ParameterDict, self).__init__()
+ if parameters is not None:
+ self.update(parameters)
+
+ def __getitem__(self, key):
+ return self._parameters[key]
+
+ def __setitem__(self, key, parameter):
+ self.register_parameter(key, parameter)
+
+ def __delitem__(self, key):
+ del self._parameters[key]
+
+ def __len__(self):
+ return len(self._parameters)
+
+ def __iter__(self):
+ return iter(self._parameters.keys())
+
+ def __contains__(self, key):
+ return key in self._parameters
+
+[docs] def clear(self):
+ """Remove all items from the ParameterDict.
+ """
+ self._parameters.clear()
+
+[docs] def pop(self, key):
+ r"""Remove key from the ParameterDict and return its parameter.
+
+ Arguments:
+ key (string): key to pop from the ParameterDict
+ """
+ v = self[key]
+ del self[key]
+ return v
+
+[docs] def keys(self):
+ r"""Return an iterable of the ParameterDict keys.
+ """
+ return self._parameters.keys()
+
+[docs] def items(self):
+ r"""Return an iterable of the ParameterDict key/value pairs.
+ """
+ return self._parameters.items()
+
+[docs] def values(self):
+ r"""Return an iterable of the ParameterDict values.
+ """
+ return self._parameters.values()
+
+[docs] def update(self, parameters):
+ r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a
+ mapping or an iterable, overwriting existing keys.
+
+ .. note::
+ If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
+ an iterable of key-value pairs, the order of new elements in it is preserved.
+
+ Arguments:
+ parameters (iterable): a mapping (dictionary) from string to
+ :class:`~torch.nn.Parameter`, or an iterable of
+ key-value pairs of type (string, :class:`~torch.nn.Parameter`)
+ """
+ if not isinstance(parameters, container_abcs.Iterable):
+ raise TypeError("ParametersDict.update should be called with an "
+ "iterable of key/value pairs, but got " +
+ type(parameters).__name__)
+
+ if isinstance(parameters, container_abcs.Mapping):
+ if isinstance(parameters, (OrderedDict, ParameterDict)):
+ for key, parameter in parameters.items():
+ self[key] = parameter
+ else:
+ for key, parameter in sorted(parameters.items()):
+ self[key] = parameter
+ else:
+ for j, p in enumerate(parameters):
+ if not isinstance(p, container_abcs.Iterable):
+ raise TypeError("ParameterDict update sequence element "
+ "#" + str(j) + " should be Iterable; is" +
+ type(p).__name__)
+ if not len(p) == 2:
+ raise ValueError("ParameterDict update sequence element "
+ "#" + str(j) + " has length " + str(len(p)) +
+ "; 2 is required")
+ self[p[0]] = p[1]
+
+ def extra_repr(self):
+ child_lines = []
+ for k, p in self._parameters.items():
+ size_str = 'x'.join(str(size) for size in p.size())
+ device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
+ parastr = 'Parameter containing: [{} of size {}{}]'.format(
+ torch.typename(p.data), size_str, device_str)
+ child_lines.append(' (' + k + '): ' + parastr)
+ tmpstr = '\n'.join(child_lines)
+ return tmpstr
+
+# coding=utf-8
+import math
+import torch
+from torch.nn.parameter import Parameter
+from .. import functional as F
+from .. import init
+from .module import Module
+from .utils import _single, _pair, _triple
+from ..._jit_internal import weak_module, weak_script_method, List
+
+
+@weak_module
+class _ConvNd(Module):
+
+ __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode']
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, transposed, output_padding,
+ groups, bias, padding_mode):
+ super(_ConvNd, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.transposed = transposed
+ self.output_padding = output_padding
+ self.groups = groups
+ self.padding_mode = padding_mode
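+ # weight layout: (in_channels, out_channels // groups, *kernel_size) for transposed
+ # convolutions, (out_channels, in_channels // groups, *kernel_size) otherwise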
+ if transposed:
+ self.weight = Parameter(torch.Tensor(
+ in_channels, out_channels // groups, *kernel_size))
+ else:
+ self.weight = Parameter(torch.Tensor(
+ out_channels, in_channels // groups, *kernel_size))
+ if bias:
+ self.bias = Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ init.uniform_(self.bias, -bound, bound)
+
+ def extra_repr(self):
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+ ', stride={stride}')
+ if self.padding != (0,) * len(self.padding):
+ s += ', padding={padding}'
+ if self.dilation != (1,) * len(self.dilation):
+ s += ', dilation={dilation}'
+ if self.output_padding != (0,) * len(self.output_padding):
+ s += ', output_padding={output_padding}'
+ if self.groups != 1:
+ s += ', groups={groups}'
+ if self.bias is None:
+ s += ', bias=False'
+ return s.format(**self.__dict__)
+
+
+[docs]@weak_module
+class Conv1d(_ConvNd):
+ r"""Applies a 1D convolution over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size
+ :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be
+ precisely described as:
+
+ .. math::
+ \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+ \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k)
+ \star \text{input}(N_i, k)
+
+ where :math:`\star` is the valid `cross-correlation`_ operator,
+ :math:`N` is the batch size, :math:`C` is the number of channels, and
+ :math:`L` is the length of the signal sequence.
+
+ * :attr:`stride` controls the stride for the cross-correlation, a single
+ number or a one-element tuple.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both sides
+ for :attr:`padding` number of points.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also
+ known as the à trous algorithm. It is harder to describe, but this `link`_
+ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters,
+ of size
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid
+ `cross-correlation`_, and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+
+ When `groups == in_channels` and `out_channels == K * in_channels`,
+        where `K` is a positive integer, this operation is also known in the
+        literature as a depthwise convolution.
+
+ In other words, for an input of size :math:`(N, C_{in}, L_{in})`,
+ a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
+ :math:`(C_\text{in}=C_{in}, C_\text{out}=C_{in} \times K, ..., \text{groups}=C_{in})`.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+        padding_mode (string, optional): Accepted values `zeros` and `circular`. Default: `zeros`
+ dilation (int or tuple, optional): Spacing between kernel
+ elements. Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C_{in}, L_{in})`
+ - Output: :math:`(N, C_{out}, L_{out})` where
+
+ .. math::
+ L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation}
+ \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \text{kernel\_size}}`
+ bias (Tensor): the learnable bias of the module of shape
+ (out_channels). If :attr:`bias` is ``True``, then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \text{kernel\_size}}`
+
+ Examples::
+
+ >>> m = nn.Conv1d(16, 33, 3, stride=2)
+ >>> input = torch.randn(20, 16, 50)
+ >>> output = m(input)
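+        >>> # the output length follows the formula above:
+        >>> # floor((50 + 0 - 1 * (3 - 1) - 1) / 2 + 1) = 24
+        >>> output.size()
+        torch.Size([20, 33, 24])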
+
+ .. _cross-correlation:
+ https://en.wikipedia.org/wiki/Cross-correlation
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
+ kernel_size = _single(kernel_size)
+ stride = _single(stride)
+ padding = _single(padding)
+ dilation = _single(dilation)
+ super(Conv1d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ False, _single(0), groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input):
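+        # Circular padding is applied explicitly with F.pad; the convolution
+        # itself then runs with zero padding.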
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride,
+ _single(0), self.dilation, self.groups)
+ return F.conv1d(input, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+@weak_module
+class Conv2d(_ConvNd):
+ r"""Applies a 2D convolution over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size
+ :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
+ can be precisely described as:
+
+ .. math::
+ \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
+ \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
+
+
+ where :math:`\star` is the valid 2D `cross-correlation`_ operator,
+    :math:`N` is the batch size, :math:`C` denotes the number of channels,
+    :math:`H` is the height of the input planes in pixels, and :math:`W` is
+    the width in pixels.
+
+ * :attr:`stride` controls the stride for the cross-correlation, a single
+ number or a tuple.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for :attr:`padding` number of points for each dimension.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also
+ known as the à trous algorithm. It is harder to describe, but this `link`_
+ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters, of size:
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the height and width dimension
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
+ and the second `int` for the width dimension
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid `cross-correlation`_,
+ and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+
+ When `groups == in_channels` and `out_channels == K * in_channels`,
+        where `K` is a positive integer, this operation is also known in the
+        literature as a depthwise convolution.
+
+ In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
+ a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
+ :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        padding_mode (string, optional): Accepted values `zeros` and `circular`. Default: `zeros`
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
+ \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
+ \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+ bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
+ then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+
+ Examples::
+
+ >>> # With square kernels and equal stride
+ >>> m = nn.Conv2d(16, 33, 3, stride=2)
+ >>> # non-square kernels and unequal stride and with padding
+ >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+ >>> # non-square kernels and unequal stride and with padding and dilation
+ >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
+ >>> input = torch.randn(20, 16, 50, 100)
+ >>> output = m(input)
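+        >>> # depthwise variant (sizes here are illustrative):
+        >>> # groups == in_channels and out_channels is a multiple of in_channels
+        >>> dw = nn.Conv2d(16, 32, 3, groups=16)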
+
+ .. _cross-correlation:
+ https://en.wikipedia.org/wiki/Cross-correlation
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
+ kernel_size = _pair(kernel_size)
+ stride = _pair(stride)
+ padding = _pair(padding)
+ dilation = _pair(dilation)
+ super(Conv2d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ False, _pair(0), groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input):
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
+ (self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride,
+ _pair(0), self.dilation, self.groups)
+ return F.conv2d(input, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+@weak_module
+class Conv3d(_ConvNd):
+ r"""Applies a 3D convolution over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
+ and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:
+
+ .. math::
+ out(N_i, C_{out_j}) = bias(C_{out_j}) +
+ \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)
+
+ where :math:`\star` is the valid 3D `cross-correlation`_ operator
+
+ * :attr:`stride` controls the stride for the cross-correlation.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for :attr:`padding` number of points for each dimension.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters, of size
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+ the second `int` for the height dimension and the third `int` for the width dimension
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid `cross-correlation`_,
+ and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+
+ When `groups == in_channels` and `out_channels == K * in_channels`,
+        where `K` is a positive integer, this operation is also known in the
+        literature as a depthwise convolution.
+
+ In other words, for an input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`,
+ a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
+ :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+        padding_mode (string, optional): Accepted values `zeros` and `circular`. Default: `zeros`
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+
+ .. math::
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
+ \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
+ \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
+ \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+ bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
+ then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+
+ Examples::
+
+ >>> # With square kernels and equal stride
+ >>> m = nn.Conv3d(16, 33, 3, stride=2)
+ >>> # non-square kernels and unequal stride and with padding
+ >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
+ >>> input = torch.randn(20, 16, 10, 50, 100)
+ >>> output = m(input)
+
+ .. _cross-correlation:
+ https://en.wikipedia.org/wiki/Cross-correlation
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1,
+ bias=True, padding_mode='zeros'):
+ kernel_size = _triple(kernel_size)
+ stride = _triple(stride)
+ padding = _triple(padding)
+ dilation = _triple(dilation)
+ super(Conv3d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ False, _triple(0), groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input):
+ if self.padding_mode == 'circular':
+ expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
+ (self.padding[1] + 1) // 2, self.padding[1] // 2,
+ (self.padding[0] + 1) // 2, self.padding[0] // 2)
+ return F.conv3d(F.pad(input, expanded_padding, mode='circular'),
+ self.weight, self.bias, self.stride, _triple(0),
+ self.dilation, self.groups)
+ return F.conv3d(input, self.weight, self.bias, self.stride,
+ self.padding, self.dilation, self.groups)
+
+
+@weak_module
+class _ConvTransposeMixin(object):
+ __constants__ = ['stride', 'padding', 'kernel_size', 'dim_size',
+ 'output_padding', 'groups', 'dilation', 'transposed',
+ 'bias', 'padding_mode']
+
+ @weak_script_method
+ def forward(self, input, output_size=None):
+        # type: (Tensor, Optional[List[int]]) -> Tensor
+ output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+ func = self._backend.ConvNd(
+ self.stride, self.padding, self.dilation, self.transposed,
+ output_padding, self.groups)
+ if self.bias is None:
+ return func(input, self.weight)
+ else:
+ return func(input, self.weight, self.bias)
+
+ @weak_script_method
+ def _output_padding(self, input, output_size, stride, padding, kernel_size):
+ # type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
+ if output_size is None:
+            ret = _single(self.output_padding)  # converting to list if it was not already
+ else:
+ k = input.dim() - 2
+ if len(output_size) == k + 2:
+ output_size = output_size[2:]
+ if len(output_size) != k:
+ raise ValueError(
+ "output_size must have {} or {} elements (got {})"
+ .format(k, k + 2, len(output_size)))
+
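+            # min_sizes[d] is the output length of the plain transposed conv in
+            # dim d; any size up to min_sizes[d] + stride[d] - 1 is reachable,
+            # and the requested difference becomes the output padding.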
+ min_sizes = torch.jit.annotate(List[int], [])
+ max_sizes = torch.jit.annotate(List[int], [])
+ for d in range(k):
+ dim_size = ((input.size(d + 2) - 1) * stride[d] -
+ 2 * padding[d] + kernel_size[d])
+ min_sizes.append(dim_size)
+ max_sizes.append(min_sizes[d] + stride[d] - 1)
+
+ for i in range(len(output_size)):
+ size = output_size[i]
+ min_size = min_sizes[i]
+ max_size = max_sizes[i]
+ if size < min_size or size > max_size:
+ raise ValueError((
+ "requested an output size of {}, but valid sizes range "
+ "from {} to {} (for an input of {})").format(
+ output_size, min_sizes, max_sizes, input.size()[2:]))
+
+ res = torch.jit.annotate(List[int], [])
+ for d in range(k):
+ res.append(output_size[d] - min_sizes[d])
+
+ ret = res
+ return ret
+
+
+@weak_module
+class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
+ r"""Applies a 1D transposed convolution operator over an input image
+ composed of several input planes.
+
+ This module can be seen as the gradient of Conv1d with respect to its input.
+ It is also known as a fractionally-strided convolution or
+ a deconvolution (although it is not an actual deconvolution operation).
+
+ * :attr:`stride` controls the stride for the cross-correlation.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
+ below for details.
+
+ * :attr:`output_padding` controls the additional size added to one side
+ of the output shape. See note below for details.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters (of size
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`).
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid `cross-correlation`_,
+ and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+ The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
+        amount of zero padding to both sides of the input. This is set so that
+        when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d`
+        are initialized with the same parameters, they are inverses of each other in
+ regard to the input and output shapes. However, when ``stride > 1``,
+ :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output
+ shape. :attr:`output_padding` is provided to resolve this ambiguity by
+ effectively increasing the calculated output shape on one side. Note
+ that :attr:`output_padding` is only used to find output shape, but does
+ not actually add zero-padding to output.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+ will be added to both sides of the input. Default: 0
+ output_padding (int or tuple, optional): Additional size added to one side
+ of the output shape. Default: 0
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+ Shape:
+ - Input: :math:`(N, C_{in}, L_{in})`
+ - Output: :math:`(N, C_{out}, L_{out})` where
+
+ .. math::
+ L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation}
+ \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
+ :math:`\text{kernel\_size})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \text{kernel\_size}}`
+ bias (Tensor): the learnable bias of the module of shape (out_channels).
+ If :attr:`bias` is ``True``, then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \text{kernel\_size}}`
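+
+    Examples::
+
+        >>> # a minimal sketch; channel counts and length below are illustrative
+        >>> m = nn.ConvTranspose1d(16, 33, 3, stride=2)
+        >>> input = torch.randn(20, 16, 50)
+        >>> output = m(input)
+        >>> output.size()   # (50 - 1) * 2 + (3 - 1) + 1 = 101
+        torch.Size([20, 33, 101])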
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
+ kernel_size = _single(kernel_size)
+ stride = _single(stride)
+ padding = _single(padding)
+ dilation = _single(dilation)
+ output_padding = _single(output_padding)
+ super(ConvTranspose1d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ True, output_padding, groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input, output_size=None):
+ # type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
+
+ output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+ return F.conv_transpose1d(
+ input, self.weight, self.bias, self.stride, self.padding,
+ output_padding, self.groups, self.dilation)
+
+
+@weak_module
+class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
+ r"""Applies a 2D transposed convolution operator over an input image
+ composed of several input planes.
+
+ This module can be seen as the gradient of Conv2d with respect to its input.
+ It is also known as a fractionally-strided convolution or
+ a deconvolution (although it is not an actual deconvolution operation).
+
+ * :attr:`stride` controls the stride for the cross-correlation.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
+ below for details.
+
+ * :attr:`output_padding` controls the additional size added to one side
+ of the output shape. See note below for details.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters (of size
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`).
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
+ can either be:
+
+ - a single ``int`` -- in which case the same value is used for the height and width dimensions
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
+ and the second `int` for the width dimension
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid `cross-correlation`_,
+ and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+ The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
+        amount of zero padding to both sides of the input. This is set so that
+        when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d`
+        are initialized with the same parameters, they are inverses of each other in
+ regard to the input and output shapes. However, when ``stride > 1``,
+ :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output
+ shape. :attr:`output_padding` is provided to resolve this ambiguity by
+ effectively increasing the calculated output shape on one side. Note
+ that :attr:`output_padding` is only used to find output shape, but does
+ not actually add zero-padding to output.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+ will be added to both sides of each dimension in the input. Default: 0
+ output_padding (int or tuple, optional): Additional size added to one side
+ of each dimension in the output shape. Default: 0
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+ Shape:
+ - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
+
+ .. math::
+ H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
+ \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
+ .. math::
+ W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
+ \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+        bias (Tensor): the learnable bias of the module of shape (out_channels).
+ If :attr:`bias` is ``True``, then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
+
+ Examples::
+
+ >>> # With square kernels and equal stride
+ >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
+ >>> # non-square kernels and unequal stride and with padding
+ >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
+ >>> input = torch.randn(20, 16, 50, 100)
+ >>> output = m(input)
+ >>> # exact output size can be also specified as an argument
+ >>> input = torch.randn(1, 16, 12, 12)
+ >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
+ >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
+ >>> h = downsample(input)
+ >>> h.size()
+ torch.Size([1, 16, 6, 6])
+ >>> output = upsample(h, output_size=input.size())
+ >>> output.size()
+ torch.Size([1, 16, 12, 12])
+
+ .. _cross-correlation:
+ https://en.wikipedia.org/wiki/Cross-correlation
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
+ kernel_size = _pair(kernel_size)
+ stride = _pair(stride)
+ padding = _pair(padding)
+ dilation = _pair(dilation)
+ output_padding = _pair(output_padding)
+ super(ConvTranspose2d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ True, output_padding, groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input, output_size=None):
+ # type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
+
+ output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+
+ return F.conv_transpose2d(
+ input, self.weight, self.bias, self.stride, self.padding,
+ output_padding, self.groups, self.dilation)
+
+
+@weak_module
+class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
+ r"""Applies a 3D transposed convolution operator over an input image composed of several input
+ planes.
+ The transposed convolution operator multiplies each input value element-wise by a learnable kernel,
+ and sums over the outputs from all input feature planes.
+
+ This module can be seen as the gradient of Conv3d with respect to its input.
+ It is also known as a fractionally-strided convolution or
+ a deconvolution (although it is not an actual deconvolution operation).
+
+ * :attr:`stride` controls the stride for the cross-correlation.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note
+ below for details.
+
+ * :attr:`output_padding` controls the additional size added to one side
+ of the output shape. See note below for details.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ * :attr:`groups` controls the connections between inputs and outputs.
+ :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+ :attr:`groups`. For example,
+
+ * At groups=1, all inputs are convolved to all outputs.
+ * At groups=2, the operation becomes equivalent to having two conv
+ layers side by side, each seeing half the input channels,
+ and producing half the output channels, and both subsequently
+ concatenated.
+ * At groups= :attr:`in_channels`, each input channel is convolved with
+ its own set of filters (of size
+ :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`).
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding`
+ can either be:
+
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+ the second `int` for the height dimension and the third `int` for the width dimension
+
+ .. note::
+
+        Depending on the size of your kernel, several (of the last)
+ columns of the input might be lost, because it is a valid `cross-correlation`_,
+ and not a full `cross-correlation`_.
+ It is up to the user to add proper padding.
+
+ .. note::
+ The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding``
+        amount of zero padding to both sides of the input. This is set so that
+        when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d`
+        are initialized with the same parameters, they are inverses of each other in
+ regard to the input and output shapes. However, when ``stride > 1``,
+ :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output
+ shape. :attr:`output_padding` is provided to resolve this ambiguity by
+ effectively increasing the calculated output shape on one side. Note
+ that :attr:`output_padding` is only used to find output shape, but does
+ not actually add zero-padding to output.
+
+ .. include:: cudnn_deterministic.rst
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
+ will be added to both sides of each dimension in the input. Default: 0
+ output_padding (int or tuple, optional): Additional size added to one side
+ of each dimension in the output shape. Default: 0
+ groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+ dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+
+ Shape:
+ - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+
+ .. math::
+ D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
+ \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
+ .. math::
+ H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
+ \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
+ .. math::
+ W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2]
+ \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1
+
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape
+ :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},`
+ :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
+ The values of these weights are sampled from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+        bias (Tensor): the learnable bias of the module of shape (out_channels).
+ If :attr:`bias` is ``True``, then the values of these weights are
+ sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+
+ Examples::
+
+ >>> # With square kernels and equal stride
+ >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2)
+ >>> # non-square kernels and unequal stride and with padding
+ >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2))
+ >>> input = torch.randn(20, 16, 10, 50, 100)
+ >>> output = m(input)
+
+ .. _cross-correlation:
+ https://en.wikipedia.org/wiki/Cross-correlation
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, output_padding=0, groups=1, bias=True,
+ dilation=1, padding_mode='zeros'):
+ kernel_size = _triple(kernel_size)
+ stride = _triple(stride)
+ padding = _triple(padding)
+ dilation = _triple(dilation)
+ output_padding = _triple(output_padding)
+ super(ConvTranspose3d, self).__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ True, output_padding, groups, bias, padding_mode)
+
+ @weak_script_method
+ def forward(self, input, output_size=None):
+ # type: (Tensor, Optional[List[int]]) -> Tensor
+ if self.padding_mode != 'zeros':
+ raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')
+
+ output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
+
+ return F.conv_transpose3d(
+ input, self.weight, self.bias, self.stride, self.padding,
+ output_padding, self.groups, self.dilation)
+
+
+# TODO: Conv2dLocal
+# TODO: Conv2dMap
+# TODO: ConvTranspose2dMap
+
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class PairwiseDistance(Module):
+ r"""
+ Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
+
+ .. math ::
+ \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.
+
+ Args:
+ p (real): the norm degree. Default: 2
+ eps (float, optional): Small value to avoid division by zero.
+ Default: 1e-6
+ keepdim (bool, optional): Determines whether or not to keep the vector dimension.
+ Default: False
+ Shape:
+ - Input1: :math:`(N, D)` where `D = vector dimension`
+ - Input2: :math:`(N, D)`, same shape as the Input1
+ - Output: :math:`(N)`. If :attr:`keepdim` is ``True``, then :math:`(N, 1)`.
+ Examples::
+ >>> pdist = nn.PairwiseDistance(p=2)
+ >>> input1 = torch.randn(100, 128)
+ >>> input2 = torch.randn(100, 128)
+ >>> output = pdist(input1, input2)
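+        >>> output.size()   # one distance per row, as in the Shape section above
+        torch.Size([100])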
+ """
+ __constants__ = ['norm', 'eps', 'keepdim']
+
+ def __init__(self, p=2., eps=1e-6, keepdim=False):
+ super(PairwiseDistance, self).__init__()
+ self.norm = p
+ self.eps = eps
+ self.keepdim = keepdim
+
+ @weak_script_method
+ def forward(self, x1, x2):
+ return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
+
+
+@weak_module
+class CosineSimilarity(Module):
+ r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
+
+ .. math ::
+ \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.
+
+ Args:
+ dim (int, optional): Dimension where cosine similarity is computed. Default: 1
+ eps (float, optional): Small value to avoid division by zero.
+ Default: 1e-8
+ Shape:
+ - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
+ - Input2: :math:`(\ast_1, D, \ast_2)`, same shape as the Input1
+ - Output: :math:`(\ast_1, \ast_2)`
+ Examples::
+ >>> input1 = torch.randn(100, 128)
+ >>> input2 = torch.randn(100, 128)
+ >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+ >>> output = cos(input1, input2)
+ """
+ __constants__ = ['dim', 'eps']
+
+ def __init__(self, dim=1, eps=1e-8):
+ super(CosineSimilarity, self).__init__()
+ self.dim = dim
+ self.eps = eps
+
+ @weak_script_method
+ def forward(self, x1, x2):
+ return F.cosine_similarity(x1, x2, self.dim, self.eps)
+
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+class _DropoutNd(Module):
+ __constants__ = ['p', 'inplace']
+
+ def __init__(self, p=0.5, inplace=False):
+ super(_DropoutNd, self).__init__()
+ if p < 0 or p > 1:
+ raise ValueError("dropout probability has to be between 0 and 1, "
+ "but got {}".format(p))
+ self.p = p
+ self.inplace = inplace
+
+ def extra_repr(self):
+ inplace_str = ', inplace' if self.inplace else ''
+ return 'p={}{}'.format(self.p, inplace_str)
+
+
+@weak_module
+class Dropout(_DropoutNd):
+ r"""During training, randomly zeroes some of the elements of the input
+ tensor with probability :attr:`p` using samples from a Bernoulli
+ distribution. Each channel will be zeroed out independently on every forward
+ call.
+
+ This has proven to be an effective technique for regularization and
+ preventing the co-adaptation of neurons as described in the paper
+ `Improving neural networks by preventing co-adaptation of feature
+ detectors`_ .
+
+ Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
+ training. This means that during evaluation the module simply computes an
+ identity function.
+
+ Args:
+ p: probability of an element to be zeroed. Default: 0.5
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(*)`. Input can be of any shape
+ - Output: :math:`(*)`. Output is of the same shape as input
+
+ Examples::
+
+ >>> m = nn.Dropout(p=0.2)
+ >>> input = torch.randn(20, 16)
+ >>> output = m(input)
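+        >>> # in eval mode the module is the identity (illustrative check)
+        >>> m = m.eval()
+        >>> torch.equal(m(input), input)
+        True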
+
+ .. _Improving neural networks by preventing co-adaptation of feature
+ detectors: https://arxiv.org/abs/1207.0580
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.dropout(input, self.p, self.training, self.inplace)
+
+
+@weak_module
+class Dropout2d(_DropoutNd):
+ r"""Randomly zero out entire channels (a channel is a 2D feature map,
+ e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+ batched input is a 2D tensor :math:`\text{input}[i, j]`).
+ Each channel will be zeroed out independently on every forward call with
+ probability :attr:`p` using samples from a Bernoulli distribution.
+
+ Usually the input comes from :class:`nn.Conv2d` modules.
+
+ As described in the paper
+ `Efficient Object Localization Using Convolutional Networks`_ ,
+ if adjacent pixels within feature maps are strongly correlated
+ (as is normally the case in early convolution layers) then i.i.d. dropout
+ will not regularize the activations and will otherwise just result
+ in an effective learning rate decrease.
+
+ In this case, :func:`nn.Dropout2d` will help promote independence between
+ feature maps and should be used instead.
+
+ Args:
+        p (float, optional): probability of an element to be zeroed.
+ inplace (bool, optional): If set to ``True``, will do this operation
+ in-place
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> m = nn.Dropout2d(p=0.2)
+ >>> input = torch.randn(20, 16, 32, 32)
+ >>> output = m(input)
+
+ .. _Efficient Object Localization Using Convolutional Networks:
+ http://arxiv.org/abs/1411.4280
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.dropout2d(input, self.p, self.training, self.inplace)
+
+
+@weak_module
+class Dropout3d(_DropoutNd):
+ r"""Randomly zero out entire channels (a channel is a 3D feature map,
+ e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
+ batched input is a 3D tensor :math:`\text{input}[i, j]`).
+ Each channel will be zeroed out independently on every forward call with
+ probability :attr:`p` using samples from a Bernoulli distribution.
+
+ Usually the input comes from :class:`nn.Conv3d` modules.
+
+ As described in the paper
+ `Efficient Object Localization Using Convolutional Networks`_ ,
+ if adjacent pixels within feature maps are strongly correlated
+ (as is normally the case in early convolution layers) then i.i.d. dropout
+ will not regularize the activations and will otherwise just result
+ in an effective learning rate decrease.
+
+ In this case, :func:`nn.Dropout3d` will help promote independence between
+ feature maps and should be used instead.
+
+ Args:
+ p (float, optional): probability of an element to be zeroed.
+ inplace (bool, optional): If set to ``True``, will do this operation
+ in-place
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> m = nn.Dropout3d(p=0.2)
+ >>> input = torch.randn(20, 16, 4, 32, 32)
+ >>> output = m(input)
+
+ .. _Efficient Object Localization Using Convolutional Networks:
+ http://arxiv.org/abs/1411.4280
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.dropout3d(input, self.p, self.training, self.inplace)
+
+
+@weak_module
+class AlphaDropout(_DropoutNd):
+ r"""Applies Alpha Dropout over the input.
+
+ Alpha Dropout is a type of Dropout that maintains the self-normalizing
+ property.
+ For an input with zero mean and unit standard deviation, the output of
+ Alpha Dropout maintains the original mean and standard deviation of the
+ input.
+    Alpha Dropout goes hand-in-hand with the SELU activation function, which ensures
+ that the outputs have zero mean and unit standard deviation.
+
+ During training, it randomly masks some of the elements of the input
+    tensor with probability *p* using samples from a Bernoulli distribution.
+    The elements to be masked are randomized on every forward call, and scaled
+ and shifted to maintain zero mean and unit standard deviation.
+
+ During evaluation the module simply computes an identity function.
+
+ More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+ Args:
+ p (float): probability of an element to be dropped. Default: 0.5
+ inplace (bool, optional): If set to ``True``, will do this operation
+ in-place
+
+ Shape:
+ - Input: :math:`(*)`. Input can be of any shape
+ - Output: :math:`(*)`. Output is of the same shape as input
+
+ Examples::
+
+ >>> m = nn.AlphaDropout(p=0.2)
+ >>> input = torch.randn(20, 16)
+ >>> output = m(input)
+
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.alpha_dropout(input, self.p, self.training)
+
+
+@weak_module
+class FeatureAlphaDropout(_DropoutNd):
+
+ @weak_script_method
+ def forward(self, input):
+ return F.feature_alpha_dropout(input, self.p, self.training)
+
+# coding=utf-8
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class Fold(Module):
+ r"""Combines an array of sliding local blocks into a large containing
+ tensor.
+
+ Consider a batched :attr:`input` tensor containing sliding local blocks,
+ e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
+ where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})`
+ is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
+ spatial locations each containing a :math:`C`-channeled vector), and
+ :math:`L` is the total number of blocks. (This is exactly the
+ same specification as the output shape of :class:`~torch.nn.Unfold`.) This
+ operation combines these local blocks into the large :attr:`output` tensor
+ of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
+ by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
+ arguments must satisfy
+
+ .. math::
+ L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
+ - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,
+
+ where :math:`d` is over all spatial dimensions.
+
+ * :attr:`output_size` describes the spatial shape of the large containing
+ tensor of the sliding local blocks. It is useful to resolve the ambiguity
+      when multiple input shapes map to the same number of sliding blocks, e.g.,
+ with ``stride > 0``.
+
+ The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
+ how the sliding blocks are retrieved.
+
+ * :attr:`stride` controls the stride for the sliding blocks.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for :attr:`padding` number of points for each dimension before
+ reshaping.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ Args:
+ output_size (int or tuple): the shape of the spatial dimensions of the
+ output (i.e., ``output.sizes()[2:]``)
+ kernel_size (int or tuple): the size of the sliding blocks
+ stride (int or tuple): the stride of the sliding blocks in the input
+ spatial dimensions. Default: 1
+ padding (int or tuple, optional): implicit zero padding to be added on
+ both sides of input. Default: 0
+ dilation (int or tuple, optional): a parameter that controls the
+ stride of elements within the
+ neighborhood. Default: 1
+
+ * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
+ :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then
+ their values will be replicated across all spatial dimensions.
+
+ * For the case of two output spatial dimensions this operation is sometimes
+ called ``col2im``.
+
+ .. note::
+ :class:`~torch.nn.Fold` calculates each combined value in the resulting
+ large tensor by summing all values from all containing blocks.
+ :class:`~torch.nn.Unfold` extracts the values in the local blocks by
+ copying from the large tensor. So, if the blocks overlap, they are not
+ inverses of each other.
+
+ .. warning::
+ Currently, only 4-D output tensors (batched image-like tensors) are
+ supported.
+
+ Shape:
+ - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)`
+ - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above
+
+ Examples::
+
+ >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
+ >>> input = torch.randn(1, 3 * 2 * 2, 12)
+ >>> output = fold(input)
+ >>> output.size()
+ torch.Size([1, 3, 4, 5])
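+        >>> # the 12 columns above are the L sliding blocks from the formula:
+        >>> # floor((4 - 2) / 1 + 1) * floor((5 - 2) / 1 + 1) = 3 * 4 = 12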
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+ """
+ __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
+ 'stride']
+
+ def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
+ super(Fold, self).__init__()
+ self.output_size = output_size
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.padding = padding
+ self.stride = stride
+
+ @weak_script_method
+ def forward(self, input):
+ return F.fold(input, self.output_size, self.kernel_size, self.dilation,
+ self.padding, self.stride)
+
+ def extra_repr(self):
+ return 'output_size={output_size}, kernel_size={kernel_size}, ' \
+ 'dilation={dilation}, padding={padding}, stride={stride}'.format(
+ **self.__dict__
+ )
+
+
+@weak_module
+class Unfold(Module):
+ r"""Extracts sliding local blocks from a batched input tensor.
+
+    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
+ where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
+ and :math:`*` represent arbitrary spatial dimensions. This operation flattens
+ each sliding :attr:`kernel_size`-sized block within the spatial dimensions
+ of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`
+ tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where
+ :math:`C \times \prod(\text{kernel\_size})` is the total number of values
+ within each block (a block has :math:`\prod(\text{kernel\_size})` spatial
+ locations each containing a :math:`C`-channeled vector), and :math:`L` is
+ the total number of such blocks:
+
+ .. math::
+ L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
+ - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,
+
+ where :math:`\text{spatial\_size}` is formed by the spatial dimensions
+ of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
+ dimensions.
+
+ Therefore, indexing :attr:`output` at the last dimension (column dimension)
+ gives all values within a certain block.
+
+ The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
+ how the sliding blocks are retrieved.
+
+ * :attr:`stride` controls the stride for the sliding blocks.
+
+ * :attr:`padding` controls the amount of implicit zero-paddings on both
+ sides for :attr:`padding` number of points for each dimension before
+ reshaping.
+
+ * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ Args:
+ kernel_size (int or tuple): the size of the sliding blocks
+ stride (int or tuple, optional): the stride of the sliding blocks in the input
+ spatial dimensions. Default: 1
+ padding (int or tuple, optional): implicit zero padding to be added on
+ both sides of input. Default: 0
+ dilation (int or tuple, optional): a parameter that controls the
+ stride of elements within the
+ neighborhood. Default: 1
+
+ * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
+ :attr:`stride` is an int or a tuple of length 1, their values will be
+ replicated across all spatial dimensions.
+
+ * For the case of two input spatial dimensions this operation is sometimes
+ called ``im2col``.
+
+ .. note::
+ :class:`~torch.nn.Fold` calculates each combined value in the resulting
+ large tensor by summing all values from all containing blocks.
+ :class:`~torch.nn.Unfold` extracts the values in the local blocks by
+ copying from the large tensor. So, if the blocks overlap, they are not
+ inverses of each other.
+
+ .. warning::
+ Currently, only 4-D input tensors (batched image-like tensors) are
+ supported.
+
+ Shape:
+ - Input: :math:`(N, C, *)`
+ - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above
+
+ Examples::
+
+ >>> unfold = nn.Unfold(kernel_size=(2, 3))
+ >>> input = torch.randn(2, 5, 3, 4)
+ >>> output = unfold(input)
+ >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
+ >>> # 4 blocks (2x3 kernels) in total in the 3x4 input
+ >>> output.size()
+ torch.Size([2, 30, 4])
+
+ >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
+ >>> inp = torch.randn(1, 3, 10, 12)
+ >>> w = torch.randn(2, 3, 4, 5)
+ >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
+ >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
+ >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
+ >>> # or equivalently (and avoiding a copy),
+ >>> # out = out_unf.view(1, 2, 7, 8)
+ >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
+ tensor(1.9073e-06)
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+
+ """
+ __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
+
+ def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
+ super(Unfold, self).__init__()
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+ self.padding = padding
+ self.stride = stride
+
+ @weak_script_method
+ def forward(self, input):
+ return F.unfold(input, self.kernel_size, self.dilation,
+ self.padding, self.stride)
+
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
+ ' stride={stride}'.format(**self.__dict__)
+
+from .batchnorm import _BatchNorm
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+class _InstanceNorm(_BatchNorm):
+ __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
+ 'track_running_stats', 'momentum', 'eps']
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
+ track_running_stats=False):
+ super(_InstanceNorm, self).__init__(
+ num_features, eps, momentum, affine, track_running_stats)
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ raise NotImplementedError
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+ # at version 1: removed running_mean and running_var when
+ # track_running_stats=False (default)
+ if version is None and not self.track_running_stats:
+ running_stats_keys = []
+ for name in ('running_mean', 'running_var'):
+ key = prefix + name
+ if key in state_dict:
+ running_stats_keys.append(key)
+ if len(running_stats_keys) > 0:
+ error_msgs.append(
+ 'Unexpected running stats buffer(s) {names} for {klass} '
+ 'with track_running_stats=False. If state_dict is a '
+ 'checkpoint saved before 0.4.0, this may be expected '
+ 'because {klass} does not track running stats by default '
+ 'since 0.4.0. Please remove these keys from state_dict. If '
+ 'the running stats are actually needed, instead set '
+ 'track_running_stats=True in {klass} to enable them. See '
+ 'the documentation of {klass} for details.'
+ .format(names=" and ".join('"{}"'.format(k) for k in running_stats_keys),
+ klass=self.__class__.__name__))
+ for key in running_stats_keys:
+ state_dict.pop(key)
+
+ super(_InstanceNorm, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ @weak_script_method
+ def forward(self, input):
+ self._check_input_dim(input)
+
+ return F.instance_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training or not self.track_running_stats, self.momentum, self.eps)
+
+
+@weak_module

+class InstanceNorm1d(_InstanceNorm):
+ r"""Applies Instance Normalization over a 3D input (a mini-batch of 1D
+ inputs with optional additional channel dimension) as described in the paper
+ `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension separately
+ for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+ of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
+
+ By default, this layer uses instance statistics computed from input data in
+ both training and evaluation modes.
+
+ If :attr:`track_running_stats` is set to ``True``, during training this
+ layer keeps running estimates of its computed mean and variance, which are
+ then used for normalization during evaluation. The running estimates are
+ kept with a default :attr:`momentum` of 0.1.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
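+
+        For example, with a running estimate of :math:`1.0`, a new observed value of
+        :math:`3.0`, and ``momentum=0.5`` (an illustrative value, not the default),
+        the update gives :math:`0.5 \times 1.0 + 0.5 \times 3.0 = 2.0`.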
+
+ .. note::
+ :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
+ have some subtle differences. :class:`InstanceNorm1d` is applied
+ on each channel of channeled data like multidimensional time series, but
+        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
+        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
+        transform, while :class:`InstanceNorm1d` usually does not apply an affine
+        transform.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ momentum: the value used for the running_mean and running_var computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters, initialized the same way as done for batch normalization.
+ Default: ``False``.
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, C, L)`
+ - Output: :math:`(N, C, L)` (same shape as input)
+
+ Examples::
+
+ >>> # Without Learnable Parameters
+ >>> m = nn.InstanceNorm1d(100)
+ >>> # With Learnable Parameters
+ >>> m = nn.InstanceNorm1d(100, affine=True)
+ >>> input = torch.randn(20, 100, 40)
+ >>> output = m(input)
+
+ .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+ https://arxiv.org/abs/1607.08022
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() == 2:
+ raise ValueError(
+                'InstanceNorm1d returns 0-filled tensor for 2D input. '
+                'This is because InstanceNorm1d reshapes inputs to '
+                '(1, N * C, ...) from (N, C, ...) and this makes '
+                'variances 0.'
+ )
+ if input.dim() != 3:
+ raise ValueError('expected 3D input (got {}D input)'
+ .format(input.dim()))
+
+
+@weak_module
+class InstanceNorm2d(_InstanceNorm):
+ r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
+ with additional channel dimension) as described in the paper
+ `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension separately
+ for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+ of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
+
+ By default, this layer uses instance statistics computed from input data in
+ both training and evaluation modes.
+
+ If :attr:`track_running_stats` is set to ``True``, during training this
+ layer keeps running estimates of its computed mean and variance, which are
+ then used for normalization during evaluation. The running estimates are
+ kept with a default :attr:`momentum` of 0.1.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ .. note::
+ :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
+ have some subtle differences. :class:`InstanceNorm2d` is applied
+ on each channel of channeled data like RGB images, but
+        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
+        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
+        transform, while :class:`InstanceNorm2d` usually does not apply an affine
+        transform.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, H, W)`
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ momentum: the value used for the running_mean and running_var computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters, initialized the same way as done for batch normalization.
+ Default: ``False``.
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> # Without Learnable Parameters
+ >>> m = nn.InstanceNorm2d(100)
+ >>> # With Learnable Parameters
+ >>> m = nn.InstanceNorm2d(100, affine=True)
+ >>> input = torch.randn(20, 100, 35, 45)
+ >>> output = m(input)
+
+ .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+ https://arxiv.org/abs/1607.08022
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+
+
+@weak_module
+class InstanceNorm3d(_InstanceNorm):
+ r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
+ with additional channel dimension) as described in the paper
+ `Instance Normalization: The Missing Ingredient for Fast Stylization`_ .
+
+ .. math::
+
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated per-dimension separately
+ for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
+
+ By default, this layer uses instance statistics computed from input data in
+ both training and evaluation modes.
+
+ If :attr:`track_running_stats` is set to ``True``, during training this
+ layer keeps running estimates of its computed mean and variance, which are
+ then used for normalization during evaluation. The running estimates are
+ kept with a default :attr:`momentum` of 0.1.
+
+ .. note::
+ This :attr:`momentum` argument is different from one used in optimizer
+ classes and the conventional notion of momentum. Mathematically, the
+ update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+ new observed value.
+
+ .. note::
+ :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
+ have some subtle differences. :class:`InstanceNorm3d` is applied
+ on each channel of channeled data like 3D models with RGB color, but
+        :class:`LayerNorm` is usually applied on an entire sample and often in NLP
+        tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
+        transform, while :class:`InstanceNorm3d` usually does not apply an affine
+        transform.
+
+ Args:
+ num_features: :math:`C` from an expected input of size
+ :math:`(N, C, D, H, W)`
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ momentum: the value used for the running_mean and running_var computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, this module has
+ learnable affine parameters, initialized the same way as done for batch normalization.
+ Default: ``False``.
+ track_running_stats: a boolean value that when set to ``True``, this
+ module tracks the running mean and variance, and when set to ``False``,
+ this module does not track such statistics and always uses batch
+ statistics in both training and eval modes. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> # Without Learnable Parameters
+ >>> m = nn.InstanceNorm3d(100)
+ >>> # With Learnable Parameters
+ >>> m = nn.InstanceNorm3d(100, affine=True)
+ >>> input = torch.randn(20, 100, 35, 45, 10)
+ >>> output = m(input)
+
+ .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
+ https://arxiv.org/abs/1607.08022
+ """
+
+ @weak_script_method
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+
+import math
+
+import torch
+from torch.nn.parameter import Parameter
+from .. import functional as F
+from .. import init
+from .module import Module
+from ..._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class Identity(Module):
+ r"""A placeholder identity operator that is argument-insensitive.
+
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
+ Examples::
+
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 20])
+
+ """
+ def __init__(self, *args, **kwargs):
+ super(Identity, self).__init__()
+
+ @weak_script_method
+ def forward(self, input):
+ return input
+
+
+@weak_module
+class Linear(Module):
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+ Args:
+ in_features: size of each input sample
+ out_features: size of each output sample
+ bias: If set to ``False``, the layer will not learn an additive bias.
+ Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
+ additional dimensions and :math:`H_{in} = \text{in\_features}`
+ - Output: :math:`(N, *, H_{out})` where all but the last dimension
+ are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
+
+ Attributes:
+ weight: the learnable weights of the module of shape
+ :math:`(\text{out\_features}, \text{in\_features})`. The values are
+ initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
+ :math:`k = \frac{1}{\text{in\_features}}`
+ bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
+ If :attr:`bias` is ``True``, the values are initialized from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{\text{in\_features}}`
+
+ Examples::
+
+ >>> m = nn.Linear(20, 30)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 30])
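+        >>> # A sketch (illustrative, not an official example): the same result
+        >>> # computed directly from the layer's parameters.
+        >>> torch.allclose(output, input.matmul(m.weight.t()) + m.bias)
+        True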
+ """
+ __constants__ = ['bias']
+
+ def __init__(self, in_features, out_features, bias=True):
+ super(Linear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = Parameter(torch.Tensor(out_features, in_features))
+ if bias:
+ self.bias = Parameter(torch.Tensor(out_features))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ init.uniform_(self.bias, -bound, bound)
+
+ @weak_script_method
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+ def extra_repr(self):
+ return 'in_features={}, out_features={}, bias={}'.format(
+ self.in_features, self.out_features, self.bias is not None
+ )
+
+
+@weak_module
+class Bilinear(Module):
+ r"""Applies a bilinear transformation to the incoming data:
+ :math:`y = x_1 A x_2 + b`
+
+ Args:
+ in1_features: size of each first input sample
+ in2_features: size of each second input sample
+ out_features: size of each output sample
+ bias: If set to False, the layer will not learn an additive bias.
+ Default: ``True``
+
+ Shape:
+ - Input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
+ :math:`*` means any number of additional dimensions. All but the last dimension
+ of the inputs should be the same.
+ - Input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
+ - Output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}`
+ and all but the last dimension are the same shape as the input.
+
+ Attributes:
+ weight: the learnable weights of the module of shape
+ :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
+ The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
+ :math:`k = \frac{1}{\text{in1\_features}}`
+ bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
+ If :attr:`bias` is ``True``, the values are initialized from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
+ :math:`k = \frac{1}{\text{in1\_features}}`
+
+ Examples::
+
+ >>> m = nn.Bilinear(20, 30, 40)
+ >>> input1 = torch.randn(128, 20)
+ >>> input2 = torch.randn(128, 30)
+ >>> output = m(input1, input2)
+ >>> print(output.size())
+ torch.Size([128, 40])
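+        >>> # A sketch (illustrative; assumes torch.einsum accepts operands as
+        >>> # separate arguments): the bilinear form written out explicitly as
+        >>> # y[n, o] = sum_{i, j} x1[n, i] * A[o, i, j] * x2[n, j] + b[o].
+        >>> manual = torch.einsum('ni,oij,nj->no', input1, m.weight, input2) + m.bias
+        >>> torch.allclose(output, manual, atol=1e-5)
+        True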
+ """
+ __constants__ = ['in1_features', 'in2_features', 'out_features', 'bias']
+
+ def __init__(self, in1_features, in2_features, out_features, bias=True):
+ super(Bilinear, self).__init__()
+ self.in1_features = in1_features
+ self.in2_features = in2_features
+ self.out_features = out_features
+ self.weight = Parameter(torch.Tensor(out_features, in1_features, in2_features))
+
+ if bias:
+ self.bias = Parameter(torch.Tensor(out_features))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ bound = 1 / math.sqrt(self.weight.size(1))
+ init.uniform_(self.weight, -bound, bound)
+ if self.bias is not None:
+ init.uniform_(self.bias, -bound, bound)
+
+ @weak_script_method
+ def forward(self, input1, input2):
+ return F.bilinear(input1, input2, self.weight, self.bias)
+
+ def extra_repr(self):
+ return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
+ self.in1_features, self.in2_features, self.out_features, self.bias is not None
+ )
+
+# TODO: PartialLinear - maybe in sparse?
+
+import warnings
+
+from .module import Module
+from .. import functional as F
+from .. import _reduction as _Reduction
+from ..._jit_internal import weak_module, weak_script_method
+
+
+class _Loss(Module):
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(_Loss, self).__init__()
+ if size_average is not None or reduce is not None:
+ self.reduction = _Reduction.legacy_get_string(size_average, reduce)
+ else:
+ self.reduction = reduction
+
+
+class _WeightedLoss(_Loss):
+ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
+ super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
+ self.register_buffer('weight', weight)
+
+
+@weak_module
+class L1Loss(_Loss):
+ r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
+ the input :math:`x` and target :math:`y`.
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = \left| x_n - y_n \right|,
+
+ where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+ (default ``'mean'``), then:
+
+ .. math::
+ \ell(x, y) =
+ \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+ :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+ of :math:`n` elements each.
+
+ The sum operation still operates over all the elements, and divides by :math:`n`.
+
+ The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then
+ :math:`(N, *)`, same shape as the input
+
+ Examples::
+
+ >>> loss = nn.L1Loss()
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> target = torch.randn(3, 5)
+ >>> output = loss(input, target)
+ >>> output.backward()
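+        >>> # Illustrative sketch: with the default 'mean' reduction this equals
+        >>> # the mean absolute difference computed by hand.
+        >>> torch.allclose(output, (input - target).abs().mean())
+        True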
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(L1Loss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.l1_loss(input, target, reduction=self.reduction)
+
+
+@weak_module
+class NLLLoss(_WeightedLoss):
+ r"""The negative log likelihood loss. It is useful to train a classification
+ problem with `C` classes.
+
+ If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning
+ weight to each of the classes. This is particularly useful when you have an
+ unbalanced training set.
+
+ The `input` given through a forward call is expected to contain
+ log-probabilities of each class. `input` has to be a Tensor of size either
+ :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)`
+ with :math:`K \geq 1` for the `K`-dimensional case (described later).
+
+ Obtaining log-probabilities in a neural network is easily achieved by
+ adding a `LogSoftmax` layer in the last layer of your network.
+ You may use `CrossEntropyLoss` instead, if you prefer not to add an extra
+ layer.
+
+ The `target` that this loss expects should be a class index in the range :math:`[0, C-1]`
+ where `C = number of classes`; if `ignore_index` is specified, this loss also accepts
+ this class index (this index may not necessarily be in the class range).
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - w_{y_n} x_{n,y_n}, \quad
+ w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},
+
+ where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+ (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
+ \text{if reduction} = \text{'mean';}\\
+ \sum_{n=1}^N l_n, &
+ \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+ Can also be used for higher dimension inputs, such as 2D images, by providing
+ an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
+ where :math:`K` is the number of dimensions, and a target of appropriate shape
+ (see below). In the case of images, it computes NLL loss per-pixel.
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to each
+ class. If given, it has to be a Tensor of size `C`. Otherwise, it is
+ treated as if having all ones.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When
+ :attr:`size_average` is ``True``, the loss is averaged over
+ non-ignored targets.
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(N, C)` where `C = number of classes`, or
+ :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
+ in the case of `K`-dimensional loss.
+ - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
+ :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
+ K-dimensional loss.
+ - Output: scalar.
+ If :attr:`reduction` is ``'none'``, then the same size as the target: :math:`(N)`, or
+ :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
+ of K-dimensional loss.
+
+ Examples::
+
+ >>> m = nn.LogSoftmax(dim=1)
+ >>> loss = nn.NLLLoss()
+ >>> # input is of size N x C = 3 x 5
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> # each element in target has to have 0 <= value < C
+ >>> target = torch.tensor([1, 0, 4])
+ >>> output = loss(m(input), target)
+ >>> output.backward()
+ >>>
+ >>>
+ >>> # 2D loss example (used, for example, with image inputs)
+ >>> N, C = 5, 4
+ >>> loss = nn.NLLLoss()
+ >>> # input is of size N x C x height x width
+ >>> data = torch.randn(N, 16, 10, 10)
+ >>> conv = nn.Conv2d(16, C, (3, 3))
+ >>> m = nn.LogSoftmax(dim=1)
+ >>> # each element in target has to have 0 <= value < C
+ >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
+ >>> output = loss(m(conv(data)), target)
+ >>> output.backward()
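+        >>> # Illustrative sketch (not an official example): with the default
+        >>> # 'mean' reduction and no class weights, the loss is the average of
+        >>> # the negative log-probabilities selected by the targets.
+        >>> lp = m(torch.randn(3, 5))
+        >>> tgt = torch.tensor([1, 0, 4])
+        >>> torch.allclose(loss(lp, tgt), -lp[torch.arange(3), tgt].mean())
+        True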
+ """
+ __constants__ = ['ignore_index', 'weight', 'reduction']
+
+ def __init__(self, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ super(NLLLoss, self).__init__(weight, size_average, reduce, reduction)
+ self.ignore_index = ignore_index
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
+
+
+@weak_module
+class NLLLoss2d(NLLLoss):
+ def __init__(self, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ warnings.warn("NLLLoss2d has been deprecated. "
+ "Please use NLLLoss instead as a drop-in replacement and see "
+ "https://pytorch.org/docs/master/nn.html#torch.nn.NLLLoss for more details.")
+ super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction)
+
+
+@weak_module
+class PoissonNLLLoss(_Loss):
+ r"""Negative log likelihood loss with Poisson distribution of target.
+
+ The loss can be described as:
+
+ .. math::
+ \text{target} \sim \mathrm{Poisson}(\text{input})
+
+ \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input})
+ + \log(\text{target!})
+
+    The last term can be omitted or approximated with Stirling's formula. The
+    approximation is used for target values greater than 1. For targets less than
+    or equal to 1, zeros are added to the loss.
+
+ Args:
+ log_input (bool, optional): if ``True`` the loss is computed as
+ :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is
+ :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`.
+ full (bool, optional): whether to compute full loss, i. e. to add the
+ Stirling approximation term
+
+ .. math::
+ \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}).
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
+ :attr:`log_input = False`. Default: 1e-8
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Examples::
+
+ >>> loss = nn.PoissonNLLLoss()
+ >>> log_input = torch.randn(5, 2, requires_grad=True)
+ >>> target = torch.randn(5, 2)
+ >>> output = loss(log_input, target)
+ >>> output.backward()
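+        >>> # Illustrative sketch: with log_input=True (the default) and full=False,
+        >>> # the loss reduces to the mean of exp(input) - target * input.
+        >>> torch.allclose(output, (torch.exp(log_input) - target * log_input).mean())
+        True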
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`,
+ the same shape as the input
+ """
+ __constants__ = ['log_input', 'full', 'eps', 'reduction']
+
+ def __init__(self, log_input=True, full=False, size_average=None,
+ eps=1e-8, reduce=None, reduction='mean'):
+ super(PoissonNLLLoss, self).__init__(size_average, reduce, reduction)
+ self.log_input = log_input
+ self.full = full
+ self.eps = eps
+
+ @weak_script_method
+ def forward(self, log_input, target):
+ return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
+ eps=self.eps, reduction=self.reduction)
+
+
+@weak_module
+class KLDivLoss(_Loss):
+ r"""The `Kullback-Leibler divergence`_ Loss
+
+ KL divergence is a useful distance measure for continuous distributions
+ and is often useful when performing direct regression over the space of
+ (discretely sampled) continuous output distributions.
+
+ As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
+ *log-probabilities* and is not restricted to a 2D Tensor.
+ The targets are given as *probabilities* (i.e. without taking the logarithm).
+
+ This criterion expects a `target` `Tensor` of the same size as the
+ `input` `Tensor`.
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ l(x,y) = L = \{ l_1,\dots,l_N \}, \quad
+ l_n = y_n \cdot \left( \log y_n - x_n \right)
+
+ where the index :math:`N` spans all dimensions of ``input`` and :math:`L` has the same
+ shape as ``input``. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';} \\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+    In the default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations
+ **as well as** over dimensions. ``'batchmean'`` mode gives the correct KL divergence where losses
+ are averaged over batch dimension only. ``'mean'`` mode's behavior will be changed to the same as
+ ``'batchmean'`` in the next major release.
+
+ .. _Kullback-Leibler divergence:
+ https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
+ ``'none'``: no reduction will be applied.
+ ``'batchmean'``: the sum of the output will be divided by batchsize.
+ ``'sum'``: the output will be summed.
+ ``'mean'``: the output will be divided by the number of elements in the output.
+ Default: ``'mean'``
+
+ .. note::
+ :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
+ and in the meantime, specifying either of those two args will override :attr:`reduction`.
+
+ .. note::
+        :attr:`reduction` = ``'mean'`` doesn't return the true KL divergence value; please use
+        :attr:`reduction` = ``'batchmean'``, which aligns with the mathematical definition of
+        KL divergence. In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``.
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`,
+ the same shape as the input
+
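+    A minimal usage sketch (illustrative, not an official example)::
+
+        >>> loss = nn.KLDivLoss(reduction='batchmean')
+        >>> x = torch.log_softmax(torch.randn(3, 5), dim=1)   # log-probabilities
+        >>> y = torch.softmax(torch.randn(3, 5), dim=1)       # probabilities
+        >>> out = loss(x, y)
+        >>> torch.allclose(out, (y * (y.log() - x)).sum() / 3)
+        True
+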
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(KLDivLoss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.kl_div(input, target, reduction=self.reduction)
+
+
+@weak_module
+class MSELoss(_Loss):
+ r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
+ each element in the input :math:`x` and target :math:`y`.
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = \left( x_n - y_n \right)^2,
+
+ where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+ (default ``'mean'``), then:
+
+ .. math::
+ \ell(x, y) =
+ \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+ :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+ of :math:`n` elements each.
+
+ The sum operation still operates over all the elements, and divides by :math:`n`.
+
+ The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+
+ Examples::
+
+ >>> loss = nn.MSELoss()
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> target = torch.randn(3, 5)
+ >>> output = loss(input, target)
+ >>> output.backward()
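+        >>> # Illustrative sketch: with the default 'mean' reduction this equals
+        >>> # the mean squared difference computed by hand.
+        >>> torch.allclose(output, ((input - target) ** 2).mean())
+        True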
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(MSELoss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.mse_loss(input, target, reduction=self.reduction)
+
+
+@weak_module
+class BCELoss(_WeightedLoss):
+ r"""Creates a criterion that measures the Binary Cross Entropy
+ between the target and the output:
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right],
+
+ where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+ (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+    This is used for measuring the error of a reconstruction in, for example,
+    an auto-encoder. Note that the targets :math:`y` should be numbers
+    between 0 and 1.
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to the loss
+ of each batch element. If given, has to be a Tensor of size `nbatch`.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
+ shape as input.
+
+ Examples::
+
+ >>> m = nn.Sigmoid()
+ >>> loss = nn.BCELoss()
+ >>> input = torch.randn(3, requires_grad=True)
+ >>> target = torch.empty(3).random_(2)
+ >>> output = loss(m(input), target)
+ >>> output.backward()
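+        >>> # Illustrative sketch: the same value computed from the definition,
+        >>> # with p = sigmoid(input) playing the role of x_n above.
+        >>> p = m(input)
+        >>> torch.allclose(output, -(target * p.log() + (1 - target) * (1 - p).log()).mean())
+        True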
+ """
+ __constants__ = ['reduction', 'weight']
+
+ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
+ super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
+
+
+@weak_module
+class BCEWithLogitsLoss(_Loss):
+ r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
+ class. This version is more numerically stable than using a plain `Sigmoid`
+ followed by a `BCELoss` as, by combining the operations into one layer,
+ we take advantage of the log-sum-exp trick for numerical stability.
+
+ The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - w_n \left[ y_n \cdot \log \sigma(x_n)
+ + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right],
+
+ where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
+ (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+    This is used for measuring the error of a reconstruction in, for example,
+    an auto-encoder. Note that the targets `t[i]` should be numbers
+    between 0 and 1.
+
+ It's possible to trade off recall and precision by adding weights to positive examples.
+ In the case of multi-label classification the loss can be described as:
+
+ .. math::
+ \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad
+ l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c})
+ + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right],
+
+ where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification,
+ :math:`c = 1` for single-label binary classification),
+ :math:`n` is the number of the sample in the batch and
+ :math:`p_c` is the weight of the positive answer for the class :math:`c`.
+
+ :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision.
+
+ For example, if a dataset contains 100 positive and 300 negative examples of a single class,
+ then `pos_weight` for the class should be equal to :math:`\frac{300}{100}=3`.
+ The loss would act as if the dataset contains :math:`3\times 100=300` positive examples.
+
+ Examples::
+
+ >>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
+ >>> output = torch.full([10, 64], 0.999) # A prediction (logit)
+ >>> pos_weight = torch.ones([64]) # All weights are equal to 1
+ >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+ >>> criterion(output, target) # -log(sigmoid(0.999))
+ tensor(0.3135)
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to the loss
+ of each batch element. If given, has to be a Tensor of size `nbatch`.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+ pos_weight (Tensor, optional): a weight of positive examples.
+ Must be a vector with length equal to the number of classes.
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
+ shape as input.
+
+ Examples::
+
+ >>> loss = nn.BCEWithLogitsLoss()
+ >>> input = torch.randn(3, requires_grad=True)
+ >>> target = torch.empty(3).random_(2)
+ >>> output = loss(input, target)
+ >>> output.backward()
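+        >>> # Illustrative sketch: numerically equivalent (but less stable) to
+        >>> # applying a Sigmoid followed by BCELoss.
+        >>> torch.allclose(output, nn.BCELoss()(torch.sigmoid(input), target))
+        True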
+ """
+ __constants__ = ['weight', 'pos_weight', 'reduction']
+
+ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None):
+ super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
+ self.register_buffer('weight', weight)
+ self.register_buffer('pos_weight', pos_weight)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.binary_cross_entropy_with_logits(input, target,
+ self.weight,
+ pos_weight=self.pos_weight,
+ reduction=self.reduction)
+
+
+@weak_module
+class HingeEmbeddingLoss(_Loss):
+ r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
+ (containing 1 or -1).
+ This is usually used for measuring whether two inputs are similar or
+ dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically
+ used for learning nonlinear embeddings or semi-supervised learning.
+
+    The loss function for the :math:`n`-th sample in the mini-batch is
+
+ .. math::
+ l_n = \begin{cases}
+ x_n, & \text{if}\; y_n = 1,\\
+ \max \{0, \Delta - x_n\}, & \text{if}\; y_n = -1,
+ \end{cases}
+
+    and the total loss function is
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+ where :math:`L = \{l_1,\dots,l_N\}^\top`.
+
+ Args:
+ margin (float, optional): Has a default value of `1`.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(*)` where :math:`*` means any number of dimensions. The sum operation
+ operates over all the elements.
+ - Target: :math:`(*)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
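+
+    A minimal usage sketch (illustrative, not an official example)::
+
+        >>> loss = nn.HingeEmbeddingLoss()   # margin defaults to 1
+        >>> x = torch.tensor([0.3, 1.5, 0.2])
+        >>> y = torch.tensor([1., -1., -1.])
+        >>> loss(x, y)   # mean of [0.3, max(0, 1 - 1.5), max(0, 1 - 0.2)]
+        tensor(0.3667)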
+ """
+ __constants__ = ['margin', 'reduction']
+
+ def __init__(self, margin=1.0, size_average=None, reduce=None, reduction='mean'):
+ super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
+ self.margin = margin
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction)
+
+
+@weak_module
+class MultiLabelMarginLoss(_Loss):
+ r"""Creates a criterion that optimizes a multi-class multi-classification
+ hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
+ and output :math:`y` (which is a 2D `Tensor` of target class indices).
+ For each sample in the mini-batch:
+
+ .. math::
+ \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
+
+ where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
+ :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
+ :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
+ and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
+
+ :math:`y` and :math:`x` must have the same size.
+
+ The criterion only considers a contiguous block of non-negative targets that
+ starts at the front.
+
+ This allows for different samples to have variable amounts of target classes.
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C`
+ is the number of classes.
+ - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input.
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
+
+ Examples::
+
+ >>> loss = nn.MultiLabelMarginLoss()
+ >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
+ >>> # for target y, only consider labels 3 and 0, not after label -1
+ >>> y = torch.LongTensor([[3, 0, -1, 1]])
+ >>> loss(x, y)
+ >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
+ tensor(0.8500)
+
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.multilabel_margin_loss(input, target, reduction=self.reduction)
+
+
+@weak_module
+class SmoothL1Loss(_Loss):
+ r"""Creates a criterion that uses a squared term if the absolute
+ element-wise error falls below 1 and an L1 term otherwise.
+ It is less sensitive to outliers than the `MSELoss` and in some cases
+ prevents exploding gradients (e.g. see `Fast R-CNN` paper by Ross Girshick).
+ Also known as the Huber loss:
+
+ .. math::
+ \text{loss}(x, y) = \frac{1}{n} \sum_{i} z_{i}
+
+ where :math:`z_{i}` is given by:
+
+ .. math::
+ z_{i} =
+ \begin{cases}
+ 0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
+ |x_i - y_i| - 0.5, & \text{otherwise }
+ \end{cases}
+
+    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total of :math:`n`
+    elements each; the sum operation still operates over all the elements, and
+    divides by :math:`n`.
+
+    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(N, *)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(N, *)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then
+ :math:`(N, *)`, same shape as the input
+
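+    A minimal usage sketch (illustrative, not an official example)::
+
+        >>> loss = nn.SmoothL1Loss()
+        >>> x = torch.tensor([0.5, 3.0])
+        >>> y = torch.tensor([0.0, 0.0])
+        >>> loss(x, y)   # mean of [0.5 * 0.5 ** 2, 3.0 - 0.5]
+        tensor(1.3125)
+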
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(SmoothL1Loss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.smooth_l1_loss(input, target, reduction=self.reduction)
+
+
+@weak_module
+class SoftMarginLoss(_Loss):
+ r"""Creates a criterion that optimizes a two-class classification
+ logistic loss between input tensor :math:`x` and target tensor :math:`y`
+ (containing 1 or -1).
+
+ .. math::
+ \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
+
+ Args:
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+        - Input: :math:`(*)` where :math:`*` means any number of additional
+ dimensions
+ - Target: :math:`(*)`, same shape as the input
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
+
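+    A minimal usage sketch (illustrative, not an official example)::
+
+        >>> loss = nn.SoftMarginLoss()
+        >>> x = torch.tensor([0.5, -1.0])
+        >>> y = torch.tensor([1., -1.])
+        >>> torch.allclose(loss(x, y), torch.log(1 + torch.exp(-y * x)).mean())
+        True
+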
+ """
+ __constants__ = ['reduction']
+
+ def __init__(self, size_average=None, reduce=None, reduction='mean'):
+ super(SoftMarginLoss, self).__init__(size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.soft_margin_loss(input, target, reduction=self.reduction)
+
+
+@weak_module
+class CrossEntropyLoss(_WeightedLoss):
+ r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.
+
+ It is useful when training a classification problem with `C` classes.
+ If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
+ assigning weight to each of the classes.
+ This is particularly useful when you have an unbalanced training set.
+
+ The `input` is expected to contain raw, unnormalized scores for each class.
+
+ `input` has to be a Tensor of size either :math:`(minibatch, C)` or
+ :math:`(minibatch, C, d_1, d_2, ..., d_K)`
+ with :math:`K \geq 1` for the `K`-dimensional case (described later).
+
+ This criterion expects a class index in the range :math:`[0, C-1]` as the
+ `target` for each value of a 1D tensor of size `minibatch`; if `ignore_index`
+ is specified, this criterion also accepts this class index (this index may not
+ necessarily be in the class range).
+
+ The loss can be described as:
+
+ .. math::
+ \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
+ = -x[class] + \log\left(\sum_j \exp(x[j])\right)
+
+ or in the case of the :attr:`weight` argument being specified:
+
+ .. math::
+ \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)
+
+ The losses are averaged across observations for each minibatch.
+
+ Can also be used for higher dimension inputs, such as 2D images, by providing
+ an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
+ where :math:`K` is the number of dimensions, and a target of appropriate shape
+ (see below).
+
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to each class.
+ If given, has to be a Tensor of size `C`
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When :attr:`size_average` is
+ ``True``, the loss is averaged over non-ignored targets.
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(N, C)` where `C = number of classes`, or
+ :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
+ in the case of `K`-dimensional loss.
+ - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
+ :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
+ K-dimensional loss.
+ - Output: scalar.
+ If :attr:`reduction` is ``'none'``, then the same size as the target:
+ :math:`(N)`, or
+ :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
+ of K-dimensional loss.
+
+ Examples::
+
+ >>> loss = nn.CrossEntropyLoss()
+ >>> input = torch.randn(3, 5, requires_grad=True)
+ >>> target = torch.empty(3, dtype=torch.long).random_(5)
+ >>> output = loss(input, target)
+ >>> output.backward()
+ """
+ __constants__ = ['weight', 'ignore_index', 'reduction']
+
+ def __init__(self, weight=None, size_average=None, ignore_index=-100,
+ reduce=None, reduction='mean'):
+ super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
+ self.ignore_index = ignore_index
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.cross_entropy(input, target, weight=self.weight,
+ ignore_index=self.ignore_index, reduction=self.reduction)
+
+
+[docs]@weak_module
+class MultiLabelSoftMarginLoss(_WeightedLoss):
+ r"""Creates a criterion that optimizes a multi-label one-versus-all
+ loss based on max-entropy, between input :math:`x` and target :math:`y` of size
+ :math:`(N, C)`.
+ For each sample in the minibatch:
+
+ .. math::
+        loss(x, y) = - \frac{1}{C} * \sum_i \left[ y[i] * \log((1 + \exp(-x[i]))^{-1})
+                         + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right) \right]
+
+ where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`,
+ :math:`y[i] \in \left\{0, \; 1\right\}`.
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to each
+ class. If given, it has to be a Tensor of size `C`. Otherwise, it is
+ treated as if having all ones.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes.
+        - Target: :math:`(N, C)`, label targets with the same shape as the input,
+          where each element is either :math:`0` or :math:`1`.
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
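+
+    Example (a minimal usage sketch; the batch size and number of classes below
+    are chosen only for illustration)::
+
+        >>> loss = nn.MultiLabelSoftMarginLoss()
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> # each target entry is 0 or 1 (multi-label ground truth)
+        >>> target = torch.empty(3, 5).random_(2)
+        >>> output = loss(input, target)
+        >>> output.backward()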
+ """
+ __constants__ = ['weight', 'reduction']
+
+ def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
+ super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction)
+
+
+[docs]@weak_module
+class CosineEmbeddingLoss(_Loss):
+ r"""Creates a criterion that measures the loss given input tensors
+ :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
+ This is used for measuring whether two inputs are similar or dissimilar,
+ using the cosine distance, and is typically used for learning nonlinear
+ embeddings or semi-supervised learning.
+
+ The loss function for each sample is:
+
+ .. math::
+ \text{loss}(x, y) =
+ \begin{cases}
+ 1 - \cos(x_1, x_2), & \text{if } y = 1 \\
+ \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
+ \end{cases}
+
+ Args:
+        margin (float, optional): Should be a number from :math:`-1` to :math:`1`;
+            :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
+            default value is :math:`0`.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
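+
+    Example (a minimal usage sketch; the shapes, margin and labels below are
+    chosen only for illustration)::
+
+        >>> loss = nn.CosineEmbeddingLoss(margin=0.5)
+        >>> input1 = torch.randn(3, 5, requires_grad=True)
+        >>> input2 = torch.randn(3, 5, requires_grad=True)
+        >>> # 1 marks a similar pair, -1 a dissimilar pair
+        >>> target = torch.tensor([1., -1., 1.])
+        >>> output = loss(input1, input2, target)
+        >>> output.backward()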
+ """
+ __constants__ = ['margin', 'reduction']
+
+ def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'):
+ super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction)
+ self.margin = margin
+
+ @weak_script_method
+ def forward(self, input1, input2, target):
+ return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
+
+
+[docs]@weak_module
+class MarginRankingLoss(_Loss):
+ r"""Creates a criterion that measures the loss given
+ inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`,
+ and a label 1D mini-batch tensor :math:`y` (containing 1 or -1).
+
+    If :math:`y = 1` then it is assumed that the first input should be ranked higher
+    (have a larger value) than the second input, and vice-versa for :math:`y = -1`.
+
+ The loss function for each sample in the mini-batch is:
+
+ .. math::
+ \text{loss}(x, y) = \max(0, -y * (x1 - x2) + \text{margin})
+
+ Args:
+ margin (float, optional): Has a default value of :math:`0`.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(N, D)` where `N` is the batch size and `D` is the size of a sample.
+ - Target: :math:`(N)`
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
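+
+    Example (a minimal usage sketch; the batch size, margin and labels below are
+    chosen only for illustration)::
+
+        >>> loss = nn.MarginRankingLoss(margin=0.5)
+        >>> input1 = torch.randn(3, requires_grad=True)
+        >>> input2 = torch.randn(3, requires_grad=True)
+        >>> # 1 means input1 should rank higher, -1 means input2 should
+        >>> target = torch.tensor([1., -1., 1.])
+        >>> output = loss(input1, input2, target)
+        >>> output.backward()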
+ """
+ __constants__ = ['margin', 'reduction']
+
+ def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'):
+ super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
+ self.margin = margin
+
+ @weak_script_method
+ def forward(self, input1, input2, target):
+ return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
+
+
+[docs]@weak_module
+class MultiMarginLoss(_WeightedLoss):
+ r"""Creates a criterion that optimizes a multi-class classification hinge
+ loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
+ output :math:`y` (which is a 1D tensor of target class indices,
+ :math:`0 \leq y \leq \text{x.size}(1)-1`):
+
+ For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar
+ output :math:`y` is:
+
+ .. math::
+        \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)}
+
+    where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`
+    and :math:`i \neq y`.
+
+ Optionally, you can give non-equal weighting on the classes by passing
+ a 1D :attr:`weight` tensor into the constructor.
+
+ The loss function then becomes:
+
+ .. math::
+        \text{loss}(x, y) = \frac{\sum_i \max(0, w[y] * (\text{margin} - x[y] + x[i]))^p}{\text{x.size}(0)}
+
+ Args:
+ p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2`
+ are the only supported values.
+ margin (float, optional): Has a default value of :math:`1`.
+ weight (Tensor, optional): a manual rescaling weight given to each
+ class. If given, it has to be a Tensor of size `C`. Otherwise, it is
+ treated as if having all ones.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
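+
+    Example (a minimal usage sketch; the batch size, number of classes and
+    targets below are chosen only for illustration)::
+
+        >>> loss = nn.MultiMarginLoss()
+        >>> input = torch.randn(3, 5, requires_grad=True)   # 3 samples, 5 classes
+        >>> target = torch.tensor([1, 0, 4])                # a class index per sample
+        >>> output = loss(input, target)
+        >>> output.backward()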
+ """
+ __constants__ = ['p', 'margin', 'weight', 'reduction']
+
+ def __init__(self, p=1, margin=1., weight=None, size_average=None,
+ reduce=None, reduction='mean'):
+ super(MultiMarginLoss, self).__init__(weight, size_average, reduce, reduction)
+ if p != 1 and p != 2:
+ raise ValueError("only p == 1 and p == 2 supported")
+ assert weight is None or weight.dim() == 1
+ self.p = p
+ self.margin = margin
+
+ @weak_script_method
+ def forward(self, input, target):
+ return F.multi_margin_loss(input, target, p=self.p, margin=self.margin,
+ weight=self.weight, reduction=self.reduction)
+
+
+[docs]@weak_module
+class TripletMarginLoss(_Loss):
+    r"""Creates a criterion that measures the triplet loss given input
+    tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
+    This is used for measuring a relative similarity between samples. A triplet
+    is composed of `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
+    examples` respectively). The shapes of all input tensors should be
+    :math:`(N, D)`.
+
+ The distance swap is described in detail in the paper `Learning shallow
+ convolutional feature descriptors with triplet losses`_ by
+ V. Balntas, E. Riba et al.
+
+ The loss function for each sample in the mini-batch is:
+
+ .. math::
+ L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
+
+
+ where
+
+ .. math::
+ d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
+
+ Args:
+ margin (float, optional): Default: :math:`1`.
+ p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
+ swap (bool, optional): The distance swap is described in detail in the paper
+ `Learning shallow convolutional feature descriptors with triplet losses` by
+ V. Balntas, E. Riba et al. Default: ``False``.
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when reduce is ``False``. Default: ``True``
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+ ``'mean'``: the sum of the output will be divided by the number of
+ elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in the meantime,
+ specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+
+ Shape:
+ - Input: :math:`(N, D)` where :math:`D` is the vector dimension.
+ - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
+
+ >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
+ >>> input1 = torch.randn(100, 128, requires_grad=True)
+ >>> input2 = torch.randn(100, 128, requires_grad=True)
+ >>> input3 = torch.randn(100, 128, requires_grad=True)
+ >>> output = triplet_loss(input1, input2, input3)
+ >>> output.backward()
+
+ .. _Learning shallow convolutional feature descriptors with triplet losses:
+ http://www.bmva.org/bmvc/2016/papers/paper119/index.html
+ """
+ __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction']
+
+ def __init__(self, margin=1.0, p=2., eps=1e-6, swap=False, size_average=None,
+ reduce=None, reduction='mean'):
+ super(TripletMarginLoss, self).__init__(size_average, reduce, reduction)
+ self.margin = margin
+ self.p = p
+ self.eps = eps
+ self.swap = swap
+
+ @weak_script_method
+ def forward(self, anchor, positive, negative):
+ return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
+ eps=self.eps, swap=self.swap, reduction=self.reduction)
+
+
+[docs]@weak_module
+class CTCLoss(_Loss):
+ r"""The Connectionist Temporal Classification loss.
+
+ Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
+ probability of possible alignments of input to target, producing a loss value which is differentiable
+ with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
+ limits the length of the target sequence such that it must be :math:`\leq` the input length.
+
+    **Args:**
+        **blank** (int, optional): blank label. Default :math:`0`.
+        **reduction** (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the output losses will be divided by the target lengths and
+            then the mean over the batch is taken, ``'sum'``: the output losses will
+            be summed. Default: ``'mean'``
+        **zero_infinity** (bool, optional): Whether to zero infinite losses and the
+            associated gradients. Infinite losses mainly occur when the inputs are
+            too short to be aligned to the targets. Default: ``False``
+
+ **Inputs:**
+ **log_probs**: Tensor of size :math:`(T, N, C)`
+ | :math:`T = \text{input length}`
+ | :math:`N = \text{batch size}`
+ | :math:`C = \text{number of classes (including blank)}`
+
+ The logarithmized probabilities of the outputs
+ (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
+ **targets**: Tensor of size :math:`(N, S)` or :math:`(\text{sum(target_lengths)})`
+ | :math:`N = \text{batch size}`
+ | :math:`S = \text{max target length, if shape is } (N, S)`.
+
+ | Target sequences. Each element in the target sequence is a class index. Target index
+ cannot be blank (default=0).
+
+ | In the :math:`(N, S)` form, targets are padded to the length of the longest sequence, and stacked.
+ | In the :math:`(\text{sum(target_lengths)})` form, the targets are assumed to be un-padded and concatenated
+ within 1 dimension.
+ **input_lengths**: Tuple or tensor of size :math:`(N)`.
+ Lengths of the inputs (must each be :math:`\leq T`).
+ Lengths are specified for each sequence to achieve masking under the
+ assumption that sequences are padded to equal lengths.
+ **target_lengths**: Tuple or tensor of size :math:`(N)`.
+ | Lengths of the targets. Lengths are specified for each sequence to achieve masking under the
+ assumption that sequences are padded to equal lengths.
+
+ | If target shape is :math:`(N,S)`, target_lengths are effectively the stop index
+ :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
+ each target in a batch. Lengths must each be :math:`\leq S`
+
+ | If the targets are given as a 1d tensor that is the concatenation of individual targets,
+ the target_lengths must add up to the total length of the tensor.
+
+ Example::
+
+ >>> T = 50 # Input sequence length
+ >>> C = 20 # Number of classes (excluding blank)
+ >>> N = 16 # Batch size
+ >>> S = 30 # Target sequence length of longest target in batch
+ >>> S_min = 10 # Minimum target length, for demonstration purposes
+ >>>
+ >>> # Initialize random batch of input vectors, for *size = (T,N,C)
+ >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
+ >>>
+ >>> # Initialize random batch of targets (0 = blank, 1:C+1 = classes)
+ >>> target = torch.randint(low=1, high=C+1, size=(N, S), dtype=torch.long)
+ >>>
+ >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
+ >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
+ >>> ctc_loss = nn.CTCLoss()
+ >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
+ >>> loss.backward()
+
+ Reference:
+ A. Graves et al.: Connectionist Temporal Classification:
+ Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
+ https://www.cs.toronto.edu/~graves/icml_2006.pdf
+
+ .. Note::
+        In order to use CuDNN, the following must be satisfied: :attr:`targets` must be
+        in concatenated format, all :attr:`input_lengths` must be `T`, :math:`blank=0`,
+        :attr:`target_lengths` :math:`\leq 256`, and the integer arguments must be of
+        dtype :attr:`torch.int32`.
+
+ The regular implementation uses the (more common in PyTorch) `torch.long` dtype.
+
+
+ .. include:: cudnn_deterministic.rst
+
+ """
+ __constants__ = ['blank', 'reduction']
+
+ def __init__(self, blank=0, reduction='mean', zero_infinity=False):
+ super(CTCLoss, self).__init__(reduction=reduction)
+ self.blank = blank
+ self.zero_infinity = zero_infinity
+
+ @weak_script_method
+ def forward(self, log_probs, targets, input_lengths, target_lengths):
+ return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
+ self.zero_infinity)
+
+# TODO: L1HingeEmbeddingCriterion
+# TODO: MSECriterion weight
+# TODO: ClassSimplexCriterion
+
+from collections import OrderedDict, namedtuple
+import functools
+import itertools
+
+import torch
+from ..backends.thnn import backend as thnn_backend
+from ..parameter import Parameter
+import torch.utils.hooks as hooks
+
+
+_IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])
+
+
+def _addindent(s_, numSpaces):
+ s = s_.split('\n')
+ # don't do anything for single-line stuff
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(numSpaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+
+[docs]class Module(object):
+ r"""Base class for all neural network modules.
+
+ Your models should also subclass this class.
+
+ Modules can also contain other Modules, allowing to nest them in
+ a tree structure. You can assign the submodules as regular attributes::
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Model(nn.Module):
+ def __init__(self):
+ super(Model, self).__init__()
+ self.conv1 = nn.Conv2d(1, 20, 5)
+ self.conv2 = nn.Conv2d(20, 20, 5)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ return F.relu(self.conv2(x))
+
+ Submodules assigned in this way will be registered, and will have their
+ parameters converted too when you call :meth:`to`, etc.
+ """
+
+ dump_patches = False
+
+ r"""This allows better BC support for :meth:`load_state_dict`. In
+    :meth:`state_dict`, the version number will be saved in the attribute
+ `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
+ dictionary with keys that follow the naming convention of state dict. See
+ ``_load_from_state_dict`` on how to use this information in loading.
+
+ If new parameters/buffers are added/removed from a module, this number shall
+ be bumped, and the module's `_load_from_state_dict` method can compare the
+ version number and do appropriate changes if the state dict is from before
+ the change."""
+ _version = 1
+
+ def __init__(self):
+ self._backend = thnn_backend
+ self._parameters = OrderedDict()
+ self._buffers = OrderedDict()
+ self._backward_hooks = OrderedDict()
+ self._forward_hooks = OrderedDict()
+ self._forward_pre_hooks = OrderedDict()
+ self._state_dict_hooks = OrderedDict()
+ self._load_state_dict_pre_hooks = OrderedDict()
+ self._modules = OrderedDict()
+ self.training = True
+
+[docs] def forward(self, *input):
+ r"""Defines the computation performed at every call.
+
+ Should be overridden by all subclasses.
+
+ .. note::
+ Although the recipe for forward pass needs to be defined within
+ this function, one should call the :class:`Module` instance afterwards
+ instead of this since the former takes care of running the
+ registered hooks while the latter silently ignores them.
+ """
+ raise NotImplementedError
+
+[docs] def register_buffer(self, name, tensor):
+ r"""Adds a persistent buffer to the module.
+
+        This is typically used to register a buffer that should not be
+ considered a model parameter. For example, BatchNorm's ``running_mean``
+ is not a parameter, but is part of the persistent state.
+
+ Buffers can be accessed as attributes using given names.
+
+ Args:
+ name (string): name of the buffer. The buffer can be accessed
+ from this module using the given name
+ tensor (Tensor): buffer to be registered.
+
+ Example::
+
+ >>> self.register_buffer('running_mean', torch.zeros(num_features))
+
+ """
+ if '_buffers' not in self.__dict__:
+ raise AttributeError(
+ "cannot assign buffer before Module.__init__() call")
+ elif not isinstance(name, torch._six.string_classes):
+ raise TypeError("buffer name should be a string. "
+ "Got {}".format(torch.typename(name)))
+ elif '.' in name:
+ raise KeyError("buffer name can't contain \".\"")
+ elif name == '':
+ raise KeyError("buffer name can't be empty string \"\"")
+ elif hasattr(self, name) and name not in self._buffers:
+ raise KeyError("attribute '{}' already exists".format(name))
+ elif tensor is not None and not isinstance(tensor, torch.Tensor):
+ raise TypeError("cannot assign '{}' object to buffer '{}' "
+ "(torch Tensor or None required)"
+ .format(torch.typename(tensor), name))
+ else:
+ self._buffers[name] = tensor
+
+[docs] def register_parameter(self, name, param):
+ r"""Adds a parameter to the module.
+
+ The parameter can be accessed as an attribute using given name.
+
+ Args:
+ name (string): name of the parameter. The parameter can be accessed
+ from this module using the given name
+ param (Parameter): parameter to be added to the module.
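+
+        Example (a minimal sketch; the parameter name ``scale`` is illustrative)::
+
+            >>> layer = nn.Linear(2, 2)
+            >>> layer.register_parameter('scale', nn.Parameter(torch.ones(1)))
+            >>> 'scale' in dict(layer.named_parameters())
+            True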
+ """
+ if '_parameters' not in self.__dict__:
+ raise AttributeError(
+ "cannot assign parameter before Module.__init__() call")
+
+ elif not isinstance(name, torch._six.string_classes):
+ raise TypeError("parameter name should be a string. "
+ "Got {}".format(torch.typename(name)))
+ elif '.' in name:
+ raise KeyError("parameter name can't contain \".\"")
+ elif name == '':
+ raise KeyError("parameter name can't be empty string \"\"")
+ elif hasattr(self, name) and name not in self._parameters:
+ raise KeyError("attribute '{}' already exists".format(name))
+
+ if param is None:
+ self._parameters[name] = None
+ elif not isinstance(param, Parameter):
+ raise TypeError("cannot assign '{}' object to parameter '{}' "
+ "(torch.nn.Parameter or None required)"
+ .format(torch.typename(param), name))
+ elif param.grad_fn:
+ raise ValueError(
+ "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
+ "parameters must be created explicitly. To express '{0}' "
+ "as a function of another Tensor, compute the value in "
+ "the forward() method.".format(name))
+ else:
+ self._parameters[name] = param
+
+[docs] def add_module(self, name, module):
+ r"""Adds a child module to the current module.
+
+ The module can be accessed as an attribute using the given name.
+
+ Args:
+ name (string): name of the child module. The child module can be
+ accessed from this module using the given name
+ module (Module): child module to be added to the module.
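+
+        Example (a minimal sketch; the names ``net`` and ``conv1`` are illustrative)::
+
+            >>> net = nn.Module()
+            >>> net.add_module('conv1', nn.Conv2d(1, 20, 5))
+            >>> net.conv1
+            Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))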
+ """
+ if not isinstance(module, Module) and module is not None:
+ raise TypeError("{} is not a Module subclass".format(
+ torch.typename(module)))
+ elif not isinstance(name, torch._six.string_classes):
+ raise TypeError("module name should be a string. Got {}".format(
+ torch.typename(name)))
+ elif hasattr(self, name) and name not in self._modules:
+ raise KeyError("attribute '{}' already exists".format(name))
+ elif '.' in name:
+ raise KeyError("module name can't contain \".\"")
+ elif name == '':
+ raise KeyError("module name can't be empty string \"\"")
+ self._modules[name] = module
+
+ def _apply(self, fn):
+ for module in self.children():
+ module._apply(fn)
+
+ for param in self._parameters.values():
+ if param is not None:
+ # Tensors stored in modules are graph leaves, and we don't
+ # want to create copy nodes, so we have to unpack the data.
+ param.data = fn(param.data)
+ if param._grad is not None:
+ param._grad.data = fn(param._grad.data)
+
+ for key, buf in self._buffers.items():
+ if buf is not None:
+ self._buffers[key] = fn(buf)
+
+ return self
+
+[docs] def apply(self, fn):
+ r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
+ as well as self. Typical use includes initializing the parameters of a model
+ (see also :ref:`torch-nn-init`).
+
+ Args:
+ fn (:class:`Module` -> None): function to be applied to each submodule
+
+ Returns:
+ Module: self
+
+ Example::
+
+ >>> def init_weights(m):
+ >>> print(m)
+ >>> if type(m) == nn.Linear:
+ >>> m.weight.data.fill_(1.0)
+ >>> print(m.weight)
+ >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
+ >>> net.apply(init_weights)
+ Linear(in_features=2, out_features=2, bias=True)
+ Parameter containing:
+ tensor([[ 1., 1.],
+ [ 1., 1.]])
+ Linear(in_features=2, out_features=2, bias=True)
+ Parameter containing:
+ tensor([[ 1., 1.],
+ [ 1., 1.]])
+ Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ )
+ Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ )
+ """
+ for module in self.children():
+ module.apply(fn)
+ fn(self)
+ return self
+
+[docs] def cuda(self, device=None):
+ r"""Moves all model parameters and buffers to the GPU.
+
+        This also makes associated parameters and buffers different objects. So
+        it should be called before constructing the optimizer if the module will
+        live on the GPU while being optimized.
+
+ Arguments:
+ device (int, optional): if specified, all parameters will be
+ copied to that device
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.cuda(device))
+
+[docs] def cpu(self):
+ r"""Moves all model parameters and buffers to the CPU.
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.cpu())
+
+[docs] def type(self, dst_type):
+ r"""Casts all parameters and buffers to :attr:`dst_type`.
+
+ Arguments:
+ dst_type (type or string): the desired type
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.type(dst_type))
+
+[docs] def float(self):
+ r"""Casts all floating point parameters and buffers to float datatype.
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.float() if t.is_floating_point() else t)
+
+[docs] def double(self):
+ r"""Casts all floating point parameters and buffers to ``double`` datatype.
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.double() if t.is_floating_point() else t)
+
+[docs] def half(self):
+ r"""Casts all floating point parameters and buffers to ``half`` datatype.
+
+ Returns:
+ Module: self
+ """
+ return self._apply(lambda t: t.half() if t.is_floating_point() else t)
+
+[docs] def to(self, *args, **kwargs):
+ r"""Moves and/or casts the parameters and buffers.
+
+ This can be called as
+
+ .. function:: to(device=None, dtype=None, non_blocking=False)
+
+ .. function:: to(dtype, non_blocking=False)
+
+ .. function:: to(tensor, non_blocking=False)
+
+ Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
+ floating point desired :attr:`dtype` s. In addition, this method will
+ only cast the floating point parameters and buffers to :attr:`dtype`
+        (if given). The integral parameters and buffers will be moved to
+        :attr:`device`, if that is given, but with dtypes unchanged. When
+ :attr:`non_blocking` is set, it tries to convert/move asynchronously
+ with respect to the host if possible, e.g., moving CPU Tensors with
+ pinned memory to CUDA devices.
+
+ See below for examples.
+
+ .. note::
+ This method modifies the module in-place.
+
+ Args:
+ device (:class:`torch.device`): the desired device of the parameters
+ and buffers in this module
+ dtype (:class:`torch.dtype`): the desired floating point type of
+ the floating point parameters and buffers in this module
+ tensor (torch.Tensor): Tensor whose dtype and device are the desired
+ dtype and device for all parameters and buffers in this module
+
+ Returns:
+ Module: self
+
+ Example::
+
+ >>> linear = nn.Linear(2, 2)
+ >>> linear.weight
+ Parameter containing:
+ tensor([[ 0.1913, -0.3420],
+ [-0.5113, -0.2325]])
+ >>> linear.to(torch.double)
+ Linear(in_features=2, out_features=2, bias=True)
+ >>> linear.weight
+ Parameter containing:
+ tensor([[ 0.1913, -0.3420],
+ [-0.5113, -0.2325]], dtype=torch.float64)
+ >>> gpu1 = torch.device("cuda:1")
+ >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
+ Linear(in_features=2, out_features=2, bias=True)
+ >>> linear.weight
+ Parameter containing:
+ tensor([[ 0.1914, -0.3420],
+ [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
+ >>> cpu = torch.device("cpu")
+ >>> linear.to(cpu)
+ Linear(in_features=2, out_features=2, bias=True)
+ >>> linear.weight
+ Parameter containing:
+ tensor([[ 0.1914, -0.3420],
+ [-0.5112, -0.2324]], dtype=torch.float16)
+
+ """
+
+ device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
+
+ if dtype is not None:
+ if not dtype.is_floating_point:
+ raise TypeError('nn.Module.to only accepts floating point '
+ 'dtypes, but got desired dtype={}'.format(dtype))
+
+ def convert(t):
+ return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
+
+ return self._apply(convert)
+
+[docs] def register_backward_hook(self, hook):
+ r"""Registers a backward hook on the module.
+
+ The hook will be called every time the gradients with respect to module
+ inputs are computed. The hook should have the following signature::
+
+ hook(module, grad_input, grad_output) -> Tensor or None
+
+ The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
+ module has multiple inputs or outputs. The hook should not modify its
+ arguments, but it can optionally return a new gradient with respect to
+ input that will be used in place of :attr:`grad_input` in subsequent
+ computations.
+
+ Returns:
+ :class:`torch.utils.hooks.RemovableHandle`:
+ a handle that can be used to remove the added hook by calling
+ ``handle.remove()``
+
+ .. warning ::
+
+ The current implementation will not have the presented behavior
+            for a complex :class:`Module` that performs many operations.
+ In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
+ contain the gradients for a subset of the inputs and outputs.
+ For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
+ directly on a specific input or output to get the required gradients.
+
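+        Example (a minimal sketch; ``clip_grad_input`` is an illustrative hook that
+        clamps the gradients flowing into the module)::
+
+            >>> def clip_grad_input(module, grad_input, grad_output):
+            >>>     # return a replacement for grad_input (entries may be None)
+            >>>     return tuple(g.clamp(-1., 1.) if g is not None else g
+            >>>                  for g in grad_input)
+            >>> m = nn.Linear(2, 3)
+            >>> handle = m.register_backward_hook(clip_grad_input)
+            >>> m(torch.randn(4, 2, requires_grad=True)).sum().backward()
+            >>> handle.remove()
+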
+ """
+ handle = hooks.RemovableHandle(self._backward_hooks)
+ self._backward_hooks[handle.id] = hook
+ return handle
+
+[docs] def register_forward_pre_hook(self, hook):
+ r"""Registers a forward pre-hook on the module.
+
+ The hook will be called every time before :func:`forward` is invoked.
+ It should have the following signature::
+
+ hook(module, input) -> None
+
+ The hook should not modify the input.
+
+ Returns:
+ :class:`torch.utils.hooks.RemovableHandle`:
+ a handle that can be used to remove the added hook by calling
+ ``handle.remove()``
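+
+        Example (a minimal sketch; ``log_input_shape`` is an illustrative hook)::
+
+            >>> def log_input_shape(module, input):
+            >>>     # `input` is the tuple of positional arguments passed to forward()
+            >>>     print(module.__class__.__name__, input[0].shape)
+            >>> m = nn.Linear(2, 3)
+            >>> handle = m.register_forward_pre_hook(log_input_shape)
+            >>> _ = m(torch.randn(4, 2))
+            Linear torch.Size([4, 2])
+            >>> handle.remove()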
+ """
+ handle = hooks.RemovableHandle(self._forward_pre_hooks)
+ self._forward_pre_hooks[handle.id] = hook
+ return handle
+
+[docs] def register_forward_hook(self, hook):
+ r"""Registers a forward hook on the module.
+
+ The hook will be called every time after :func:`forward` has computed an output.
+ It should have the following signature::
+
+ hook(module, input, output) -> None
+
+ The hook should not modify the input or output.
+
+ Returns:
+ :class:`torch.utils.hooks.RemovableHandle`:
+ a handle that can be used to remove the added hook by calling
+ ``handle.remove()``
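+
+        Example (a minimal sketch; ``print_output_shape`` is an illustrative hook)::
+
+            >>> def print_output_shape(module, input, output):
+            >>>     # inspect the activation without modifying it
+            >>>     print(module.__class__.__name__, output.shape)
+            >>> m = nn.Linear(2, 3)
+            >>> handle = m.register_forward_hook(print_output_shape)
+            >>> _ = m(torch.randn(4, 2))
+            Linear torch.Size([4, 3])
+            >>> handle.remove()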
+ """
+ handle = hooks.RemovableHandle(self._forward_hooks)
+ self._forward_hooks[handle.id] = hook
+ return handle
+
+ def _tracing_name(self, tracing_state):
+ if not tracing_state._traced_module_stack:
+ return None
+ module = tracing_state._traced_module_stack[-1]
+ for name, child in module.named_children():
+ if child is self:
+ return name
+ return None
+
+ def _slow_forward(self, *input, **kwargs):
+ tracing_state = torch._C._get_tracing_state()
+ if not tracing_state:
+ return self.forward(*input, **kwargs)
+ if not hasattr(tracing_state, '_traced_module_stack'):
+ tracing_state._traced_module_stack = []
+ name = self._tracing_name(tracing_state)
+ if name:
+ tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
+ else:
+ tracing_state.push_scope(self._get_name())
+ tracing_state._traced_module_stack.append(self)
+ try:
+ result = self.forward(*input, **kwargs)
+ finally:
+ tracing_state.pop_scope()
+ tracing_state._traced_module_stack.pop()
+ return result
+
+ def __call__(self, *input, **kwargs):
+ for hook in self._forward_pre_hooks.values():
+ hook(self, input)
+ if torch._C._get_tracing_state():
+ result = self._slow_forward(*input, **kwargs)
+ else:
+ result = self.forward(*input, **kwargs)
+ for hook in self._forward_hooks.values():
+ hook_result = hook(self, input, result)
+ if hook_result is not None:
+ raise RuntimeError(
+                    "forward hooks should never return any values, but '{}' "
+                    "didn't return None".format(hook))
+ if len(self._backward_hooks) > 0:
+ var = result
+ while not isinstance(var, torch.Tensor):
+ if isinstance(var, dict):
+ var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
+ else:
+ var = var[0]
+ grad_fn = var.grad_fn
+ if grad_fn is not None:
+ for hook in self._backward_hooks.values():
+ wrapper = functools.partial(hook, self)
+ functools.update_wrapper(wrapper, hook)
+ grad_fn.register_hook(wrapper)
+ return result
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ # Support loading old checkpoints that don't have the following attrs:
+ if '_forward_pre_hooks' not in self.__dict__:
+ self._forward_pre_hooks = OrderedDict()
+ if '_state_dict_hooks' not in self.__dict__:
+ self._state_dict_hooks = OrderedDict()
+ if '_load_state_dict_pre_hooks' not in self.__dict__:
+ self._load_state_dict_pre_hooks = OrderedDict()
+
+ def __getattr__(self, name):
+ if '_parameters' in self.__dict__:
+ _parameters = self.__dict__['_parameters']
+ if name in _parameters:
+ return _parameters[name]
+ if '_buffers' in self.__dict__:
+ _buffers = self.__dict__['_buffers']
+ if name in _buffers:
+ return _buffers[name]
+ if '_modules' in self.__dict__:
+ modules = self.__dict__['_modules']
+ if name in modules:
+ return modules[name]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, name))
+
+ def __setattr__(self, name, value):
+ def remove_from(*dicts):
+ for d in dicts:
+ if name in d:
+ del d[name]
+
+ params = self.__dict__.get('_parameters')
+ if isinstance(value, Parameter):
+ if params is None:
+ raise AttributeError(
+ "cannot assign parameters before Module.__init__() call")
+ remove_from(self.__dict__, self._buffers, self._modules)
+ self.register_parameter(name, value)
+ elif params is not None and name in params:
+ if value is not None:
+ raise TypeError("cannot assign '{}' as parameter '{}' "
+ "(torch.nn.Parameter or None expected)"
+ .format(torch.typename(value), name))
+ self.register_parameter(name, value)
+ else:
+ modules = self.__dict__.get('_modules')
+ if isinstance(value, Module):
+ if modules is None:
+ raise AttributeError(
+ "cannot assign module before Module.__init__() call")
+ remove_from(self.__dict__, self._parameters, self._buffers)
+ modules[name] = value
+ elif modules is not None and name in modules:
+ if value is not None:
+ raise TypeError("cannot assign '{}' as child module '{}' "
+ "(torch.nn.Module or None expected)"
+ .format(torch.typename(value), name))
+ modules[name] = value
+ else:
+ buffers = self.__dict__.get('_buffers')
+ if buffers is not None and name in buffers:
+ if value is not None and not isinstance(value, torch.Tensor):
+ raise TypeError("cannot assign '{}' as buffer '{}' "
+ "(torch.Tensor or None expected)"
+ .format(torch.typename(value), name))
+ buffers[name] = value
+ else:
+ object.__setattr__(self, name, value)
+
+ def __delattr__(self, name):
+ if name in self._parameters:
+ del self._parameters[name]
+ elif name in self._buffers:
+ del self._buffers[name]
+ elif name in self._modules:
+ del self._modules[name]
+ else:
+ object.__delattr__(self, name)
+
+ def _register_state_dict_hook(self, hook):
+ r"""These hooks will be called with arguments: `self`, `state_dict`,
+ `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
+ Note that only parameters and buffers of `self` or its children are
+ guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
+ inplace or return a new one.
+ """
+ handle = hooks.RemovableHandle(self._state_dict_hooks)
+ self._state_dict_hooks[handle.id] = hook
+ return handle
+
+[docs] def state_dict(self, destination=None, prefix='', keep_vars=False):
+ r"""Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ Returns:
+ dict:
+ a dictionary containing a whole state of the module
+
+ Example::
+
+ >>> module.state_dict().keys()
+ ['bias', 'weight']
+
+ """
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
+ for name, param in self._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.data
+ for name, buf in self._buffers.items():
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.data
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in self._state_dict_hooks.values():
+ hook_result = hook(self, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+ def _register_load_state_dict_pre_hook(self, hook):
+ r"""These hooks will be called with arguments: `state_dict`, `prefix`,
+ `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
+ `error_msgs`, before loading `state_dict` into `self`. These arguments
+ are exactly the same as those of `_load_from_state_dict`.
+ """
+ handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
+ self._load_state_dict_pre_hooks[handle.id] = hook
+ return handle
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
+ this module, but not its descendants. This is called on every submodule
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+ For state dicts without metadata, :attr:`local_metadata` is empty.
+ Subclasses can achieve class-specific backward compatible loading using
+ the version number at `local_metadata.get("version", None)`.
+
+ .. note::
+ :attr:`state_dict` is not the same object as the input
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+ it can be modified.
+
+ Arguments:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+            local_metadata (dict): a dict containing the metadata for this module.
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` with :attr:`prefix` match the names of
+ parameters and buffers in this module
+ missing_keys (list of str): if ``strict=True``, add missing keys to
+ this list
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
+ keys to this list
+ error_msgs (list of str): error messages should be added to this
+ list, and will be reported together in
+ :meth:`~torch.nn.Module.load_state_dict`
+ """
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
+ local_state = {k: v.data for k, v in local_name_params if v is not None}
+
+ for name, param in local_state.items():
+ key = prefix + name
+ if key in state_dict:
+ input_param = state_dict[key]
+
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if len(param.shape) == 0 and len(input_param.shape) == 1:
+ input_param = input_param[0]
+
+ if input_param.shape != param.shape:
+ # local shape should match the one in checkpoint
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+ 'the shape in current model is {}.'
+ .format(key, input_param.shape, param.shape))
+ continue
+
+ if isinstance(input_param, Parameter):
+ # backwards compatibility for serialized parameters
+ input_param = input_param.data
+ try:
+ param.copy_(input_param)
+ except Exception:
+ error_msgs.append('While copying the parameter named "{}", '
+ 'whose dimensions in the model are {} and '
+ 'whose dimensions in the checkpoint are {}.'
+ .format(key, param.size(), input_param.size()))
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ for key in state_dict.keys():
+ if key.startswith(prefix):
+ input_name = key[len(prefix):]
+ input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
+ if input_name not in self._modules and input_name not in local_state:
+ unexpected_keys.append(key)
+
+[docs] def load_state_dict(self, state_dict, strict=True):
+ r"""Copies parameters and buffers from :attr:`state_dict` into
+ this module and its descendants. If :attr:`strict` is ``True``, then
+ the keys of :attr:`state_dict` must exactly match the keys returned
+ by this module's :meth:`~torch.nn.Module.state_dict` function.
+
+ Arguments:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ strict (bool, optional): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+ Returns:
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+ * **missing_keys** is a list of str containing the missing keys
+ * **unexpected_keys** is a list of str containing the unexpected keys
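+
+        Example (a minimal sketch; the checkpoint here is built in-memory only to
+        illustrate how mismatched keys are reported with ``strict=False``)::
+
+            >>> ckpt = nn.Linear(4, 2).state_dict()      # keys: 'weight', 'bias'
+            >>> model = nn.Sequential(nn.Linear(4, 2))   # expects '0.weight', '0.bias'
+            >>> result = model.load_state_dict(ckpt, strict=False)
+            >>> result.missing_keys, result.unexpected_keys
+            (['0.weight', '0.bias'], ['weight', 'bias'])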
+ """
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ def load(module, prefix=''):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(self)
+
+ if strict:
+ if len(unexpected_keys) > 0:
+ error_msgs.insert(
+ 0, 'Unexpected key(s) in state_dict: {}. '.format(
+ ', '.join('"{}"'.format(k) for k in unexpected_keys)))
+ if len(missing_keys) > 0:
+ error_msgs.insert(
+ 0, 'Missing key(s) in state_dict: {}. '.format(
+ ', '.join('"{}"'.format(k) for k in missing_keys)))
+
+ if len(error_msgs) > 0:
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+ self.__class__.__name__, "\n\t".join(error_msgs)))
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
+
+ def _named_members(self, get_members_fn, prefix='', recurse=True):
+ r"""Helper method for yielding various names + members of modules."""
+ memo = set()
+ modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+ for module_prefix, module in modules:
+ members = get_members_fn(module)
+ for k, v in members:
+ if v is None or v in memo:
+ continue
+ memo.add(v)
+ name = module_prefix + ('.' if module_prefix else '') + k
+ yield name, v
+
+[docs] def parameters(self, recurse=True):
+ r"""Returns an iterator over module parameters.
+
+ This is typically passed to an optimizer.
+
+ Args:
+ recurse (bool): if True, then yields parameters of this module
+ and all submodules. Otherwise, yields only parameters that
+ are direct members of this module.
+
+ Yields:
+ Parameter: module parameter
+
+ Example::
+
+ >>> for param in model.parameters():
+ >>> print(type(param.data), param.size())
+ <class 'torch.FloatTensor'> (20L,)
+ <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
+
+ """
+ for name, param in self.named_parameters(recurse=recurse):
+ yield param
+
+[docs] def named_parameters(self, prefix='', recurse=True):
+ r"""Returns an iterator over module parameters, yielding both the
+ name of the parameter as well as the parameter itself.
+
+ Args:
+ prefix (str): prefix to prepend to all parameter names.
+ recurse (bool): if True, then yields parameters of this module
+ and all submodules. Otherwise, yields only parameters that
+ are direct members of this module.
+
+ Yields:
+ (string, Parameter): Tuple containing the name and parameter
+
+ Example::
+
+ >>> for name, param in self.named_parameters():
+ >>> if name in ['bias']:
+ >>> print(param.size())
+
+ """
+ gen = self._named_members(
+ lambda module: module._parameters.items(),
+ prefix=prefix, recurse=recurse)
+ for elem in gen:
+ yield elem
+
+[docs] def buffers(self, recurse=True):
+ r"""Returns an iterator over module buffers.
+
+ Args:
+ recurse (bool): if True, then yields buffers of this module
+ and all submodules. Otherwise, yields only buffers that
+ are direct members of this module.
+
+ Yields:
+ torch.Tensor: module buffer
+
+ Example::
+
+ >>> for buf in model.buffers():
+ >>> print(type(buf.data), buf.size())
+ <class 'torch.FloatTensor'> (20L,)
+ <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
+
+ """
+ for name, buf in self.named_buffers(recurse=recurse):
+ yield buf
+
+[docs] def named_buffers(self, prefix='', recurse=True):
+ r"""Returns an iterator over module buffers, yielding both the
+ name of the buffer as well as the buffer itself.
+
+ Args:
+ prefix (str): prefix to prepend to all buffer names.
+ recurse (bool): if True, then yields buffers of this module
+ and all submodules. Otherwise, yields only buffers that
+ are direct members of this module.
+
+ Yields:
+ (string, torch.Tensor): Tuple containing the name and buffer
+
+ Example::
+
+ >>> for name, buf in self.named_buffers():
+ >>> if name in ['running_var']:
+ >>> print(buf.size())
+
+ """
+ gen = self._named_members(
+ lambda module: module._buffers.items(),
+ prefix=prefix, recurse=recurse)
+ for elem in gen:
+ yield elem
+
+[docs] def children(self):
+ r"""Returns an iterator over immediate children modules.
+
+ Yields:
+ Module: a child module
+ """
+ for name, module in self.named_children():
+ yield module
+
+[docs] def named_children(self):
+ r"""Returns an iterator over immediate children modules, yielding both
+ the name of the module as well as the module itself.
+
+ Yields:
+ (string, Module): Tuple containing a name and child module
+
+ Example::
+
+ >>> for name, module in model.named_children():
+ >>> if name in ['conv4', 'conv5']:
+ >>> print(module)
+
+ """
+ memo = set()
+ for name, module in self._modules.items():
+ if module is not None and module not in memo:
+ memo.add(module)
+ yield name, module
+
+[docs] def modules(self):
+ r"""Returns an iterator over all modules in the network.
+
+ Yields:
+ Module: a module in the network
+
+ Note:
+ Duplicate modules are returned only once. In the following
+ example, ``l`` will be returned only once.
+
+ Example::
+
+ >>> l = nn.Linear(2, 2)
+ >>> net = nn.Sequential(l, l)
+ >>> for idx, m in enumerate(net.modules()):
+ print(idx, '->', m)
+
+ 0 -> Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ )
+ 1 -> Linear(in_features=2, out_features=2, bias=True)
+
+ """
+ for name, module in self.named_modules():
+ yield module
+
+[docs] def named_modules(self, memo=None, prefix=''):
+ r"""Returns an iterator over all modules in the network, yielding
+ both the name of the module as well as the module itself.
+
+ Yields:
+ (string, Module): Tuple of name and module
+
+ Note:
+ Duplicate modules are returned only once. In the following
+ example, ``l`` will be returned only once.
+
+ Example::
+
+ >>> l = nn.Linear(2, 2)
+ >>> net = nn.Sequential(l, l)
+ >>> for idx, m in enumerate(net.named_modules()):
+ print(idx, '->', m)
+
+ 0 -> ('', Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ ))
+ 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+ """
+
+ if memo is None:
+ memo = set()
+ if self not in memo:
+ memo.add(self)
+ yield prefix, self
+ for name, module in self._modules.items():
+ if module is None:
+ continue
+ submodule_prefix = prefix + ('.' if prefix else '') + name
+ for m in module.named_modules(memo, submodule_prefix):
+ yield m
+
+[docs] def train(self, mode=True):
+ r"""Sets the module in training mode.
+
+        This has an effect only on certain modules. See the documentation of
+ particular modules for details of their behaviors in training/evaluation
+ mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
+ etc.
+
+ Returns:
+ Module: self
+ """
+ self.training = mode
+ for module in self.children():
+ module.train(mode)
+ return self
+
+[docs] def eval(self):
+ r"""Sets the module in evaluation mode.
+
+        This has an effect only on certain modules. See the documentation of
+ particular modules for details of their behaviors in training/evaluation
+ mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
+ etc.
+ """
+ return self.train(False)
+
+[docs] def zero_grad(self):
+ r"""Sets gradients of all model parameters to zero."""
+ for p in self.parameters():
+ if p.grad is not None:
+ p.grad.detach_()
+ p.grad.zero_()
+
+ def share_memory(self):
+ return self._apply(lambda t: t.share_memory_())
+
+ def _get_name(self):
+ return self.__class__.__name__
+
+[docs] def extra_repr(self):
+ r"""Set the extra representation of the module
+
+ To print customized extra information, you should reimplement
+ this method in your own modules. Both single-line and multi-line
+ strings are acceptable.
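+
+        Example (a minimal sketch; ``Scale`` is an illustrative module)::
+
+            >>> class Scale(nn.Module):
+            >>>     def __init__(self, factor):
+            >>>         super(Scale, self).__init__()
+            >>>         self.factor = factor
+            >>>     def forward(self, x):
+            >>>         return x * self.factor
+            >>>     def extra_repr(self):
+            >>>         return 'factor={}'.format(self.factor)
+            >>> print(Scale(2.0))
+            Scale(factor=2.0)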
+ """
+ return ''
+
+ def __repr__(self):
+ # We treat the extra repr like the sub-module, one item per line
+ extra_lines = []
+ extra_repr = self.extra_repr()
+ # empty string will be split into list ['']
+ if extra_repr:
+ extra_lines = extra_repr.split('\n')
+ child_lines = []
+ for key, module in self._modules.items():
+ mod_str = repr(module)
+ mod_str = _addindent(mod_str, 2)
+ child_lines.append('(' + key + '): ' + mod_str)
+ lines = extra_lines + child_lines
+
+ main_str = self._get_name() + '('
+ if lines:
+ # simple one-liner info, which most builtin Modules will use
+ if len(extra_lines) == 1 and not child_lines:
+ main_str += extra_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(lines) + '\n'
+
+ main_str += ')'
+ return main_str
+
+ def __dir__(self):
+ module_attrs = dir(self.__class__)
+ attrs = list(self.__dict__.keys())
+ parameters = list(self._parameters.keys())
+ modules = list(self._modules.keys())
+ buffers = list(self._buffers.keys())
+ keys = module_attrs + attrs + parameters + modules + buffers
+
+ # Eliminate attrs that are not legal Python variable names
+ keys = [key for key in keys if not key[0].isdigit()]
+
+ return sorted(keys)
+
+import torch
+import numbers
+from torch.nn.parameter import Parameter
+from .module import Module
+from .. import functional as F
+from .. import init
+from ..._jit_internal import weak_module, weak_script_method
+
+
+[docs]@weak_module
+class LocalResponseNorm(Module):
+ r"""Applies local response normalization over an input signal composed
+ of several input planes, where channels occupy the second dimension.
+ Applies normalization across channels.
+
+ .. math::
+ b_{c} = a_{c}\left(k + \frac{\alpha}{n}
+ \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}
+
+ Args:
+ size: amount of neighbouring channels used for normalization
+ alpha: multiplicative factor. Default: 0.0001
+ beta: exponent. Default: 0.75
+ k: additive factor. Default: 1
+
+ Shape:
+ - Input: :math:`(N, C, *)`
+ - Output: :math:`(N, C, *)` (same shape as input)
+
+ Examples::
+
+ >>> lrn = nn.LocalResponseNorm(2)
+ >>> signal_2d = torch.randn(32, 5, 24, 24)
+ >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
+ >>> output_2d = lrn(signal_2d)
+ >>> output_4d = lrn(signal_4d)
+
+ """
+ __constants__ = ['size', 'alpha', 'beta', 'k']
+
+ def __init__(self, size, alpha=1e-4, beta=0.75, k=1.):
+ super(LocalResponseNorm, self).__init__()
+ self.size = size
+ self.alpha = alpha
+ self.beta = beta
+ self.k = k
+
+ @weak_script_method
+ def forward(self, input):
+ return F.local_response_norm(input, self.size, self.alpha, self.beta,
+ self.k)
+
+ def extra_repr(self):
+ return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
+
+
+class CrossMapLRN2d(Module):
+
+ def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
+ super(CrossMapLRN2d, self).__init__()
+ self.size = size
+ self.alpha = alpha
+ self.beta = beta
+ self.k = k
+
+ def forward(self, input):
+ return self._backend.CrossMapLRN2d(self.size, self.alpha, self.beta,
+ self.k)(input)
+
+ def extra_repr(self):
+ return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
+
+
+[docs]@weak_module
+class LayerNorm(Module):
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
+ the paper `Layer Normalization`_ .
+
+ .. math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated separately over the last
+    certain number of dimensions, which have to be of the shape specified by
+ :attr:`normalized_shape`.
+ :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+ :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
+
+ .. note::
+ Unlike Batch Normalization and Instance Normalization, which apply
+ scalar scale and bias for each entire channel/plane with the
+ :attr:`affine` option, Layer Normalization applies per-element scale and
+ bias with :attr:`elementwise_affine`.
+
+ This layer uses statistics computed from input data in both training and
+ evaluation modes.
+
+ Args:
+ normalized_shape (int or list or torch.Size): input shape from an expected input
+ of size
+
+ .. math::
+ [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
+ \times \ldots \times \text{normalized\_shape}[-1]]
+
+ If a single integer is used, it is treated as a singleton list, and this module will
+ normalize over the last dimension which is expected to be of that specific size.
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ elementwise_affine: a boolean value that, when set to ``True``, gives this module
+ learnable per-element affine parameters initialized to ones (for weights)
+ and zeros (for biases). Default: ``True``.
+
+ Shape:
+ - Input: :math:`(N, *)`
+ - Output: :math:`(N, *)` (same shape as input)
+
+ Examples::
+
+ >>> input = torch.randn(20, 5, 10, 10)
+ >>> # With Learnable Parameters
+ >>> m = nn.LayerNorm(input.size()[1:])
+ >>> # Without Learnable Parameters
+ >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
+ >>> # Normalize over last two dimensions
+ >>> m = nn.LayerNorm([10, 10])
+ >>> # Normalize over last dimension of size 10
+ >>> m = nn.LayerNorm(10)
+ >>> # Activating the module
+ >>> output = m(input)
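+ >>> # Rough sanity check (illustrative): the last module above normalizes over the
+ >>> # final dimension only, so the per-position mean is ~0 up to floating point error
+ >>> bool(output.mean(-1).abs().max() < 1e-4)
+ True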
+
+ .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+ """
+ __constants__ = ['normalized_shape', 'weight', 'bias', 'eps']
+
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+ super(LayerNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = torch.Size(normalized_shape)
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = Parameter(torch.Tensor(*normalized_shape))
+ self.bias = Parameter(torch.Tensor(*normalized_shape))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.elementwise_affine:
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+ @weak_script_method
+ def forward(self, input):
+ return F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps)
+
+ def extra_repr(self):
+ return '{normalized_shape}, eps={eps}, ' \
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+
+
+[docs]@weak_module
+class GroupNorm(Module):
+ r"""Applies Group Normalization over a mini-batch of inputs as described in
+ the paper `Group Normalization`_ .
+
+ .. math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The input channels are separated into :attr:`num_groups` groups, each containing
+ ``num_channels / num_groups`` channels. The mean and standard-deviation are calculated
+ separately over each group. :math:`\gamma` and :math:`\beta` are learnable
+ per-channel affine transform parameter vectors of size :attr:`num_channels` if
+ :attr:`affine` is ``True``.
+
+ This layer uses statistics computed from input data in both training and
+ evaluation modes.
+
+ Args:
+ num_groups (int): number of groups to separate the channels into
+ num_channels (int): number of channels expected in input
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ affine: a boolean value that, when set to ``True``, gives this module
+ learnable per-channel affine parameters initialized to ones (for weights)
+ and zeros (for biases). Default: ``True``.
+
+ Shape:
+ - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
+ - Output: :math:`(N, C, *)` (same shape as input)
+
+ Examples::
+
+ >>> input = torch.randn(20, 6, 10, 10)
+ >>> # Separate 6 channels into 3 groups
+ >>> m = nn.GroupNorm(3, 6)
+ >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
+ >>> m = nn.GroupNorm(6, 6)
+ >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
+ >>> m = nn.GroupNorm(1, 6)
+ >>> # Activating the module
+ >>> output = m(input)
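+ >>> # Hedged equivalence sketch: with the default ones/zeros affine initialization,
+ >>> # a single group behaves like LayerNorm over (C, H, W), up to floating point error
+ >>> torch.allclose(nn.GroupNorm(1, 6)(input), nn.LayerNorm(input.size()[1:])(input), atol=1e-6)
+ True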
+
+ .. _`Group Normalization`: https://arxiv.org/abs/1803.08494
+ """
+ __constants__ = ['num_groups', 'num_channels', 'eps', 'affine', 'weight',
+ 'bias']
+
+ def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
+ super(GroupNorm, self).__init__()
+ self.num_groups = num_groups
+ self.num_channels = num_channels
+ self.eps = eps
+ self.affine = affine
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_channels))
+ self.bias = Parameter(torch.Tensor(num_channels))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.affine:
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+ @weak_script_method
+ def forward(self, input):
+ return F.group_norm(
+ input, self.num_groups, self.weight, self.bias, self.eps)
+
+ def extra_repr(self):
+ return '{num_groups}, {num_channels}, eps={eps}, ' \
+ 'affine={affine}'.format(**self.__dict__)
+
+
+# TODO: ContrastiveNorm2d
+# TODO: DivisiveNorm2d
+# TODO: SubtractiveNorm2d
+
+from .module import Module
+from .utils import _pair, _quadruple, _ntuple
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+# TODO: grad_output size asserts in THNN
+
+
+@weak_module
+class _ConstantPadNd(Module):
+ __constants__ = ['padding', 'value']
+
+ def __init__(self, value):
+ super(_ConstantPadNd, self).__init__()
+ self.value = value
+
+ @weak_script_method
+ def forward(self, input):
+ return F.pad(input, self.padding, 'constant', self.value)
+
+ def extra_repr(self):
+ return 'padding={}, value={}'.format(self.padding, self.value)
+
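+# A minimal sketch (hedged): each ConstantPad* subclass below only expands ``padding``
+# into the flat, last-dimension-first tuple that F.pad expects, so e.g.
+#   F.pad(torch.randn(1, 2, 4), (2, 2), 'constant', 3.5)
+# gives the same result as ConstantPad1d(2, 3.5) applied to the same tensor.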
+
+[docs]@weak_module
+class ConstantPad1d(_ConstantPadNd):
+ r"""Pads the input tensor boundaries with a constant value.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in both boundaries. If a 2-`tuple`, uses
+ (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+
+ Shape:
+ - Input: :math:`(N, C, W_{in})`
+ - Output: :math:`(N, C, W_{out})` where
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ConstantPad1d(2, 3.5)
+ >>> input = torch.randn(1, 2, 4)
+ >>> input
+ tensor([[[-1.0491, -0.7152, -0.0749, 0.8530],
+ [-1.3287, 1.8966, 0.1466, -0.2771]]])
+ >>> m(input)
+ tensor([[[ 3.5000, 3.5000, -1.0491, -0.7152, -0.0749, 0.8530, 3.5000,
+ 3.5000],
+ [ 3.5000, 3.5000, -1.3287, 1.8966, 0.1466, -0.2771, 3.5000,
+ 3.5000]]])
+ >>> m = nn.ConstantPad1d(2, 3.5)
+ >>> input = torch.randn(1, 2, 3)
+ >>> input
+ tensor([[[ 1.6616, 1.4523, -1.1255],
+ [-3.6372, 0.1182, -1.8652]]])
+ >>> m(input)
+ tensor([[[ 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000, 3.5000]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ConstantPad1d((3, 1), 3.5)
+ >>> m(input)
+ tensor([[[ 3.5000, 3.5000, 3.5000, 1.6616, 1.4523, -1.1255, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
+
+ """
+
+ def __init__(self, padding, value):
+ super(ConstantPad1d, self).__init__(value)
+ self.padding = _pair(padding)
+
+
+[docs]@weak_module
+class ConstantPad2d(_ConstantPadNd):
+ r"""Pads the input tensor boundaries with a constant value.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+ :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ConstantPad2d(2, 3.5)
+ >>> input = torch.randn(1, 2, 2)
+ >>> input
+ tensor([[[ 1.6585, 0.4320],
+ [-0.8701, -0.4649]]])
+ >>> m(input)
+ tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 1.6585, 0.4320, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, -0.8701, -0.4649, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ConstantPad2d((3, 0, 2, 1), 3.5)
+ >>> m(input)
+ tensor([[[ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000],
+ [ 3.5000, 3.5000, 3.5000, 1.6585, 0.4320],
+ [ 3.5000, 3.5000, 3.5000, -0.8701, -0.4649],
+ [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]])
+
+ """
+ __constants__ = ['padding', 'value']
+
+ def __init__(self, padding, value):
+ super(ConstantPad2d, self).__init__(value)
+ self.padding = _quadruple(padding)
+
+
+[docs]@weak_module
+class ConstantPad3d(_ConstantPadNd):
+ r"""Pads the input tensor boundaries with a constant value.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 6-`tuple`, uses
+ (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
+ :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
+ :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+
+ :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ConstantPad3d(3, 3.5)
+ >>> input = torch.randn(16, 3, 10, 20, 30)
+ >>> output = m(input)
+ >>> # using different paddings for different sides
+ >>> m = nn.ConstantPad3d((3, 3, 6, 6, 0, 1), 3.5)
+ >>> output = m(input)
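+ >>> # Worked shape check (illustrative) for the asymmetric case above:
+ >>> # D = 10 + 0 + 1, H = 20 + 6 + 6, W = 30 + 3 + 3
+ >>> output.shape
+ torch.Size([16, 3, 11, 32, 36])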
+
+ """
+
+ def __init__(self, padding, value):
+ super(ConstantPad3d, self).__init__(value)
+ self.padding = _ntuple(6)(padding)
+
+
+@weak_module
+class _ReflectionPadNd(Module):
+ __constants__ = ['padding']
+
+ @weak_script_method
+ def forward(self, input):
+ return F.pad(input, self.padding, 'reflect')
+
+ def extra_repr(self):
+ return '{}'.format(self.padding)
+
+
+[docs]@weak_module
+class ReflectionPad1d(_ReflectionPadNd):
+ r"""Pads the input tensor using the reflection of the input boundary.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 2-`tuple`, uses
+ (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+
+ Shape:
+ - Input: :math:`(N, C, W_{in})`
+ - Output: :math:`(N, C, W_{out})` where
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ReflectionPad1d(2)
+ >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
+ >>> input
+ tensor([[[0., 1., 2., 3.],
+ [4., 5., 6., 7.]]])
+ >>> m(input)
+ tensor([[[2., 1., 0., 1., 2., 3., 2., 1.],
+ [6., 5., 4., 5., 6., 7., 6., 5.]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ReflectionPad1d((3, 1))
+ >>> m(input)
+ tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
+ [7., 6., 5., 4., 5., 6., 7., 6.]]])
+
+ """
+
+ def __init__(self, padding):
+ super(ReflectionPad1d, self).__init__()
+ self.padding = _pair(padding)
+
+
+[docs]@weak_module
+class ReflectionPad2d(_ReflectionPadNd):
+ r"""Pads the input tensor using the reflection of the input boundary.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+ :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ReflectionPad2d(2)
+ >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
+ >>> input
+ tensor([[[[0., 1., 2.],
+ [3., 4., 5.],
+ [6., 7., 8.]]]])
+ >>> m(input)
+ tensor([[[[8., 7., 6., 7., 8., 7., 6.],
+ [5., 4., 3., 4., 5., 4., 3.],
+ [2., 1., 0., 1., 2., 1., 0.],
+ [5., 4., 3., 4., 5., 4., 3.],
+ [8., 7., 6., 7., 8., 7., 6.],
+ [5., 4., 3., 4., 5., 4., 3.],
+ [2., 1., 0., 1., 2., 1., 0.]]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ReflectionPad2d((1, 1, 2, 0))
+ >>> m(input)
+ tensor([[[[7., 6., 7., 8., 7.],
+ [4., 3., 4., 5., 4.],
+ [1., 0., 1., 2., 1.],
+ [4., 3., 4., 5., 4.],
+ [7., 6., 7., 8., 7.]]]])
+
+ """
+
+ def __init__(self, padding):
+ super(ReflectionPad2d, self).__init__()
+ self.padding = _quadruple(padding)
+
+
+@weak_module
+class _ReplicationPadNd(Module):
+ __constants__ = ['padding']
+
+ @weak_script_method
+ def forward(self, input):
+ return F.pad(input, self.padding, 'replicate')
+
+ def extra_repr(self):
+ return '{}'.format(self.padding)
+
+
+[docs]@weak_module
+class ReplicationPad1d(_ReplicationPadNd):
+ r"""Pads the input tensor using replication of the input boundary.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 2-`tuple`, uses
+ (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
+
+ Shape:
+ - Input: :math:`(N, C, W_{in})`
+ - Output: :math:`(N, C, W_{out})` where
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ReplicationPad1d(2)
+ >>> input = torch.arange(8, dtype=torch.float).reshape(1, 2, 4)
+ >>> input
+ tensor([[[0., 1., 2., 3.],
+ [4., 5., 6., 7.]]])
+ >>> m(input)
+ tensor([[[0., 0., 0., 1., 2., 3., 3., 3.],
+ [4., 4., 4., 5., 6., 7., 7., 7.]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ReplicationPad1d((3, 1))
+ >>> m(input)
+ tensor([[[0., 0., 0., 0., 1., 2., 3., 3.],
+ [4., 4., 4., 4., 5., 6., 7., 7.]]])
+
+ """
+
+ def __init__(self, padding):
+ super(ReplicationPad1d, self).__init__()
+ self.padding = _pair(padding)
+
+
+[docs]@weak_module
+class ReplicationPad2d(_ReplicationPadNd):
+ r"""Pads the input tensor using replication of the input boundary.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+ :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ReplicationPad2d(2)
+ >>> input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)
+ >>> input
+ tensor([[[[0., 1., 2.],
+ [3., 4., 5.],
+ [6., 7., 8.]]]])
+ >>> m(input)
+ tensor([[[[0., 0., 0., 1., 2., 2., 2.],
+ [0., 0., 0., 1., 2., 2., 2.],
+ [0., 0., 0., 1., 2., 2., 2.],
+ [3., 3., 3., 4., 5., 5., 5.],
+ [6., 6., 6., 7., 8., 8., 8.],
+ [6., 6., 6., 7., 8., 8., 8.],
+ [6., 6., 6., 7., 8., 8., 8.]]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ReplicationPad2d((1, 1, 2, 0))
+ >>> m(input)
+ tensor([[[[0., 0., 1., 2., 2.],
+ [0., 0., 1., 2., 2.],
+ [0., 0., 1., 2., 2.],
+ [3., 3., 4., 5., 5.],
+ [6., 6., 7., 8., 8.]]]])
+
+ """
+
+ def __init__(self, padding):
+ super(ReplicationPad2d, self).__init__()
+ self.padding = _quadruple(padding)
+
+
+[docs]@weak_module
+class ReplicationPad3d(_ReplicationPadNd):
+ r"""Pads the input tensor using replication of the input boundary.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 6-`tuple`, uses
+ (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
+ :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
+ :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+
+ :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ReplicationPad3d(3)
+ >>> input = torch.randn(16, 3, 8, 320, 480)
+ >>> output = m(input)
+ >>> # using different paddings for different sides
+ >>> m = nn.ReplicationPad3d((3, 3, 6, 6, 1, 1))
+ >>> output = m(input)
+
+ """
+
+ def __init__(self, padding):
+ super(ReplicationPad3d, self).__init__()
+ self.padding = _ntuple(6)(padding)
+
+
+[docs]@weak_module
+class ZeroPad2d(ConstantPad2d):
+ r"""Pads the input tensor boundaries with zero.
+
+ For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.
+
+ Args:
+ padding (int, tuple): the size of the padding. If it is an `int`, uses the same
+ padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
+ :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`
+
+ :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
+
+ Examples::
+
+ >>> m = nn.ZeroPad2d(2)
+ >>> input = torch.randn(1, 1, 3, 3)
+ >>> input
+ tensor([[[[-0.1678, -0.4418, 1.9466],
+ [ 0.9604, -0.4219, -0.5241],
+ [-0.9162, -0.5436, -0.6446]]]])
+ >>> m(input)
+ tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, -0.1678, -0.4418, 1.9466, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, 0.9604, -0.4219, -0.5241, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, -0.9162, -0.5436, -0.6446, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
+ >>> # using different paddings for different sides
+ >>> m = nn.ZeroPad2d((1, 1, 2, 0))
+ >>> m(input)
+ tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
+ [ 0.0000, -0.1678, -0.4418, 1.9466, 0.0000],
+ [ 0.0000, 0.9604, -0.4219, -0.5241, 0.0000],
+ [ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]])
+
+ """
+
+ def __init__(self, padding):
+ super(ZeroPad2d, self).__init__(padding, 0.)
+
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+[docs]@weak_module
+class PixelShuffle(Module):
+ r"""Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
+ to a tensor of shape :math:`(*, C, H \times r, W \times r)`.
+
+ This is useful for implementing efficient sub-pixel convolution
+ with a stride of :math:`1/r`.
+
+ Look at the paper:
+ `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
+ by Shi et. al (2016) for more details.
+
+ Args:
+ upscale_factor (int): factor to increase spatial resolution by
+
+ Shape:
+ - Input: :math:`(N, L, H_{in}, W_{in})` where :math:`L=C \times \text{upscale\_factor}^2`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+ :math:`H_{out} = H_{in} \times \text{upscale\_factor}`
+ and :math:`W_{out} = W_{in} \times \text{upscale\_factor}`
+
+ Examples::
+
+ >>> pixel_shuffle = nn.PixelShuffle(3)
+ >>> input = torch.randn(1, 9, 4, 4)
+ >>> output = pixel_shuffle(input)
+ >>> print(output.size())
+ torch.Size([1, 1, 12, 12])
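+ >>> # Hedged equivalence sketch: pixel shuffle is a pure data rearrangement, so the
+ >>> # same result can be obtained with view/permute/reshape (shapes from the example above)
+ >>> manual = input.view(1, 1, 3, 3, 4, 4).permute(0, 1, 4, 2, 5, 3).reshape(1, 1, 12, 12)
+ >>> torch.equal(output, manual)
+ True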
+
+ .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
+ https://arxiv.org/abs/1609.05158
+ """
+ __constants__ = ['upscale_factor']
+
+ def __init__(self, upscale_factor):
+ super(PixelShuffle, self).__init__()
+ self.upscale_factor = upscale_factor
+
+ @weak_script_method
+ def forward(self, input):
+ return F.pixel_shuffle(input, self.upscale_factor)
+
+ def extra_repr(self):
+ return 'upscale_factor={}'.format(self.upscale_factor)
+
+from .module import Module
+from .utils import _single, _pair, _triple
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class _MaxPoolNd(Module):
+ __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
+ 'return_indices', 'ceil_mode']
+
+ def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
+ return_indices=False, ceil_mode=False):
+ super(_MaxPoolNd, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride or kernel_size
+ self.padding = padding
+ self.dilation = dilation
+ self.return_indices = return_indices
+ self.ceil_mode = ceil_mode
+
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
+ ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
+
+
+[docs]@weak_module
+class MaxPool1d(_MaxPoolNd):
+ r"""Applies a 1D max pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, L)`
+ and output :math:`(N, C, L_{out})` can be precisely described as:
+
+ .. math::
+ out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
+ input(N_i, C_j, stride \times k + m)
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ Args:
+ kernel_size: the size of the window to take a max over
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on both sides
+ dilation: a parameter that controls the stride of elements in the window
+ return_indices: if ``True``, will return the max indices along with the outputs.
+ Useful for :class:`torch.nn.MaxUnpool1d` later
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, L_{in})`
+ - Output: :math:`(N, C, L_{out})`, where
+
+ .. math::
+ L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
+ \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of size=3, stride=2
+ >>> m = nn.MaxPool1d(3, stride=2)
+ >>> input = torch.randn(20, 16, 50)
+ >>> output = m(input)
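+ >>> # Worked shape check (illustrative): L_out = floor((50 - 1*(3 - 1) - 1) / 2 + 1) = 24
+ >>> output.shape
+ torch.Size([20, 16, 24])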
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.max_pool1d(input, self.kernel_size, self.stride,
+ self.padding, self.dilation, self.ceil_mode,
+ self.return_indices)
+
+ def extra_repr(self):
+ return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \
+ ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
+
+
+[docs]@weak_module
+class MaxPool2d(_MaxPoolNd):
+ r"""Applies a 2D max pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
+ output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
+ can be precisely described as:
+
+ .. math::
+ \begin{aligned}
+ out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+ & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
+ \text{stride[1]} \times w + n)
+ \end{aligned}
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the height and width dimension
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
+ and the second `int` for the width dimension
+
+ Args:
+ kernel_size: the size of the window to take a max over
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on both sides
+ dilation: a parameter that controls the stride of elements in the window
+ return_indices: if ``True``, will return the max indices along with the outputs.
+ Useful for :class:`torch.nn.MaxUnpool2d` later
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})`, where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
+ \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
+ \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> m = nn.MaxPool2d(3, stride=2)
+ >>> # pool of non-square window
+ >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
+ >>> input = torch.randn(20, 16, 50, 32)
+ >>> output = m(input)
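+ >>> # Worked shape check (illustrative) for the non-square window above:
+ >>> # H_out = floor((50 - 2 - 1)/2 + 1) = 24, W_out = floor((32 - 1 - 1)/1 + 1) = 31
+ >>> output.shape
+ torch.Size([20, 16, 24, 31])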
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.max_pool2d(input, self.kernel_size, self.stride,
+ self.padding, self.dilation, self.ceil_mode,
+ self.return_indices)
+
+
+[docs]@weak_module
+class MaxPool3d(_MaxPoolNd):
+ r"""Applies a 3D max pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
+ output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+ can be precisely described as:
+
+ .. math::
+ \begin{aligned}
+ \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
+ & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
+ \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
+ \end{aligned}
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
+ It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+ the second `int` for the height dimension and the third `int` for the width dimension
+
+ Args:
+ kernel_size: the size of the window to take a max over
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on all three sides
+ dilation: a parameter that controls the stride of elements in the window
+ return_indices: if ``True``, will return the max indices along with the outputs.
+ Useful for :class:`torch.nn.MaxUnpool3d` later
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
+
+ .. math::
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
+ (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
+ (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
+ (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> m = nn.MaxPool3d(3, stride=2)
+ >>> # pool of non-square window
+ >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
+ >>> input = torch.randn(20, 16, 50, 44, 31)
+ >>> output = m(input)
+
+ .. _link:
+ https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+ """ # noqa: E501
+
+ @weak_script_method
+ def forward(self, input):
+ return F.max_pool3d(input, self.kernel_size, self.stride,
+ self.padding, self.dilation, self.ceil_mode,
+ self.return_indices)
+
+
+@weak_module
+class _MaxUnpoolNd(Module):
+
+ def extra_repr(self):
+ return 'kernel_size={}, stride={}, padding={}'.format(
+ self.kernel_size, self.stride, self.padding
+ )
+
+
+[docs]@weak_module
+class MaxUnpool1d(_MaxUnpoolNd):
+ r"""Computes a partial inverse of :class:`MaxPool1d`.
+
+ :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.
+
+ :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`
+ including the indices of the maximal values and computes a partial inverse
+ in which all non-maximal values are set to zero.
+
+ .. note:: :class:`MaxPool1d` can map several input sizes to the same output
+ sizes. Hence, the inversion process can be ambiguous.
+ To accommodate this, you can provide the needed output size
+ as an additional argument :attr:`output_size` in the forward call.
+ See the Inputs and Example below.
+
+ Args:
+ kernel_size (int or tuple): Size of the max pooling window.
+ stride (int or tuple): Stride of the max pooling window.
+ It is set to :attr:`kernel_size` by default.
+ padding (int or tuple): Padding that was added to the input
+
+ Inputs:
+ - `input`: the input Tensor to invert
+ - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
+ - `output_size` (optional): the targeted output size
+
+ Shape:
+ - Input: :math:`(N, C, H_{in})`
+ - Output: :math:`(N, C, H_{out})`, where
+
+ .. math::
+ H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
+
+ or as given by :attr:`output_size` in the call operator
+
+ Example::
+
+ >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
+ >>> unpool = nn.MaxUnpool1d(2, stride=2)
+ >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
+ >>> output, indices = pool(input)
+ >>> unpool(output, indices)
+ tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
+
+ >>> # Example showcasing the use of output_size
+ >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
+ >>> output, indices = pool(input)
+ >>> unpool(output, indices, output_size=input.size())
+ tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]])
+
+ >>> unpool(output, indices)
+ tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
+ """
+
+ def __init__(self, kernel_size, stride=None, padding=0):
+ super(MaxUnpool1d, self).__init__()
+ self.kernel_size = _single(kernel_size)
+ self.stride = _single(stride or kernel_size)
+ self.padding = _single(padding)
+
+ def forward(self, input, indices, output_size=None):
+ return F.max_unpool1d(input, indices, self.kernel_size, self.stride,
+ self.padding, output_size)
+
+
+[docs]@weak_module
+class MaxUnpool2d(_MaxUnpoolNd):
+ r"""Computes a partial inverse of :class:`MaxPool2d`.
+
+ :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.
+
+ :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
+ including the indices of the maximal values and computes a partial inverse
+ in which all non-maximal values are set to zero.
+
+ .. note:: :class:`MaxPool2d` can map several input sizes to the same output
+ sizes. Hence, the inversion process can be ambiguous.
+ To accommodate this, you can provide the needed output size
+ as an additional argument :attr:`output_size` in the forward call.
+ See the Inputs and Example below.
+
+ Args:
+ kernel_size (int or tuple): Size of the max pooling window.
+ stride (int or tuple): Stride of the max pooling window.
+ It is set to :attr:`kernel_size` by default.
+ padding (int or tuple): Padding that was added to the input
+
+ Inputs:
+ - `input`: the input Tensor to invert
+ - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d`
+ - `output_size` (optional): the targeted output size
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})`, where
+
+ .. math::
+ H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
+
+ .. math::
+ W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
+
+ or as given by :attr:`output_size` in the call operator
+
+ Example::
+
+ >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
+ >>> unpool = nn.MaxUnpool2d(2, stride=2)
+ >>> input = torch.tensor([[[[ 1., 2, 3, 4],
+ [ 5, 6, 7, 8],
+ [ 9, 10, 11, 12],
+ [13, 14, 15, 16]]]])
+ >>> output, indices = pool(input)
+ >>> unpool(output, indices)
+ tensor([[[[ 0., 0., 0., 0.],
+ [ 0., 6., 0., 8.],
+ [ 0., 0., 0., 0.],
+ [ 0., 14., 0., 16.]]]])
+
+ >>> # specify a different output size than input size
+ >>> unpool(output, indices, output_size=torch.Size([1, 1, 5, 5]))
+ tensor([[[[ 0., 0., 0., 0., 0.],
+ [ 6., 0., 8., 0., 0.],
+ [ 0., 0., 0., 14., 0.],
+ [ 16., 0., 0., 0., 0.],
+ [ 0., 0., 0., 0., 0.]]]])
+ """
+
+ def __init__(self, kernel_size, stride=None, padding=0):
+ super(MaxUnpool2d, self).__init__()
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride or kernel_size)
+ self.padding = _pair(padding)
+
+ def forward(self, input, indices, output_size=None):
+ return F.max_unpool2d(input, indices, self.kernel_size, self.stride,
+ self.padding, output_size)
+
+
+[docs]@weak_module
+class MaxUnpool3d(_MaxUnpoolNd):
+ r"""Computes a partial inverse of :class:`MaxPool3d`.
+
+ :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
+ :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`
+ including the indices of the maximal values and computes a partial inverse
+ in which all non-maximal values are set to zero.
+
+ .. note:: :class:`MaxPool3d` can map several input sizes to the same output
+ sizes. Hence, the inversion process can be ambiguous.
+ To accommodate this, you can provide the needed output size
+ as an additional argument :attr:`output_size` in the forward call.
+ See the Inputs section below.
+
+ Args:
+ kernel_size (int or tuple): Size of the max pooling window.
+ stride (int or tuple): Stride of the max pooling window.
+ It is set to :attr:`kernel_size` by default.
+ padding (int or tuple): Padding that was added to the input
+
+ Inputs:
+ - `input`: the input Tensor to invert
+ - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d`
+ - `output_size` (optional): the targeted output size
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
+
+ .. math::
+ D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
+
+ .. math::
+ H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
+
+ .. math::
+ W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]}
+
+ or as given by :attr:`output_size` in the call operator
+
+ Example::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True)
+ >>> unpool = nn.MaxUnpool3d(3, stride=2)
+ >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15))
+ >>> unpooled_output = unpool(output, indices)
+ >>> unpooled_output.size()
+ torch.Size([20, 16, 51, 33, 15])
+ """
+
+ def __init__(self, kernel_size, stride=None, padding=0):
+ super(MaxUnpool3d, self).__init__()
+ self.kernel_size = _triple(kernel_size)
+ self.stride = _triple(stride or kernel_size)
+ self.padding = _triple(padding)
+
+ def forward(self, input, indices, output_size=None):
+ return F.max_unpool3d(input, indices, self.kernel_size, self.stride,
+ self.padding, output_size)
+
+
+@weak_module
+class _AvgPoolNd(Module):
+ __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad']
+
+ def extra_repr(self):
+ return 'kernel_size={}, stride={}, padding={}'.format(
+ self.kernel_size, self.stride, self.padding
+ )
+
+
+[docs]@weak_module
+class AvgPool1d(_AvgPoolNd):
+ r"""Applies a 1D average pooling over an input signal composed of several
+ input planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,
+ output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k`
+ can be precisely described as:
+
+ .. math::
+
+ \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1}
+ \text{input}(N_i, C_j, \text{stride} \times l + m)
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be
+ an ``int`` or a one-element tuple.
+
+ Args:
+ kernel_size: the size of the window
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on both sides
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+ count_include_pad: when True, will include the zero-padding in the averaging calculation
+
+ Shape:
+ - Input: :math:`(N, C, L_{in})`
+ - Output: :math:`(N, C, L_{out})`, where
+
+ .. math::
+ L_{out} = \left\lfloor \frac{L_{in} +
+ 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool with window of size=3, stride=2
+ >>> m = nn.AvgPool1d(3, stride=2)
+ >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
+ tensor([[[ 2., 4., 6.]]])
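+ >>> # Worked check: with kernel 3 and stride 2 the windows are [1, 2, 3], [3, 4, 5]
+ >>> # and [5, 6, 7], whose means are 2., 4. and 6. respectively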
+ """
+ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+ count_include_pad=True):
+ super(AvgPool1d, self).__init__()
+ self.kernel_size = _single(kernel_size)
+ self.stride = _single(stride if stride is not None else kernel_size)
+ self.padding = _single(padding)
+ self.ceil_mode = ceil_mode
+ self.count_include_pad = count_include_pad
+
+ @weak_script_method
+ def forward(self, input):
+ return F.avg_pool1d(
+ input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
+ self.count_include_pad)
+
+
+[docs]@weak_module
+class AvgPool2d(_AvgPoolNd):
+ r"""Applies a 2D average pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
+ output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
+ can be precisely described as:
+
+ .. math::
+
+ out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
+ input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
+ for :attr:`padding` number of points.
+
+ The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the height and width dimension
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
+ and the second `int` for the width dimension
+
+ Args:
+ kernel_size: the size of the window
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on both sides
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+ count_include_pad: when True, will include the zero-padding in the averaging calculation
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})`, where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
+ \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
+ \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> m = nn.AvgPool2d(3, stride=2)
+ >>> # pool of non-square window
+ >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
+ >>> input = torch.randn(20, 16, 50, 32)
+ >>> output = m(input)
+ """
+ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+ count_include_pad=True):
+ super(AvgPool2d, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride or kernel_size
+ self.padding = padding
+ self.ceil_mode = ceil_mode
+ self.count_include_pad = count_include_pad
+
+ @weak_script_method
+ def forward(self, input):
+ return F.avg_pool2d(input, self.kernel_size, self.stride,
+ self.padding, self.ceil_mode, self.count_include_pad)
+
+
+[docs]@weak_module
+class AvgPool3d(_AvgPoolNd):
+ r"""Applies a 3D average pooling over an input signal composed of several input
+ planes.
+
+ In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
+ output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
+ can be precisely described as:
+
+ .. math::
+ \begin{aligned}
+ \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
+ & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
+ \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
+ {kD \times kH \times kW}
+ \end{aligned}
+
+ If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
+ for :attr:`padding` number of points.
+
+ The parameters :attr:`kernel_size`, :attr:`stride` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
+ - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+ the second `int` for the height dimension and the third `int` for the width dimension
+
+ Args:
+ kernel_size: the size of the window
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ padding: implicit zero padding to be added on all three sides
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+ count_include_pad: when True, will include the zero-padding in the averaging calculation
+
+ Shape:
+ - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
+
+ .. math::
+ D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
+ \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
+ \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
+ \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
+
+ Examples::
+
+ >>> # pool of square window of size=3, stride=2
+ >>> m = nn.AvgPool3d(3, stride=2)
+ >>> # pool of non-square window
+ >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
+ >>> input = torch.randn(20, 16, 50, 44, 31)
+ >>> output = m(input)
+ """
+ def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+ count_include_pad=True):
+ super(AvgPool3d, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride or kernel_size
+ self.padding = padding
+ self.ceil_mode = ceil_mode
+ self.count_include_pad = count_include_pad
+
+ @weak_script_method
+ def forward(self, input):
+ return F.avg_pool3d(input, self.kernel_size, self.stride,
+ self.padding, self.ceil_mode, self.count_include_pad)
+
+ def __setstate__(self, d):
+ super(AvgPool3d, self).__setstate__(d)
+ self.__dict__.setdefault('padding', 0)
+ self.__dict__.setdefault('ceil_mode', False)
+ self.__dict__.setdefault('count_include_pad', True)
+
+
+[docs]@weak_module
+class FractionalMaxPool2d(Module):
+ r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
+
+ Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
+
+ The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
+ step size determined by the target output size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ kernel_size: the size of the window to take a max over.
+ Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)`
+ output_size: the target output size of the image of the form `oH x oW`.
+ Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`
+ output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
+ This has to be a number or tuple in the range (0, 1)
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``
+
+ Examples:
+ >>> # pool of square window of size=3, and target output size 13x12
+ >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
+ >>> # pool of square window and target output size being half of input image size
+ >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
+ >>> input = torch.randn(20, 16, 50, 32)
+ >>> output = m(input)
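+ >>> # Illustrative note: with output_ratio=(0.5, 0.5) the output above has roughly half
+ >>> # the input's spatial size, i.e. approximately (20, 16, 25, 16)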
+
+ .. _Fractional MaxPooling:
+ http://arxiv.org/abs/1412.6071
+ """
+ __constants__ = ['kernel_size', 'return_indices', 'output_size',
+ 'output_ratio']
+
+ def __init__(self, kernel_size, output_size=None, output_ratio=None,
+ return_indices=False, _random_samples=None):
+ super(FractionalMaxPool2d, self).__init__()
+ self.kernel_size = _pair(kernel_size)
+ self.return_indices = return_indices
+ self.register_buffer('_random_samples', _random_samples)
+ self.output_size = _pair(output_size) if output_size is not None else None
+ self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
+ if output_size is None and output_ratio is None:
+ raise ValueError("FractionalMaxPool2d requires specifying either "
+ "an output size, or a pooling ratio")
+ if output_size is not None and output_ratio is not None:
+ raise ValueError("only one of output_size and output_ratio may be specified")
+ if self.output_ratio is not None:
+ if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
+ raise ValueError("output_ratio must be between 0 and 1 (got {})"
+ .format(output_ratio))
+
+ @weak_script_method
+ def forward(self, input):
+ return F.fractional_max_pool2d(
+ input, self.kernel_size, self.output_size, self.output_ratio,
+ self.return_indices,
+ _random_samples=self._random_samples)
+
+
+@weak_module
+class FractionalMaxPool3d(Module):
+ r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
+
+ Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
+
+ The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
+ step size determined by the target output size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ kernel_size: the size of the window to take a max over.
+ Can be a single number k (for a cubic kernel of k x k x k) or a tuple `(kT, kH, kW)`
+ output_size: the target output size of the image of the form `oT x oH x oW`.
+ Can be a tuple `(oT, oH, oW)` or a single number oH for a cubic output `oH x oH x oH`
+ output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
+ This has to be a number or tuple in the range (0, 1)
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False``
+
+ Examples:
+ >>> # pool of cubic window of size=3, and target output size 13x12x11
+ >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11))
+ >>> # pool of cubic window and target output size being half of input size
+ >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5))
+ >>> input = torch.randn(20, 16, 50, 32, 16)
+ >>> output = m(input)
+
+ .. _Fractional MaxPooling:
+ http://arxiv.org/abs/1412.6071
+ """
+ __constants__ = ['kernel_size', 'return_indices', 'output_size',
+ 'output_ratio']
+
+ def __init__(self, kernel_size, output_size=None, output_ratio=None,
+ return_indices=False, _random_samples=None):
+ super(FractionalMaxPool3d, self).__init__()
+ self.kernel_size = _triple(kernel_size)
+ self.return_indices = return_indices
+ self.register_buffer('_random_samples', _random_samples)
+ self.output_size = _triple(output_size) if output_size is not None else None
+ self.output_ratio = _triple(output_ratio) if output_ratio is not None else None
+ if output_size is None and output_ratio is None:
+ raise ValueError("FractionalMaxPool3d requires specifying either "
+ "an output size, or a pooling ratio")
+ if output_size is not None and output_ratio is not None:
+ raise ValueError("only one of output_size and output_ratio may be specified")
+ if self.output_ratio is not None:
+ if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1 and 0 < self.output_ratio[2] < 1):
+ raise ValueError("output_ratio must be between 0 and 1 (got {})"
+ .format(output_ratio))
+
+ @weak_script_method
+ def forward(self, input):
+ return F.fractional_max_pool3d(
+ input, self.kernel_size, self.output_size, self.output_ratio,
+ self.return_indices,
+ _random_samples=self._random_samples)
+
+
+@weak_module
+class _LPPoolNd(Module):
+ __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode']
+
+ def __init__(self, norm_type, kernel_size, stride=None, ceil_mode=False):
+ super(_LPPoolNd, self).__init__()
+ self.norm_type = norm_type
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.ceil_mode = ceil_mode
+
+ def extra_repr(self):
+ return 'norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, ' \
+ 'ceil_mode={ceil_mode}'.format(**self.__dict__)
+
+
+[docs]@weak_module
+class LPPool1d(_LPPoolNd):
+ r"""Applies a 1D power-average pooling over an input signal composed of several input
+ planes.
+
+ On each window, the function computed is:
+
+ .. math::
+ f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
+
+ - At p = :math:`\infty`, one gets Max Pooling
+ - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)
+
+ .. note:: If the sum to the power of `p` is zero, the gradient of this function is
+ not defined. This implementation will set the gradient to zero in this case.
+
+ Args:
+ kernel_size: a single int, the size of the window
+ stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, L_{in})`
+ - Output: :math:`(N, C, L_{out})`, where
+
+ .. math::
+ L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
+
+ Examples::
+ >>> # power-2 pool of window of length 3, with stride 2.
+ >>> m = nn.LPPool1d(2, 3, stride=2)
+ >>> input = torch.randn(20, 16, 50)
+ >>> output = m(input)
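+ >>> # Hedged sketch: with norm_type=1 the power-average pool reduces to a sum pool,
+ >>> # i.e. kernel_size times the average pool, up to floating point error
+ >>> torch.allclose(nn.LPPool1d(1, 3, stride=2)(input), 3 * nn.AvgPool1d(3, stride=2)(input))
+ True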
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
+ self.stride, self.ceil_mode)
+
+
+[docs]@weak_module
+class LPPool2d(_LPPoolNd):
+ r"""Applies a 2D power-average pooling over an input signal composed of several input
+ planes.
+
+ On each window, the function computed is:
+
+ .. math::
+ f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
+
+ - At p = :math:`\infty`, one gets Max Pooling
+ - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
+
+ The parameters :attr:`kernel_size`, :attr:`stride` can either be:
+
+ - a single ``int`` -- in which case the same value is used for the height and width dimension
+ - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
+ and the second `int` for the width dimension
+
+ .. note:: If the sum to the power of `p` is zero, the gradient of this function is
+ not defined. This implementation will set the gradient to zero in this case.
+
+ Args:
+ kernel_size: the size of the window
+ stride: the stride of the window. Default value is :attr:`kernel_size`
+ ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})`, where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
+
+ Examples::
+
+ >>> # power-2 pool of square window of size=3, stride=2
+ >>> m = nn.LPPool2d(2, 3, stride=2)
+ >>> # pool of non-square window of power 1.2
+ >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1))
+ >>> input = torch.randn(20, 16, 50, 32)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
+ self.stride, self.ceil_mode)
+
+
+@weak_module
+class _AdaptiveMaxPoolNd(Module):
+ __constants__ = ['output_size', 'return_indices']
+
+ def __init__(self, output_size, return_indices=False):
+ super(_AdaptiveMaxPoolNd, self).__init__()
+ self.output_size = output_size
+ self.return_indices = return_indices
+
+ def extra_repr(self):
+ return 'output_size={}'.format(self.output_size)
+
+# FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and
+# output shapes are, and how the operation computes output.
+
+
+[docs]@weak_module
+class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
+ r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
+
+ The output size is H, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size H
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to nn.MaxUnpool1d. Default: ``False``
+
+ Examples:
+ >>> # target output size of 5
+ >>> m = nn.AdaptiveMaxPool1d(5)
+ >>> input = torch.randn(1, 64, 8)
+ >>> output = m(input)
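+ >>> # Illustrative: the output length always matches the requested size exactly,
+ >>> # independent of the input length
+ >>> output.shape
+ torch.Size([1, 64, 5])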
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
+
+
+[docs]@weak_module
+class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
+ r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
+
+ The output is of size H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size of the image of the form H x W.
+ Can be a tuple (H, W) or a single H for a square image H x H.
+ H and W can be either an ``int``, or ``None``, which means the size will
+ be the same as that of the input.
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to nn.MaxUnpool2d. Default: ``False``
+
+ Examples:
+ >>> # target output size of 5x7
+ >>> m = nn.AdaptiveMaxPool2d((5,7))
+ >>> input = torch.randn(1, 64, 8, 9)
+ >>> output = m(input)
+ >>> # target output size of 7x7 (square)
+ >>> m = nn.AdaptiveMaxPool2d(7)
+ >>> input = torch.randn(1, 64, 10, 9)
+ >>> output = m(input)
+ >>> # target output size of 10x7
+ >>> m = nn.AdaptiveMaxPool2d((None, 7))
+ >>> input = torch.randn(1, 64, 10, 9)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
+
+
+@weak_module
+class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
+ r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
+
+ The output is of size D x H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size of the image of the form D x H x W.
+ Can be a tuple (D, H, W) or a single D for a cube D x D x D.
+            D, H and W can be either an ``int`` or ``None``, which means the size will
+ be the same as that of the input.
+
+ return_indices: if ``True``, will return the indices along with the outputs.
+ Useful to pass to nn.MaxUnpool3d. Default: ``False``
+
+ Examples:
+ >>> # target output size of 5x7x9
+ >>> m = nn.AdaptiveMaxPool3d((5,7,9))
+ >>> input = torch.randn(1, 64, 8, 9, 10)
+ >>> output = m(input)
+ >>> # target output size of 7x7x7 (cube)
+ >>> m = nn.AdaptiveMaxPool3d(7)
+ >>> input = torch.randn(1, 64, 10, 9, 8)
+ >>> output = m(input)
+ >>> # target output size of 7x9x8
+ >>> m = nn.AdaptiveMaxPool3d((7, None, None))
+ >>> input = torch.randn(1, 64, 10, 9, 8)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
+
+
+@weak_module
+class _AdaptiveAvgPoolNd(Module):
+ __constants__ = ['output_size']
+
+ def __init__(self, output_size):
+ super(_AdaptiveAvgPoolNd, self).__init__()
+ self.output_size = output_size
+
+ def extra_repr(self):
+ return 'output_size={}'.format(self.output_size)
+
+
+@weak_module
+class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
+ r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
+
+ The output size is H, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size H
+
+ Examples:
+ >>> # target output size of 5
+ >>> m = nn.AdaptiveAvgPool1d(5)
+ >>> input = torch.randn(1, 64, 8)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_avg_pool1d(input, self.output_size)
+
+
+@weak_module
+class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
+ r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
+
+ The output is of size H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size of the image of the form H x W.
+ Can be a tuple (H, W) or a single H for a square image H x H.
+            H and W can be either an ``int`` or ``None``, which means the size will
+ be the same as that of the input.
+
+ Examples:
+ >>> # target output size of 5x7
+ >>> m = nn.AdaptiveAvgPool2d((5,7))
+ >>> input = torch.randn(1, 64, 8, 9)
+ >>> output = m(input)
+ >>> # target output size of 7x7 (square)
+ >>> m = nn.AdaptiveAvgPool2d(7)
+ >>> input = torch.randn(1, 64, 10, 9)
+ >>> output = m(input)
+ >>> # target output size of 10x7
+        >>> m = nn.AdaptiveAvgPool2d((None, 7))
+ >>> input = torch.randn(1, 64, 10, 9)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_avg_pool2d(input, self.output_size)
+
+
+@weak_module
+class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
+ r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
+
+ The output is of size D x H x W, for any input size.
+ The number of output features is equal to the number of input planes.
+
+ Args:
+ output_size: the target output size of the form D x H x W.
+ Can be a tuple (D, H, W) or a single number D for a cube D x D x D.
+            D, H and W can be either an ``int`` or ``None``, which means the size will
+ be the same as that of the input.
+
+ Examples:
+ >>> # target output size of 5x7x9
+ >>> m = nn.AdaptiveAvgPool3d((5,7,9))
+ >>> input = torch.randn(1, 64, 8, 9, 10)
+ >>> output = m(input)
+ >>> # target output size of 7x7x7 (cube)
+ >>> m = nn.AdaptiveAvgPool3d(7)
+ >>> input = torch.randn(1, 64, 10, 9, 8)
+ >>> output = m(input)
+ >>> # target output size of 7x9x8
+        >>> m = nn.AdaptiveAvgPool3d((7, None, None))
+ >>> input = torch.randn(1, 64, 10, 9, 8)
+ >>> output = m(input)
+
+ """
+
+ @weak_script_method
+ def forward(self, input):
+ return F.adaptive_avg_pool3d(input, self.output_size)
+
+import math
+import torch
+import warnings
+import numbers
+
+from .module import Module
+from ..parameter import Parameter
+from ..utils.rnn import PackedSequence, get_packed_sequence
+from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method, weak_script, \
+ _parameter_list
+
+_rnn_impls = {
+ 'GRU': _VF.gru,
+ 'RNN_TANH': _VF.rnn_tanh,
+ 'RNN_RELU': _VF.rnn_relu,
+}
+
+
+@weak_script
+def apply_permutation(tensor, permutation, dim=1):
+ # type: (Tensor, Tensor, int) -> Tensor
+ return tensor.index_select(dim, permutation)
+
+
+class RNNBase(Module):
+ __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
+ 'batch_first', 'dropout', 'bidirectional', '_flat_parameters']
+
+ def __init__(self, mode, input_size, hidden_size,
+ num_layers=1, bias=True, batch_first=False,
+ dropout=0., bidirectional=False):
+ super(RNNBase, self).__init__()
+ self.mode = mode
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.bias = bias
+ self.batch_first = batch_first
+ self.dropout = dropout
+ self.bidirectional = bidirectional
+ num_directions = 2 if bidirectional else 1
+
+ if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
+ isinstance(dropout, bool):
+ raise ValueError("dropout should be a number in range [0, 1] "
+ "representing the probability of an element being "
+ "zeroed")
+ if dropout > 0 and num_layers == 1:
+ warnings.warn("dropout option adds dropout after all but last "
+ "recurrent layer, so non-zero dropout expects "
+ "num_layers greater than 1, but got dropout={} and "
+ "num_layers={}".format(dropout, num_layers))
+
+ if mode == 'LSTM':
+ gate_size = 4 * hidden_size
+ elif mode == 'GRU':
+ gate_size = 3 * hidden_size
+ elif mode == 'RNN_TANH':
+ gate_size = hidden_size
+ elif mode == 'RNN_RELU':
+ gate_size = hidden_size
+ else:
+ raise ValueError("Unrecognized RNN mode: " + mode)
+
+ self._all_weights = []
+ for layer in range(num_layers):
+ for direction in range(num_directions):
+ layer_input_size = input_size if layer == 0 else hidden_size * num_directions
+
+ w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
+ w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
+ b_ih = Parameter(torch.Tensor(gate_size))
+ # Second bias vector included for CuDNN compatibility. Only one
+ # bias vector is needed in standard definition.
+ b_hh = Parameter(torch.Tensor(gate_size))
+ layer_params = (w_ih, w_hh, b_ih, b_hh)
+
+ suffix = '_reverse' if direction == 1 else ''
+ param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
+ if bias:
+ param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
+ param_names = [x.format(layer, suffix) for x in param_names]
+
+ for name, param in zip(param_names, layer_params):
+ setattr(self, name, param)
+ self._all_weights.append(param_names)
+
+ self.flatten_parameters()
+ self.reset_parameters()
+
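+    # Added sketch (not part of the original source): __init__ above registers
+    # parameters named ``weight_ih_l{layer}{suffix}``, ``weight_hh_l{layer}{suffix}``
+    # and, when ``bias=True``, ``bias_ih_l{layer}{suffix}`` / ``bias_hh_l{layer}{suffix}``,
+    # with suffix ``'_reverse'`` for the backward direction. For example:
+    #
+    #   >>> rnn = nn.GRU(input_size=4, hidden_size=5, num_layers=2, bidirectional=True)
+    #   >>> rnn.weight_ih_l0.shape          # gate_size = 3 * hidden_size for GRU
+    #   torch.Size([15, 4])
+    #   >>> rnn.weight_ih_l1_reverse.shape  # layer 1 input is hidden_size * num_directions
+    #   torch.Size([15, 10])
+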
+ def flatten_parameters(self):
+ """Resets parameter data pointer so that they can use faster code paths.
+
+ Right now, this works only if the module is on the GPU and cuDNN is enabled.
+ Otherwise, it's a no-op.
+ """
+ any_param = next(self.parameters()).data
+ if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
+ return
+
+ # If any parameters alias, we fall back to the slower, copying code path. This is
+ # a sufficient check, because overlapping parameter buffers that don't completely
+ # alias would break the assumptions of the uniqueness check in
+ # Module.named_parameters().
+ all_weights = self._flat_weights
+ unique_data_ptrs = set(p.data_ptr() for p in all_weights)
+ if len(unique_data_ptrs) != len(all_weights):
+ return
+
+ with torch.cuda.device_of(any_param):
+ import torch.backends.cudnn.rnn as rnn
+
+ # NB: This is a temporary hack while we still don't have Tensor
+ # bindings for ATen functions
+ with torch.no_grad():
+ # NB: this is an INPLACE function on all_weights, that's why the
+ # no_grad() is necessary.
+ torch._cudnn_rnn_flatten_weight(
+ all_weights, (4 if self.bias else 2),
+ self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
+ self.batch_first, bool(self.bidirectional))
+
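+    # Usage sketch (an assumption added for illustration, not from the original
+    # source): if an RNN's parameters are re-created or replicated after
+    # construction (for example by ``nn.DataParallel``), the flattened cuDNN
+    # buffer may no longer alias the parameters and cuDNN can fall back to the
+    # slower path; calling ``flatten_parameters()`` again restores the fast path.
+    # As documented above, it is a no-op on CPU or when cuDNN is not usable.
+    #
+    #   >>> lstm = nn.LSTM(10, 20, 2).to('cuda')  # hypothetical GPU placement
+    #   >>> lstm.flatten_parameters()
+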
+ def _apply(self, fn):
+ ret = super(RNNBase, self)._apply(fn)
+ self.flatten_parameters()
+ return ret
+
+ def reset_parameters(self):
+ stdv = 1.0 / math.sqrt(self.hidden_size)
+ for weight in self.parameters():
+ init.uniform_(weight, -stdv, stdv)
+
+ def _get_flat_weights_names(self):
+ return [weight for weights in self._all_weights for weight in weights]
+
+ @_parameter_list(_get_flat_weights_names)
+ def _get_flat_weights(self):
+ return self._flat_weights
+
+ @weak_script_method
+ def check_input(self, input, batch_sizes):
+ # type: (Tensor, Optional[Tensor]) -> None
+ expected_input_dim = 2 if batch_sizes is not None else 3
+ if input.dim() != expected_input_dim:
+ raise RuntimeError(
+ 'input must have {} dimensions, got {}'.format(
+ expected_input_dim, input.dim()))
+ if self.input_size != input.size(-1):
+ raise RuntimeError(
+ 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
+ self.input_size, input.size(-1)))
+
+ @weak_script_method
+ def get_expected_hidden_size(self, input, batch_sizes):
+ # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+ if batch_sizes is not None:
+ mini_batch = batch_sizes[0]
+ mini_batch = int(mini_batch)
+ else:
+ mini_batch = input.size(0) if self.batch_first else input.size(1)
+ num_directions = 2 if self.bidirectional else 1
+ expected_hidden_size = (self.num_layers * num_directions,
+ mini_batch, self.hidden_size)
+ return expected_hidden_size
+
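+    # Worked example of the shape computed above (added note, not part of the
+    # original source): for a bidirectional 2-layer RNN with hidden_size=20 and a
+    # mini-batch of 3, the expected hidden state shape is
+    # (num_layers * num_directions, batch, hidden_size) = (2 * 2, 3, 20) = (4, 3, 20).
+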
+ @weak_script_method
+ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
+ # type: (Tensor, Tuple[int, int, int], str) -> None
+ if hx.size() != expected_hidden_size:
+ raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
+
+ def check_forward_args(self, input, hidden, batch_sizes):
+ self.check_input(input, batch_sizes)
+ expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+
+ self.check_hidden_size(hidden, expected_hidden_size)
+
+ def permute_hidden(self, hx, permutation):
+ if permutation is None:
+ return hx
+ return apply_permutation(hx, permutation)
+
+ def forward(self, input, hx=None):
+ is_packed = isinstance(input, PackedSequence)
+ if is_packed:
+ input, batch_sizes, sorted_indices, unsorted_indices = input
+ max_batch_size = batch_sizes[0]
+ max_batch_size = int(max_batch_size)
+ else:
+ batch_sizes = None
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
+ sorted_indices = None
+ unsorted_indices = None
+
+ if hx is None:
+ num_directions = 2 if self.bidirectional else 1
+ hx = torch.zeros(self.num_layers * num_directions,
+ max_batch_size, self.hidden_size,
+ dtype=input.dtype, device=input.device)
+ else:
+ # Each batch of the hidden state should match the input sequence that
+            # the user believes they are passing in.
+ hx = self.permute_hidden(hx, sorted_indices)
+
+ self.check_forward_args(input, hx, batch_sizes)
+ _impl = _rnn_impls[self.mode]
+ if batch_sizes is None:
+ result = _impl(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
+ self.dropout, self.training, self.bidirectional, self.batch_first)
+ else:
+ result = _impl(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
+ self.num_layers, self.dropout, self.training, self.bidirectional)
+ output = result[0]
+ hidden = result[1]
+
+ if is_packed:
+ output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+ return output, self.permute_hidden(hidden, unsorted_indices)
+
+ def extra_repr(self):
+ s = '{input_size}, {hidden_size}'
+ if self.num_layers != 1:
+ s += ', num_layers={num_layers}'
+ if self.bias is not True:
+ s += ', bias={bias}'
+ if self.batch_first is not False:
+ s += ', batch_first={batch_first}'
+ if self.dropout != 0:
+ s += ', dropout={dropout}'
+ if self.bidirectional is not False:
+ s += ', bidirectional={bidirectional}'
+ return s.format(**self.__dict__)
+
+ def __setstate__(self, d):
+ super(RNNBase, self).__setstate__(d)
+ if 'all_weights' in d:
+ self._all_weights = d['all_weights']
+ if isinstance(self._all_weights[0][0], str):
+ return
+ num_layers = self.num_layers
+ num_directions = 2 if self.bidirectional else 1
+ self._all_weights = []
+ for layer in range(num_layers):
+ for direction in range(num_directions):
+ suffix = '_reverse' if direction == 1 else ''
+ weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
+ weights = [x.format(layer, suffix) for x in weights]
+ if self.bias:
+ self._all_weights += [weights]
+ else:
+ self._all_weights += [weights[:2]]
+
+ @property
+ def _flat_weights(self):
+ return [p for layerparams in self.all_weights for p in layerparams]
+
+ @property
+ def all_weights(self):
+ return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
+
+
+class RNN(RNNBase):
+ r"""Applies a multi-layer Elman RNN with :math:`tanh` or :math:`ReLU` non-linearity to an
+ input sequence.
+
+
+ For each element in the input sequence, each layer computes the following
+ function:
+
+ .. math::
+ h_t = \text{tanh}(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})
+
+ where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
+ the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
+ previous layer at time `t-1` or the initial hidden state at time `0`.
+ If :attr:`nonlinearity` is ``'relu'``, then `ReLU` is used instead of `tanh`.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+ would mean stacking two RNNs together to form a `stacked RNN`,
+ with the second RNN taking in outputs of the first RNN and
+ computing the final results. Default: 1
+ nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ Default: ``True``
+ batch_first: If ``True``, then the input and output tensors are provided
+ as `(batch, seq, feature)`. Default: ``False``
+ dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+ RNN layer except the last layer, with dropout probability equal to
+ :attr:`dropout`. Default: 0
+ bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``
+
+ Inputs: input, h_0
+ - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
+ of the input sequence. The input can also be a packed variable length
+ sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
+ or :func:`torch.nn.utils.rnn.pack_sequence`
+ for details.
+ - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the initial hidden state for each element in the batch.
+ Defaults to zero if not provided. If the RNN is bidirectional,
+ num_directions should be 2, else it should be 1.
+
+ Outputs: output, h_n
+ - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
+ containing the output features (`h_t`) from the last layer of the RNN,
+ for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has
+ been given as the input, the output will also be a packed sequence.
+
+ For the unpacked case, the directions can be separated
+ using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+ with forward and backward being direction `0` and `1` respectively.
+ Similarly, the directions can be separated in the packed case.
+ - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the hidden state for `t = seq_len`.
+
+ Like *output*, the layers can be separated using
+ ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
+
+ Shape:
+ - Input1: :math:`(L, N, H_{in})` tensor containing input features where
+ :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
+        - Input2: :math:`(S, N, H_{out})` tensor
+          containing the initial hidden state for each element in the batch,
+          where :math:`S=\text{num\_layers} * \text{num\_directions}` and
+          :math:`H_{out}=\text{hidden\_size}`. Defaults to zero if not provided.
+          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
+        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
+ - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
+ for each element in the batch
+
+ Attributes:
+ weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
+ of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
+ `(hidden_size, num_directions * hidden_size)`
+ weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
+ of shape `(hidden_size, hidden_size)`
+ bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
+ of shape `(hidden_size)`
+ bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
+ of shape `(hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ .. include:: cudnn_persistent_rnn.rst
+
+ Examples::
+
+ >>> rnn = nn.RNN(10, 20, 2)
+ >>> input = torch.randn(5, 3, 10)
+ >>> h0 = torch.randn(2, 3, 20)
+ >>> output, hn = rnn(input, h0)
+ """
+
+ def __init__(self, *args, **kwargs):
+ if 'nonlinearity' in kwargs:
+ if kwargs['nonlinearity'] == 'tanh':
+ mode = 'RNN_TANH'
+ elif kwargs['nonlinearity'] == 'relu':
+ mode = 'RNN_RELU'
+ else:
+ raise ValueError("Unknown nonlinearity '{}'".format(
+ kwargs['nonlinearity']))
+ del kwargs['nonlinearity']
+ else:
+ mode = 'RNN_TANH'
+
+ super(RNN, self).__init__(mode, *args, **kwargs)
+
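+# Added check (not part of the original source): the ``nonlinearity`` keyword
+# handled above selects the underlying mode passed to RNNBase.
+#
+#   >>> nn.RNN(10, 20, nonlinearity='relu').mode
+#   'RNN_RELU'
+#   >>> nn.RNN(10, 20).mode  # default nonlinearity is 'tanh'
+#   'RNN_TANH'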
+
+@weak_module
+class LSTM(RNNBase):
+ r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
+ sequence.
+
+
+ For each element in the input sequence, each layer computes the following
+ function:
+
+ .. math::
+ \begin{array}{ll} \\
+ i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
+ f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
+ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
+ o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
+ c_t = f_t * c_{(t-1)} + i_t * g_t \\
+ h_t = o_t * \tanh(c_t) \\
+ \end{array}
+
+ where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
+ state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{(t-1)}`
+ is the hidden state of the layer at time `t-1` or the initial hidden
+ state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
+ :math:`o_t` are the input, forget, cell, and output gates, respectively.
+ :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
+
+ In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
+ (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
+ dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
+ variable which is :math:`0` with probability :attr:`dropout`.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+ would mean stacking two LSTMs together to form a `stacked LSTM`,
+ with the second LSTM taking in outputs of the first LSTM and
+ computing the final results. Default: 1
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ Default: ``True``
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False``
+ dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+ LSTM layer except the last layer, with dropout probability equal to
+ :attr:`dropout`. Default: 0
+ bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
+
+ Inputs: input, (h_0, c_0)
+ - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
+ of the input sequence.
+ The input can also be a packed variable length sequence.
+ See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
+ :func:`torch.nn.utils.rnn.pack_sequence` for details.
+ - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the initial hidden state for each element in the batch.
+ If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
+ - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the initial cell state for each element in the batch.
+
+ If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
+
+
+ Outputs: output, (h_n, c_n)
+ - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
+ containing the output features `(h_t)` from the last layer of the LSTM,
+ for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
+ given as the input, the output will also be a packed sequence.
+
+ For the unpacked case, the directions can be separated
+ using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+ with forward and backward being direction `0` and `1` respectively.
+ Similarly, the directions can be separated in the packed case.
+ - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the hidden state for `t = seq_len`.
+
+ Like *output*, the layers can be separated using
+ ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
+ - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the cell state for `t = seq_len`.
+
+ Attributes:
+ weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
+ `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
+ Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`
+ weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
+ `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`
+ bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
+ `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
+ bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
+ `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ .. include:: cudnn_persistent_rnn.rst
+
+ Examples::
+
+ >>> rnn = nn.LSTM(10, 20, 2)
+ >>> input = torch.randn(5, 3, 10)
+ >>> h0 = torch.randn(2, 3, 20)
+ >>> c0 = torch.randn(2, 3, 20)
+ >>> output, (hn, cn) = rnn(input, (h0, c0))
+ """
+ __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
+
+ def __init__(self, *args, **kwargs):
+ super(LSTM, self).__init__('LSTM', *args, **kwargs)
+
+ @weak_script_method
+ def check_forward_args(self, input, hidden, batch_sizes):
+ # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
+ self.check_input(input, batch_sizes)
+ expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+
+ self.check_hidden_size(hidden[0], expected_hidden_size,
+ 'Expected hidden[0] size {}, got {}')
+ self.check_hidden_size(hidden[1], expected_hidden_size,
+ 'Expected hidden[1] size {}, got {}')
+
+ @weak_script_method
+ def permute_hidden(self, hx, permutation):
+ # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+ if permutation is None:
+ return hx
+ return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
+
+ @weak_script_method
+ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
+ if hx is None:
+ num_directions = 2 if self.bidirectional else 1
+ zeros = torch.zeros(self.num_layers * num_directions,
+ max_batch_size, self.hidden_size,
+ dtype=input.dtype, device=input.device)
+ hx = (zeros, zeros)
+ else:
+ # Each batch of the hidden state should match the input sequence that
+            # the user believes they are passing in.
+ hx = self.permute_hidden(hx, sorted_indices)
+
+ self.check_forward_args(input, hx, batch_sizes)
+ if batch_sizes is None:
+ result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
+ self.dropout, self.training, self.bidirectional, self.batch_first)
+ else:
+ result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
+ self.num_layers, self.dropout, self.training, self.bidirectional)
+ output = result[0]
+ hidden = result[1:]
+
+ return output, hidden
+
+ @weak_script_method
+ def forward_tensor(self, input, hx=None):
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+ batch_sizes = None
+ max_batch_size = input.size(0) if self.batch_first else input.size(1)
+ sorted_indices = None
+ unsorted_indices = None
+
+ output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+ return output, self.permute_hidden(hidden, unsorted_indices)
+
+ @weak_script_method
+ def forward_packed(self, input, hx=None):
+ # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
+ input, batch_sizes, sorted_indices, unsorted_indices = input
+ max_batch_size = batch_sizes[0]
+ max_batch_size = int(max_batch_size)
+
+ output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+ output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
+ return output, self.permute_hidden(hidden, unsorted_indices)
+
+ def forward(self, input, hx=None):
+ if isinstance(input, PackedSequence):
+ return self.forward_packed(input, hx)
+ else:
+ return self.forward_tensor(input, hx)
+
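+# Minimal sketch of the packed-sequence path dispatched to above (an added
+# illustration, not part of the original source); lengths must be sorted in
+# decreasing order for the default pack_padded_sequence behaviour:
+#
+#   >>> lstm = nn.LSTM(10, 20, 2)
+#   >>> padded = torch.randn(5, 3, 10)                   # (seq_len, batch, input_size)
+#   >>> lengths = torch.tensor([5, 3, 1])
+#   >>> packed = nn.utils.rnn.pack_padded_sequence(padded, lengths)
+#   >>> out_packed, (hn, cn) = lstm(packed)              # uses forward_packed
+#   >>> out, out_lengths = nn.utils.rnn.pad_packed_sequence(out_packed)
+#   >>> out.shape, hn.shape
+#   (torch.Size([5, 3, 20]), torch.Size([2, 3, 20]))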
+
+class GRU(RNNBase):
+ r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
+
+
+ For each element in the input sequence, each layer computes the following
+ function:
+
+ .. math::
+ \begin{array}{ll}
+ r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
+ z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
+ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
+ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
+ \end{array}
+
+ where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
+ at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
+ at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
+ :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
+ :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
+
+ In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
+ (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
+ dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
+ variable which is :math:`0` with probability :attr:`dropout`.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
+ would mean stacking two GRUs together to form a `stacked GRU`,
+ with the second GRU taking in outputs of the first GRU and
+ computing the final results. Default: 1
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ Default: ``True``
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False``
+ dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
+ GRU layer except the last layer, with dropout probability equal to
+ :attr:`dropout`. Default: 0
+ bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
+
+ Inputs: input, h_0
+ - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
+ of the input sequence. The input can also be a packed variable length
+ sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
+ for details.
+ - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the initial hidden state for each element in the batch.
+ Defaults to zero if not provided. If the RNN is bidirectional,
+ num_directions should be 2, else it should be 1.
+
+ Outputs: output, h_n
+ - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
+ containing the output features h_t from the last layer of the GRU,
+ for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
+ given as the input, the output will also be a packed sequence.
+ For the unpacked case, the directions can be separated
+ using ``output.view(seq_len, batch, num_directions, hidden_size)``,
+ with forward and backward being direction `0` and `1` respectively.
+
+ Similarly, the directions can be separated in the packed case.
+ - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
+ containing the hidden state for `t = seq_len`
+
+ Like *output*, the layers can be separated using
+ ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
+
+ Shape:
+ - Input1: :math:`(L, N, H_{in})` tensor containing input features where
+ :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
+        - Input2: :math:`(S, N, H_{out})` tensor
+          containing the initial hidden state for each element in the batch,
+          where :math:`S=\text{num\_layers} * \text{num\_directions}` and
+          :math:`H_{out}=\text{hidden\_size}`. Defaults to zero if not provided.
+          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
+        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
+ - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
+ for each element in the batch
+
+ Attributes:
+ weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
+ (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
+ Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
+ weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
+ (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
+ bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
+ (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
+ bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
+ (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ .. include:: cudnn_persistent_rnn.rst
+
+ Examples::
+
+ >>> rnn = nn.GRU(10, 20, 2)
+ >>> input = torch.randn(5, 3, 10)
+ >>> h0 = torch.randn(2, 3, 20)
+ >>> output, hn = rnn(input, h0)
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(GRU, self).__init__('GRU', *args, **kwargs)
+
+
+class RNNCellBase(Module):
+ __constants__ = ['input_size', 'hidden_size', 'bias']
+
+ def __init__(self, input_size, hidden_size, bias, num_chunks):
+ super(RNNCellBase, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.bias = bias
+ self.weight_ih = Parameter(torch.Tensor(num_chunks * hidden_size, input_size))
+ self.weight_hh = Parameter(torch.Tensor(num_chunks * hidden_size, hidden_size))
+ if bias:
+ self.bias_ih = Parameter(torch.Tensor(num_chunks * hidden_size))
+ self.bias_hh = Parameter(torch.Tensor(num_chunks * hidden_size))
+ else:
+ self.register_parameter('bias_ih', None)
+ self.register_parameter('bias_hh', None)
+ self.reset_parameters()
+
+ def extra_repr(self):
+ s = '{input_size}, {hidden_size}'
+ if 'bias' in self.__dict__ and self.bias is not True:
+ s += ', bias={bias}'
+ if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+ s += ', nonlinearity={nonlinearity}'
+ return s.format(**self.__dict__)
+
+ @weak_script_method
+ def check_forward_input(self, input):
+ if input.size(1) != self.input_size:
+ raise RuntimeError(
+ "input has inconsistent input_size: got {}, expected {}".format(
+ input.size(1), self.input_size))
+
+ @weak_script_method
+ def check_forward_hidden(self, input, hx, hidden_label=''):
+ # type: (Tensor, Tensor, str) -> None
+ if input.size(0) != hx.size(0):
+ raise RuntimeError(
+ "Input batch size {} doesn't match hidden{} batch size {}".format(
+ input.size(0), hidden_label, hx.size(0)))
+
+ if hx.size(1) != self.hidden_size:
+ raise RuntimeError(
+ "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
+ hidden_label, hx.size(1), self.hidden_size))
+
+ def reset_parameters(self):
+ stdv = 1.0 / math.sqrt(self.hidden_size)
+ for weight in self.parameters():
+ init.uniform_(weight, -stdv, stdv)
+
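+# Added check of the ``num_chunks`` layout above (not part of the original
+# source): each cell stacks its gates into single weight matrices, so the leading
+# dimension is num_chunks * hidden_size.
+#
+#   >>> nn.LSTMCell(10, 20).weight_ih.shape  # num_chunks = 4
+#   torch.Size([80, 10])
+#   >>> nn.GRUCell(10, 20).weight_hh.shape   # num_chunks = 3
+#   torch.Size([60, 20])
+#   >>> nn.RNNCell(10, 20).weight_ih.shape   # num_chunks = 1
+#   torch.Size([20, 10])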
+
+@weak_module
+class RNNCell(RNNCellBase):
+ r"""An Elman RNN cell with tanh or ReLU non-linearity.
+
+ .. math::
+
+ h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+
+ If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ Default: ``True``
+ nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
+
+ Inputs: input, hidden
+ - **input** of shape `(batch, input_size)`: tensor containing input features
+ - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
+ state for each element in the batch.
+ Defaults to zero if not provided.
+
+ Outputs: h'
+ - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
+ for each element in the batch
+
+ Shape:
+ - Input1: :math:`(N, H_{in})` tensor containing input features where
+ :math:`H_{in}` = `input_size`
+ - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
+ state for each element in the batch where :math:`H_{out}` = `hidden_size`
+ Defaults to zero if not provided.
+ - Output: :math:`(N, H_{out})` tensor containing the next hidden state
+ for each element in the batch
+
+ Attributes:
+ weight_ih: the learnable input-hidden weights, of shape
+ `(hidden_size, input_size)`
+ weight_hh: the learnable hidden-hidden weights, of shape
+ `(hidden_size, hidden_size)`
+ bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
+ bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ Examples::
+
+ >>> rnn = nn.RNNCell(10, 20)
+ >>> input = torch.randn(6, 3, 10)
+ >>> hx = torch.randn(3, 20)
+ >>> output = []
+ >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+ """
+ __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
+
+ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
+ super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
+ self.nonlinearity = nonlinearity
+
+ @weak_script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
+ self.check_forward_input(input)
+ if hx is None:
+ hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ self.check_forward_hidden(input, hx, '')
+ if self.nonlinearity == "tanh":
+ ret = _VF.rnn_tanh_cell(
+ input, hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
+ elif self.nonlinearity == "relu":
+ ret = _VF.rnn_relu_cell(
+ input, hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
+ else:
+ ret = input # TODO: remove when jit supports exception flow
+ raise RuntimeError(
+ "Unknown nonlinearity: {}".format(self.nonlinearity))
+ return ret
+
+
+@weak_module
+class LSTMCell(RNNCellBase):
+ r"""A long short-term memory (LSTM) cell.
+
+ .. math::
+
+ \begin{array}{ll}
+ i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
+ f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
+ g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
+ o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
+ c' = f * c + i * g \\
+ h' = o * \tanh(c') \\
+ \end{array}
+
+ where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and
+ `b_hh`. Default: ``True``
+
+ Inputs: input, (h_0, c_0)
+ - **input** of shape `(batch, input_size)`: tensor containing input features
+ - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
+ state for each element in the batch.
+ - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
+ for each element in the batch.
+
+ If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
+
+ Outputs: (h_1, c_1)
+ - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
+ for each element in the batch
+ - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
+ for each element in the batch
+
+ Attributes:
+ weight_ih: the learnable input-hidden weights, of shape
+ `(4*hidden_size, input_size)`
+ weight_hh: the learnable hidden-hidden weights, of shape
+ `(4*hidden_size, hidden_size)`
+ bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
+ bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ Examples::
+
+ >>> rnn = nn.LSTMCell(10, 20)
+ >>> input = torch.randn(6, 3, 10)
+ >>> hx = torch.randn(3, 20)
+ >>> cx = torch.randn(3, 20)
+ >>> output = []
+ >>> for i in range(6):
+        ...     hx, cx = rnn(input[i], (hx, cx))
+        ...     output.append(hx)
+ """
+
+ def __init__(self, input_size, hidden_size, bias=True):
+ super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
+
+ @weak_script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+ self.check_forward_input(input)
+ if hx is None:
+ zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ hx = (zeros, zeros)
+ self.check_forward_hidden(input, hx[0], '[0]')
+ self.check_forward_hidden(input, hx[1], '[1]')
+ return _VF.lstm_cell(
+ input, hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
+
+
+@weak_module
+class GRUCell(RNNCellBase):
+ r"""A gated recurrent unit (GRU) cell
+
+ .. math::
+
+ \begin{array}{ll}
+ r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
+ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
+ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
+ h' = (1 - z) * n + z * h
+ \end{array}
+
+ where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
+
+ Args:
+ input_size: The number of expected features in the input `x`
+ hidden_size: The number of features in the hidden state `h`
+ bias: If ``False``, then the layer does not use bias weights `b_ih` and
+ `b_hh`. Default: ``True``
+
+ Inputs: input, hidden
+ - **input** of shape `(batch, input_size)`: tensor containing input features
+ - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
+ state for each element in the batch.
+ Defaults to zero if not provided.
+
+ Outputs: h'
+ - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
+ for each element in the batch
+
+ Shape:
+ - Input1: :math:`(N, H_{in})` tensor containing input features where
+ :math:`H_{in}` = `input_size`
+ - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
+ state for each element in the batch where :math:`H_{out}` = `hidden_size`
+ Defaults to zero if not provided.
+ - Output: :math:`(N, H_{out})` tensor containing the next hidden state
+ for each element in the batch
+
+ Attributes:
+ weight_ih: the learnable input-hidden weights, of shape
+ `(3*hidden_size, input_size)`
+ weight_hh: the learnable hidden-hidden weights, of shape
+ `(3*hidden_size, hidden_size)`
+ bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
+ bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
+
+ .. note::
+ All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
+ where :math:`k = \frac{1}{\text{hidden\_size}}`
+
+ Examples::
+
+ >>> rnn = nn.GRUCell(10, 20)
+ >>> input = torch.randn(6, 3, 10)
+ >>> hx = torch.randn(3, 20)
+ >>> output = []
+ >>> for i in range(6):
+        ...     hx = rnn(input[i], hx)
+        ...     output.append(hx)
+ """
+
+ def __init__(self, input_size, hidden_size, bias=True):
+ super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
+
+ @weak_script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
+ self.check_forward_input(input)
+ if hx is None:
+ hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ self.check_forward_hidden(input, hx, '')
+ return _VF.gru_cell(
+ input, hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
+
+import torch
+from torch.nn.parameter import Parameter
+
+from .module import Module
+from .. import functional as F
+from .. import init
+from torch._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class Embedding(Module):
+ r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+ This module is often used to store word embeddings and retrieve them using indices.
+ The input to the module is a list of indices, and the output is the corresponding
+ word embeddings.
+
+ Args:
+ num_embeddings (int): size of the dictionary of embeddings
+ embedding_dim (int): the size of each embedding vector
+ padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
+ (initialized to zeros) whenever it encounters the index.
+ max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+ is renormalized to have norm :attr:`max_norm`.
+ norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
+ scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
+ the words in the mini-batch. Default ``False``.
+ sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
+ See Notes for more details regarding sparse gradients.
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+ initialized from :math:`\mathcal{N}(0, 1)`
+
+ Shape:
+ - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
+ - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+
+ .. note::
+ Keep in mind that only a limited number of optimizers support
+ sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
+ :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
+
+ .. note::
+ With :attr:`padding_idx` set, the embedding vector at
+ :attr:`padding_idx` is initialized to all zeros. However, note that this
+ vector can be modified afterwards, e.g., using a customized
+ initialization method, and thus changing the vector used to pad the
+ output. The gradient for this vector from :class:`~torch.nn.Embedding`
+ is always zero.
+
+ Examples::
+
+ >>> # an Embedding module containing 10 tensors of size 3
+ >>> embedding = nn.Embedding(10, 3)
+ >>> # a batch of 2 samples of 4 indices each
+ >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
+ >>> embedding(input)
+ tensor([[[-0.0251, -1.6902, 0.7172],
+ [-0.6431, 0.0748, 0.6969],
+ [ 1.4970, 1.3448, -0.9685],
+ [-0.3677, -2.7265, -0.1685]],
+
+ [[ 1.4970, 1.3448, -0.9685],
+ [ 0.4362, -0.4004, 0.9400],
+ [-0.6431, 0.0748, 0.6969],
+ [ 0.9124, -2.3616, 1.1151]]])
+
+
+ >>> # example with padding_idx
+ >>> embedding = nn.Embedding(10, 3, padding_idx=0)
+ >>> input = torch.LongTensor([[0,2,0,5]])
+ >>> embedding(input)
+ tensor([[[ 0.0000, 0.0000, 0.0000],
+ [ 0.1535, -2.0309, 0.9315],
+ [ 0.0000, 0.0000, 0.0000],
+ [-0.1655, 0.9897, 0.0635]]])
+ """
+ __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
+ 'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']
+
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
+ max_norm=None, norm_type=2., scale_grad_by_freq=False,
+ sparse=False, _weight=None):
+ super(Embedding, self).__init__()
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ if padding_idx is not None:
+ if padding_idx > 0:
+ assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ elif padding_idx < 0:
+ assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ padding_idx = self.num_embeddings + padding_idx
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ if _weight is None:
+ self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+ self.reset_parameters()
+ else:
+ assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+ 'Shape of weight does not match num_embeddings and embedding_dim'
+ self.weight = Parameter(_weight)
+ self.sparse = sparse
+
+ def reset_parameters(self):
+ init.normal_(self.weight)
+ if self.padding_idx is not None:
+ with torch.no_grad():
+ self.weight[self.padding_idx].fill_(0)
+
+ @weak_script_method
+ def forward(self, input):
+ return F.embedding(
+ input, self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+ def extra_repr(self):
+ s = '{num_embeddings}, {embedding_dim}'
+ if self.padding_idx is not None:
+ s += ', padding_idx={padding_idx}'
+ if self.max_norm is not None:
+ s += ', max_norm={max_norm}'
+ if self.norm_type != 2:
+ s += ', norm_type={norm_type}'
+ if self.scale_grad_by_freq is not False:
+ s += ', scale_grad_by_freq={scale_grad_by_freq}'
+ if self.sparse is not False:
+ s += ', sparse=True'
+ return s.format(**self.__dict__)
+
+    @classmethod
+ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
+ max_norm=None, norm_type=2., scale_grad_by_freq=False,
+ sparse=False):
+ r"""Creates Embedding instance from given 2-dimensional FloatTensor.
+
+ Args:
+ embeddings (Tensor): FloatTensor containing weights for the Embedding.
+ First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
+ freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+ Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
+ padding_idx (int, optional): See module initialization documentation.
+ max_norm (float, optional): See module initialization documentation.
+ norm_type (float, optional): See module initialization documentation. Default ``2``.
+ scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+ sparse (bool, optional): See module initialization documentation.
+
+ Examples::
+
+ >>> # FloatTensor containing pretrained weights
+ >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+ >>> embedding = nn.Embedding.from_pretrained(weight)
+ >>> # Get embeddings for index 1
+ >>> input = torch.LongTensor([1])
+ >>> embedding(input)
+ tensor([[ 4.0000, 5.1000, 6.3000]])
+ """
+ assert embeddings.dim() == 2, \
+ 'Embeddings parameter is expected to be 2-dimensional'
+ rows, cols = embeddings.shape
+ embedding = cls(
+ num_embeddings=rows,
+ embedding_dim=cols,
+ _weight=embeddings,
+ padding_idx=padding_idx,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ scale_grad_by_freq=scale_grad_by_freq,
+ sparse=sparse)
+ embedding.weight.requires_grad = not freeze
+ return embedding
+
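+# Added checks of the notes in the Embedding docstring above (not part of the
+# original source): with ``sparse=True`` the weight gradient is a sparse tensor,
+# and the row at ``padding_idx`` always receives a zero gradient.
+#
+#   >>> emb = nn.Embedding(10, 3, sparse=True)
+#   >>> emb(torch.LongTensor([1, 2])).sum().backward()
+#   >>> emb.weight.grad.is_sparse
+#   True
+#   >>> emb = nn.Embedding(10, 3, padding_idx=0)
+#   >>> emb(torch.LongTensor([0, 1])).sum().backward()
+#   >>> emb.weight.grad[0]
+#   tensor([0., 0., 0.])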
+
+@weak_module
+class EmbeddingBag(Module):
+ r"""Computes sums or means of 'bags' of embeddings, without instantiating the
+ intermediate embeddings.
+
+ For bags of constant length and no :attr:`per_sample_weights`, this class
+
+ * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=0)``,
+ * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=0)``,
+ * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=0)``.
+
+ However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
+ operations.
+
+ EmbeddingBag also supports per-sample weights as an argument to the forward
+ pass. This scales the output of the Embedding before performing a weighted
+    reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
+ only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
+ :attr:`per_sample_weights`.
+
+ Args:
+ num_embeddings (int): size of the dictionary of embeddings
+ embedding_dim (int): the size of each embedding vector
+ max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
+ is renormalized to have norm :attr:`max_norm`.
+ norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
+ scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
+ the words in the mini-batch. Default ``False``.
+ Note: this option is not supported when ``mode="max"``.
+ mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
+ ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
+ into consideration. ``"mean"`` computes the average of the values
+ in the bag, ``"max"`` computes the max value over each bag.
+ Default: ``"mean"``
+ sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
+ Notes for more details regarding sparse gradients. Note: this option is not
+ supported when ``mode="max"``.
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
+ initialized from :math:`\mathcal{N}(0, 1)`.
+
+ Inputs: :attr:`input` (LongTensor), :attr:`offsets` (LongTensor, optional), and
+        :attr:`per_sample_weights` (Tensor, optional)
+
+ - If :attr:`input` is 2D of shape `(B, N)`,
+
+ it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and
+ this will return ``B`` values aggregated in a way depending on the :attr:`mode`.
+ :attr:`offsets` is ignored and required to be ``None`` in this case.
+
+ - If :attr:`input` is 1D of shape `(N)`,
+
+ it will be treated as a concatenation of multiple bags (sequences).
+ :attr:`offsets` is required to be a 1D tensor containing the
+ starting index positions of each bag in :attr:`input`. Therefore,
+ for :attr:`offsets` of shape `(B)`, :attr:`input` will be viewed as
+ having ``B`` bags. Empty bags (i.e., having 0-length) will have
+ returned vectors filled by zeros.
+
+ per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
+ to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
+ must have exactly the same shape as input and is treated as having the same
+ :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
+
+
+ Output shape: `(B, embedding_dim)`
+
+ Examples::
+
+ >>> # an Embedding module containing 10 tensors of size 3
+ >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
+ >>> # a batch of 2 samples of 4 indices each
+ >>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
+ >>> offsets = torch.LongTensor([0,4])
+ >>> embedding_sum(input, offsets)
+ tensor([[-0.8861, -5.4350, -0.0523],
+ [ 1.1306, -2.5798, -1.0044]])
+ """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
+ 'scale_grad_by_freq', 'mode', 'sparse', '_weight']
+
+ def __init__(self, num_embeddings, embedding_dim,
+ max_norm=None, norm_type=2., scale_grad_by_freq=False,
+ mode='mean', sparse=False, _weight=None):
+ super(EmbeddingBag, self).__init__()
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ if _weight is None:
+ self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
+ self.reset_parameters()
+ else:
+ assert list(_weight.shape) == [num_embeddings, embedding_dim], \
+ 'Shape of weight does not match num_embeddings and embedding_dim'
+ self.weight = Parameter(_weight)
+ self.mode = mode
+ self.sparse = sparse
+
+ def reset_parameters(self):
+ init.normal_(self.weight)
+
+ @weak_script_method
+ def forward(self, input, offsets=None, per_sample_weights=None):
+ # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+ return F.embedding_bag(input, self.weight, offsets,
+ self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.mode, self.sparse,
+ per_sample_weights)
+
+ def extra_repr(self):
+ s = '{num_embeddings}, {embedding_dim}'
+ if self.max_norm is not None:
+ s += ', max_norm={max_norm}'
+ if self.norm_type != 2:
+ s += ', norm_type={norm_type}'
+ if self.scale_grad_by_freq is not False:
+ s += ', scale_grad_by_freq={scale_grad_by_freq}'
+ s += ', mode={mode}'
+ return s.format(**self.__dict__)
+
+    @classmethod
+ def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
+ norm_type=2., scale_grad_by_freq=False,
+ mode='mean', sparse=False):
+ r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.
+
+ Args:
+ embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
+ First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
+ freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+ Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
+ max_norm (float, optional): See module initialization documentation. Default: ``None``
+ norm_type (float, optional): See module initialization documentation. Default ``2``.
+ scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+ mode (string, optional): See module initialization documentation. Default: ``"mean"``
+ sparse (bool, optional): See module initialization documentation. Default: ``False``.
+
+ Examples::
+
+ >>> # FloatTensor containing pretrained weights
+ >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+ >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
+ >>> # Get embeddings for index 1
+ >>> input = torch.LongTensor([[1, 0]])
+ >>> embeddingbag(input)
+ tensor([[ 2.5000, 3.7000, 4.6500]])
+ """
+ assert embeddings.dim() == 2, \
+ 'Embeddings parameter is expected to be 2-dimensional'
+ rows, cols = embeddings.shape
+ embeddingbag = cls(
+ num_embeddings=rows,
+ embedding_dim=cols,
+ _weight=embeddings,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ scale_grad_by_freq=scale_grad_by_freq,
+ mode=mode,
+ sparse=sparse)
+ embeddingbag.weight.requires_grad = not freeze
+ return embeddingbag
+
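+# Added sketch following the EmbeddingBag class above (not part of the original
+# source): with a 1D input, ``offsets`` marks where each bag starts, and
+# ``per_sample_weights`` (supported only with mode='sum') rescales each looked-up
+# vector before the reduction; the output shape stays (num_bags, embedding_dim).
+#
+#   >>> bag = nn.EmbeddingBag(10, 3, mode='sum')
+#   >>> input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
+#   >>> offsets = torch.LongTensor([0, 4])       # bags: input[0:4] and input[4:8]
+#   >>> weights = torch.rand(8)                  # one weight per entry of ``input``
+#   >>> bag(input, offsets, per_sample_weights=weights).shape
+#   torch.Size([2, 3])
+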
+from .module import Module
+from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
+
+
+@weak_module
+class Upsample(Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
+
+ The algorithms available for upsampling are nearest neighbor and linear,
+ bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
+ respectively.
+
+ One can either give a :attr:`scale_factor` or the target output :attr:`size` to
+ calculate the output size. (You cannot give both, as it is ambiguous)
+
+ Args:
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
+ output spatial sizes
+ scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
+ multiplier for spatial size. Has to match input size if it is a tuple.
+ mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
+ ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
+ Default: ``'nearest'``
+ align_corners (bool, optional): if ``True``, the corner pixels of the input
+            and output tensors are aligned, thus preserving the values at
+ those pixels. This only has effect when :attr:`mode` is
+ ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
+ or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
+
+ .. math::
+ D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor
+
+ .. math::
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
+
+ .. warning::
+ With ``align_corners = True``, the linearly interpolating modes
+ (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
+ align the output and input pixels, and thus the output values can depend
+ on the input size. This was the default behavior for these modes up to
+ version 0.3.1. Since then, the default behavior is
+ ``align_corners = False``. See below for concrete examples on how this
+ affects the outputs.
+
+ .. note::
+ If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.
+
+ Examples::
+
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
+ >>> input
+ tensor([[[[ 1., 2.],
+ [ 3., 4.]]]])
+
+ >>> m = nn.Upsample(scale_factor=2, mode='nearest')
+ >>> m(input)
+ tensor([[[[ 1., 1., 2., 2.],
+ [ 1., 1., 2., 2.],
+ [ 3., 3., 4., 4.],
+ [ 3., 3., 4., 4.]]]])
+
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
+ >>> m(input)
+ tensor([[[[ 1.0000, 1.2500, 1.7500, 2.0000],
+ [ 1.5000, 1.7500, 2.2500, 2.5000],
+ [ 2.5000, 2.7500, 3.2500, 3.5000],
+ [ 3.0000, 3.2500, 3.7500, 4.0000]]]])
+
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ >>> m(input)
+ tensor([[[[ 1.0000, 1.3333, 1.6667, 2.0000],
+ [ 1.6667, 2.0000, 2.3333, 2.6667],
+ [ 2.3333, 2.6667, 3.0000, 3.3333],
+ [ 3.0000, 3.3333, 3.6667, 4.0000]]]])
+
+ >>> # Try scaling the same data in a larger tensor
+ >>>
+ >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
+ >>> input_3x3[:, :, :2, :2].copy_(input)
+ tensor([[[[ 1., 2.],
+ [ 3., 4.]]]])
+ >>> input_3x3
+ tensor([[[[ 1., 2., 0.],
+ [ 3., 4., 0.],
+ [ 0., 0., 0.]]]])
+
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
+        >>> # Notice that values in the top left corner are the same as with the small input (except at boundary)
+ >>> m(input_3x3)
+ tensor([[[[ 1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
+ [ 1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
+ [ 2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
+ [ 2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
+ [ 0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
+
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+ >>> # Notice that values in top left corner are now changed
+ >>> m(input_3x3)
+ tensor([[[[ 1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
+ [ 1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
+ [ 2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
+ [ 2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
+ [ 1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
+ """
+ __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name']
+
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ super(Upsample, self).__init__()
+ self.name = type(self).__name__
+ self.size = size
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.align_corners = align_corners
+
+ @weak_script_method
+ def forward(self, input):
+ return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
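+# Illustrative sketch (not part of the module source above): Upsample with an
+# explicit output `size` instead of a `scale_factor`; the forward call simply
+# delegates to F.interpolate with the same arguments.
+import torch
+import torch.nn as nn
+
+x = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
+up = nn.Upsample(size=(3, 3), mode='bilinear', align_corners=True)
+y = up(x)                                          # shape: (1, 1, 3, 3)
+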
+
+@weak_module
+class UpsamplingNearest2d(Upsample):
+ r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input
+ channels.
+
+ To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
+    as its constructor argument.
+
+ When :attr:`size` is given, it is the output size of the image `(h, w)`.
+
+ Args:
+ size (int or Tuple[int, int], optional): output spatial sizes
+ scale_factor (float or Tuple[float, float], optional): multiplier for
+ spatial size.
+
+ .. warning::
+ This class is deprecated in favor of :func:`~nn.functional.interpolate`.
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ .. math::
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
+
+ Examples::
+
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
+ >>> input
+ tensor([[[[ 1., 2.],
+ [ 3., 4.]]]])
+
+ >>> m = nn.UpsamplingNearest2d(scale_factor=2)
+ >>> m(input)
+ tensor([[[[ 1., 1., 2., 2.],
+ [ 1., 1., 2., 2.],
+ [ 3., 3., 4., 4.],
+ [ 3., 3., 4., 4.]]]])
+ """
+ def __init__(self, size=None, scale_factor=None):
+ super(UpsamplingNearest2d, self).__init__(size, scale_factor, mode='nearest')
+
+
+@weak_module
+class UpsamplingBilinear2d(Upsample):
+ r"""Applies a 2D bilinear upsampling to an input signal composed of several input
+ channels.
+
+ To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
+    as its constructor argument.
+
+ When :attr:`size` is given, it is the output size of the image `(h, w)`.
+
+ Args:
+ size (int or Tuple[int, int], optional): output spatial sizes
+ scale_factor (float or Tuple[float, float], optional): multiplier for
+ spatial size.
+
+ .. warning::
+ This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
+ equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
+
+ Shape:
+ - Input: :math:`(N, C, H_{in}, W_{in})`
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+ .. math::
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
+
+ Examples::
+
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
+ >>> input
+ tensor([[[[ 1., 2.],
+ [ 3., 4.]]]])
+
+ >>> m = nn.UpsamplingBilinear2d(scale_factor=2)
+ >>> m(input)
+ tensor([[[[ 1.0000, 1.3333, 1.6667, 2.0000],
+ [ 1.6667, 2.0000, 2.3333, 2.6667],
+ [ 2.3333, 2.6667, 3.0000, 3.3333],
+ [ 3.0000, 3.3333, 3.6667, 4.0000]]]])
+ """
+ def __init__(self, size=None, scale_factor=None):
+ super(UpsamplingBilinear2d, self).__init__(size, scale_factor, mode='bilinear', align_corners=True)
+
+import operator
+import torch
+import warnings
+from itertools import chain
+from ..modules import Module
+from .scatter_gather import scatter_kwargs, gather
+from .replicate import replicate
+from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index
+
+
+def _check_balance(device_ids):
+ imbalance_warn = """
+ There is an imbalance between your GPUs. You may want to exclude GPU {} which
+ has less than 75% of the memory or cores of GPU {}. You can do so by setting
+ the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
+ environment variable."""
+ device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+ dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
+
+ def warn_imbalance(get_prop):
+ values = [get_prop(props) for props in dev_props]
+ min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
+ max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
+ if min_val / max_val < 0.75:
+ warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
+ return True
+ return False
+
+ if warn_imbalance(lambda props: props.total_memory):
+ return
+ if warn_imbalance(lambda props: props.multi_processor_count):
+ return
+
+
+class DataParallel(Module):
+ r"""Implements data parallelism at the module level.
+
+ This container parallelizes the application of the given :attr:`module` by
+ splitting the input across the specified devices by chunking in the batch
+ dimension (other objects will be copied once per device). In the forward
+ pass, the module is replicated on each device, and each replica handles a
+ portion of the input. During the backwards pass, gradients from each replica
+ are summed into the original module.
+
+ The batch size should be larger than the number of GPUs used.
+
+ See also: :ref:`cuda-nn-dataparallel-instead`
+
+ Arbitrary positional and keyword inputs are allowed to be passed into
+    DataParallel but some types are specially handled. Tensors will be
+    **scattered** on the dim specified (default 0). Tuple, list and dict types will
+ be shallow copied. The other types will be shared among different threads
+ and can be corrupted if written to in the model's forward pass.
+
+ The parallelized :attr:`module` must have its parameters and buffers on
+ ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
+ module.
+
+ .. warning::
+ In each forward, :attr:`module` is **replicated** on each device, so any
+ updates to the running module in ``forward`` will be lost. For example,
+ if :attr:`module` has a counter attribute that is incremented in each
+ ``forward``, it will always stay at the initial value because the update
+ is done on the replicas which are destroyed after ``forward``. However,
+ :class:`~torch.nn.DataParallel` guarantees that the replica on
+ ``device[0]`` will have its parameters and buffers sharing storage with
+ the base parallelized :attr:`module`. So **in-place** updates to the
+ parameters or buffers on ``device[0]`` will be recorded. E.g.,
+ :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
+ rely on this behavior to update the buffers.
+
+ .. warning::
+ Forward and backward hooks defined on :attr:`module` and its submodules
+ will be invoked ``len(device_ids)`` times, each with inputs located on
+ a particular device. Particularly, the hooks are only guaranteed to be
+ executed in correct order with respect to operations on corresponding
+ devices. For example, it is not guaranteed that hooks set via
+ :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
+ `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
+ that each such hook be executed before the corresponding
+ :meth:`~torch.nn.Module.forward` call of that device.
+
+ .. warning::
+ When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
+ :func:`forward`, this wrapper will return a vector of length equal to
+ number of devices used in data parallelism, containing the result from
+ each device.
+
+ .. note::
+ There is a subtlety in using the
+ ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
+ :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
+ See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
+ details.
+
+
+ Args:
+ module (Module): module to be parallelized
+ device_ids (list of int or torch.device): CUDA devices (default: all devices)
+ output_device (int or torch.device): device location of output (default: device_ids[0])
+
+ Attributes:
+ module (Module): the module to be parallelized
+
+ Example::
+
+ >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
+ >>> output = net(input_var) # input_var can be on any device, including CPU
+ """
+
+ # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
+
+ def __init__(self, module, device_ids=None, output_device=None, dim=0):
+ super(DataParallel, self).__init__()
+
+ if not torch.cuda.is_available():
+ self.module = module
+ self.device_ids = []
+ return
+
+ if device_ids is None:
+ device_ids = list(range(torch.cuda.device_count()))
+ if output_device is None:
+ output_device = device_ids[0]
+
+ self.dim = dim
+ self.module = module
+ self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+ self.output_device = _get_device_index(output_device, True)
+ self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
+
+ _check_balance(self.device_ids)
+
+ if len(self.device_ids) == 1:
+ self.module.cuda(device_ids[0])
+
+ def forward(self, *inputs, **kwargs):
+ if not self.device_ids:
+ return self.module(*inputs, **kwargs)
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(self.src_device_obj, t.device))
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ return self.module(*inputs[0], **kwargs[0])
+ replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+ outputs = self.parallel_apply(replicas, inputs, kwargs)
+ return self.gather(outputs, self.output_device)
+
+ def replicate(self, module, device_ids):
+ return replicate(module, device_ids)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def parallel_apply(self, replicas, inputs, kwargs):
+ return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+ def gather(self, outputs, output_device):
+ return gather(outputs, output_device, dim=self.dim)
+
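+# Illustrative sketch (not part of the module source above), assuming a machine
+# with at least two CUDA devices: parameters and buffers must live on
+# device_ids[0] before wrapping, while inputs may live on any device.
+import torch
+import torch.nn as nn
+
+model = nn.Linear(10, 5).cuda(0)                   # put parameters/buffers on device_ids[0]
+net = nn.DataParallel(model, device_ids=[0, 1])
+inputs = torch.randn(8, 10)                        # scattered along dim 0 across the devices
+outputs = net(inputs)                              # gathered on the output device (cuda:0)
+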
+
+def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
+ r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
+
+ This is the functional version of the DataParallel module.
+
+ Args:
+ module (Module): the module to evaluate in parallel
+ inputs (Tensor): inputs to the module
+ device_ids (list of int or torch.device): GPU ids on which to replicate module
+        output_device (int or torch.device): GPU location of the output.
+            Use -1 to indicate the CPU. (default: device_ids[0])
+ Returns:
+ a Tensor containing the result of module(input) located on
+ output_device
+ """
+ if not isinstance(inputs, tuple):
+ inputs = (inputs,)
+
+ if device_ids is None:
+ device_ids = list(range(torch.cuda.device_count()))
+
+ if output_device is None:
+ output_device = device_ids[0]
+
+ device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+ output_device = _get_device_index(output_device, True)
+ src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
+
+ for t in chain(module.parameters(), module.buffers()):
+ if t.device != src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(src_device_obj, t.device))
+
+ inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
+ if len(device_ids) == 1:
+ return module(*inputs[0], **module_kwargs[0])
+ used_device_ids = device_ids[:len(inputs)]
+ replicas = replicate(module, used_device_ids)
+ outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
+ return gather(outputs, output_device, dim)
+
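+# Illustrative sketch of the functional form above (assumes at least two CUDA
+# devices): evaluate module(inputs) across GPUs without keeping a DataParallel
+# wrapper around.
+import torch
+import torch.nn as nn
+from torch.nn.parallel import data_parallel
+
+module = nn.Linear(10, 5).cuda(0)
+inputs = torch.randn(8, 10).cuda(0)
+out = data_parallel(module, inputs, device_ids=[0, 1])   # result lives on device_ids[0]
+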
+import copy
+import itertools
+
+import torch
+
+from torch.cuda.comm import broadcast_coalesced
+import torch.distributed as dist
+
+if dist.is_available():
+ from torch.distributed.distributed_c10d import _get_default_group
+
+from ..modules import Module
+from .replicate import replicate
+from .scatter_gather import scatter_kwargs, gather
+from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index
+
+
+def _find_tensors(obj):
+ r"""
+ Recursively find all tensors contained in the specified object.
+ """
+ if isinstance(obj, torch.Tensor):
+ return [obj]
+ if isinstance(obj, (list, tuple)):
+ return itertools.chain(*map(_find_tensors, obj))
+ if isinstance(obj, dict):
+ return itertools.chain(*map(_find_tensors, obj.values()))
+ return []
+
+
+class DistributedDataParallel(Module):
+ r"""Implements distributed data parallelism that is based on
+ ``torch.distributed`` package at the module level.
+
+ This container parallelizes the application of the given module by
+ splitting the input across the specified devices by chunking in the batch
+ dimension. The module is replicated on each machine and each device, and
+ each such replica handles a portion of the input. During the backwards
+ pass, gradients from each node are averaged.
+
+ The batch size should be larger than the number of GPUs used locally.
+
+ See also: :ref:`distributed-basics` and :ref:`cuda-nn-dataparallel-instead`.
+ The same constraints on input as in :class:`torch.nn.DataParallel` apply.
+
+    Creation of this class requires that ``torch.distributed`` be already
+ initialized, by calling :func:`torch.distributed.init_process_group`.
+
+ ``DistributedDataParallel`` can be used in the following two ways:
+
+ (1) Single-Process Multi-GPU
+
+ In this case, a single process will be
+ spawned on each host/node and each process will operate on all the GPUs
+ of the node where it's running. To use ``DistributedDataParallel`` in
+    this way, you can simply construct the model as follows:
+
+ >>> torch.distributed.init_process_group(backend="nccl")
+ >>> model = DistributedDataParallel(model) # device_ids will include all GPU devices by default
+
+ (2) Multi-Process Single-GPU
+
+ This is the highly recommended way to use ``DistributedDataParallel``, with
+ multiple processes, each of which operates on a single GPU. This is
+ currently the fastest approach to do data parallel training using PyTorch
+ and applies to both single-node(multi-GPU) and multi-node data
+ parallel training. It is proven to be significantly faster than
+ :class:`torch.nn.DataParallel` for single-node multi-GPU data
+ parallel training.
+
+ Here is how to use it: on each host with N GPUs, you should spawn up N
+ processes, while ensuring that each process individually works on a single GPU
+ from 0 to N-1. Therefore, it is your job to ensure that your training script
+ operates on a single given GPU by calling:
+
+ >>> torch.cuda.set_device(i)
+
+    where i is from 0 to N-1. In each process, you should refer to the following
+ to construct this module:
+
+ >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
+ >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
+
+ In order to spawn up multiple processes per node, you can use either
+    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
+
+ .. note:: ``nccl`` backend is currently the fastest and
+ highly recommended backend to be used with Multi-Process Single-GPU
+        distributed training, and this applies to both single-node and
+        multi-node distributed training.
+
+ .. note:: This module also supports mixed-precision distributed training.
+ This means that your model can have different types of parameters such
+        as mixed types of fp16 and fp32; the gradient reduction on these
+ mixed types of parameters will just work fine.
+ Also note that ``nccl`` backend is currently the fastest and highly
+ recommended backend for fp16/fp32 mixed-precision training.
+
+ .. note:: If you use ``torch.save`` on one process to checkpoint the module,
+ and ``torch.load`` on some other processes to recover it, make sure that
+ ``map_location`` is configured properly for every process. Without
+ ``map_location``, ``torch.load`` would recover the module to devices
+ where the module was saved from.
+
+ .. warning::
+ This module works only with the ``gloo`` and ``nccl`` backends.
+
+ .. warning::
+ Constructor, forward method, and differentiation of the output (or a
+ function of the output of this module) is a distributed synchronization
+ point. Take that into account in case different processes might be
+ executing different code.
+
+ .. warning::
+ This module assumes all parameters are registered in the model by the
+ time it is created. No parameters should be added nor removed later.
+ Same applies to buffers.
+
+ .. warning::
+        This module assumes that the parameters registered in the model of each
+        distributed process are in the same order. The module itself will
+ conduct gradient all-reduction following the reverse order of the
+ registered parameters of the model. In other words, it is users'
+ responsibility to ensure that each distributed process has the exact
+ same model and thus the exact same parameter registration order.
+
+ .. warning::
+ This module assumes all buffers and gradients are dense.
+
+ .. warning::
+ This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
+ only work if gradients are to be accumulated in ``.grad`` attributes of
+ parameters).
+
+ .. warning::
+
+ If you plan on using this module with a ``nccl`` backend or a ``gloo``
+ backend (that uses Infiniband), together with a DataLoader that uses
+ multiple workers, please change the multiprocessing start method to
+ ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
+ Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
+ likely experience deadlocks if you don't change this setting.
+
+ .. warning::
+ Forward and backward hooks defined on :attr:`module` and its submodules
+ won't be invoked anymore, unless the hooks are initialized in the
+ :meth:`forward` method.
+
+ .. warning::
+ You should never try to change your model's parameters after wrapping
+ up your model with DistributedDataParallel. In other words, when
+ wrapping up your model with DistributedDataParallel, the constructor of
+ DistributedDataParallel will register the additional gradient
+ reduction functions on all the parameters of the model itself at the
+ time of construction. If you change the model's parameters after
+ the DistributedDataParallel construction, this is not supported and
+ unexpected behaviors can happen, since some parameters' gradient
+ reduction functions might not get called.
+
+ .. note::
+ Parameters are never broadcast between processes. The module performs
+ an all-reduce step on gradients and assumes that they will be modified
+ by the optimizer in all processes in the same way. Buffers
+        (e.g. BatchNorm stats) are broadcast from the module in the process of
+        rank 0 to all other replicas in the system in every iteration.
+
+ Args:
+ module (Module): module to be parallelized
+ device_ids (list of int or torch.device): CUDA devices. This should
+ only be provided when the input module resides on a single
+ CUDA device. For single-device modules, the ``i``th
+ :attr:`module` replica is placed on ``device_ids[i]``. For
+ multi-device modules and CPU modules, device_ids must be None
+ or an empty list, and input data for the forward pass must be
+ placed on the correct device. (default: all devices for
+ single-device modules)
+ output_device (int or torch.device): device location of output for
+ single-device CUDA modules. For multi-device modules and
+ CPU modules, it must be None, and the module itself
+ dictates the output location. (default: device_ids[0] for
+ single-device modules)
+ broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
+                   the module at the beginning of the forward function.
+ (default: ``True``)
+ process_group: the process group to be used for distributed data
+ all-reduction. If ``None``, the default process group, which
+                        is created by ``torch.distributed.init_process_group``,
+ will be used. (default: ``None``)
+ bucket_cap_mb: DistributedDataParallel will bucket parameters into
+ multiple buckets so that gradient reduction of each
+ bucket can potentially overlap with backward computation.
+ :attr:`bucket_cap_mb` controls the bucket size in MegaBytes (MB)
+ (default: 25)
+ find_unused_parameters (bool): Traverse the autograd graph of all tensors
+ contained in the return value of the wrapped
+ module's ``forward`` function.
+ Parameters that don't receive gradients as
+ part of this graph are preemptively marked
+ as being ready to be reduced.
+ (default: ``False``)
+ check_reduction: when setting to ``True``, it enables DistributedDataParallel
+ to automatically check if the previous iteration's
+ backward reductions were successfully issued at the
+ beginning of every iteration's forward function.
+ You normally don't need this option enabled unless you
+ are observing weird behaviors such as different ranks
+ are getting different gradients, which should not
+ happen if DistributedDataParallel is correctly used.
+ (default: ``False``)
+
+ Attributes:
+ module (Module): the module to be parallelized
+
+ Example::
+
+ >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
+ >>> net = torch.nn.DistributedDataParallel(model, pg)
+ """
+ def __init__(self, module, device_ids=None,
+ output_device=None, dim=0, broadcast_buffers=True,
+ process_group=None, bucket_cap_mb=25,
+ find_unused_parameters=False,
+ check_reduction=False):
+
+ super(DistributedDataParallel, self).__init__()
+
+ self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
+ self.is_cuda = all([p.device.type == 'cuda' for p in module.parameters()])
+
+ if not self.is_cuda or self.is_multi_device_module:
+ assert not device_ids and not output_device, (
+ "DistributedDataParallel device_ids and output_device arguments "
+ "only work with single-device CUDA modules, but got "
+ "device_ids {}, output_device {}, and module parameters {}."
+ ).format(device_ids, output_device, {p.device for p in module.parameters()})
+
+ self.device_ids = None
+ self.output_device = None
+ else:
+ # Use all devices by default for single-device CUDA modules
+ if device_ids is None:
+ device_ids = list(range(torch.cuda.device_count()))
+
+ self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+
+ if output_device is None:
+ output_device = device_ids[0]
+
+ self.output_device = _get_device_index(output_device, True)
+
+ if self.is_multi_device_module:
+ assert self.is_cuda, (
+ "DistributedDataParallel with multi-device module only works "
+ "with CUDA devices, but module parameters locate in {}."
+ ).format({p.device for p in module.parameters()})
+
+ if process_group is None:
+ self.process_group = _get_default_group()
+ else:
+ self.process_group = process_group
+
+ self.dim = dim
+ self.module = module
+ self.broadcast_buffers = broadcast_buffers
+ self.find_unused_parameters = find_unused_parameters
+
+ if check_reduction:
+ # This argument is no longer used since the reducer
+ # will ensure reduction completes even if some parameters
+ # do not receive gradients.
+ pass
+
+ MB = 1024 * 1024
+
+ # used for intra-node param sync and inter-node sync as well
+ self.broadcast_bucket_size = int(250 * MB)
+
+ # reduction bucket size
+ self.bucket_bytes_cap = int(bucket_cap_mb * MB)
+
+ # Sync params and buffers
+ module_states = list(self.module.state_dict().values())
+ if len(module_states) > 0:
+ self._dist_broadcast_coalesced(module_states,
+ self.broadcast_bucket_size)
+
+ self._ddp_init_helper()
+
+ def _ddp_init_helper(self):
+ """
+ Initialization helper function that does the following:
+
+ (1) replicating the module from device[0] to the other devices
+ (2) bucketing the parameters for reductions
+ (3) resetting the bucketing states
+ (4) registering the grad hooks
+ (5) passing a handle of DDP to SyncBatchNorm Layer
+ """
+ if self.device_ids and len(self.device_ids) > 1:
+ # only create replicas for single-device CUDA modules
+ #
+ # TODO: we don't need to replicate params in here. they're always going to
+ # be broadcasted using larger blocks in broadcast_coalesced, so it might be
+ # better to not pollute the caches with these small blocks
+ self._module_copies = replicate(self.module, self.device_ids, detach=True)
+ self._module_copies[0] = self.module
+
+ for module_copy in self._module_copies[1:]:
+ for param, copy_param in zip(self.module.parameters(), module_copy.parameters()):
+ copy_param.requires_grad = param.requires_grad
+
+ else:
+ self._module_copies = [self.module]
+
+ self.modules_params = [list(m.parameters()) for m in self._module_copies]
+ self.modules_buffers = [list(m.buffers()) for m in self._module_copies]
+
+ param_list = [
+ list(filter(lambda p: p.requires_grad, module.parameters()))
+ for module in self._module_copies]
+
+ # The bucket size limit is specified in the constructor.
+ # Additionally, we allow for a single small bucket for parameters
+ # that are defined first, such that their gradients don't spill into
+ # a much larger bucket, adding unnecessary latency after gradient
+ # computation finishes. Experiments showed 1MB is a reasonable value.
+ bucket_indices = dist._compute_bucket_assignment_by_size(
+ param_list[0],
+ [1024 * 1024, self.bucket_bytes_cap])
+
+ # Note: reverse list of buckets because we want to approximate the
+ # order in which their gradients are produced, and assume they
+ # are used in the forward pass in the order they are defined.
+ self.reducer = dist.Reducer(
+ param_list,
+ list(reversed(bucket_indices)),
+ self.process_group)
+
+ # passing a handle to torch.nn.SyncBatchNorm layer
+ self._passing_sync_batchnorm_handle(self._module_copies)
+
+ def __getstate__(self):
+ self._check_default_group()
+ attrs = copy.copy(self.__dict__)
+ del attrs['process_group']
+ del attrs['reducer']
+ return attrs
+
+ def __setstate__(self, state):
+ # If serializable, then the process group should be the default one
+ self.process_group = _get_default_group()
+ super(DistributedDataParallel, self).__setstate__(state)
+ self._ddp_init_helper()
+
+ def _check_default_group(self):
+ pickle_not_supported = False
+ try:
+ if self.process_group != _get_default_group():
+ pickle_not_supported = True
+ except RuntimeError:
+ pickle_not_supported = True
+
+ if pickle_not_supported:
+ raise RuntimeError("DDP Pickling/Unpickling are only supported "
+ "when using DDP with the default process "
+ "group. That is, when you have called "
+ "init_process_group and have not passed "
+ "process_group argument to DDP constructor")
+
+ def forward(self, *inputs, **kwargs):
+ self._sync_params()
+ if self.device_ids:
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ output = self.module(*inputs[0], **kwargs[0])
+ else:
+ outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+ output = self.gather(outputs, self.output_device)
+ else:
+ output = self.module(*inputs, **kwargs)
+
+ if torch.is_grad_enabled():
+ # We'll return the output object verbatim since it is a freeform
+ # object. We need to find any tensors in this object, though,
+ # because we need to figure out which parameters were used during
+ # this forward pass, to ensure we short circuit reduction for any
+ # unused parameters. Only if `find_unused_parameters` is set.
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ return output
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def parallel_apply(self, replicas, inputs, kwargs):
+ return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+ def gather(self, outputs, output_device):
+ return gather(outputs, output_device, dim=self.dim)
+
+ def train(self, mode=True):
+ super(DistributedDataParallel, self).train(mode)
+ for module in self._module_copies[1:]:
+ module.train(mode)
+
+ def _dist_broadcast_coalesced(self, tensors, buffer_size):
+ dist._dist_broadcast_coalesced(self.process_group, tensors, buffer_size, False)
+
+ def _sync_params(self):
+ with torch.no_grad():
+ # only do intra-node parameters sync for replicated single-device
+ # CUDA modules
+ if self.device_ids and len(self.device_ids) > 1:
+ # intra-node parameter sync
+ result = broadcast_coalesced(self.modules_params[0],
+ self.device_ids,
+ self.broadcast_bucket_size)
+ for tensors, module_params in zip(result[1:],
+ self.modules_params[1:]):
+ for tensor, param in zip(tensors, module_params):
+ param.set_(tensor)
+ # Assume we have just run the optimizer and zeroed the
+ # grads of the parameters on the root model. We need
+ # to zero the grads on all model replicas as well.
+ # This snippet is copied from torch.optim.Optimizer.
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad.zero_()
+
+ # module buffer sync
+ if self.broadcast_buffers and len(self.modules_buffers[0]) > 0:
+ # cross-node buffer sync
+ self._dist_broadcast_coalesced(self.modules_buffers[0],
+ self.broadcast_bucket_size)
+ # only do intra-node buffer sync for replicated single-device
+ # CUDA modules
+ if self.device_ids and len(self.device_ids) > 1:
+ # intra-node buffer sync
+ result = broadcast_coalesced(self.modules_buffers[0],
+ self.device_ids,
+ self.broadcast_bucket_size)
+ for tensors, module_buffers in zip(result[1:],
+ self.modules_buffers[1:]):
+ for tensor, buffer in zip(tensors, module_buffers):
+ buffer.set_(tensor)
+
+ def _passing_sync_batchnorm_handle(self, module_copies):
+ for dev_idx, module in enumerate(module_copies):
+ for layer in module.modules():
+ if isinstance(layer, torch.nn.modules.SyncBatchNorm):
+ assert self.is_cuda, "SyncBatchNorm layers only work with CUDA modules"
+ layer._specify_ddp_gpu_num(
+ len(self.device_ids) if self.device_ids else 1)
+
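+# Illustrative sketch of the multi-process, single-GPU pattern recommended in
+# the DistributedDataParallel docstring (not part of the module source above).
+# It assumes a single node whose GPU count equals the world size and an
+# available ``nccl`` backend; ``run_worker`` and the address/port values are
+# illustrative choices.
+import os
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+
+def run_worker(rank, world_size):
+    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
+    os.environ.setdefault('MASTER_PORT', '29500')
+    dist.init_process_group('nccl', rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)                    # one process per GPU
+    model = nn.Linear(10, 5).cuda(rank)
+    ddp = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
+    out = ddp(torch.randn(8, 10).cuda(rank))
+    out.sum().backward()                           # gradients are all-reduced across ranks
+    dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+    world_size = torch.cuda.device_count()
+    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)
+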
+import torch
+from collections import OrderedDict
+
+
+class Parameter(torch.Tensor):
+ r"""A kind of Tensor that is to be considered a module parameter.
+
+ Parameters are :class:`~torch.Tensor` subclasses, that have a
+ very special property when used with :class:`Module` s - when they're
+ assigned as Module attributes they are automatically added to the list of
+ its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
+    Assigning a Tensor doesn't have such an effect. This is because one might
+ want to cache some temporary state, like last hidden state of the RNN, in
+ the model. If there was no such class as :class:`Parameter`, these
+ temporaries would get registered too.
+
+ Arguments:
+ data (Tensor): parameter tensor.
+ requires_grad (bool, optional): if the parameter requires gradient. See
+ :ref:`excluding-subgraphs` for more details. Default: `True`
+ """
+
+ def __new__(cls, data=None, requires_grad=True):
+ if data is None:
+ data = torch.Tensor()
+ return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+ def __deepcopy__(self, memo):
+ if id(self) in memo:
+ return memo[id(self)]
+ else:
+ result = type(self)(self.data.clone(), self.requires_grad)
+ memo[id(self)] = result
+ return result
+
+ def __repr__(self):
+ return 'Parameter containing:\n' + super(Parameter, self).__repr__()
+
+ def __reduce_ex__(self, proto):
+ # See Note [Don't serialize hooks]
+ return (
+ torch._utils._rebuild_parameter,
+ (self.data, self.requires_grad, OrderedDict())
+ )
+
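+# Small sketch of the behavior described above (not part of the module source):
+# an attribute assigned as a Parameter is registered and shows up in
+# .parameters(); a plain Tensor attribute is not.
+import torch
+import torch.nn as nn
+
+
+class Toy(nn.Module):
+    def __init__(self):
+        super(Toy, self).__init__()
+        self.w = nn.Parameter(torch.randn(3))      # registered as a parameter
+        self.cache = torch.zeros(3)                # plain Tensor: not registered
+
+
+print(len(list(Toy().parameters())))               # prints 1
+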
+import warnings
+import torch
+from torch._six import inf
+
+
+def clip_grad_norm_(parameters, max_norm, norm_type=2):
+ r"""Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Arguments:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if norm_type == inf:
+ total_norm = max(p.grad.data.abs().max() for p in parameters)
+ else:
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ total_norm = total_norm ** (1. / norm_type)
+ clip_coef = max_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in parameters:
+ p.grad.data.mul_(clip_coef)
+ return total_norm
+
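+# Sketch of the usual call site (not part of the module source above): clip
+# after backward() and before the optimizer step.
+import torch
+import torch.nn as nn
+from torch.nn.utils import clip_grad_norm_
+
+model = nn.Linear(10, 1)
+loss = model(torch.randn(4, 10)).sum()
+loss.backward()
+total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
+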
+
+def clip_grad_norm(parameters, max_norm, norm_type=2):
+ r"""Clips gradient norm of an iterable of parameters.
+
+ .. warning::
+ This method is now deprecated in favor of
+ :func:`torch.nn.utils.clip_grad_norm_`.
+ """
+ warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
+ "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
+ return clip_grad_norm_(parameters, max_norm, norm_type)
+
+
+def clip_grad_value_(parameters, clip_value):
+ r"""Clips gradient of an iterable of parameters at specified value.
+
+ Gradients are modified in-place.
+
+ Arguments:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ clip_value (float or int): maximum allowed value of the gradients.
+ The gradients are clipped in the range
+ :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ clip_value = float(clip_value)
+ for p in filter(lambda p: p.grad is not None, parameters):
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+
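+# Sketch (not part of the module source above): clamp every gradient element
+# into [-0.5, 0.5] in-place.
+import torch
+import torch.nn as nn
+from torch.nn.utils import clip_grad_value_
+
+model = nn.Linear(10, 1)
+model(torch.randn(4, 10)).sum().backward()
+clip_grad_value_(model.parameters(), clip_value=0.5)
+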
+import torch
+
+
+def parameters_to_vector(parameters):
+ r"""Convert parameters to one vector
+
+ Arguments:
+ parameters (Iterable[Tensor]): an iterator of Tensors that are the
+ parameters of a model.
+
+ Returns:
+ The parameters represented by a single vector
+ """
+ # Flag for the device where the parameter is located
+ param_device = None
+
+ vec = []
+ for param in parameters:
+ # Ensure the parameters are located in the same device
+ param_device = _check_param_device(param, param_device)
+
+ vec.append(param.view(-1))
+ return torch.cat(vec)
+
+
+def vector_to_parameters(vec, parameters):
+ r"""Convert one vector to the parameters
+
+ Arguments:
+ vec (Tensor): a single vector represents the parameters of a model.
+ parameters (Iterable[Tensor]): an iterator of Tensors that are the
+ parameters of a model.
+ """
+ # Ensure vec of type Tensor
+ if not isinstance(vec, torch.Tensor):
+ raise TypeError('expected torch.Tensor, but got: {}'
+ .format(torch.typename(vec)))
+ # Flag for the device where the parameter is located
+ param_device = None
+
+ # Pointer for slicing the vector for each parameter
+ pointer = 0
+ for param in parameters:
+ # Ensure the parameters are located in the same device
+ param_device = _check_param_device(param, param_device)
+
+ # The length of the parameter
+ num_param = param.numel()
+ # Slice the vector, reshape it, and replace the old data of the parameter
+ param.data = vec[pointer:pointer + num_param].view_as(param).data
+
+ # Increment the pointer
+ pointer += num_param
+
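+# Sketch of a round trip through the two helpers above (not part of the module
+# source): flatten all parameters into one vector and write a modified vector
+# back in-place.
+import torch
+import torch.nn as nn
+from torch.nn.utils import parameters_to_vector, vector_to_parameters
+
+model = nn.Linear(4, 2)
+vec = parameters_to_vector(model.parameters())     # shape: (4 * 2 + 2,)
+vector_to_parameters(vec * 0, model.parameters())  # zeroes every parameter
+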
+
+def _check_param_device(param, old_param_device):
+ r"""This helper function is to check if the parameters are located
+ in the same device. Currently, the conversion between model parameters
+ and single vector form is not supported for multiple allocations,
+ e.g. parameters in different GPUs, or mixture of CPU/GPU.
+
+ Arguments:
+ param ([Tensor]): a Tensor of a parameter of a model
+ old_param_device (int): the device where the first parameter of a
+ model is allocated.
+
+ Returns:
+ old_param_device (int): report device for the first time
+ """
+
+ # Meet the first parameter
+ if old_param_device is None:
+ old_param_device = param.get_device() if param.is_cuda else -1
+ else:
+ warn = False
+ if param.is_cuda: # Check if in same GPU
+ warn = (param.get_device() != old_param_device)
+ else: # Check if in CPU
+ warn = (old_param_device != -1)
+ if warn:
+ raise TypeError('Found two parameters on different devices, '
+ 'this is currently not supported.')
+ return old_param_device
+
+from collections import namedtuple
+import warnings
+
+import torch
+
+
+PackedSequence_ = namedtuple('PackedSequence',
+ ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
+
+
+def bind(optional, fn):
+ if optional is None:
+ return None
+ return fn(optional)
+
+
+class PackedSequence(PackedSequence_):
+ r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.
+
+ All RNN modules accept packed sequences as inputs.
+
+ Note:
+ Instances of this class should never be created manually. They are meant
+ to be instantiated by functions like :func:`pack_padded_sequence`.
+
+        Batch sizes represent the number of elements at each sequence step in
+ the batch, not the varying sequence lengths passed to
+ :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
+ the :class:`PackedSequence` would contain data ``axbc`` with
+ ``batch_sizes=[2,1,1]``.
+
+ Attributes:
+ data (Tensor): Tensor containing packed sequence
+ batch_sizes (Tensor): Tensor of integers holding
+ information about the batch size at each sequence step
+ sorted_indices (Tensor, optional): Tensor of integers holding how this
+ :class:`PackedSequence` is constructed from sequences.
+        unsorted_indices (Tensor, optional): Tensor of integers holding how to
+            recover the original sequences with correct order.
+
+ .. note::
+ :attr:`data` can be on arbitrary device and of arbitrary dtype.
+ :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
+ tensors on the same device as :attr:`data`.
+
+ However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
+
+        This invariant is maintained throughout the :class:`PackedSequence`
+        class, and by all functions that construct a :class:`PackedSequence`
+        in PyTorch (i.e., they only pass in tensors conforming to this
+        constraint).
+
+ """
+
+ # NOTE [ device and dtype of a PackedSequence ]
+ #
+ # See the note above in doc string (starting with ":attr:`data` can be on
+ # arbitrary device...").
+
+ def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
+ # PackedSequence used to only have __init__(self, data, batch_sizes)
+ # without a __new__ like this. So to preserve BC for calling in keyword
+ # arg style (e.g., `PackedSequence(data=..., batch_sizes=...)`), we have
+ # to provide two arguments with exact names `data` and `batch_sizes`.
+
+ # NB: if unsorted_indices is provided, it should be the inverse permutation
+ # to sorted_indices. Don't assert it here because the PackedSequence ctor
+ # should only be used internally.
+ if unsorted_indices is None:
+ unsorted_indices = invert_permutation(sorted_indices)
+
+ # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
+ if batch_sizes is not None:
+ return super(PackedSequence, cls).__new__(
+ cls, data, batch_sizes, sorted_indices, unsorted_indices)
+
+ # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
+ else:
+ assert isinstance(data, (list, tuple)) and len(data) == 2
+ return super(PackedSequence, cls).__new__(
+ cls, data[0], data[1], sorted_indices)
+
+ def pin_memory(self):
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.pin_memory(), self.batch_sizes,
+ bind(self.sorted_indices, lambda t: t.pin_memory()),
+ bind(self.unsorted_indices, lambda t: t.pin_memory()))
+
+ def cuda(self, *args, **kwargs):
+        """Returns a GPU copy if `self.data` is not already on the GPU"""
+ if self.is_cuda:
+ return self
+ else:
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes,
+ bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)),
+ bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs)))
+
+ def cpu(self):
+        """Returns a CPU copy if `self.data` is not already on the CPU"""
+ if self.is_cuda:
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.cpu(), self.batch_sizes,
+ bind(self.sorted_indices, lambda t: t.cpu()),
+ bind(self.unsorted_indices, lambda t: t.cpu()))
+ else:
+ return self
+
+ def double(self):
+ r"""Returns copy with `self.data` cast to double type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.double(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def float(self):
+ r"""Returns copy with `self.data` cast to float type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.float(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def half(self):
+ r"""Returns copy with `self.data` cast to half type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.half(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def long(self):
+ r"""Returns copy with `self.data` cast to long type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.long(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def int(self):
+ r"""Returns copy with `self.data` cast to int type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.int(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def short(self):
+ r"""Returns copy with `self.data` cast to short type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.short(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def char(self):
+ r"""Returns copy with `self.data` cast to char type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.char(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def byte(self):
+ r"""Returns copy with `self.data` cast to byte type"""
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ return type(self)(self.data.byte(), self.batch_sizes,
+ self.sorted_indices, self.unsorted_indices)
+
+ def to(self, *args, **kwargs):
+ r"""Performs dtype and/or device conversion on `self.data`.
+
+ It has similar signature as :meth:`torch.Tensor.to`.
+
+ .. note::
+
+ If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
+ and :class:`torch.device`, then ``self`` is returned.
+ Otherwise, returns a copy with the desired configuration.
+ """
+
+ # Why not convert `batch_sizes`?
+ # See NOTE [ device and dtype of a PackedSequence ]
+ data = self.data.to(*args, **kwargs)
+ sorted_indices = self.sorted_indices
+ unsorted_indices = self.unsorted_indices
+ device_kw = 'device'
+ if device_kw in kwargs:
+ sorted_indices = bind(sorted_indices, lambda t: t.to(kwargs[device_kw]))
+ unsorted_indices = bind(unsorted_indices, lambda t: t.to(kwargs[device_kw]))
+ if data is self.data:
+ return self
+ else:
+ return type(self)(data, self.batch_sizes,
+ sorted_indices, unsorted_indices)
+
+ @property
+ def is_cuda(self):
+        r"""Returns true if `self.data` is stored on a GPU"""
+ return self.data.is_cuda
+
+ def is_pinned(self):
+        r"""Returns true if `self.data` is stored in pinned memory"""
+ return self.data.is_pinned()
+
+
+def invert_permutation(permutation):
+ if permutation is None:
+ return None
+ output = torch.empty_like(permutation)
+ output.scatter_(0, permutation,
+ torch.arange(0, permutation.numel(), device=permutation.device))
+ return output
+
+
+def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
+ r"""Packs a Tensor containing padded sequences of variable length.
+
+ :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
+ longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
+ ``*`` is any number of dimensions (including 0). If ``batch_first`` is
+ ``True``, ``B x T x *`` :attr:`input` is expected.
+
+ For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
+ ``True``, the sequences should be sorted by length in a decreasing order, i.e.
+ ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
+ one. `enforce_sorted = True` is only necessary for ONNX export.
+
+ Note:
+ This function accepts any input that has at least two dimensions. You
+ can apply it to pack the labels, and use the output of the RNN with
+ them to compute the loss directly. A Tensor can be retrieved from
+ a :class:`PackedSequence` object by accessing its ``.data`` attribute.
+
+ Arguments:
+ input (Tensor): padded batch of variable length sequences.
+ lengths (Tensor): list of sequences lengths of each batch element.
+ batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
+ format.
+ enforce_sorted (bool, optional): if ``True``, the input is expected to
+ contain sequences sorted by length in a decreasing order. If
+ ``False``, this condition is not checked. Default: ``True``.
+
+ Returns:
+ a :class:`PackedSequence` object
+ """
+ if torch._C._get_tracing_state() and not isinstance(lengths, torch.Tensor):
+ warnings.warn('pack_padded_sequence has been called with a Python list of '
+ 'sequence lengths. The tracer cannot track the data flow of Python '
+ 'values, and it will treat them as constants, likely rendering '
+ 'the trace incorrect for any other combination of lengths.',
+ category=torch.jit.TracerWarning, stacklevel=2)
+ lengths = torch.as_tensor(lengths, dtype=torch.int64)
+ if enforce_sorted:
+ sorted_indices = None
+ else:
+ lengths, sorted_indices = torch.sort(lengths, descending=True)
+ sorted_indices = sorted_indices.to(input.device)
+ batch_dim = 0 if batch_first else 1
+ input = input.index_select(batch_dim, sorted_indices)
+
+ data, batch_sizes = \
+ torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)
+ return PackedSequence(data, batch_sizes, sorted_indices)
+
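+# Sketch (not part of the module source above): pack two padded sequences of
+# lengths 3 and 2 and run them through an RNN, which accepts PackedSequence
+# inputs directly; pad_packed_sequence (defined below) undoes the packing.
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+padded = torch.tensor([[[1.], [4.]],
+                       [[2.], [5.]],
+                       [[3.], [0.]]])              # T x B x * with T=3, B=2
+lengths = torch.tensor([3, 2])                     # sorted in decreasing order
+packed = pack_padded_sequence(padded, lengths)     # batch_sizes = [2, 2, 1]
+out, _ = nn.RNN(input_size=1, hidden_size=2)(packed)
+unpacked, unpacked_lengths = pad_packed_sequence(out)   # back to T x B x hidden
+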
+
+def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
+ r"""Pads a packed batch of variable length sequences.
+
+ It is an inverse operation to :func:`pack_padded_sequence`.
+
+ The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
+ of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
+ the data will be transposed into ``B x T x *`` format.
+
+ Batch elements will be ordered decreasingly by their length.
+
+ .. note::
+ :attr:`total_length` is useful to implement the
+ ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
+ :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
+ See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
+ details.
+
+ Arguments:
+ sequence (PackedSequence): batch to pad
+ batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
+ format.
+ padding_value (float, optional): values for padded elements.
+ total_length (int, optional): if not ``None``, the output will be padded to
+ have length :attr:`total_length`. This method will throw :class:`ValueError`
+ if :attr:`total_length` is less than the max sequence length in
+ :attr:`sequence`.
+
+ Returns:
+ Tuple of Tensor containing the padded sequence, and a Tensor
+ containing the list of lengths of each sequence in the batch.
+
+ """
+ max_seq_length = sequence.batch_sizes.size(0)
+ if total_length is not None:
+ if total_length < max_seq_length:
+ raise ValueError("Expected total_length to be at least the length "
+ "of the longest sequence in input, but got "
+ "total_length={} and max sequence length being {}"
+ .format(total_length, max_seq_length))
+ max_seq_length = total_length
+ padded_output, lengths = torch._C._VariableFunctions._pad_packed_sequence(
+ sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
+ if sequence.unsorted_indices is not None:
+ batch_dim = 0 if batch_first else 1
+ return padded_output.index_select(batch_dim, sequence.unsorted_indices), \
+ lengths[sequence.unsorted_indices]
+ return padded_output, lengths
+
+
+def pad_sequence(sequences, batch_first=False, padding_value=0):
+ r"""Pad a list of variable length Tensors with ``padding_value``
+
+ ``pad_sequence`` stacks a list of Tensors along a new dimension,
+    and pads them to equal length. For example, if the input is a list of
+    sequences with size ``L x *``, the output is of size ``T x B x *`` if
+    :attr:`batch_first` is ``False``, and ``B x T x *`` otherwise.
+
+ `B` is batch size. It is equal to the number of elements in ``sequences``.
+    `T` is the length of the longest sequence.
+    `L` is the length of the sequence.
+ `*` is any number of trailing dimensions, including none.
+
+ Example:
+ >>> from torch.nn.utils.rnn import pad_sequence
+ >>> a = torch.ones(25, 300)
+ >>> b = torch.ones(22, 300)
+ >>> c = torch.ones(15, 300)
+ >>> pad_sequence([a, b, c]).size()
+ torch.Size([25, 3, 300])
+
+ Note:
+ This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
+ where `T` is the length of the longest sequence. This function assumes
+ trailing dimensions and type of all the Tensors in sequences are same.
+
+ Arguments:
+ sequences (list[Tensor]): list of variable length sequences.
+ batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
+ ``T x B x *`` otherwise
+ padding_value (float, optional): value for padded elements. Default: 0.
+
+ Returns:
+ Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
+ Tensor of size ``B x T x *`` otherwise
+ """
+
+ # assuming trailing dimensions and type of all the Tensors
+ # in sequences are same and fetching those from sequences[0]
+ max_size = sequences[0].size()
+ trailing_dims = max_size[1:]
+ max_len = max([s.size(0) for s in sequences])
+ if batch_first:
+ out_dims = (len(sequences), max_len) + trailing_dims
+ else:
+ out_dims = (max_len, len(sequences)) + trailing_dims
+
+ out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
+ for i, tensor in enumerate(sequences):
+ length = tensor.size(0)
+ # use index notation to prevent duplicate references to the tensor
+ if batch_first:
+ out_tensor[i, :length, ...] = tensor
+ else:
+ out_tensor[:length, i, ...] = tensor
+
+ return out_tensor
+
+
+def pack_sequence(sequences, enforce_sorted=True):
+ r"""Packs a list of variable length Tensors
+
+ ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
+ the length of a sequence and `*` is any number of trailing dimensions,
+ including zero.
+
+ For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
+ is ``True``, the sequences should be sorted in the order of decreasing length.
+ ``enforce_sorted = True`` is only necessary for ONNX export.
+
+
+ Example:
+ >>> from torch.nn.utils.rnn import pack_sequence
+ >>> a = torch.tensor([1,2,3])
+ >>> b = torch.tensor([4,5])
+ >>> c = torch.tensor([6])
+ >>> pack_sequence([a, b, c])
+ PackedSequence(data=tensor([ 1, 4, 6, 2, 5, 3]), batch_sizes=tensor([ 3, 2, 1]))
+
+
+ Arguments:
+ sequences (list[Tensor]): A list of sequences of decreasing length.
+ enforce_sorted (bool, optional): if ``True``, checks that the input
+ contains sequences sorted by length in a decreasing order. If
+ ``False``, this condition is not checked. Default: ``True``.
+
+ Returns:
+ a :class:`PackedSequence` object
+ """
+ lengths = [v.size(0) for v in sequences]
+ return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
+
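+# Sketch (not part of the module source above): with enforce_sorted=False the
+# sequences need not be sorted by length; sorted_indices/unsorted_indices
+# record the permutation that was applied.
+import torch
+from torch.nn.utils.rnn import pack_sequence
+
+a = torch.tensor([6.])
+b = torch.tensor([1., 2., 3.])
+packed = pack_sequence([a, b], enforce_sorted=False)
+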
+
+def get_packed_sequence(data, batch_sizes, sorted_indices, unsorted_indices):
+ return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
+
+"""
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+"""
+import torch
+from torch.nn.functional import normalize
+
+
+class SpectralNorm(object):
+ # Invariant before and after each forward call:
+ # u = normalize(W @ v)
+ # NB: At initialization, this invariant is not enforced
+
+ _version = 1
+ # At version 1:
+ # made `W` not a buffer,
+ # added `v` as a buffer, and
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
+
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError('Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(self.dim,
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ return weight_mat.reshape(height, -1)
+
+ def compute_weight(self, module, do_power_iteration):
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
+ # updated in power iteration **in-place**. This is very important
+ # because in `DataParallel` forward, the vectors (being buffers) are
+ # broadcast from the parallelized module to each module replica,
+ # which is a new module object created on the fly. And each replica
+ # runs its own spectral norm power iteration. So simply assigning
+ # the updated vectors to the module this function runs on will cause
+ # the update to be lost forever. And the next time the parallelized
+ # module is replicated, the same randomly initialized vectors are
+ # broadcast and used!
+ #
+ # Therefore, to make the change propagate back, we rely on two
+ # important behaviors (also enforced via tests):
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
+ # is already on correct device; and it makes sure that the
+ # parallelized module is already on `device[0]`.
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
+ # just fill in the values.
+ # Therefore, since the same power iteration is performed on all
+ # devices, simply updating the tensors in-place will make sure that
+ # the module replica on `device[0]` will update the _u vector on the
+ # parallelized module (by shared storage).
+ #
+ # However, after we update `u` and `v` in-place, we need to **clone**
+ # them before using them to normalize the weight. This is to support
+ # backproping through two forward passes, e.g., the common pattern in
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
+ # complain that variables needed to do backward for the first forward
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
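+ # A minimal illustration of the pattern mentioned above (a sketch only;
+ # `D`, `real` and `fake` are hypothetical user-side names, not part of
+ # this class):
+ #
+ #     loss = D(real).mean() - D(fake).mean()   # two forwards, each runs
+ #                                               # power iteration on u, v
+ #     loss.backward()
+ #
+ # Without the clones below, the backward pass through the first forward
+ # would see `u` and `v` already overwritten in-place by the second forward
+ # and the autograd engine would raise an error.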
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with torch.no_grad():
+ for _ in range(self.n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
+ u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
+ if self.n_power_iterations > 0:
+ # See above on why we need to clone
+ u = u.clone()
+ v = v.clone()
+
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))
+
+ def __call__(self, module, inputs):
+ setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))
+
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
+ # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
+ # This uses pinverse in case W^T W is not invertible.
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError("Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with torch.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+
+ h, w = weight_mat.size()
+ # randomly initialize `u` and `v`
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
+
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a plain
+ # attribute.
+ setattr(module, fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
+ module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
+ return fn
+
+
+ # This is a top level class because Py2 pickle doesn't like an inner class or an
+ # instancemethod.
+class SpectralNormLoadStateDictPreHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ # For state_dict with version None, (assuming that it has gone through at
+ # least one training forward), we have
+ #
+ # u = normalize(W_orig @ v)
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
+ #
+ # To compute `v`, we solve `W_orig @ x = u`, and let
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
+ def __call__(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ fn = self.fn
+ version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
+ if version is None or version < 1:
+ with torch.no_grad():
+ weight_orig = state_dict[prefix + fn.name + '_orig']
+ weight = state_dict.pop(prefix + fn.name)
+ sigma = (weight_orig / weight).mean()
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
+ u = state_dict[prefix + fn.name + '_u']
+ v = fn._solve_v_and_rescale(weight_mat, u, sigma)
+ state_dict[prefix + fn.name + '_v'] = v
+
+
+ # This is a top level class because Py2 pickle doesn't like an inner class or an
+ # instancemethod.
+class SpectralNormStateDictHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ def __call__(self, module, state_dict, prefix, local_metadata):
+ if 'spectral_norm' not in local_metadata:
+ local_metadata['spectral_norm'] = {}
+ key = self.fn.name + '.version'
+ if key in local_metadata['spectral_norm']:
+ raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
+ local_metadata['spectral_norm'][key] = self.fn._version
+
+
+[docs]def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
+ r"""Applies spectral normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with the spectral norm :math:`\sigma` of the weight matrix, calculated
+ using the power iteration method. If the dimension of the weight tensor is
+ greater than 2, it is reshaped to 2D in the power iteration method to get
+ the spectral norm. This is implemented via a hook that calculates the
+ spectral norm and rescales the weight before every :meth:`~Module.forward`
+ call.
+
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is ``0``, except for modules that are instances of
+ ConvTranspose{1,2,3}d, when it is ``1``
+
+ Returns:
+ The original module with the spectral norm hook
+
+ Example::
+
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_u.size()
+ torch.Size([40])
+
+ """
+ if dim is None:
+ if isinstance(module, (torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
+
+
+[docs]def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
+ name, module))
+
+r"""
+Weight Normalization from https://arxiv.org/abs/1602.07868
+"""
+from torch.nn.parameter import Parameter
+from torch import _weight_norm, norm_except_dim
+
+
+class WeightNorm(object):
+ def __init__(self, name, dim):
+ if dim is None:
+ dim = -1
+ self.name = name
+ self.dim = dim
+
+ def compute_weight(self, module):
+ g = getattr(module, self.name + '_g')
+ v = getattr(module, self.name + '_v')
+ return _weight_norm(v, g, self.dim)
+
+ @staticmethod
+ def apply(module, name, dim):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, WeightNorm) and hook.name == name:
+ raise RuntimeError("Cannot register two weight_norm hooks on "
+ "the same parameter {}".format(name))
+
+ if dim is None:
+ dim = -1
+
+ fn = WeightNorm(name, dim)
+
+ weight = getattr(module, name)
+
+ # remove w from parameter list
+ del module._parameters[name]
+
+ # add g and v as new parameters and express w as g/||v|| * v
+ module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
+ module.register_parameter(name + '_v', Parameter(weight.data))
+ setattr(module, name, fn.compute_weight(module))
+
+ # recompute weight before every forward()
+ module.register_forward_pre_hook(fn)
+
+ return fn
+
+ def remove(self, module):
+ weight = self.compute_weight(module)
+ delattr(module, self.name)
+ del module._parameters[self.name + '_g']
+ del module._parameters[self.name + '_v']
+ module.register_parameter(self.name, Parameter(weight.data))
+
+ def __call__(self, module, inputs):
+ setattr(module, self.name, self.compute_weight(module))
+
+
+[docs]def weight_norm(module, name='weight', dim=0):
+ r"""Applies weight normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
+
+ Weight normalization is a reparameterization that decouples the magnitude
+ of a weight tensor from its direction. This replaces the parameter specified
+ by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
+ (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
+ Weight normalization is implemented via a hook that recomputes the weight
+ tensor from the magnitude and direction before every :meth:`~Module.forward`
+ call.
+
+ By default, with ``dim=0``, the norm is computed independently per output
+ channel/plane. To compute a norm over the entire weight tensor, use
+ ``dim=None``.
+
+ See https://arxiv.org/abs/1602.07868
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+ dim (int, optional): dimension over which to compute the norm
+
+ Returns:
+ The original module with the weight norm hook
+
+ Example::
+
+ >>> m = weight_norm(nn.Linear(20, 40), name='weight')
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_g.size()
+ torch.Size([40, 1])
+ >>> m.weight_v.size()
+ torch.Size([40, 20])
+
+ """
+ WeightNorm.apply(module, name, dim)
+ return module
+
+
+[docs]def remove_weight_norm(module, name='weight'):
+ r"""Removes the weight normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = weight_norm(nn.Linear(20, 40))
+ >>> remove_weight_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, WeightNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("weight_norm of '{}' not found in {}"
+ .format(name, module))
+
+import torch._C as _C
+
+TensorProtoDataType = _C._onnx.TensorProtoDataType
+OperatorExportTypes = _C._onnx.OperatorExportTypes
+PYTORCH_ONNX_CAFFE2_BUNDLE = _C._onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
+
+ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
+
+
+class ExportTypes:
+ PROTOBUF_FILE = 1
+ ZIP_ARCHIVE = 2
+ COMPRESSED_ZIP_ARCHIVE = 3
+ DIRECTORY = 4
+
+
+def _export(*args, **kwargs):
+ from torch.onnx import utils
+ result = utils._export(*args, **kwargs)
+ return result
+
+
+[docs]def export(*args, **kwargs):
+ from torch.onnx import utils
+ return utils.export(*args, **kwargs)
+
+
+def export_to_pretty_string(*args, **kwargs):
+ from torch.onnx import utils
+ return utils.export_to_pretty_string(*args, **kwargs)
+
+
+def _export_to_pretty_string(*args, **kwargs):
+ from torch.onnx import utils
+ return utils._export_to_pretty_string(*args, **kwargs)
+
+
+def _optimize_trace(trace, operator_export_type):
+ from torch.onnx import utils
+ trace.set_graph(utils._optimize_graph(trace.graph(), operator_export_type))
+
+
+def set_training(*args, **kwargs):
+ from torch.onnx import utils
+ return utils.set_training(*args, **kwargs)
+
+
+def _run_symbolic_function(*args, **kwargs):
+ from torch.onnx import utils
+ return utils._run_symbolic_function(*args, **kwargs)
+
+
+def _run_symbolic_method(*args, **kwargs):
+ from torch.onnx import utils
+ return utils._run_symbolic_method(*args, **kwargs)
+
+
+def is_in_onnx_export():
+ from torch.onnx import utils
+ return utils.is_in_onnx_export()
+
+import torch
+
+from .optimizer import Optimizer
+
+
+[docs]class Adadelta(Optimizer):
+ """Implements Adadelta algorithm.
+
+ It has been proposed in `ADADELTA: An Adaptive Learning Rate Method`__.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ rho (float, optional): coefficient used for computing a running average
+ of squared gradients (default: 0.9)
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-6)
+ lr (float, optional): coefficient that scales delta before it is applied
+ to the parameters (default: 1.0)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+
+ __ https://arxiv.org/abs/1212.5701
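+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller):
+ >>> optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9)
+ >>> optimizer.zero_grad()
+ >>> loss = loss_fn(model(input), target)
+ >>> loss.backward()
+ >>> optimizer.step()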
+ """
+
+ def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= rho <= 1.0:
+ raise ValueError("Invalid rho value: {}".format(rho))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
+ super(Adadelta, self).__init__(params, defaults)
+
+[docs] def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adadelta does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['square_avg'] = torch.zeros_like(p.data)
+ state['acc_delta'] = torch.zeros_like(p.data)
+
+ square_avg, acc_delta = state['square_avg'], state['acc_delta']
+ rho, eps = group['rho'], group['eps']
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad = grad.add(group['weight_decay'], p.data)
+
+ square_avg.mul_(rho).addcmul_(1 - rho, grad, grad)
+ std = square_avg.add(eps).sqrt_()
+ delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
+ p.data.add_(-group['lr'], delta)
+ acc_delta.mul_(rho).addcmul_(1 - rho, delta, delta)
+
+ return loss
+
+import torch
+from .optimizer import Optimizer
+
+
+[docs]class Adagrad(Optimizer):
+ """Implements Adagrad algorithm.
+
+ It has been proposed in `Adaptive Subgradient Methods for Online Learning
+ and Stochastic Optimization`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-2)
+ lr_decay (float, optional): learning rate decay (default: 0)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ initial_accumulator_value (float, optional): initial value of the
+ sum-of-squared-gradients accumulator (default: 0)
+
+ .. _Adaptive Subgradient Methods for Online Learning and Stochastic
+ Optimization: http://jmlr.org/papers/v12/duchi11a.html
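+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller):
+ >>> optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2, lr_decay=1e-4)
+ >>> optimizer.zero_grad()
+ >>> loss = loss_fn(model(input), target)
+ >>> loss.backward()
+ >>> optimizer.step()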
+ """
+
+ def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= lr_decay:
+ raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= initial_accumulator_value:
+ raise ValueError("Invalid initial_accumulator_value value: {}".format(initial_accumulator_value))
+
+ defaults = dict(lr=lr, lr_decay=lr_decay, weight_decay=weight_decay,
+ initial_accumulator_value=initial_accumulator_value)
+ super(Adagrad, self).__init__(params, defaults)
+
+ for group in self.param_groups:
+ for p in group['params']:
+ state = self.state[p]
+ state['step'] = 0
+ state['sum'] = torch.full_like(p.data, initial_accumulator_value)
+
+ def share_memory(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ state = self.state[p]
+ state['sum'].share_memory_()
+
+[docs] def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ grad = p.grad.data
+ state = self.state[p]
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ if p.grad.data.is_sparse:
+ raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+ grad = grad.add(group['weight_decay'], p.data)
+
+ clr = group['lr'] / (1 + (state['step'] - 1) * group['lr_decay'])
+
+ if grad.is_sparse:
+ grad = grad.coalesce() # the update is non-linear so indices must be unique
+ grad_indices = grad._indices()
+ grad_values = grad._values()
+ size = grad.size()
+
+ def make_sparse(values):
+ constructor = grad.new
+ if grad_indices.dim() == 0 or values.dim() == 0:
+ return constructor().resize_as_(grad)
+ return constructor(grad_indices, values, size)
+ state['sum'].add_(make_sparse(grad_values.pow(2)))
+ std = state['sum'].sparse_mask(grad)
+ std_values = std._values().sqrt_().add_(1e-10)
+ p.data.add_(-clr, make_sparse(grad_values / std_values))
+ else:
+ state['sum'].addcmul_(1, grad, grad)
+ std = state['sum'].sqrt().add_(1e-10)
+ p.data.addcdiv_(-clr, grad, std)
+
+ return loss
+
+import math
+import torch
+from .optimizer import Optimizer
+
+
+[docs]class Adam(Optimizer):
+ r"""Implements Adam algorithm.
+
+ It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
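+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller):
+ >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
+ >>> optimizer.zero_grad()
+ >>> loss = loss_fn(model(input), target)
+ >>> loss.backward()
+ >>> optimizer.step()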
+ """
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad)
+ super(Adam, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Adam, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+[docs] def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+ amsgrad = group['amsgrad']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad.add_(group['weight_decay'], p.data)
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+ else:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+ p.data.addcdiv_(-step_size, exp_avg, denom)
+
+ return loss
+
+import torch
+from .optimizer import Optimizer
+
+
+[docs]class Adamax(Optimizer):
+ """Implements Adamax algorithm (a variant of Adam based on infinity norm).
+
+ It has been proposed in `Adam: A Method for Stochastic Optimization`__.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 2e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+
+ __ https://arxiv.org/abs/1412.6980
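+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller):
+ >>> optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3)
+ >>> optimizer.zero_grad()
+ >>> loss = loss_fn(model(input), target)
+ >>> loss.backward()
+ >>> optimizer.step()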
+ """
+
+ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super(Adamax, self).__init__(params, defaults)
+
+[docs] def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Adamax does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p.data)
+ state['exp_inf'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
+ beta1, beta2 = group['betas']
+ eps = group['eps']
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad = grad.add(group['weight_decay'], p.data)
+
+ # Update biased first moment estimate.
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ # Update the exponentially weighted infinity norm.
+ norm_buf = torch.cat([
+ exp_inf.mul_(beta2).unsqueeze(0),
+ grad.abs().add_(eps).unsqueeze_(0)
+ ], 0)
+ torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
+
+ bias_correction = 1 - beta1 ** state['step']
+ clr = group['lr'] / bias_correction
+
+ p.data.addcdiv_(-clr, exp_avg, exp_inf)
+
+ return loss
+
+import math
+import torch
+from .optimizer import Optimizer
+
+
+[docs]class ASGD(Optimizer):
+ """Implements Averaged Stochastic Gradient Descent.
+
+ It has been proposed in `Acceleration of stochastic approximation by
+ averaging`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-2)
+ lambd (float, optional): decay term (default: 1e-4)
+ alpha (float, optional): power for eta update (default: 0.75)
+ t0 (float, optional): point at which to start averaging (default: 1e6)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+
+ .. _Acceleration of stochastic approximation by averaging:
+ http://dl.acm.org/citation.cfm?id=131098
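+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller):
+ >>> optimizer = torch.optim.ASGD(model.parameters(), lr=1e-2, t0=1e6)
+ >>> optimizer.zero_grad()
+ >>> loss = loss_fn(model(input), target)
+ >>> loss.backward()
+ >>> optimizer.step()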
+ """
+
+ def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0,
+ weight_decay=weight_decay)
+ super(ASGD, self).__init__(params, defaults)
+
+[docs] def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('ASGD does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['eta'] = group['lr']
+ state['mu'] = 1
+ state['ax'] = torch.zeros_like(p.data)
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad = grad.add(group['weight_decay'], p.data)
+
+ # decay term
+ p.data.mul_(1 - group['lambd'] * state['eta'])
+
+ # update parameter
+ p.data.add_(-state['eta'], grad)
+
+ # averaging
+ if state['mu'] != 1:
+ state['ax'].add_(p.data.sub(state['ax']).mul(state['mu']))
+ else:
+ state['ax'].copy_(p.data)
+
+ # update eta and mu
+ state['eta'] = (group['lr'] /
+ math.pow((1 + group['lambd'] * group['lr'] * state['step']), group['alpha']))
+ state['mu'] = 1 / max(1, state['step'] - group['t0'])
+
+ return loss
+
+import torch
+from functools import reduce
+from .optimizer import Optimizer
+
+
+[docs]class LBFGS(Optimizer):
+ """Implements L-BFGS algorithm.
+
+ .. warning::
+ This optimizer doesn't support per-parameter options and parameter
+ groups (there can be only one).
+
+ .. warning::
+ Right now all parameters have to be on a single device. This will be
+ improved in the future.
+
+ .. note::
+ This is a very memory intensive optimizer (it requires additional
+ ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
+ try reducing the history size, or use a different algorithm.
+
+ Arguments:
+ lr (float): learning rate (default: 1)
+ max_iter (int): maximal number of iterations per optimization step
+ (default: 20)
+ max_eval (int): maximal number of function evaluations per optimization
+ step (default: max_iter * 1.25).
+ tolerance_grad (float): termination tolerance on first order optimality
+ (default: 1e-5).
+ tolerance_change (float): termination tolerance on function
+ value/parameter changes (default: 1e-9).
+ history_size (int): update history size (default: 100).
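+
+ Example (a minimal usage sketch; ``model``, ``loss_fn``, ``input`` and
+ ``target`` are assumed to be defined by the caller; unlike most optimizers,
+ ``step`` must be given a closure that re-evaluates the loss):
+ >>> optimizer = torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20)
+ >>> def closure():
+ >>>     optimizer.zero_grad()
+ >>>     loss = loss_fn(model(input), target)
+ >>>     loss.backward()
+ >>>     return loss
+ >>> optimizer.step(closure)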
+ """
+
+ def __init__(self, params, lr=1, max_iter=20, max_eval=None,
+ tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
+ line_search_fn=None):
+ if max_eval is None:
+ max_eval = max_iter * 5 // 4
+ defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
+ tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
+ history_size=history_size, line_search_fn=line_search_fn)
+ super(LBFGS, self).__init__(params, defaults)
+
+ if len(self.param_groups) != 1:
+ raise ValueError("LBFGS doesn't support per-parameter options "
+ "(parameter groups)")
+
+ self._params = self.param_groups[0]['params']
+ self._numel_cache = None
+
+ def _numel(self):
+ if self._numel_cache is None:
+ self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
+ return self._numel_cache
+
+ def _gather_flat_grad(self):
+ views = []
+ for p in self._params:
+ if p.grad is None:
+ view = p.data.new(p.data.numel()).zero_()
+ elif p.grad.data.is_sparse:
+ view = p.grad.data.to_dense().view(-1)
+ else:
+ view = p.grad.data.view(-1)
+ views.append(view)
+ return torch.cat(views, 0)
+
+ def _add_grad(self, step_size, update):
+ offset = 0
+ for p in self._params:
+ numel = p.numel()
+ # view as to avoid deprecated pointwise semantics
+ p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
+ offset += numel
+ assert offset == self._numel()
+
+[docs] def step(self, closure):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable): A closure that reevaluates the model
+ and returns the loss.
+ """
+ assert len(self.param_groups) == 1
+
+ group = self.param_groups[0]
+ lr = group['lr']
+ max_iter = group['max_iter']
+ max_eval = group['max_eval']
+ tolerance_grad = group['tolerance_grad']
+ tolerance_change = group['tolerance_change']
+ line_search_fn = group['line_search_fn']
+ history_size = group['history_size']
+
+ # NOTE: LBFGS has only global state, but we register it as state for
+ # the first param, because this helps with casting in load_state_dict
+ state = self.state[self._params[0]]
+ state.setdefault('func_evals', 0)
+ state.setdefault('n_iter', 0)
+
+ # evaluate initial f(x) and df/dx
+ orig_loss = closure()
+ loss = float(orig_loss)
+ current_evals = 1
+ state['func_evals'] += 1
+
+ flat_grad = self._gather_flat_grad()
+ abs_grad_sum = flat_grad.abs().sum()
+
+ if abs_grad_sum <= tolerance_grad:
+ return orig_loss
+
+ # tensors cached in state (for tracing)
+ d = state.get('d')
+ t = state.get('t')
+ old_dirs = state.get('old_dirs')
+ old_stps = state.get('old_stps')
+ H_diag = state.get('H_diag')
+ prev_flat_grad = state.get('prev_flat_grad')
+ prev_loss = state.get('prev_loss')
+
+ n_iter = 0
+ # optimize for a max of max_iter iterations
+ while n_iter < max_iter:
+ # keep track of nb of iterations
+ n_iter += 1
+ state['n_iter'] += 1
+
+ ############################################################
+ # compute gradient descent direction
+ ############################################################
+ if state['n_iter'] == 1:
+ d = flat_grad.neg()
+ old_dirs = []
+ old_stps = []
+ H_diag = 1
+ else:
+ # do lbfgs update (update memory)
+ y = flat_grad.sub(prev_flat_grad)
+ s = d.mul(t)
+ ys = y.dot(s) # y*s
+ if ys > 1e-10:
+ # updating memory
+ if len(old_dirs) == history_size:
+ # shift history by one (limited-memory)
+ old_dirs.pop(0)
+ old_stps.pop(0)
+
+ # store new direction/step
+ old_dirs.append(y)
+ old_stps.append(s)
+
+ # update scale of initial Hessian approximation
+ H_diag = ys / y.dot(y) # (y*y)
+
+ # compute the approximate (L-BFGS) inverse Hessian
+ # multiplied by the gradient
+ num_old = len(old_dirs)
+
+ if 'ro' not in state:
+ state['ro'] = [None] * history_size
+ state['al'] = [None] * history_size
+ ro = state['ro']
+ al = state['al']
+
+ for i in range(num_old):
+ ro[i] = 1. / old_dirs[i].dot(old_stps[i])
+
+ # iteration in L-BFGS loop collapsed to use just one buffer
+ q = flat_grad.neg()
+ for i in range(num_old - 1, -1, -1):
+ al[i] = old_stps[i].dot(q) * ro[i]
+ q.add_(-al[i], old_dirs[i])
+
+ # multiply by initial Hessian
+ # r/d is the final direction
+ d = r = torch.mul(q, H_diag)
+ for i in range(num_old):
+ be_i = old_dirs[i].dot(r) * ro[i]
+ r.add_(al[i] - be_i, old_stps[i])
+
+ if prev_flat_grad is None:
+ prev_flat_grad = flat_grad.clone()
+ else:
+ prev_flat_grad.copy_(flat_grad)
+ prev_loss = loss
+
+ ############################################################
+ # compute step length
+ ############################################################
+ # reset initial guess for step size
+ if state['n_iter'] == 1:
+ t = min(1., 1. / abs_grad_sum) * lr
+ else:
+ t = lr
+
+ # directional derivative
+ gtd = flat_grad.dot(d) # g * d
+
+ # optional line search: user function
+ ls_func_evals = 0
+ if line_search_fn is not None:
+ # perform line search, using user function
+ raise RuntimeError("line search function is not supported yet")
+ else:
+ # no line search, simply move with fixed-step
+ self._add_grad(t, d)
+ if n_iter != max_iter:
+ # re-evaluate function only if not in last iteration
+ # the reason we do this: in a stochastic setting,
+ # no use to re-evaluate that function here
+ loss = float(closure())
+ flat_grad = self._gather_flat_grad()
+ abs_grad_sum = flat_grad.abs().sum()
+ ls_func_evals = 1
+
+ # update func eval
+ current_evals += ls_func_evals
+ state['func_evals'] += ls_func_evals
+
+ ############################################################
+ # check conditions
+ ############################################################
+ if n_iter == max_iter:
+ break
+
+ if current_evals >= max_eval:
+ break
+
+ if abs_grad_sum <= tolerance_grad:
+ break
+
+ if gtd > -tolerance_change:
+ break
+
+ if d.mul(t).abs_().sum() <= tolerance_change:
+ break
+
+ if abs(loss - prev_loss) < tolerance_change:
+ break
+
+ state['d'] = d
+ state['t'] = t
+ state['old_dirs'] = old_dirs
+ state['old_stps'] = old_stps
+ state['H_diag'] = H_diag
+ state['prev_flat_grad'] = prev_flat_grad
+ state['prev_loss'] = prev_loss
+
+ return orig_loss
+
+import types
+import math
+from torch._six import inf
+from collections import Counter
+from functools import partial
+
+from .optimizer import Optimizer
+
+
+class _LRScheduler(object):
+ def __init__(self, optimizer, last_epoch=-1):
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+ if last_epoch == -1:
+ for group in optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ last_epoch = 0
+ else:
+ for i, group in enumerate(optimizer.param_groups):
+ if 'initial_lr' not in group:
+ raise KeyError("param 'initial_lr' is not specified "
+ "in param_groups[{}] when resuming an optimizer".format(i))
+ self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+ self.step(last_epoch)
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Arguments:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def get_lr(self):
+ raise NotImplementedError
+
+ def step(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = epoch
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+ param_group['lr'] = lr
+
+
+[docs]class LambdaLR(_LRScheduler):
+ """Sets the learning rate of each parameter group to the initial lr
+ times a given function. When last_epoch=-1, sets initial lr as lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ lr_lambda (function or list): A function which computes a multiplicative
+ factor given an integer parameter epoch, or a list of such
+ functions, one for each group in optimizer.param_groups.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ Example:
+ >>> # Assuming optimizer has two groups.
+ >>> lambda1 = lambda epoch: epoch // 30
+ >>> lambda2 = lambda epoch: 0.95 ** epoch
+ >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
+ >>> for epoch in range(100):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+
+ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+ self.optimizer = optimizer
+ if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
+ self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+ else:
+ if len(lr_lambda) != len(optimizer.param_groups):
+ raise ValueError("Expected {} lr_lambdas, but got {}".format(
+ len(optimizer.param_groups), len(lr_lambda)))
+ self.lr_lambdas = list(lr_lambda)
+ self.last_epoch = last_epoch
+ super(LambdaLR, self).__init__(optimizer, last_epoch)
+
+[docs] def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ The learning rate lambda functions will only be saved if they are callable objects
+ and not if they are functions or lambdas.
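+
+ Example (a hedged sketch; ``optimizer`` is assumed to exist and
+ ``PolyDecay`` is a hypothetical callable object, not a plain lambda):
+ >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda e: 0.95 ** e, PolyDecay()])
+ >>> sd = scheduler.state_dict()
+ >>> sd['lr_lambdas'][0] is None  # plain lambdas are not serialized
+ True
+ >>> # sd['lr_lambdas'][1] holds PolyDecay's __dict__ and is merged back
+ >>> # into an existing PolyDecay instance by load_state_dict().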
+ """
+ state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')}
+ state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
+
+ for idx, fn in enumerate(self.lr_lambdas):
+ if not isinstance(fn, types.FunctionType):
+ state_dict['lr_lambdas'][idx] = fn.__dict__.copy()
+
+ return state_dict
+
+[docs] def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Arguments:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ lr_lambdas = state_dict.pop('lr_lambdas')
+ self.__dict__.update(state_dict)
+
+ for idx, fn in enumerate(lr_lambdas):
+ if fn is not None:
+ self.lr_lambdas[idx].__dict__.update(fn)
+
+ def get_lr(self):
+ return [base_lr * lmbda(self.last_epoch)
+ for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
+
+
+[docs]class StepLR(_LRScheduler):
+ """Decays the learning rate of each parameter group by gamma every
+ step_size epochs. Notice that such decay can happen simultaneously with
+ other changes to the learning rate from outside this scheduler. When
+ last_epoch=-1, sets initial lr as lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ step_size (int): Period of learning rate decay.
+ gamma (float): Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ Example:
+ >>> # Assuming optimizer uses lr = 0.05 for all groups
+ >>> # lr = 0.05 if epoch < 30
+ >>> # lr = 0.005 if 30 <= epoch < 60
+ >>> # lr = 0.0005 if 60 <= epoch < 90
+ >>> # ...
+ >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
+ >>> for epoch in range(100):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+
+ def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
+ self.step_size = step_size
+ self.gamma = gamma
+ super(StepLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma
+ for group in self.optimizer.param_groups]
+
+
+[docs]class MultiStepLR(_LRScheduler):
+ """Decays the learning rate of each parameter group by gamma once the
+ number of epoch reaches one of the milestones. Notice that such decay can
+ happen simultaneously with other changes to the learning rate from outside
+ this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ milestones (list): List of epoch indices. Must be increasing.
+ gamma (float): Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ Example:
+ >>> # Assuming optimizer uses lr = 0.05 for all groups
+ >>> # lr = 0.05 if epoch < 30
+ >>> # lr = 0.005 if 30 <= epoch < 80
+ >>> # lr = 0.0005 if epoch >= 80
+ >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
+ >>> for epoch in range(100):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+
+ def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ super(MultiStepLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
+ for group in self.optimizer.param_groups]
+
+
+[docs]class ExponentialLR(_LRScheduler):
+ """Decays the learning rate of each parameter group by gamma every epoch.
+ When last_epoch=-1, sets initial lr as lr.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ gamma (float): Multiplicative factor of learning rate decay.
+ last_epoch (int): The index of last epoch. Default: -1.
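+
+ Example (a minimal sketch; ``optimizer``, ``train`` and ``validate`` are
+ assumed to be defined as in the examples above):
+ >>> # Assuming optimizer uses lr = 0.05 for all groups,
+ >>> # lr decays by a factor of 0.9 every epoch: 0.05, 0.045, 0.0405, ...
+ >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
+ >>> for epoch in range(100):
+ >>>     train(...)
+ >>>     validate(...)
+ >>>     scheduler.step()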
+ """
+
+ def __init__(self, optimizer, gamma, last_epoch=-1):
+ self.gamma = gamma
+ super(ExponentialLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch == 0:
+ return self.base_lrs
+ return [group['lr'] * self.gamma
+ for group in self.optimizer.param_groups]
+
+
+[docs]class CosineAnnealingLR(_LRScheduler):
+ r"""Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr and
+ :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+ .. math::
+ \eta_{t+1} = \eta_{min} + (\eta_t - \eta_{min})\frac{1 +
+ \cos(\frac{T_{cur}+1}{T_{max}}\pi)}{1 + \cos(\frac{T_{cur}}{T_{max}}\pi)},
+ T_{cur} \neq (2k+1)T_{max};\\
+ \eta_{t+1} = \eta_{t} + (\eta_{max} - \eta_{min})\frac{1 -
+ \cos(\frac{1}{T_{max}}\pi)}{2},
+ T_{cur} = (2k+1)T_{max}.\\
+
+ When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+ is defined recursively, the learning rate can be simultaneously modified
+ outside this scheduler by other operators. If the learning rate is set
+ solely by this scheduler, the learning rate at each step becomes:
+
+ .. math::
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+ \cos(\frac{T_{cur}}{T_{max}}\pi))
+
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+ implements the cosine annealing part of SGDR, and not the restarts.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ T_max (int): Maximum number of iterations.
+ eta_min (float): Minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
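+
+ Example (a minimal sketch; ``optimizer``, ``train`` and ``validate`` are
+ assumed to be defined as in the examples above):
+ >>> # anneal lr from its initial value towards eta_min over T_max epochs
+ >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0)
+ >>> for epoch in range(100):
+ >>>     train(...)
+ >>>     validate(...)
+ >>>     scheduler.step()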
+ """
+
+ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
+ self.T_max = T_max
+ self.eta_min = eta_min
+ super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch == 0:
+ return self.base_lrs
+ elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
+ return [group['lr'] + (base_lr - self.eta_min) *
+ (1 - math.cos(math.pi / self.T_max)) / 2
+ for base_lr, group in
+ zip(self.base_lrs, self.optimizer.param_groups)]
+ return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
+ (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
+ (group['lr'] - self.eta_min) + self.eta_min
+ for group in self.optimizer.param_groups]
+
+
+[docs]class ReduceLROnPlateau(object):
+ """Reduce learning rate when a metric has stopped improving.
+ Models often benefit from reducing the learning rate by a factor
+ of 2-10 once learning stagnates. This scheduler reads a metric
+ quantity and, if no improvement is seen for a 'patience' number
+ of epochs, reduces the learning rate.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ mode (str): One of `min`, `max`. In `min` mode, lr will
+ be reduced when the quantity monitored has stopped
+ decreasing; in `max` mode it will be reduced when the
+ quantity monitored has stopped increasing. Default: 'min'.
+ factor (float): Factor by which the learning rate will be
+ reduced. new_lr = lr * factor. Default: 0.1.
+ patience (int): Number of epochs with no improvement after
+ which learning rate will be reduced. For example, if
+ `patience = 2`, then we will ignore the first 2 epochs
+ with no improvement, and will only decrease the LR after the
+ 3rd epoch if the loss still hasn't improved then.
+ Default: 10.
+ verbose (bool): If ``True``, prints a message to stdout for
+ each update. Default: ``False``.
+ threshold (float): Threshold for measuring the new optimum,
+ to only focus on significant changes. Default: 1e-4.
+ threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
+ dynamic_threshold = best * ( 1 + threshold ) in 'max'
+ mode or best * ( 1 - threshold ) in `min` mode.
+ In `abs` mode, dynamic_threshold = best + threshold in
+ `max` mode or best - threshold in `min` mode. Default: 'rel'.
+ cooldown (int): Number of epochs to wait before resuming
+ normal operation after lr has been reduced. Default: 0.
+ min_lr (float or list): A scalar or a list of scalars. A
+ lower bound on the learning rate of all param groups
+ or each group respectively. Default: 0.
+ eps (float): Minimal decay applied to lr. If the difference
+ between new and old lr is smaller than eps, the update is
+ ignored. Default: 1e-8.
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
+ >>> for epoch in range(10):
+ >>> train(...)
+ >>> val_loss = validate(...)
+ >>> # Note that step should be called after validate()
+ >>> scheduler.step(val_loss)
+ """
+
+ def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
+ verbose=False, threshold=1e-4, threshold_mode='rel',
+ cooldown=0, min_lr=0, eps=1e-8):
+
+ if factor >= 1.0:
+ raise ValueError('Factor should be < 1.0.')
+ self.factor = factor
+
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+
+ if isinstance(min_lr, list) or isinstance(min_lr, tuple):
+ if len(min_lr) != len(optimizer.param_groups):
+ raise ValueError("expected {} min_lrs, got {}".format(
+ len(optimizer.param_groups), len(min_lr)))
+ self.min_lrs = list(min_lr)
+ else:
+ self.min_lrs = [min_lr] * len(optimizer.param_groups)
+
+ self.patience = patience
+ self.verbose = verbose
+ self.cooldown = cooldown
+ self.cooldown_counter = 0
+ self.mode = mode
+ self.threshold = threshold
+ self.threshold_mode = threshold_mode
+ self.best = None
+ self.num_bad_epochs = None
+ self.mode_worse = None # the worse value for the chosen mode
+ self.is_better = None
+ self.eps = eps
+ self.last_epoch = -1
+ self._init_is_better(mode=mode, threshold=threshold,
+ threshold_mode=threshold_mode)
+ self._reset()
+
+ def _reset(self):
+ """Resets num_bad_epochs counter and cooldown counter."""
+ self.best = self.mode_worse
+ self.cooldown_counter = 0
+ self.num_bad_epochs = 0
+
+ def step(self, metrics, epoch=None):
+ # convert `metrics` to float, in case it's a zero-dim Tensor
+ current = float(metrics)
+ if epoch is None:
+ epoch = self.last_epoch = self.last_epoch + 1
+ self.last_epoch = epoch
+
+ if self.is_better(current, self.best):
+ self.best = current
+ self.num_bad_epochs = 0
+ else:
+ self.num_bad_epochs += 1
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
+
+ if self.num_bad_epochs > self.patience:
+ self._reduce_lr(epoch)
+ self.cooldown_counter = self.cooldown
+ self.num_bad_epochs = 0
+
+ def _reduce_lr(self, epoch):
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ old_lr = float(param_group['lr'])
+ new_lr = max(old_lr * self.factor, self.min_lrs[i])
+ if old_lr - new_lr > self.eps:
+ param_group['lr'] = new_lr
+ if self.verbose:
+ print('Epoch {:5d}: reducing learning rate'
+ ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
+
+ @property
+ def in_cooldown(self):
+ return self.cooldown_counter > 0
+
+ def _cmp(self, mode, threshold_mode, threshold, a, best):
+ if mode == 'min' and threshold_mode == 'rel':
+ rel_epsilon = 1. - threshold
+ return a < best * rel_epsilon
+
+ elif mode == 'min' and threshold_mode == 'abs':
+ return a < best - threshold
+
+ elif mode == 'max' and threshold_mode == 'rel':
+ rel_epsilon = threshold + 1.
+ return a > best * rel_epsilon
+
+ else: # mode == 'max' and epsilon_mode == 'abs':
+ return a > best + threshold
+
+ def _init_is_better(self, mode, threshold, threshold_mode):
+ if mode not in {'min', 'max'}:
+ raise ValueError('mode ' + mode + ' is unknown!')
+ if threshold_mode not in {'rel', 'abs'}:
+ raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
+
+ if mode == 'min':
+ self.mode_worse = inf
+ else: # mode == 'max':
+ self.mode_worse = -inf
+
+ self.is_better = partial(self._cmp, mode, threshold_mode, threshold)
+
+ def state_dict(self):
+ return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}}
+
+ def load_state_dict(self, state_dict):
+ self.__dict__.update(state_dict)
+ self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
+
+
+[docs]class CyclicLR(_LRScheduler):
+ """Sets the learning rate of each parameter group according to
+ cyclical learning rate policy (CLR). The policy cycles the learning
+ rate between two boundaries with a constant frequency, as detailed in
+ the paper `Cyclical Learning Rates for Training Neural Networks`_.
+ The distance between the two boundaries can be scaled on a per-iteration
+ or per-cycle basis.
+
+ Cyclical learning rate policy changes the learning rate after every batch.
+ `step` should be called after a batch has been used for training.
+
+ This class has three built-in policies, as put forth in the paper:
+ "triangular":
+ A basic triangular cycle w/ no amplitude scaling.
+ "triangular2":
+ A basic triangular cycle that scales initial amplitude by half each cycle.
+ "exp_range":
+ A cycle that scales initial amplitude by gamma**(cycle iterations) at each
+ cycle iteration.
+
+ This implementation was adapted from the github repo: `bckenstler/CLR`_
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ base_lr (float or list): Initial learning rate which is the
+ lower boundary in the cycle for each parameter group.
+ max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_lr - base_lr).
+ The lr at any cycle is the sum of base_lr
+ and some scaling of the amplitude; therefore
+ max_lr may not actually be reached depending on
+ scaling function.
+ step_size_up (int): Number of training iterations in the
+ increasing half of a cycle. Default: 2000
+ step_size_down (int): Number of training iterations in the
+ decreasing half of a cycle. If step_size_down is None,
+ it is set to step_size_up. Default: None
+ mode (str): One of {triangular, triangular2, exp_range}.
+ Values correspond to policies detailed above.
+ If scale_fn is not None, this argument is ignored.
+ Default: 'triangular'
+ gamma (float): Constant in 'exp_range' scaling function:
+ gamma**(cycle iterations)
+ Default: 1.0
+ scale_fn (function): Custom scaling policy defined by a single
+ argument lambda function, where
+ 0 <= scale_fn(x) <= 1 for all x >= 0.
+ If specified, then 'mode' is ignored.
+ Default: None
+ scale_mode (str): {'cycle', 'iterations'}.
+ Defines whether scale_fn is evaluated on
+ cycle number or cycle iterations (training
+ iterations since start of cycle).
+ Default: 'cycle'
+ cycle_momentum (bool): If ``True``, momentum is cycled inversely
+ to learning rate between 'base_momentum' and 'max_momentum'.
+ Default: True
+ base_momentum (float or list): Initial momentum which is the
+ lower boundary in the cycle for each parameter group.
+ Default: 0.8
+ max_momentum (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_momentum - base_momentum).
+ The momentum at any cycle is the difference of max_momentum
+ and some scaling of the amplitude; therefore
+ base_momentum may not actually be reached depending on
+ scaling function. Default: 0.9
+ last_epoch (int): The index of the last batch. This parameter is used when
+ resuming a training job. Since `step()` should be invoked after each
+ batch instead of after each epoch, this number represents the total
+ number of *batches* computed, not the total number of epochs computed.
+ When last_epoch=-1, the schedule is started from the beginning.
+ Default: -1
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+
+ .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
+ .. _bckenstler/CLR: https://github.com/bckenstler/CLR
+ """
+
+ def __init__(self,
+ optimizer,
+ base_lr,
+ max_lr,
+ step_size_up=2000,
+ step_size_down=None,
+ mode='triangular',
+ gamma=1.,
+ scale_fn=None,
+ scale_mode='cycle',
+ cycle_momentum=True,
+ base_momentum=0.8,
+ max_momentum=0.9,
+ last_epoch=-1):
+
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+
+ base_lrs = self._format_param('base_lr', optimizer, base_lr)
+ if last_epoch == -1:
+ for lr, group in zip(base_lrs, optimizer.param_groups):
+ group['lr'] = lr
+
+ self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
+
+ step_size_up = float(step_size_up)
+ step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
+ self.total_size = step_size_up + step_size_down
+ self.step_ratio = step_size_up / self.total_size
+
+ if mode not in ['triangular', 'triangular2', 'exp_range'] \
+ and scale_fn is None:
+ raise ValueError('mode is invalid and scale_fn is None')
+
+ self.mode = mode
+ self.gamma = gamma
+
+ if scale_fn is None:
+ if self.mode == 'triangular':
+ self.scale_fn = self._triangular_scale_fn
+ self.scale_mode = 'cycle'
+ elif self.mode == 'triangular2':
+ self.scale_fn = self._triangular2_scale_fn
+ self.scale_mode = 'cycle'
+ elif self.mode == 'exp_range':
+ self.scale_fn = self._exp_range_scale_fn
+ self.scale_mode = 'iterations'
+ else:
+ self.scale_fn = scale_fn
+ self.scale_mode = scale_mode
+
+ self.cycle_momentum = cycle_momentum
+ if cycle_momentum:
+ if 'momentum' not in optimizer.defaults:
+ raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+
+ base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+ if last_epoch == -1:
+ for momentum, group in zip(base_momentums, optimizer.param_groups):
+ group['momentum'] = momentum
+ self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
+ self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+
+ super(CyclicLR, self).__init__(optimizer, last_epoch)
+
+ def _format_param(self, name, optimizer, param):
+ """Return correctly formatted lr/momentum for each param group."""
+ if isinstance(param, (list, tuple)):
+ if len(param) != len(optimizer.param_groups):
+ raise ValueError("expected {} values for {}, got {}".format(
+ len(optimizer.param_groups), name, len(param)))
+ return param
+ else:
+ return [param] * len(optimizer.param_groups)
+
+ def _triangular_scale_fn(self, x):
+ return 1.
+
+ def _triangular2_scale_fn(self, x):
+ return 1 / (2. ** (x - 1))
+
+ def _exp_range_scale_fn(self, x):
+ return self.gamma**(x)
+
+    def get_lr(self):
+ """Calculates the learning rate at batch index. This function treats
+ `self.last_epoch` as the last batch index.
+
+ If `self.cycle_momentum` is ``True``, this function has a side effect of
+ updating the optimizer's momentum.
+ """
+ cycle = math.floor(1 + self.last_epoch / self.total_size)
+ x = 1. + self.last_epoch / self.total_size - cycle
+ if x <= self.step_ratio:
+ scale_factor = x / self.step_ratio
+ else:
+ scale_factor = (x - 1) / (self.step_ratio - 1)
+
+ lrs = []
+ for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+ base_height = (max_lr - base_lr) * scale_factor
+ if self.scale_mode == 'cycle':
+ lr = base_lr + base_height * self.scale_fn(cycle)
+ else:
+ lr = base_lr + base_height * self.scale_fn(self.last_epoch)
+ lrs.append(lr)
+
+ if self.cycle_momentum:
+ momentums = []
+ for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
+ base_height = (max_momentum - base_momentum) * scale_factor
+ if self.scale_mode == 'cycle':
+ momentum = max_momentum - base_height * self.scale_fn(cycle)
+ else:
+ momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
+ momentums.append(momentum)
+ for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+ param_group['momentum'] = momentum
+
+ return lrs
+
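+# A minimal standalone sketch (not part of the original source) of the 'triangular'
+# policy computed by CyclicLR.get_lr above: the batch index is mapped to a position
+# x inside the current cycle and linearly interpolated between base_lr and max_lr.
+# The helper name and the example hyperparameter values are illustrative only.
+def _triangular_lr_sketch(last_epoch, base_lr=0.001, max_lr=0.006,
+                          step_size_up=2000., step_size_down=2000.):
+    total_size = step_size_up + step_size_down
+    step_ratio = step_size_up / total_size
+    cycle = math.floor(1 + last_epoch / total_size)
+    x = 1. + last_epoch / total_size - cycle
+    if x <= step_ratio:
+        scale_factor = x / step_ratio                  # increasing half of the cycle
+    else:
+        scale_factor = (x - 1) / (step_ratio - 1)      # decreasing half of the cycle
+    return base_lr + (max_lr - base_lr) * scale_factor
+
+# _triangular_lr_sketch(0) == base_lr, _triangular_lr_sketch(2000) == max_lr, and
+# _triangular_lr_sketch(4000) is back at base_lr, so the lr traces a triangle wave.
+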
+
+class CosineAnnealingWarmRestarts(_LRScheduler):
+ r"""Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+ is the number of epochs since the last restart and :math:`T_{i}` is the number
+ of epochs between two warm restarts in SGDR:
+
+ .. math::
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+ \cos(\frac{T_{cur}}{T_{i}}\pi))
+
+ When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+    When :math:`T_{cur}=0` (after restart), set :math:`\eta_t=\eta_{max}`.
+
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ T_0 (int): Number of iterations for the first restart.
+        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
+ eta_min (float, optional): Minimum learning rate. Default: 0.
+ last_epoch (int, optional): The index of last epoch. Default: -1.
+
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
+ """
+
+ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
+ if T_0 <= 0 or not isinstance(T_0, int):
+ raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+ if T_mult < 1 or not isinstance(T_mult, int):
+ raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mul))
+ self.T_0 = T_0
+ self.T_i = T_0
+ self.T_mult = T_mult
+ self.eta_min = eta_min
+ super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
+ self.T_cur = last_epoch
+
+ def get_lr(self):
+ return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
+ for base_lr in self.base_lrs]
+
+ def step(self, epoch=None):
+ """Step could be called after every update, i.e. if one epoch has 10 iterations
+ (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc.
+
+ This function can be called in an interleaved way.
+
+ Example:
+ >>> scheduler = SGDR(optimizer, T_0, T_mult)
+ >>> for epoch in range(20):
+ >>> scheduler.step()
+ >>> scheduler.step(26)
+ >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
+ """
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.T_cur = self.T_cur + 1
+ if self.T_cur >= self.T_i:
+ self.T_cur = self.T_cur - self.T_i
+ self.T_i = self.T_i * self.T_mult
+ else:
+ if epoch >= self.T_0:
+ if self.T_mult == 1:
+ self.T_cur = epoch % self.T_0
+ else:
+ n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
+ self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
+ self.T_i = self.T_0 * self.T_mult ** (n)
+ else:
+ self.T_i = self.T_0
+ self.T_cur = epoch
+ self.last_epoch = math.floor(epoch)
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+ param_group['lr'] = lr
+
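+# A small standalone sketch (not part of the original source) of the closed-form
+# lookup used in CosineAnnealingWarmRestarts.step() above: given an absolute epoch,
+# find the length T_i of the cycle it falls into and the offset T_cur inside it.
+# The helper name is illustrative only.
+def _warm_restart_cycle_sketch(epoch, T_0, T_mult):
+    if epoch < T_0:
+        return epoch, T_0                        # still inside the first cycle
+    if T_mult == 1:
+        return epoch % T_0, T_0                  # every cycle has the same length
+    # cycle lengths are T_0, T_0*T_mult, T_0*T_mult**2, ...; n is the cycle index
+    n = int(math.log(epoch / T_0 * (T_mult - 1) + 1, T_mult))
+    T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
+    return T_cur, T_0 * T_mult ** n
+
+# e.g. with T_0=10 and T_mult=2 the cycles cover epochs [0, 10), [10, 30), [30, 70), ...
+# so _warm_restart_cycle_sketch(26, 10, 2) == (16.0, 20): 16 epochs into a 20-epoch cycle.
+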
+from collections import defaultdict
+from torch._six import container_abcs
+
+import torch
+from copy import deepcopy
+from itertools import chain
+
+
+class _RequiredParameter(object):
+ """Singleton class representing a required parameter for an Optimizer."""
+ def __repr__(self):
+ return "<required parameter>"
+
+required = _RequiredParameter()
+
+
+class Optimizer(object):
+ r"""Base class for all optimizers.
+
+ .. warning::
+ Parameters need to be specified as collections that have a deterministic
+ ordering that is consistent between runs. Examples of objects that don't
+ satisfy those properties are sets and iterators over values of dictionaries.
+
+ Arguments:
+ params (iterable): an iterable of :class:`torch.Tensor` s or
+ :class:`dict` s. Specifies what Tensors should be optimized.
+        defaults (dict): a dict containing default values of optimization
+ options (used when a parameter group doesn't specify them).
+ """
+
+ def __init__(self, params, defaults):
+ self.defaults = defaults
+
+ if isinstance(params, torch.Tensor):
+ raise TypeError("params argument given to the optimizer should be "
+ "an iterable of Tensors or dicts, but got " +
+ torch.typename(params))
+
+ self.state = defaultdict(dict)
+ self.param_groups = []
+
+ param_groups = list(params)
+ if len(param_groups) == 0:
+ raise ValueError("optimizer got an empty parameter list")
+ if not isinstance(param_groups[0], dict):
+ param_groups = [{'params': param_groups}]
+
+ for param_group in param_groups:
+ self.add_param_group(param_group)
+
+ def __getstate__(self):
+ return {
+ 'defaults': self.defaults,
+ 'state': self.state,
+ 'param_groups': self.param_groups,
+ }
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + ' ('
+ for i, group in enumerate(self.param_groups):
+ format_string += '\n'
+ format_string += 'Parameter Group {0}\n'.format(i)
+ for key in sorted(group.keys()):
+ if key != 'params':
+ format_string += ' {0}: {1}\n'.format(key, group[key])
+ format_string += ')'
+ return format_string
+
+    def state_dict(self):
+ r"""Returns the state of the optimizer as a :class:`dict`.
+
+ It contains two entries:
+
+ * state - a dict holding current optimization state. Its content
+ differs between optimizer classes.
+ * param_groups - a dict containing all parameter groups
+ """
+ # Save ids instead of Tensors
+ def pack_group(group):
+ packed = {k: v for k, v in group.items() if k != 'params'}
+ packed['params'] = [id(p) for p in group['params']]
+ return packed
+ param_groups = [pack_group(g) for g in self.param_groups]
+ # Remap state to use ids as keys
+ packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
+ for k, v in self.state.items()}
+ return {
+ 'state': packed_state,
+ 'param_groups': param_groups,
+ }
+
+    def load_state_dict(self, state_dict):
+ r"""Loads the optimizer state.
+
+ Arguments:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict['param_groups']
+
+ if len(groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of "
+ "parameter groups")
+ param_lens = (len(g['params']) for g in groups)
+ saved_lens = (len(g['params']) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError("loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group")
+
+ # Update the state
+ id_map = {old_id: p for old_id, p in
+ zip(chain(*(g['params'] for g in saved_groups)),
+ chain(*(g['params'] for g in groups)))}
+
+ def cast(param, value):
+ r"""Make a deep copy of value, casting all tensors to device of param."""
+ if isinstance(value, torch.Tensor):
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+ if param.is_floating_point():
+ value = value.to(param.dtype)
+ value = value.to(param.device)
+ return value
+ elif isinstance(value, dict):
+ return {k: cast(param, v) for k, v in value.items()}
+ elif isinstance(value, container_abcs.Iterable):
+ return type(value)(cast(param, v) for v in value)
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ for k, v in state_dict['state'].items():
+ if k in id_map:
+ param = id_map[k]
+ state[param] = cast(param, v)
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group['params'] = group['params']
+ return new_group
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({'state': state, 'param_groups': param_groups})
+
+    def zero_grad(self):
+ r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is not None:
+ p.grad.detach_()
+ p.grad.zero_()
+
+    def step(self, closure):
+ r"""Performs a single optimization step (parameter update).
+
+ Arguments:
+ closure (callable): A closure that reevaluates the model and
+ returns the loss. Optional for most optimizers.
+ """
+ raise NotImplementedError
+
+    def add_param_group(self, param_group):
+ r"""Add a param group to the :class:`Optimizer` s `param_groups`.
+
+ This can be useful when fine tuning a pre-trained network as frozen layers can be made
+ trainable and added to the :class:`Optimizer` as training progresses.
+
+ Arguments:
+ param_group (dict): Specifies what Tensors should be optimized along with group
+ specific optimization options.
+ """
+ assert isinstance(param_group, dict), "param group must be a dict"
+
+ params = param_group['params']
+ if isinstance(params, torch.Tensor):
+ param_group['params'] = [params]
+ elif isinstance(params, set):
+ raise TypeError('optimizer parameters need to be organized in ordered collections, but '
+ 'the ordering of tensors in sets will change between runs. Please use a list instead.')
+ else:
+ param_group['params'] = list(params)
+
+ for param in param_group['params']:
+ if not isinstance(param, torch.Tensor):
+ raise TypeError("optimizer can only optimize Tensors, "
+ "but one of the params is " + torch.typename(param))
+ if not param.is_leaf:
+ raise ValueError("can't optimize a non-leaf Tensor")
+
+ for name, default in self.defaults.items():
+ if default is required and name not in param_group:
+ raise ValueError("parameter group didn't specify a value of required optimization parameter " +
+ name)
+ else:
+ param_group.setdefault(name, default)
+
+ param_set = set()
+ for group in self.param_groups:
+ param_set.update(set(group['params']))
+
+ if not param_set.isdisjoint(set(param_group['params'])):
+ raise ValueError("some parameters appear in more than one parameter group")
+
+ self.param_groups.append(param_group)
+
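+# A minimal sketch (not part of the original source) of the checkpoint round trip
+# built on state_dict()/load_state_dict() above: parameters are replaced by ids when
+# packing, and load_state_dict maps those ids back onto the live parameter groups.
+# The helper name and file name are illustrative only.
+def _optimizer_checkpoint_sketch(optimizer, path='optimizer.pt'):
+    torch.save(optimizer.state_dict(), path)
+    # ... later, after recreating the same model and an optimizer over its parameters ...
+    optimizer.load_state_dict(torch.load(path))
+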
+import torch
+from .optimizer import Optimizer
+
+
+class RMSprop(Optimizer):
+ """Implements RMSprop algorithm.
+
+ Proposed by G. Hinton in his
+ `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
+
+ The centered version first appears in `Generating Sequences
+ With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-2)
+ momentum (float, optional): momentum factor (default: 0)
+ alpha (float, optional): smoothing constant (default: 0.99)
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+        centered (bool, optional): if ``True``, compute the centered RMSProp,
+            where the gradient is normalized by an estimation of its variance
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+
+ """
+
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= momentum:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= alpha:
+ raise ValueError("Invalid alpha value: {}".format(alpha))
+
+ defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+ super(RMSprop, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(RMSprop, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('momentum', 0)
+ group.setdefault('centered', False)
+
+    def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('RMSprop does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['square_avg'] = torch.zeros_like(p.data)
+ if group['momentum'] > 0:
+ state['momentum_buffer'] = torch.zeros_like(p.data)
+ if group['centered']:
+ state['grad_avg'] = torch.zeros_like(p.data)
+
+ square_avg = state['square_avg']
+ alpha = group['alpha']
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ grad = grad.add(group['weight_decay'], p.data)
+
+ square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
+
+ if group['centered']:
+ grad_avg = state['grad_avg']
+ grad_avg.mul_(alpha).add_(1 - alpha, grad)
+ avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
+ else:
+ avg = square_avg.sqrt().add_(group['eps'])
+
+ if group['momentum'] > 0:
+ buf = state['momentum_buffer']
+ buf.mul_(group['momentum']).addcdiv_(grad, avg)
+ p.data.add_(-group['lr'], buf)
+ else:
+ p.data.addcdiv_(-group['lr'], grad, avg)
+
+ return loss
+
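+# A compact reference sketch (not part of the original source) of one plain RMSprop
+# update for a single dense parameter, mirroring the in-place tensor operations in
+# step() above; the momentum, centered and weight_decay variants are omitted.
+def _rmsprop_step_sketch(p, grad, square_avg, lr=1e-2, alpha=0.99, eps=1e-8):
+    square_avg = alpha * square_avg + (1 - alpha) * grad * grad   # running avg of grad^2
+    avg = square_avg.sqrt() + eps                                 # denominator
+    return p - lr * grad / avg, square_avg
+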
+import torch
+from .optimizer import Optimizer
+
+
+class Rprop(Optimizer):
+ """Implements the resilient backpropagation algorithm.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-2)
+        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), which
+            are multiplicative decrease and increase factors
+ (default: (0.5, 1.2))
+ step_sizes (Tuple[float, float], optional): a pair of minimal and
+ maximal allowed step sizes (default: (1e-6, 50))
+ """
+
+ def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 < etas[0] < 1.0 < etas[1]:
+ raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
+
+ defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes)
+ super(Rprop, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError('Rprop does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['prev'] = torch.zeros_like(p.data)
+ state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr'])
+
+ etaminus, etaplus = group['etas']
+ step_size_min, step_size_max = group['step_sizes']
+ step_size = state['step_size']
+
+ state['step'] += 1
+
+ sign = grad.mul(state['prev']).sign()
+ sign[sign.gt(0)] = etaplus
+ sign[sign.lt(0)] = etaminus
+ sign[sign.eq(0)] = 1
+
+ # update stepsizes with step size updates
+ step_size.mul_(sign).clamp_(step_size_min, step_size_max)
+
+ # for dir<0, dfdx=0
+ # for dir>=0 dfdx=dfdx
+ grad = grad.clone()
+ grad[sign.eq(etaminus)] = 0
+
+ # update parameters
+ p.data.addcmul_(-1, grad.sign(), step_size)
+
+ state['prev'].copy_(grad)
+
+ return loss
+
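+# An illustrative sketch (not part of the original source) of the Rprop rule applied
+# in step() above: the per-element step size grows by etaplus while the gradient keeps
+# its sign, shrinks by etaminus when the sign flips, and the parameter moves against
+# the sign of the gradient by that step size. The helper name is illustrative only.
+def _rprop_step_sketch(p, grad, prev_grad, step_size,
+                       etas=(0.5, 1.2), step_sizes=(1e-6, 50)):
+    etaminus, etaplus = etas
+    sign = grad.mul(prev_grad).sign()
+    scale = torch.ones_like(step_size)
+    scale[sign.gt(0)] = etaplus                 # gradient kept its sign: grow the step
+    scale[sign.lt(0)] = etaminus                # gradient flipped sign: shrink the step
+    step_size = (step_size * scale).clamp(*step_sizes)
+    grad = grad.clone()
+    grad[sign.lt(0)] = 0                        # skip the update where the sign flipped
+    return p - grad.sign() * step_size, step_size, grad
+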
+import torch
+from .optimizer import Optimizer, required
+
+
+class SGD(Optimizer):
+ r"""Implements stochastic gradient descent (optionally with momentum).
+
+ Nesterov momentum is based on the formula from
+ `On the importance of initialization and momentum in deep learning`__.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float): learning rate
+ momentum (float, optional): momentum factor (default: 0)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ dampening (float, optional): dampening for momentum (default: 0)
+ nesterov (bool, optional): enables Nesterov momentum (default: False)
+
+ Example:
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> optimizer.zero_grad()
+ >>> loss_fn(model(input), target).backward()
+ >>> optimizer.step()
+
+ __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+ .. note::
+ The implementation of SGD with Momentum/Nesterov subtly differs from
+ Sutskever et. al. and implementations in some other frameworks.
+
+ Considering the specific case of Momentum, the update can be written as
+
+ .. math::
+ v = \rho * v + g \\
+ p = p - lr * v
+
+ where p, g, v and :math:`\rho` denote the parameters, gradient,
+ velocity, and momentum respectively.
+
+ This is in contrast to Sutskever et. al. and
+ other frameworks which employ an update of the form
+
+ .. math::
+ v = \rho * v + lr * g \\
+ p = p - v
+
+ The Nesterov version is analogously modified.
+ """
+
+ def __init__(self, params, lr=required, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False):
+ if lr is not required and lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if momentum < 0.0:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if weight_decay < 0.0:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+ weight_decay=weight_decay, nesterov=nesterov)
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ super(SGD, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(SGD, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('nesterov', False)
+
+    def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ d_p = p.grad.data
+ if weight_decay != 0:
+ d_p.add_(weight_decay, p.data)
+ if momentum != 0:
+ param_state = self.state[p]
+ if 'momentum_buffer' not in param_state:
+ buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
+ else:
+ buf = param_state['momentum_buffer']
+ buf.mul_(momentum).add_(1 - dampening, d_p)
+ if nesterov:
+ d_p = d_p.add(momentum, buf)
+ else:
+ d_p = buf
+
+ p.data.add_(-group['lr'], d_p)
+
+ return loss
+
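+# A tiny sketch (not part of the original source) contrasting the two momentum
+# formulations described in the note above, for a single parameter and scalar
+# hyperparameters; dampening and the Nesterov variant are omitted.
+def _sgd_momentum_pytorch_form(p, g, v, lr, rho):
+    v = rho * v + g             # velocity accumulates raw gradients
+    return p - lr * v, v
+
+
+def _sgd_momentum_sutskever_form(p, g, v, lr, rho):
+    v = rho * v + lr * g        # velocity accumulates lr-scaled gradients
+    return p - v, v
+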
+import math
+import torch
+from .optimizer import Optimizer
+
+
+class SparseAdam(Optimizer):
+ r"""Implements lazy version of Adam algorithm suitable for sparse tensors.
+
+ In this variant, only moments that show up in the gradient get updated, and
+ only those portions of the gradient get applied to the parameters.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ """
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
+ if not 0.0 < lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 < eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps)
+ super(SparseAdam, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if not grad.is_sparse:
+ raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+ state['step'] += 1
+
+ grad = grad.coalesce() # the update is non-linear so indices must be unique
+ grad_indices = grad._indices()
+ grad_values = grad._values()
+ size = grad.size()
+
+ def make_sparse(values):
+ constructor = grad.new
+ if grad_indices.dim() == 0 or values.dim() == 0:
+ return constructor().resize_as_(grad)
+ return constructor(grad_indices, values, size)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # Decay the first and second moment running average coefficient
+ # old <- b * old + (1 - b) * new
+ # <==> old += (1 - b) * (new - old)
+ old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
+ exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
+ exp_avg.add_(make_sparse(exp_avg_update_values))
+ old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
+ exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
+ exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
+
+ # Dense addition again is intended, avoiding another sparse_mask
+ numer = exp_avg_update_values.add_(old_exp_avg_values)
+ exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
+ denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
+ del exp_avg_update_values, exp_avg_sq_update_values
+
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+ p.data.add_(make_sparse(-step_size * numer.div_(denom)))
+
+ return loss
+
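+# An illustrative usage sketch (not part of the original source): SparseAdam is meant
+# for modules that emit sparse gradients, such as an Embedding built with sparse=True.
+# The helper name and the sizes below are illustrative only.
+def _sparse_adam_usage_sketch():
+    embedding = torch.nn.Embedding(1000, 64, sparse=True)
+    optimizer = SparseAdam(embedding.parameters(), lr=1e-3)
+    loss = embedding(torch.tensor([1, 2, 3])).sum()
+    loss.backward()             # embedding.weight.grad is a sparse tensor here
+    optimizer.step()            # only the rows touched by the batch are updated
+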
+import torch
+
+
+class SobolEngine(object):
+ r"""
+ The :class:`torch.quasirandom.SobolEngine` is an engine for generating
+ (scrambled) Sobol sequences. Sobol sequences are an example of low
+ discrepancy quasi-random sequences.
+
+ This implementation of an engine for Sobol sequences is capable of
+ sampling sequences up to a maximum dimension of 1111. It uses direction
+ numbers to generate these sequences, and these numbers have been adapted
+ from `here <http://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-old.1111>`_.
+
+ References:
+ - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
+ Journal of Complexity, 14(4):466-489, December 1998.
+
+ - I. M. Sobol. The distribution of points in a cube and the accurate
+ evaluation of integrals.
+ Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.
+
+ Args:
+ dimension (Int): The dimensionality of the sequence to be drawn
+ scramble (bool, optional): Setting this to ``True`` will produce
+ scrambled Sobol sequences. Scrambling is
+ capable of producing better Sobol
+ sequences. Default: ``False``.
+ seed (Int, optional): This is the seed for the scrambling. The seed
+ of the random number generator is set to this,
+ if specified. Default: ``None``
+
+ Examples::
+
+ >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
+ >>> soboleng.draw(3)
+ tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
+ [0.7500, 0.2500, 0.7500, 0.2500, 0.7500],
+ [0.2500, 0.7500, 0.2500, 0.7500, 0.2500]])
+ """
+ MAXBIT = 30
+ MAXDIM = 1111
+
+ def __init__(self, dimension, scramble=False, seed=None):
+ if dimension > self.MAXDIM or dimension < 1:
+ raise ValueError("Supported range of dimensionality "
+ "for SobolEngine is [1, {}]".format(self.MAXDIM))
+
+ self.seed = seed
+ self.scramble = scramble
+ self.dimension = dimension
+
+ self.sobolstate = torch.zeros(dimension, self.MAXBIT, dtype=torch.long)
+ torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
+
+ if self.scramble:
+ g = torch.Generator()
+ if self.seed is not None:
+ g.manual_seed(self.seed)
+
+ self.shift = torch.mv(torch.randint(2, (self.dimension, self.MAXBIT), generator=g),
+ torch.pow(2, torch.arange(0, self.MAXBIT)))
+
+ ltm = torch.randint(2, (self.dimension, self.MAXBIT, self.MAXBIT), generator=g).tril()
+
+ torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
+ else:
+ self.shift = torch.zeros(self.dimension, dtype=torch.long)
+
+ self.quasi = self.shift.clone()
+ self.num_generated = 0
+
+    def draw(self, n=1, out=None, dtype=torch.float32):
+ r"""
+ Function to draw a sequence of :attr:`n` points from a Sobol sequence.
+ Note that the samples are dependent on the previous samples. The size
+ of the result is :math:`(n, dimension)`.
+
+ Args:
+ n (Int, optional): The length of sequence of points to draw.
+ Default: 1
+ out (Tensor, optional): The output tensor
+ dtype (:class:`torch.dtype`, optional): the desired data type of the
+ returned tensor.
+ Default: ``torch.float32``
+ """
+ result, self.quasi = torch._sobol_engine_draw(self.quasi, n, self.sobolstate,
+ self.dimension, self.num_generated, dtype=dtype)
+ self.num_generated += n
+ if out is not None:
+ out.resize_as_(result).copy_(result)
+ return out
+ return result
+
+    def reset(self):
+ r"""
+ Function to reset the ``SobolEngine`` to base state.
+ """
+ self.quasi.copy_(self.shift)
+ self.num_generated = 0
+ return self
+
+    def fast_forward(self, n):
+ r"""
+ Function to fast-forward the state of the ``SobolEngine`` by
+ :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
+ without using the samples.
+
+ Args:
+ n (Int): The number of steps to fast-forward by.
+ """
+ torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated)
+ self.num_generated += n
+ return self
+
+ def __repr__(self):
+ fmt_string = ['dimension={}'.format(self.dimension)]
+ if self.scramble:
+ fmt_string += ['scramble=True']
+ if self.seed is not None:
+ fmt_string += ['seed={}'.format(self.seed)]
+ return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
+
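+# A small sketch (not part of the original source) showing that fast_forward(n) is
+# equivalent to drawing and discarding n samples: the two engines below go on to
+# produce identical points. The helper name is illustrative only.
+def _sobol_fast_forward_sketch():
+    a = SobolEngine(dimension=3)
+    b = SobolEngine(dimension=3)
+    a.draw(5)               # draw and discard 5 points
+    b.fast_forward(5)       # skip ahead 5 points without materializing them
+    return torch.equal(a.draw(2), b.draw(2))   # True
+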
+import contextlib
+import warnings
+
+from torch._C import default_generator
+
+
+def set_rng_state(new_state):
+ r"""Sets the random number generator state.
+
+ Args:
+ new_state (torch.ByteTensor): The desired state
+ """
+ default_generator.set_state(new_state)
+
+
+def get_rng_state():
+ r"""Returns the random number generator state as a `torch.ByteTensor`."""
+ return default_generator.get_state()
+
+
+def manual_seed(seed):
+ r"""Sets the seed for generating random numbers. Returns a
+ `torch._C.Generator` object.
+
+ Args:
+ seed (int): The desired seed.
+ """
+ seed = int(seed)
+ import torch.cuda
+
+ if not torch.cuda._in_bad_fork:
+ torch.cuda.manual_seed_all(seed)
+
+ return default_generator.manual_seed(seed)
+
+
+def initial_seed():
+ r"""Returns the initial seed for generating random numbers as a
+ Python `long`.
+ """
+ return default_generator.initial_seed()
+
+
+_fork_rng_warned_already = False
+
+
+@contextlib.contextmanager
+def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
+ """
+ Forks the RNG, so that when you return, the RNG is reset
+ to the state that it was previously in.
+
+ Arguments:
+ devices (iterable of CUDA IDs): CUDA devices for which to fork
+ the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
+ on all devices, but will emit a warning if your machine has a lot
+ of devices, since this function will run very slowly in that case.
+            If you explicitly specify devices, this warning will be suppressed.
+ enabled (bool): if ``False``, the RNG is not forked. This is a convenience
+ argument for easily disabling the context manager without having
+ to delete it and unindent your Python code under it.
+ """
+
+ import torch.cuda
+ global _fork_rng_warned_already
+
+ # Internal arguments:
+ # _caller: the function which called fork_rng, which the user used
+ # _devices_kw: the devices keyword of _caller
+
+ if not enabled:
+ yield
+ return
+
+ if devices is None:
+ num_devices = torch.cuda.device_count()
+ if num_devices > 1 and not _fork_rng_warned_already:
+ warnings.warn(
+ ("CUDA reports that you have {num_devices} available devices, and you "
+ "have used {caller} without explicitly specifying which devices are being used. "
+ "For safety, we initialize *every* CUDA device by default, which "
+ "can be quite slow if you have a lot of GPUs. If you know that you are only "
+ "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
+ "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
+ "you are actually using. For example, if you are using CPU only, "
+ "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
+ "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
+ "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
+ "to `range(torch.cuda.device_count())`."
+ ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
+ _fork_rng_warned_already = True
+ devices = list(range(num_devices))
+ else:
+ # Protect against user passing us a generator; we need to traverse this
+ # multiple times but a generator will be exhausted upon first traversal
+ devices = list(devices)
+
+ cpu_rng_state = torch.get_rng_state()
+ gpu_rng_states = []
+ for device in devices:
+ with torch.cuda.device(device):
+ gpu_rng_states.append(torch.cuda.get_rng_state())
+
+ try:
+ yield
+ finally:
+ torch.set_rng_state(cpu_rng_state)
+ for device, gpu_rng_state in zip(devices, gpu_rng_states):
+ with torch.cuda.device(device):
+ torch.cuda.set_rng_state(gpu_rng_state)
+
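+# A short usage sketch (not part of the original source): random draws made inside
+# fork_rng do not disturb the RNG state observed by code outside the context manager.
+# The helper name is illustrative only.
+def _fork_rng_usage_sketch():
+    import torch
+    before = torch.get_rng_state()
+    with fork_rng(devices=[]):              # devices=[] skips CUDA RNG handling
+        torch.manual_seed(0)
+        _ = torch.rand(3)                   # consumes RNG state only inside the fork
+    after = torch.get_rng_state()
+    return torch.equal(before, after)       # True: the outer CPU RNG state is restored
+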
+import difflib
+import inspect
+import os
+import io
+import shutil
+import struct
+import sys
+import torch
+import tarfile
+import zipfile
+import tempfile
+import warnings
+from contextlib import closing, contextmanager
+from ._utils import _import_dotted_name
+from ._six import string_classes as _string_classes
+if sys.version_info[0] == 2:
+ import cPickle as pickle
+else:
+ import pickle
+ import pathlib
+
+DEFAULT_PROTOCOL = 2
+
+LONG_SIZE = struct.Struct('=l').size
+INT_SIZE = struct.Struct('=i').size
+SHORT_SIZE = struct.Struct('=h').size
+
+MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
+PROTOCOL_VERSION = 1001
+STORAGE_KEY_SEPARATOR = ','
+
+
+class SourceChangeWarning(Warning):
+ pass
+
+
+@contextmanager
+def mkdtemp():
+ path = tempfile.mkdtemp()
+ yield path
+ shutil.rmtree(path)
+
+
+_package_registry = []
+
+
+def register_package(priority, tagger, deserializer):
+ queue_elem = (priority, tagger, deserializer)
+ _package_registry.append(queue_elem)
+ _package_registry.sort()
+
+
+def _cpu_tag(obj):
+ if type(obj).__module__ == 'torch':
+ return 'cpu'
+
+
+def _cuda_tag(obj):
+ if type(obj).__module__ == 'torch.cuda':
+ return 'cuda:' + str(obj.get_device())
+
+
+def _cpu_deserialize(obj, location):
+ if location == 'cpu':
+ return obj
+
+
+def validate_cuda_device(location):
+ if isinstance(location, torch.device):
+ location = str(location)
+ if not isinstance(location, _string_classes):
+ raise ValueError("location should be a string or torch.device")
+ if location[5:] == '':
+ device = 0
+ else:
+ device = max(int(location[5:]), 0)
+
+ if not torch.cuda.is_available():
+ raise RuntimeError('Attempting to deserialize object on a CUDA '
+ 'device but torch.cuda.is_available() is False. '
+ 'If you are running on a CPU-only machine, '
+ 'please use torch.load with map_location=\'cpu\' '
+ 'to map your storages to the CPU.')
+ if device >= torch.cuda.device_count():
+ raise RuntimeError('Attempting to deserialize object on CUDA device '
+ '{} but torch.cuda.device_count() is {}. Please use '
+ 'torch.load with map_location to map your storages '
+ 'to an existing device.'.format(
+ device, torch.cuda.device_count()))
+ return device
+
+
+def _cuda_deserialize(obj, location):
+ if location.startswith('cuda'):
+ device = validate_cuda_device(location)
+ if getattr(obj, "_torch_load_uninitialized", False):
+ storage_type = getattr(torch.cuda, type(obj).__name__)
+ with torch.cuda.device(device):
+ return storage_type(obj.size())
+ else:
+ return obj.cuda(device)
+
+
+register_package(10, _cpu_tag, _cpu_deserialize)
+register_package(20, _cuda_tag, _cuda_deserialize)
+
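+# An illustrative sketch (not part of the original source) of how an extension could
+# hook into the registry above; the 'xla' tag, the priority value and the helper names
+# are purely hypothetical, only the register_package(priority, tagger, deserializer)
+# call itself mirrors the built-in registrations.
+def _register_xla_package_sketch():
+    def _xla_tag(obj):
+        if type(obj).__module__ == 'torch_xla':        # hypothetical extension module
+            return 'xla'
+
+    def _xla_deserialize(obj, location):
+        if location == 'xla':
+            return obj                                 # move the storage to its device here
+
+    register_package(30, _xla_tag, _xla_deserialize)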
+
+def location_tag(storage):
+ for _, tagger, _ in _package_registry:
+ location = tagger(storage)
+ if location:
+ return location
+ raise RuntimeError("don't know how to determine data location of " +
+ torch.typename(storage))
+
+
+def default_restore_location(storage, location):
+ for _, _, fn in _package_registry:
+ result = fn(storage, location)
+ if result is not None:
+ return result
+ raise RuntimeError("don't know how to restore data location of " +
+ torch.typename(storage) + " (tagged with " +
+ location + ")")
+
+
+def normalize_storage_type(storage_type):
+ return getattr(torch, storage_type.__name__)
+
+
+def storage_to_tensor_type(storage):
+ storage_type = type(storage)
+ module = _import_dotted_name(storage_type.__module__)
+ return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+
+
+def _with_file_like(f, mode, body):
+ """
+ Executes a body function with a file object for f, opening
+ it in 'mode' if it is a string filename.
+ """
+ new_fd = False
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
+ (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ new_fd = True
+ f = open(f, mode)
+ try:
+ return body(f)
+ finally:
+ if new_fd:
+ f.close()
+
+
+def _is_compressed_file(f):
+ compress_modules = ['gzip']
+ try:
+ return f.__module__ in compress_modules
+ except AttributeError:
+ return False
+
+
+def _should_read_directly(f):
+ """
+ Checks if f is a file that should be read directly. It should be read
+ directly if it is backed by a real file (has a fileno) and is not a
+    compressed file (e.g. gzip)
+ """
+ if _is_compressed_file(f):
+ return False
+ try:
+ return f.fileno() >= 0
+ except io.UnsupportedOperation:
+ return False
+ except AttributeError:
+ return False
+
+
+def _check_seekable(f):
+
+ def raise_err_msg(patterns, e):
+ for p in patterns:
+ if p in str(e):
+ msg = (str(e) + ". You can only torch.load from a file that is seekable." +
+ " Please pre-load the data into a buffer like io.BytesIO and" +
+ " try to load from it instead.")
+ raise type(e)(msg)
+ raise e
+
+ try:
+ f.seek(f.tell())
+ return True
+ except (io.UnsupportedOperation, AttributeError) as e:
+ raise_err_msg(["seek", "tell"], e)
+
+
+def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
+ """Saves an object to a disk file.
+
+ See also: :ref:`recommend-saving-models`
+
+ Args:
+ obj: saved object
+ f: a file-like object (has to implement write and flush) or a string
+ containing a file name
+ pickle_module: module used for pickling metadata and objects
+ pickle_protocol: can be specified to override the default protocol
+
+ .. warning::
+ If you are using Python 2, torch.save does NOT support StringIO.StringIO
+ as a valid file-like object. This is because the write method should return
+ the number of bytes written; StringIO.write() does not do this.
+
+ Please use something like io.BytesIO instead.
+
+ Example:
+ >>> # Save to file
+ >>> x = torch.tensor([0, 1, 2, 3, 4])
+ >>> torch.save(x, 'tensor.pt')
+ >>> # Save to io.BytesIO buffer
+ >>> buffer = io.BytesIO()
+ >>> torch.save(x, buffer)
+ """
+ return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
+
+
+def _save(obj, f, pickle_module, pickle_protocol):
+ if sys.version_info[0] == 2:
+ import StringIO
+ if isinstance(f, StringIO.StringIO):
+ msg = ('torch.save received unsupported StringIO.StringIO file object, whose '
+ 'write method does not return the number of bytes written. '
+ 'Please use something like io.BytesIO for torch.save instead.')
+ raise RuntimeError(msg)
+
+ import torch.nn as nn
+ serialized_container_types = {}
+ serialized_storages = {}
+
+ def persistent_id(obj):
+ # FIXME: the docs say that persistent_id should only return a string
+ # but torch store returns tuples. This works only in the binary protocol
+ # see
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
+ if isinstance(obj, type) and issubclass(obj, nn.Module):
+ if obj in serialized_container_types:
+ return None
+ serialized_container_types[obj] = True
+ source_file = source = None
+ try:
+ source_file = inspect.getsourcefile(obj)
+ source = inspect.getsource(obj)
+ except Exception: # saving the source is optional, so we can ignore any errors
+ warnings.warn("Couldn't retrieve source code for container of "
+ "type " + obj.__name__ + ". It won't be checked "
+ "for correctness upon loading.")
+ return ('module', obj, source_file, source)
+ elif torch.is_storage(obj):
+ storage_type = normalize_storage_type(type(obj))
+ # Offset is always 0, but we keep it for backwards compatibility
+ # with the old serialization format (which supported storage views)
+ offset = 0
+ obj_key = str(obj._cdata)
+ location = location_tag(obj)
+ serialized_storages[obj_key] = obj
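+            # Storage views are no longer produced on save, so is_view below is always
+            # False; the slot is kept so the on-disk format stays backward compatible.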
+ is_view = obj._cdata != obj._cdata
+ if is_view:
+ view_metadata = (str(obj._cdata), offset, obj.size())
+ else:
+ view_metadata = None
+
+ return ('storage',
+ storage_type,
+ obj_key,
+ location,
+ obj.size(),
+ view_metadata)
+
+ return None
+
+ sys_info = dict(
+ protocol_version=PROTOCOL_VERSION,
+ little_endian=sys.byteorder == 'little',
+ type_sizes=dict(
+ short=SHORT_SIZE,
+ int=INT_SIZE,
+ long=LONG_SIZE,
+ ),
+ )
+
+ pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
+ pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
+ pickle_module.dump(sys_info, f, protocol=pickle_protocol)
+ pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
+ pickler.persistent_id = persistent_id
+ pickler.dump(obj)
+
+ serialized_storage_keys = sorted(serialized_storages.keys())
+ pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
+ f.flush()
+ for key in serialized_storage_keys:
+ serialized_storages[key]._write_file(f, _should_read_directly(f))
+
+
+def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
+ """Loads an object saved with :func:`torch.save` from a file.
+
+ :meth:`torch.load` uses Python's unpickling facilities but treats storages,
+ which underlie tensors, specially. They are first deserialized on the
+ CPU and are then moved to the device they were saved from. If this fails
+ (e.g. because the run time system doesn't have certain devices), an exception
+ is raised. However, storages can be dynamically remapped to an alternative
+ set of devices using the `map_location` argument.
+
+ If `map_location` is a callable, it will be called once for each serialized
+ storage with two arguments: storage and location. The storage argument
+ will be the initial deserialization of the storage, residing on the CPU.
+ Each serialized storage has a location tag associated with it which
+ identifies the device it was saved from, and this tag is the second
+ argument passed to map_location. The builtin location tags are `'cpu'` for
+ CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.
+ `map_location` should return either None or a storage. If `map_location` returns
+ a storage, it will be used as the final deserialized object, already moved to
+    the right device. Otherwise, :func:`torch.load` will fall back to the default
+ behavior, as if `map_location` wasn't specified.
+
+ If `map_location` is a string, it should be a device tag, where all tensors
+ should be loaded.
+
+ Otherwise, if `map_location` is a dict, it will be used to remap location tags
+ appearing in the file (keys), to ones that specify where to put the
+ storages (values).
+
+ User extensions can register their own location tags and tagging and
+ deserialization methods using `register_package`.
+
+ Args:
+ f: a file-like object (has to implement read, readline, tell, and seek),
+ or a string containing a file name
+ map_location: a function, torch.device, string or a dict specifying how to remap storage
+ locations
+ pickle_module: module used for unpickling metadata and objects (has to
+ match the pickle_module used to serialize file)
+ pickle_load_args: optional keyword arguments passed over to
+ ``pickle_module.load`` and ``pickle_module.Unpickler``, e.g.,
+ ``encoding=...``.
+
+ .. note::
+ When you call :meth:`torch.load()` on a file which contains GPU tensors, those tensors
+ will be loaded to GPU by default. You can call `torch.load(.., map_location='cpu')`
+ and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
+
+ .. note::
+ In Python 3, when loading files saved by Python 2, you may encounter
+ ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``. This is
+ caused by the difference of handling in byte strings in Python2 and
+ Python 3. You may use extra ``encoding`` keyword argument to specify how
+ these objects should be loaded, e.g., ``encoding='latin1'`` decodes them
+ to strings using ``latin1`` encoding, and ``encoding='bytes'`` keeps them
+ as byte arrays which can be decoded later with ``byte_array.decode(...)``.
+
+ Example:
+ >>> torch.load('tensors.pt')
+ # Load all tensors onto the CPU
+ >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
+ # Load all tensors onto the CPU, using a function
+ >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
+ # Load all tensors onto GPU 1
+ >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
+ # Map tensors from GPU 1 to GPU 0
+ >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
+ # Load tensor from io.BytesIO object
+ >>> with open('tensor.pt', 'rb') as f:
+        >>>     buffer = io.BytesIO(f.read())
+ >>> torch.load(buffer)
+ """
+ new_fd = False
+ if isinstance(f, str) or \
+ (sys.version_info[0] == 2 and isinstance(f, unicode)):
+ new_fd = True
+ f = open(f, 'rb')
+ elif (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
+ new_fd = True
+ f = f.open('rb')
+ try:
+ return _load(f, map_location, pickle_module, **pickle_load_args)
+ finally:
+ if new_fd:
+ f.close()
+
+
+def _load(f, map_location, pickle_module, **pickle_load_args):
+ deserialized_objects = {}
+
+ if map_location is None:
+ restore_location = default_restore_location
+ elif isinstance(map_location, dict):
+ def restore_location(storage, location):
+ location = map_location.get(location, location)
+ return default_restore_location(storage, location)
+ elif isinstance(map_location, _string_classes):
+ def restore_location(storage, location):
+ return default_restore_location(storage, map_location)
+ elif isinstance(map_location, torch.device):
+ def restore_location(storage, location):
+ return default_restore_location(storage, str(map_location))
+ else:
+ def restore_location(storage, location):
+ result = map_location(storage, location)
+ if result is None:
+ result = default_restore_location(storage, location)
+ return result
+
+ def _check_container_source(container_type, source_file, original_source):
+ try:
+ current_source = inspect.getsource(container_type)
+ except Exception: # saving the source is optional, so we can ignore any errors
+ warnings.warn("Couldn't retrieve source code for container of "
+ "type " + container_type.__name__ + ". It won't be checked "
+ "for correctness upon loading.")
+ return
+ if original_source != current_source:
+ if container_type.dump_patches:
+ file_name = container_type.__name__ + '.patch'
+ diff = difflib.unified_diff(current_source.split('\n'),
+ original_source.split('\n'),
+ source_file,
+ source_file, lineterm="")
+ lines = '\n'.join(diff)
+ try:
+ with open(file_name, 'a+') as f:
+ file_size = f.seek(0, 2)
+ f.seek(0)
+ if file_size == 0:
+ f.write(lines)
+ elif file_size != len(lines) or f.read() != lines:
+ raise IOError
+ msg = ("Saved a reverse patch to " + file_name + ". "
+ "Run `patch -p0 < " + file_name + "` to revert your "
+ "changes.")
+ except IOError:
+ msg = ("Tried to save a patch, but couldn't create a "
+ "writable file " + file_name + ". Make sure it "
+ "doesn't exist and your working directory is "
+ "writable.")
+ else:
+ msg = ("you can retrieve the original source code by "
+ "accessing the object's source attribute or set "
+ "`torch.nn.Module.dump_patches = True` and use the "
+ "patch tool to revert the changes.")
+ msg = ("source code of class '{}' has changed. {}"
+ .format(torch.typename(container_type), msg))
+ warnings.warn(msg, SourceChangeWarning)
+
+ def legacy_load(f):
+ deserialized_objects = {}
+
+ def persistent_load(saved_id):
+ if isinstance(saved_id, tuple):
+ # Ignore containers that don't have any sources saved
+ if all(saved_id[1:]):
+ _check_container_source(*saved_id)
+ return saved_id[0]
+ return deserialized_objects[int(saved_id)]
+
+ with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
+ mkdtemp() as tmpdir:
+
+ tar.extract('storages', path=tmpdir)
+ with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
+ num_storages = pickle_module.load(f, **pickle_load_args)
+ for i in range(num_storages):
+ args = pickle_module.load(f, **pickle_load_args)
+ key, location, storage_type = args
+ obj = storage_type._new_with_file(f)
+ obj = restore_location(obj, location)
+ deserialized_objects[key] = obj
+
+ storage_views = pickle_module.load(f, **pickle_load_args)
+ for target_cdata, root_cdata, offset, size in storage_views:
+ root = deserialized_objects[root_cdata]
+ deserialized_objects[target_cdata] = root[offset:offset + size]
+
+ tar.extract('tensors', path=tmpdir)
+ with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
+ num_tensors = pickle_module.load(f, **pickle_load_args)
+ for _ in range(num_tensors):
+ args = pickle_module.load(f, **pickle_load_args)
+ key, storage_id, original_tensor_type = args
+ storage = deserialized_objects[storage_id]
+ tensor_type = storage_to_tensor_type(storage)
+ ndim, = struct.unpack('<i', f.read(4))
+ # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
+ f.read(4)
+ size = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
+ stride = struct.unpack('<{}q'.format(ndim), f.read(8 * ndim))
+ storage_offset, = struct.unpack('<q', f.read(8))
+ tensor = tensor_type().set_(storage, storage_offset, size, stride)
+ deserialized_objects[key] = tensor
+
+ pickle_file = tar.extractfile('pickle')
+ unpickler = pickle_module.Unpickler(pickle_file, **pickle_load_args)
+ unpickler.persistent_load = persistent_load
+ result = unpickler.load()
+ return result
+
+ deserialized_objects = {}
+
+ def maybe_decode_ascii(bytes_str):
+ # When using encoding='bytes' in Py3, some **internal** keys stored as
+ # strings in Py2 are loaded as bytes. This function decodes them with
+ # ascii encoding, one that Py3 uses by default.
+ #
+ # NOTE: This should only be used on internal keys (e.g., `typename` and
+        # `location` in `persistent_load` below!)
+ if isinstance(bytes_str, bytes):
+ return bytes_str.decode('ascii')
+ return bytes_str
+
+ def persistent_load(saved_id):
+ assert isinstance(saved_id, tuple)
+ typename = maybe_decode_ascii(saved_id[0])
+ data = saved_id[1:]
+
+ if typename == 'module':
+ # Ignore containers that don't have any sources saved
+ if all(data[1:]):
+ _check_container_source(*data)
+ return data[0]
+ elif typename == 'storage':
+ data_type, root_key, location, size, view_metadata = data
+ location = maybe_decode_ascii(location)
+ if root_key not in deserialized_objects:
+ obj = data_type(size)
+ obj._torch_load_uninitialized = True
+ deserialized_objects[root_key] = restore_location(obj, location)
+ storage = deserialized_objects[root_key]
+ if view_metadata is not None:
+ view_key, offset, view_size = view_metadata
+ if view_key not in deserialized_objects:
+ deserialized_objects[view_key] = storage[offset:offset + view_size]
+ return deserialized_objects[view_key]
+ else:
+ return storage
+ else:
+ raise RuntimeError("Unknown saved id type: %s" % saved_id[0])
+
+ _check_seekable(f)
+ f_should_read_directly = _should_read_directly(f)
+
+ if f_should_read_directly and f.tell() == 0:
+ # legacy_load requires that f has fileno()
+ # only if offset is zero we can attempt the legacy tar file loader
+ try:
+ return legacy_load(f)
+ except tarfile.TarError:
+ if zipfile.is_zipfile(f):
+ # .zip is used for torch.jit.save and will throw an un-pickling error here
+ raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
+ # if not a tarfile, reset file offset and proceed
+ f.seek(0)
+
+ magic_number = pickle_module.load(f, **pickle_load_args)
+ if magic_number != MAGIC_NUMBER:
+ raise RuntimeError("Invalid magic number; corrupt file?")
+ protocol_version = pickle_module.load(f, **pickle_load_args)
+ if protocol_version != PROTOCOL_VERSION:
+ raise RuntimeError("Invalid protocol version: %s" % protocol_version)
+
+ _sys_info = pickle_module.load(f, **pickle_load_args)
+ unpickler = pickle_module.Unpickler(f, **pickle_load_args)
+ unpickler.persistent_load = persistent_load
+ result = unpickler.load()
+
+ deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
+
+ offset = f.tell() if f_should_read_directly else None
+ for key in deserialized_storage_keys:
+ assert key in deserialized_objects
+ deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
+ offset = None
+
+ return result
+
+# The Tensor classes are added to this module by python_tensor.cpp
+import torch
+
+__all__ = [
+ 'addmm',
+ 'mm',
+ 'sum',
+]
+
+
+def addmm(mat, mat1, mat2, beta=1, alpha=1):
+ r"""
+    This function does the exact same thing as :func:`torch.addmm` in the forward,
+    except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
+    needs to have `sparse_dim = 2`. Note that the gradient of :attr:`mat1` is a
+    coalesced sparse tensor.
+
+ Args:
+ mat (Tensor): a dense matrix to be added
+ mat1 (SparseTensor): a sparse matrix to be multiplied
+ mat2 (Tensor): a dense matrix be multiplied
+ beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
+ alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
+ """
+ return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
+
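+# A brief usage sketch (not part of the original source), in the spirit of the mm
+# example below: addmm adds a dense matrix to the scaled sparse-dense product
+# alpha * (mat1 @ mat2). The helper name is illustrative only.
+def _sparse_addmm_usage_sketch():
+    mat = torch.randn(2, 2)
+    mat1 = torch.randn(2, 3).to_sparse().requires_grad_(True)
+    mat2 = torch.randn(3, 2, requires_grad=True)
+    out = addmm(mat, mat1, mat2, beta=1., alpha=1.)   # same shape as mat
+    out.sum().backward()                              # mat1.grad is a coalesced sparse tensor
+    return out
+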
+
+def mm(mat1, mat2):
+ r"""
+ Performs a matrix multiplication of the sparse matrix :attr:`mat1`
+    and the dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
+    :math:`(n \times m)` tensor and :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
+    :math:`(n \times p)` dense tensor. :attr:`mat1` needs to have `sparse_dim = 2`.
+    This function also supports backward for both matrices. Note that the gradient of
+    :attr:`mat1` is a coalesced sparse tensor.
+
+ Args:
+ mat1 (SparseTensor): the first sparse matrix to be multiplied
+ mat2 (Tensor): the second dense matrix to be multiplied
+
+ Example::
+
+ >>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
+ >>> a
+ tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
+ [0, 1, 2, 0, 1, 2]]),
+ values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]),
+ size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
+
+ >>> b = torch.randn(3, 2, requires_grad=True)
+ >>> b
+ tensor([[-0.6479, 0.7874],
+ [-1.2056, 0.5641],
+ [-1.1716, -0.9923]], requires_grad=True)
+
+ >>> y = torch.sparse.mm(a, b)
+ >>> y
+ tensor([[-0.3323, 1.8723],
+ [-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>)
+ >>> y.sum().backward()
+ >>> a.grad
+ tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
+ [0, 1, 2, 0, 1, 2]]),
+ values=tensor([ 0.1394, -0.6415, -2.1639, 0.1394, -0.6415, -2.1639]),
+ size=(2, 3), nnz=6, layout=torch.sparse_coo)
+ """
+ return torch._sparse_mm(mat1, mat2)
+
+
+def sum(input, dim=None, dtype=None):
+ r"""
+ Returns the sum of each row of SparseTensor :attr:`input` in the given
+ dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
+    reduce over all of them. When summing over all ``sparse_dim``, this method
+    returns a dense Tensor instead of a SparseTensor.
+
+    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting in an
+    output tensor having :attr:`dim` fewer dimensions than :attr:`input`.
+
+    During backward, only gradients at ``nnz`` locations of :attr:`input`
+    will propagate back. Note that the gradient of :attr:`input` is coalesced.
+
+ Args:
+ input (Tensor): the input SparseTensor
+ dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
+ over all dims.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
+ Default: dtype of :attr:`input`.
+
+ Example::
+
+ >>> nnz = 3
+ >>> dims = [5, 5, 2, 3]
+ >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
+ torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
+ >>> V = torch.randn(nnz, dims[2], dims[3])
+ >>> size = torch.Size(dims)
+ >>> S = torch.sparse_coo_tensor(I, V, size)
+ >>> S
+ tensor(indices=tensor([[2, 0, 3],
+ [2, 4, 1]]),
+ values=tensor([[[-0.6438, -1.6467, 1.4004],
+ [ 0.3411, 0.0918, -0.2312]],
+
+ [[ 0.5348, 0.0634, -2.0494],
+ [-0.7125, -1.0646, 2.1844]],
+
+ [[ 0.1276, 0.1874, -0.6334],
+ [-1.9682, -0.5340, 0.7483]]]),
+ size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
+
+        # when summing over only part of the sparse dims, a SparseTensor is returned
+ >>> torch.sparse.sum(S, [1, 3])
+ tensor(indices=tensor([[0, 2, 3]]),
+ values=tensor([[-1.4512, 0.4073],
+ [-0.8901, 0.2017],
+ [-0.3183, -1.7539]]),
+ size=(5, 2), nnz=3, layout=torch.sparse_coo)
+
+        # when summing over all sparse dims, a dense Tensor is returned
+        # with the summed dims squeezed
+ >>> torch.sparse.sum(S, [0, 1, 3])
+ tensor([-2.6596, -1.1450])
+ """
+ if dtype is None:
+ if dim is not None:
+ return torch._sparse_sum(input, dim)
+ else:
+ return torch._sparse_sum(input)
+ else:
+ if dim is not None:
+ return torch._sparse_sum(input, dim, dtype=dtype)
+ else:
+ return torch._sparse_sum(input, dtype=dtype)
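+
+# Illustrative sketch (not part of the original module): gradients of
+# torch.sparse.sum only flow back to the nnz locations of the input.
+# >>> i = torch.tensor([[0, 1], [1, 0]])
+# >>> v = torch.tensor([1., 2.])
+# >>> S = torch.sparse_coo_tensor(i, v, (2, 2), requires_grad=True)
+# >>> torch.sparse.sum(S).backward()
+# >>> S.grad  # a coalesced sparse tensor with ones at the two nnz positions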
+
+import io
+
+import torch
+from ._utils import _type, _cuda
+
+
+class _StorageBase(object):
+ is_cuda = False
+ is_sparse = False
+
+ def __str__(self):
+ content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
+ return content + '\n[{} of size {}]'.format(torch.typename(self), len(self))
+
+ def __repr__(self):
+ return str(self)
+
+ def __iter__(self):
+ return iter(map(lambda i: self[i], range(self.size())))
+
+ def __copy__(self):
+ return self.clone()
+
+ def __deepcopy__(self, memo):
+ memo = memo.setdefault('torch', {})
+ if self._cdata in memo:
+ return memo[self._cdata]
+ new_storage = self.clone()
+ memo[self._cdata] = new_storage
+ return new_storage
+
+ def __reduce__(self):
+ b = io.BytesIO()
+ torch.save(self, b)
+ return (_load_from_bytes, (b.getvalue(),))
+
+ def __sizeof__(self):
+ return super(_StorageBase, self).__sizeof__() + self.element_size() * self.size()
+
+ def clone(self):
+ """Returns a copy of this storage"""
+ device = self.get_device() if self.is_cuda else -1
+ with torch.cuda.device(device):
+ return type(self)(self.size()).copy_(self)
+
+ def tolist(self):
+ """Returns a list containing the elements of this storage"""
+ return [v for v in self]
+
+ def cpu(self):
+ """Returns a CPU copy of this storage if it's not already on the CPU"""
+ return self.type(getattr(torch, self.__class__.__name__))
+
+ def double(self):
+ """Casts this storage to double type"""
+ return self.type(type(self).__module__ + '.DoubleStorage')
+
+ def float(self):
+ """Casts this storage to float type"""
+ return self.type(type(self).__module__ + '.FloatStorage')
+
+ def half(self):
+ """Casts this storage to half type"""
+ return self.type(type(self).__module__ + '.HalfStorage')
+
+ def long(self):
+ """Casts this storage to long type"""
+ return self.type(type(self).__module__ + '.LongStorage')
+
+ def int(self):
+ """Casts this storage to int type"""
+ return self.type(type(self).__module__ + '.IntStorage')
+
+ def short(self):
+ """Casts this storage to short type"""
+ return self.type(type(self).__module__ + '.ShortStorage')
+
+ def char(self):
+ """Casts this storage to char type"""
+ return self.type(type(self).__module__ + '.CharStorage')
+
+ def byte(self):
+ """Casts this storage to byte type"""
+ return self.type(type(self).__module__ + '.ByteStorage')
+
+ def bool(self):
+ """Casts this storage to bool type"""
+ return self.type(type(self).__module__ + '.BoolStorage')
+
+ def pin_memory(self):
+ """Copies the storage to pinned memory, if it's not already pinned."""
+ if self.is_cuda:
+ raise TypeError("cannot pin '{0}' only CPU memory can be pinned"
+ .format(self.type()))
+ import torch.cuda
+ allocator = torch.cuda._host_allocator()
+ return type(self)(self.size(), allocator=allocator).copy_(self)
+
+ def share_memory_(self):
+ """Moves the storage to shared memory.
+
+ This is a no-op for storages already in shared memory and for CUDA
+ storages, which do not need to be moved for sharing across processes.
+ Storages in shared memory cannot be resized.
+
+ Returns: self
+ """
+ from torch.multiprocessing import get_sharing_strategy
+ if self.is_cuda:
+ pass # CUDA doesn't use POSIX shared memory
+ elif get_sharing_strategy() == 'file_system':
+ self._share_filename_()
+ else:
+ self._share_fd_()
+ return self
+
+ @classmethod
+ def _new_shared(cls, size):
+ """Creates a new storage in shared memory with the same data type"""
+ from torch.multiprocessing import get_sharing_strategy
+ if cls.is_cuda:
+ return cls(size)
+ elif get_sharing_strategy() == 'file_system':
+ return cls._new_using_filename(size)
+ else:
+ return cls._new_using_fd(size)
+
+
+def _load_from_bytes(b):
+ return torch.load(io.BytesIO(b))
+
+
+_StorageBase.type = _type
+_StorageBase.cuda = _cuda
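+
+# Illustrative sketch (not part of the original module): the casting and
+# sharing helpers defined on _StorageBase, shown on a FloatStorage.
+# >>> s = torch.FloatStorage([1.0, 2.0, 3.0])
+# >>> s.double()           # returns a new torch.DoubleStorage copy
+# >>> s.tolist()
+# [1.0, 2.0, 3.0]
+# >>> s.share_memory_()    # in place; a no-op if the storage is already shared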
+
+import sys
+import torch
+import torch._C as _C
+from collections import OrderedDict
+import torch.utils.hooks as hooks
+import warnings
+import weakref
+from torch._six import imap
+from torch._C import _add_docstr
+from numbers import Number
+
+
+# NB: If you subclass Tensor, and want to share the subclassed class
+# across processes, you must also update torch/multiprocessing/reductions.py
+# to define a ForkingPickler serialization mode for the class.
+#
+# NB: If you add a new method to Tensor, you must update
+# torch/__init__.py.in to add a type annotation for your method;
+# otherwise, it will not show up in autocomplete.
+class Tensor(torch._C._TensorBase):
+ def __deepcopy__(self, memo):
+ if not self.is_leaf:
+ raise RuntimeError("Only Tensors created explicitly by the user "
+ "(graph leaves) support the deepcopy protocol at the moment")
+ if id(self) in memo:
+ return memo[id(self)]
+ with torch.no_grad():
+ if self.is_sparse:
+ new_tensor = self.clone()
+ else:
+ new_storage = self.storage().__deepcopy__(memo)
+ new_tensor = self.new()
+ new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
+ memo[id(self)] = new_tensor
+ new_tensor.requires_grad = self.requires_grad
+ return new_tensor
+
+ def __reduce_ex__(self, proto):
+ # See Note [Don't serialize hooks]
+ torch.utils.hooks.warn_if_has_hooks(self)
+ args = (self.storage(),
+ self.storage_offset(),
+ tuple(self.size()),
+ self.stride(),
+ self.requires_grad,
+ OrderedDict()) # previously was self._backward_hooks
+ return (torch._utils._rebuild_tensor_v2, args)
+
+ def __setstate__(self, state):
+ # Warning: this method is NOT called when you torch.load() a tensor;
+ # that is managed by _rebuild_tensor_v2
+ if not self.is_leaf:
+ raise RuntimeError('__setstate__ can be only called on leaf Tensors')
+ if len(state) == 4:
+ # legacy serialization of Tensor
+ self.set_(*state)
+ return
+ elif len(state) == 5:
+ # legacy serialization of Variable
+ self.data = state[0]
+ state = (state[3], state[4], state[2])
+ # The setting of _backward_hooks is expected to be a no-op.
+ # See Note [Don't serialize hooks]
+ self.requires_grad, _, self._backward_hooks = state
+
+ def __repr__(self):
+ # All strings are unicode in Python 3, while we have to encode unicode
+        # strings in Python 2. If we can't, let Python decide the best
+ # characters to replace unicode characters with.
+ if sys.version_info > (3,):
+ return torch._tensor_str._str(self)
+ else:
+ if hasattr(sys.stdout, 'encoding'):
+ return torch._tensor_str._str(self).encode(
+ sys.stdout.encoding or 'UTF-8', 'replace')
+ else:
+ return torch._tensor_str._str(self).encode('UTF-8', 'replace')
+
+[docs] def backward(self, gradient=None, retain_graph=None, create_graph=False):
+ r"""Computes the gradient of current tensor w.r.t. graph leaves.
+
+ The graph is differentiated using the chain rule. If the tensor is
+ non-scalar (i.e. its data has more than one element) and requires
+ gradient, the function additionally requires specifying ``gradient``.
+        It should be a tensor of matching type and location that contains
+        the gradient of the differentiated function w.r.t. ``self``.
+
+ This function accumulates gradients in the leaves - you might need to
+ zero them before calling it.
+
+ Arguments:
+ gradient (Tensor or None): Gradient w.r.t. the
+ tensor. If it is a tensor, it will be automatically converted
+ to a Tensor that does not require grad unless ``create_graph`` is True.
+ None values can be specified for scalar Tensors or ones that
+ don't require grad. If a None value would be acceptable then
+ this argument is optional.
+ retain_graph (bool, optional): If ``False``, the graph used to compute
+ the grads will be freed. Note that in nearly all cases setting
+ this option to True is not needed and often can be worked around
+ in a much more efficient way. Defaults to the value of
+ ``create_graph``.
+            create_graph (bool, optional): If ``True``, the graph of the derivative will
+                be constructed, allowing higher order derivative products to be
+                computed. Defaults to ``False``.
+ """
+ torch.autograd.backward(self, gradient, retain_graph, create_graph)
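+
+    # Illustrative sketch (not part of the original module): backward on a
+    # non-scalar tensor requires an explicit ``gradient`` argument.
+    # >>> x = torch.randn(3, requires_grad=True)
+    # >>> y = x * 2
+    # >>> y.backward(torch.ones_like(y))  # gradient of the same shape as y
+    # >>> x.grad
+    # tensor([2., 2., 2.])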
+
+[docs] def register_hook(self, hook):
+ r"""Registers a backward hook.
+
+ The hook will be called every time a gradient with respect to the
+ Tensor is computed. The hook should have the following signature::
+
+ hook(grad) -> Tensor or None
+
+
+ The hook should not modify its argument, but it can optionally return
+ a new gradient which will be used in place of :attr:`grad`.
+
+ This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the tensor.
+
+ Example::
+
+ >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
+ >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
+ >>> v.backward(torch.tensor([1., 2., 3.]))
+        >>> v.grad
+        tensor([2., 4., 6.])
+
+ >>> h.remove() # removes the hook
+ """
+ if not self.requires_grad:
+ raise RuntimeError("cannot register a hook on a tensor that "
+ "doesn't require gradient")
+ if self._backward_hooks is None:
+ self._backward_hooks = OrderedDict()
+ if self.grad_fn is not None:
+ self.grad_fn._register_hook_dict(self)
+ handle = hooks.RemovableHandle(self._backward_hooks)
+ self._backward_hooks[handle.id] = hook
+ return handle
+
+ def reinforce(self, reward):
+        def trim(text):
+            return '\n'.join([line.strip() for line in text.split('\n')])
+
+ raise RuntimeError(trim(r"""reinforce() was removed.
+ Use torch.distributions instead.
+ See https://pytorch.org/docs/master/distributions.html
+
+ Instead of:
+
+ probs = policy_network(state)
+ action = probs.multinomial()
+ next_state, reward = env.step(action)
+ action.reinforce(reward)
+ action.backward()
+
+ Use:
+
+ probs = policy_network(state)
+ # NOTE: categorical is equivalent to what used to be called multinomial
+ m = torch.distributions.Categorical(probs)
+ action = m.sample()
+ next_state, reward = env.step(action)
+ loss = -m.log_prob(action) * reward
+ loss.backward()
+ """))
+
+ detach = _add_docstr(_C._TensorBase.detach, r"""
+ Returns a new Tensor, detached from the current graph.
+
+ The result will never require gradient.
+
+ .. note::
+
+ Returned Tensor shares the same storage with the original one.
+ In-place modifications on either of them will be seen, and may trigger
+ errors in correctness checks.
+ IMPORTANT NOTE: Previously, in-place size / stride / storage changes
+ (such as `resize_` / `resize_as_` / `set_` / `transpose_`) to the returned tensor
+      would also update the original tensor. Now, these in-place changes will not update the
+ original tensor anymore, and will instead trigger an error.
+ For sparse tensors:
+ In-place indices / values changes (such as `zero_` / `copy_` / `add_`) to the
+ returned tensor will not update the original tensor anymore, and will instead
+ trigger an error.
+ """)
+
+ detach_ = _add_docstr(_C._TensorBase.detach_, r"""
+ Detaches the Tensor from the graph that created it, making it a leaf.
+ Views cannot be detached in-place.
+ """)
+
+[docs] def retain_grad(self):
+ r"""Enables .grad attribute for non-leaf Tensors."""
+ if self.grad_fn is None: # no-op for leaves
+ return
+ if not self.requires_grad:
+ raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
+ if hasattr(self, 'retains_grad'):
+ return
+ weak_self = weakref.ref(self)
+
+ def retain_grad_hook(grad):
+ var = weak_self()
+ if var is None:
+ return
+ if var._grad is None:
+ var._grad = grad.clone()
+ else:
+ var._grad = var._grad + grad
+
+ self.register_hook(retain_grad_hook)
+ self.retains_grad = True
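+
+    # Illustrative sketch (not part of the original module): retain_grad()
+    # makes ``.grad`` available on a non-leaf tensor after backward.
+    # >>> x = torch.ones(2, requires_grad=True)
+    # >>> y = x * 2           # non-leaf; y.grad would normally stay None
+    # >>> y.retain_grad()
+    # >>> y.sum().backward()
+    # >>> y.grad
+    # tensor([1., 1.])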
+
+[docs] def is_pinned(self):
+ r"""Returns true if this tensor resides in pinned memory"""
+ storage = self.storage()
+ return storage.is_pinned() if storage else False
+
+
+
+
+
+ def __reversed__(self):
+ r"""Reverses the tensor along dimension 0."""
+ if self.dim() == 0:
+ return self
+ else:
+ return self.flip(0)
+
+[docs] def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
+ r"""See :func:`torch.norm`"""
+ return torch.norm(self, p, dim, keepdim, dtype=dtype)
+
+[docs] def pstrf(self, upper=True):
+ r"""See :func:`torch.pstrf`"""
+ warnings.warn("torch.pstrf is deprecated in favour of torch.cholesky and will be removed "
+ "in the next release.", stacklevel=2)
+ return super(Tensor, self).pstrf(upper=upper)
+
+[docs] def potrf(self, upper=True):
+ r"""See :func:`torch.cholesky`"""
+ warnings.warn("torch.potrf is deprecated in favour of torch.cholesky and will be removed "
+ "in the next release. Please use torch.cholesky instead and note that the "
+ ":attr:`upper` argument in torch.cholesky defaults to ``False``.", stacklevel=2)
+ return super(Tensor, self).cholesky(upper=upper)
+
+[docs] def potri(self, upper=True):
+ r"""See :func:`torch.cholesky_inverse`"""
+ warnings.warn("torch.potri is deprecated in favour of torch.cholesky_inverse and will be "
+ "removed in the next release. Please use torch.cholesky_inverse instead and "
+ "note that the :attr:`upper` argument in torch.cholesky_inverse defaults to "
+ "``False``.", stacklevel=2)
+ return super(Tensor, self).cholesky_inverse(upper=upper)
+
+[docs] def potrs(self, u, upper=True):
+ r"""See :func:`torch.cholesky_solve`"""
+ warnings.warn("torch.potrs is deprecated in favour of torch.cholesky_solve and "
+ "will be removed in the next release. Please use torch.cholesky_solve instead "
+ "and note that the :attr:`upper` argument in torch.cholesky_solve defaults "
+ "to ``False``.", stacklevel=2)
+ return super(Tensor, self).cholesky_solve(u, upper=upper)
+
+[docs] def gesv(self, A):
+ r"""See :func:`torch.solve`"""
+ warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+ "next release. Please use torch.solve instead.", stacklevel=2)
+ return super(Tensor, self).solve(A)
+
+[docs] def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
+ r"""See :func:`torch.triangular_solve`"""
+ warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+ "removed in the next release. Please use torch.triangular_solve instead.",
+ stacklevel=2)
+ return super(Tensor, self).triangular_solve(A, upper=upper,
+ transpose=transpose, unitriangular=unitriangular)
+
+[docs] def btrifact(self, pivot=True):
+ r"""See :func:`torch.lu`"""
+ warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
+ "the next release. Please use torch.lu instead.", stacklevel=2)
+ return torch._lu_with_info(self, pivot=pivot, check_errors=True)
+
+[docs] def btrifact_with_info(self, pivot=True):
+ r"""See :func:`torch.lu`"""
+ warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
+ "get_infos argument and will be removed in the next release. Please use "
+ "torch.lu with the get_infos argument set to True instead.", stacklevel=2)
+ return torch._lu_with_info(self, pivot=pivot, check_errors=False)
+
+[docs] def btrisolve(self, LU_data, LU_pivots):
+ r"""See :func:`torch.lu_solve`"""
+ warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+ "removed in the next release. Please use torch.lu_solve instead.",
+ stacklevel=2)
+ return super(Tensor, self).lu_solve(LU_data=LU_data, LU_pivots=LU_pivots)
+
+[docs] def lu(self, pivot=True, get_infos=False):
+ r"""See :func:`torch.lu`"""
+ # If get_infos is True, then we don't need to check for errors and vice versa
+ LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
+ if get_infos:
+ return LU, pivots, infos
+ else:
+ return LU, pivots
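+
+    # Illustrative sketch (not part of the original module): Tensor.lu()
+    # returns the packed factorization and pivots; get_infos=True also
+    # returns the info tensor instead of raising on numerical errors.
+    # >>> A = torch.randn(3, 3)
+    # >>> LU, pivots = A.lu()
+    # >>> LU, pivots, infos = A.lu(get_infos=True)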
+
+[docs] def stft(self, n_fft, hop_length=None, win_length=None, window=None,
+ center=True, pad_mode='reflect', normalized=False, onesided=True):
+ r"""See :func:`torch.stft`
+
+ .. warning::
+          This function changed its signature at version 0.4.1. Calling with
+          the previous signature may cause an error or return an incorrect result.
+ """
+ return torch.stft(self, n_fft, hop_length, win_length, window, center,
+ pad_mode, normalized, onesided)
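+
+    # Illustrative sketch (not part of the original module): stft with the
+    # post-0.4.1 signature; the last dimension holds the real and imaginary parts.
+    # >>> x = torch.randn(1024)
+    # >>> spec = x.stft(n_fft=256, hop_length=128)
+    # >>> spec.shape[-1]
+    # 2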
+
+ def resize(self, *sizes):
+ warnings.warn("non-inplace resize is deprecated")
+ from torch.autograd._functions import Resize
+ return Resize.apply(self, sizes)
+
+ def resize_as(self, tensor):
+ warnings.warn("non-inplace resize_as is deprecated")
+ from torch.autograd._functions import Resize
+ return Resize.apply(self, tensor.size())
+
+[docs] def split(self, split_size, dim=0):
+ r"""See :func:`torch.split`
+ """
+ if isinstance(split_size, int):
+ return super(Tensor, self).split(split_size, dim)
+ else:
+ return super(Tensor, self).split_with_sizes(split_size, dim)
+
+[docs] def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+ r"""Returns the unique elements of the input tensor.
+
+ See :func:`torch.unique`
+ """
+ return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
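+
+    # Illustrative sketch (not part of the original module):
+    # >>> t = torch.tensor([1, 3, 2, 3])
+    # >>> t.unique()
+    # tensor([1, 2, 3])
+    # >>> t.unique(return_counts=True)
+    # (tensor([1, 2, 3]), tensor([1, 1, 2]))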
+
+[docs] def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
+ r"""Eliminates all but the first element from every consecutive group of equivalent elements.
+
+ See :func:`torch.unique_consecutive`
+ """
+ return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+
+ def __rsub__(self, other):
+ return _C._VariableFunctions.rsub(self, other)
+
+ def __rdiv__(self, other):
+ if self.dtype.is_floating_point:
+ return self.reciprocal() * other
+ else:
+ return (self.double().reciprocal() * other).type_as(self)
+
+ __rtruediv__ = __rdiv__
+ __itruediv__ = _C._TensorBase.__idiv__
+
+ __pow__ = _C._TensorBase.pow
+
+ def __format__(self, format_spec):
+ if self.dim() == 0:
+ return self.item().__format__(format_spec)
+ return object.__format__(self, format_spec)
+
+ def __ipow__(self, other):
+ raise NotImplementedError("in-place pow not implemented")
+
+ def __rpow__(self, other):
+ return self.new_tensor(other) ** self
+
+ def __floordiv__(self, other):
+ result = self / other
+ if result.dtype.is_floating_point:
+ result = result.trunc()
+ return result
+
+ def __rfloordiv__(self, other):
+ result = other / self
+ if result.dtype.is_floating_point:
+ result = result.trunc()
+ return result
+
+ __neg__ = _C._TensorBase.neg
+
+ __eq__ = _C._TensorBase.eq
+ __ne__ = _C._TensorBase.ne
+ __lt__ = _C._TensorBase.lt
+ __le__ = _C._TensorBase.le
+ __gt__ = _C._TensorBase.gt
+ __ge__ = _C._TensorBase.ge
+ __abs__ = _C._TensorBase.abs
+
+ def __len__(self):
+ if self.dim() == 0:
+ raise TypeError("len() of a 0-d tensor")
+ return self.shape[0]
+
+ def __iter__(self):
+ # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a
+ # generator and don't eagerly perform all the indexes. This could
+ # save us work, and also helps keep trace ordering deterministic
+ # (e.g., if you zip(*hiddens), the eager map will force all the
+ # indexes of hiddens[0] before hiddens[1], while the generator
+ # map will interleave them.)
+ if self.dim() == 0:
+ raise TypeError('iteration over a 0-d tensor')
+ if torch._C._get_tracing_state():
+ warnings.warn('Iterating over a tensor might cause the trace to be incorrect. '
+ 'Passing a tensor of different shape won\'t change the number of '
+ 'iterations executed (and might lead to errors or silently give '
+ 'incorrect results).', category=RuntimeWarning)
+ return iter(imap(lambda i: self[i], range(self.size(0))))
+
+ def __hash__(self):
+ return id(self)
+
+ def __dir__(self):
+ tensor_methods = dir(self.__class__)
+ tensor_methods.remove('volatile') # deprecated
+ attrs = list(self.__dict__.keys())
+ keys = tensor_methods + attrs
+
+        # property only available on dense CUDA tensors
+ if (not self.is_cuda) or self.is_sparse:
+ keys.remove("__cuda_array_interface__")
+
+ return sorted(keys)
+
+ # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
+ __array_priority__ = 1000 # prefer Tensor ops over numpy ones
+
+ def __array__(self, dtype=None):
+ if dtype is None:
+ return self.numpy()
+ else:
+ return self.numpy().astype(dtype, copy=False)
+
+ # Wrap Numpy array again in a suitable tensor when done, to support e.g.
+ # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
+ def __array_wrap__(self, array):
+ if array.dtype == bool:
+ # Workaround, torch has no built-in bool tensor
+ array = array.astype('uint8')
+ return torch.from_numpy(array)
+
+ def __contains__(self, element):
+ r"""Check if `element` is present in tensor
+
+ Arguments:
+ element (Tensor or scalar): element to be checked
+                for presence in the current tensor
+ """
+ if isinstance(element, (torch.Tensor, Number)):
+ return (element == self).any().item()
+ return NotImplemented
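+
+    # Illustrative sketch (not part of the original module): the ``in``
+    # operator dispatches to __contains__ for scalars and tensors.
+    # >>> t = torch.tensor([1, 2, 3])
+    # >>> 2 in t
+    # True
+    # >>> torch.tensor(5) in t
+    # False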
+
+ @property
+ def __cuda_array_interface__(self):
+ """Array view description for cuda tensors.
+
+ See:
+ https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
+ """
+
+ # raise AttributeError for unsupported tensors, so that
+ # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
+ if not self.is_cuda:
+ raise AttributeError(
+ "Can't get __cuda_array_interface__ on non-CUDA tensor type: %s "
+ "If CUDA data is required use tensor.cuda() to copy tensor to device memory." %
+ self.type()
+ )
+
+ if self.is_sparse:
+ raise AttributeError(
+ "Can't get __cuda_array_interface__ on sparse type: %s "
+ "Use Tensor.to_dense() to convert to a dense tensor first." %
+ self.type()
+ )
+
+ # RuntimeError, matching tensor.__array__() behavior.
+ if self.requires_grad:
+ raise RuntimeError(
+ "Can't get __cuda_array_interface__ on Variable that requires grad. "
+ "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
+ )
+
+ # CUDA devices are little-endian and tensors are stored in native byte
+ # order. 1-byte entries are endian-agnostic.
+ typestr = {
+ torch.float16: "<f2",
+ torch.float32: "<f4",
+ torch.float64: "<f8",
+ torch.uint8: "|u1",
+ torch.int8: "|i1",
+ torch.int16: "<i2",
+ torch.int32: "<i4",
+ torch.int64: "<i8",
+ }[self.dtype]
+
+ itemsize = self.storage().element_size()
+
+ shape = self.shape
+ strides = tuple(s * itemsize for s in self.stride())
+ data = (self.data_ptr(), False) # read-only is false
+
+ return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=0)
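+
+    # Illustrative sketch (not part of the original module; assumes a CUDA
+    # build and that numba is installed): the interface lets other libraries
+    # view the tensor's device memory without copying.
+    # >>> t = torch.arange(4, device='cuda', dtype=torch.float32)
+    # >>> t.__cuda_array_interface__['typestr']
+    # '<f4'
+    # >>> import numba.cuda
+    # >>> view = numba.cuda.as_cuda_array(t)  # zero-copy view of the same memory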
+
+ __module__ = 'torch'
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+import torch
+import warnings
+
+
+def detach_variable(inputs):
+ if isinstance(inputs, tuple):
+ out = []
+ for inp in inputs:
+ if not isinstance(inp, torch.Tensor):
+ out.append(inp)
+ continue
+
+ x = inp.detach()
+ x.requires_grad = inp.requires_grad
+ out.append(x)
+ return tuple(out)
+ else:
+        raise RuntimeError(
+            "Only tuple of tensors is supported. Got unsupported input type: " + type(inputs).__name__)
+
+
+def check_backward_validity(inputs):
+ if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
+ warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
+
+
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+ # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+ # the conditionals short-circuit.
+ fwd_gpu_devices = list(set(arg.get_device() for arg in args
+ if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+ fwd_gpu_states = []
+ for device in fwd_gpu_devices:
+ with torch.cuda.device(device):
+ fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+ return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+ for device, state in zip(devices, states):
+ with torch.cuda.device(device):
+ torch.cuda.set_rng_state(state)
+
+
+class CheckpointFunction(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, run_function, preserve_rng_state, *args):
+ check_backward_validity(args)
+ ctx.run_function = run_function
+ ctx.preserve_rng_state = preserve_rng_state
+ if preserve_rng_state:
+ ctx.fwd_cpu_state = torch.get_rng_state()
+ # Don't eagerly initialize the cuda context by accident.
+ # (If the user intends that the context is initialized later, within their
+ # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+ # we have no way to anticipate this will happen before we run the function.)
+ ctx.had_cuda_in_fwd = False
+ if torch.cuda._initialized:
+ ctx.had_cuda_in_fwd = True
+ ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
+ ctx.save_for_backward(*args)
+ with torch.no_grad():
+ outputs = run_function(*args)
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
+ inputs = ctx.saved_tensors
+ # Stash the surrounding rng state, and mimic the state that was
+        # present at this time during forward. Restore the surrounding state
+ # when we're done.
+ rng_devices = []
+ if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+ rng_devices = ctx.fwd_gpu_devices
+ with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+ if ctx.preserve_rng_state:
+ torch.set_rng_state(ctx.fwd_cpu_state)
+ if ctx.had_cuda_in_fwd:
+ set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
+ detached_inputs = detach_variable(inputs)
+ with torch.enable_grad():
+ outputs = ctx.run_function(*detached_inputs)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+ torch.autograd.backward(outputs, args)
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
+ for inp in detached_inputs)
+ return (None, None) + grads
+
+
+[docs]def checkpoint(function, *args, **kwargs):
+ r"""Checkpoint a model or part of the model
+
+    Checkpointing works by trading compute for memory. Rather than storing all
+    intermediate activations of the entire computation graph for computing
+    backward, the checkpointed part does **not** save intermediate activations,
+    and instead recomputes them in the backward pass. It can be applied to any
+    part of a model.
+
+    Specifically, in the forward pass, :attr:`function` will run in a
+    :func:`torch.no_grad` manner, i.e., not storing the intermediate
+    activations. Instead, the forward pass saves the inputs tuple and the
+    :attr:`function` parameter. In the backward pass, the saved inputs and
+    :attr:`function` are retrieved, and the forward pass is computed on
+    :attr:`function` again, now tracking the intermediate activations, and then
+    the gradients are calculated using these activation values.
+
+ .. warning::
+ Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
+ with :func:`torch.autograd.backward`.
+
+ .. warning::
+ If :attr:`function` invocation during backward does anything different
+ than the one during forward, e.g., due to some global variable, the
+ checkpointed version won't be equivalent, and unfortunately it can't be
+ detected.
+
+    .. warning::
+ At least one of the inputs needs to have :code:`requires_grad=True` if
+ grads are needed for model inputs, otherwise the checkpointed part of the
+ model won't have gradients.
+
+ Args:
+ function: describes what to run in the forward pass of the model or
+ part of the model. It should also know how to handle the inputs
+            passed as the tuple. For example, in an LSTM, if the user passes
+ ``(activation, hidden)``, :attr:`function` should correctly use the
+ first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state (bool, optional, default=True): If ``False``, omit stashing
+            and restoring the RNG state during each checkpoint.
+ args: tuple containing inputs to the :attr:`function`
+
+ Returns:
+ Output of running :attr:`function` on :attr:`*args`
+ """
+ # Hack to mix *args with **kwargs in a python 2.7-compliant way
+ preserve = kwargs.pop('preserve_rng_state', True)
+ if kwargs:
+ raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+ return CheckpointFunction.apply(function, preserve, *args)
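+
+# Illustrative sketch (not part of the original module): checkpointing part
+# of a model so its activations are recomputed during the backward pass.
+# >>> import torch.nn as nn
+# >>> from torch.utils.checkpoint import checkpoint
+# >>> block = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
+# >>> x = torch.randn(4, 10, requires_grad=True)
+# >>> out = checkpoint(block, x)    # forward runs under no_grad
+# >>> out.sum().backward()          # block is re-run here to recompute activations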
+
+
+[docs]def checkpoint_sequential(functions, segments, *inputs, **kwargs):
+ r"""A helper function for checkpointing sequential models.
+
+ Sequential models execute a list of modules/functions in order
+ (sequentially). Therefore, we can divide such a model in various segments
+ and checkpoint each segment. All segments except the last will run in
+ :func:`torch.no_grad` manner, i.e., not storing the intermediate
+ activations. The inputs of each checkpointed segment will be saved for
+ re-running the segment in the backward pass.
+
+ See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
+
+ .. warning::
+ Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
+ with :func:`torch.autograd.backward`.
+
+    .. warning::
+ At least one of the inputs needs to have :code:`requires_grad=True` if
+ grads are needed for model inputs, otherwise the checkpointed part of the
+ model won't have gradients.
+
+ Args:
+ functions: A :class:`torch.nn.Sequential` or the list of modules or
+ functions (comprising the model) to run sequentially.
+ segments: Number of chunks to create in the model
+ inputs: tuple of Tensors that are inputs to :attr:`functions`
+        preserve_rng_state (bool, optional, default=True): If ``False``, omit stashing
+            and restoring the RNG state during each checkpoint.
+
+ Returns:
+ Output of running :attr:`functions` sequentially on :attr:`*inputs`
+
+ Example:
+ >>> model = nn.Sequential(...)
+ >>> input_var = checkpoint_sequential(model, chunks, input_var)
+ """
+ # Hack to mix *args with **kwargs in a python 2.7-compliant way
+ preserve = kwargs.pop('preserve_rng_state', True)
+ if kwargs:
+ raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+ def run_function(start, end, functions):
+ def forward(*inputs):
+ for j in range(start, end + 1):
+ if isinstance(inputs, tuple):
+ inputs = functions[j](*inputs)
+ else:
+ inputs = functions[j](inputs)
+ return inputs
+ return forward
+
+ if isinstance(functions, torch.nn.Sequential):
+ functions = list(functions.children())
+
+ segment_size = len(functions) // segments
+    # the last segment is run without checkpointing so it is differentiated normally
+ end = -1
+ for start in range(0, segment_size * (segments - 1), segment_size):
+ end = start + segment_size - 1
+ inputs = checkpoint(run_function(start, end, functions), *inputs,
+ preserve_rng_state=preserve)
+ if not isinstance(inputs, tuple):
+ inputs = (inputs,)
+ return run_function(end + 1, len(functions) - 1, functions)(*inputs)
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+import copy
+import glob
+import imp
+import os
+import re
+import setuptools
+import subprocess
+import sys
+import sysconfig
+import tempfile
+import warnings
+
+import torch
+from .file_baton import FileBaton
+from ._cpp_extension_versioner import ExtensionVersioner
+
+from setuptools.command.build_ext import build_ext
+
+
+IS_WINDOWS = sys.platform == 'win32'
+
+
+def _find_cuda_home():
+ '''Finds the CUDA install path.'''
+ # Guess #1
+ cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+ if cuda_home is None:
+ # Guess #2
+ if IS_WINDOWS:
+ cuda_homes = glob.glob(
+ 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
+ if len(cuda_homes) == 0:
+ cuda_home = ''
+ else:
+ cuda_home = cuda_homes[0]
+ else:
+ cuda_home = '/usr/local/cuda'
+ if not os.path.exists(cuda_home):
+ # Guess #3
+ try:
+ which = 'where' if IS_WINDOWS else 'which'
+ nvcc = subprocess.check_output(
+ [which, 'nvcc']).decode().rstrip('\r\n')
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
+ except Exception:
+ cuda_home = None
+ if cuda_home and not torch.cuda.is_available():
+ print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
+ return cuda_home
+
+
+MINIMUM_GCC_VERSION = (4, 9, 0)
+MINIMUM_MSVC_VERSION = (19, 0, 24215)
+ABI_INCOMPATIBILITY_WARNING = '''
+
+ !! WARNING !!
+
+!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Your compiler ({}) may be ABI-incompatible with PyTorch!
+Please use a compiler that is ABI-compatible with GCC 4.9 and above.
+See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.
+
+See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
+for instructions on how to install GCC 4.9 or higher.
+!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ !! WARNING !!
+'''
+WRONG_COMPILER_WARNING = '''
+
+ !! WARNING !!
+
+!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Your compiler ({user_compiler}) is not compatible with the compiler PyTorch was
+built with for this platform, which is {pytorch_compiler} on {platform}. Please
+use {pytorch_compiler} to compile your extension. Alternatively, you may
+compile PyTorch from source using {user_compiler}, and then you can also use
+{user_compiler} to compile your extension.
+
+See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
+with compiling PyTorch from source.
+!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ !! WARNING !!
+'''
+CUDA_HOME = _find_cuda_home()
+CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
+# PyTorch releases have the version pattern major.minor.patch, whereas when
+# PyTorch is built from source, we append the git commit hash, which gives
+# it the below pattern.
+BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
+
+COMMON_NVCC_FLAGS = [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+]
+
+
+JIT_EXTENSION_VERSIONER = ExtensionVersioner()
+
+
+def _is_binary_build():
+ return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
+
+
+def _accepted_compilers_for_platform():
+ return ['clang++', 'clang'] if sys.platform.startswith('darwin') else ['g++', 'gcc']
+
+
+def get_default_build_root():
+ '''
+    Returns the path to the root folder under which extensions will be built.
+
+ For each extension module built, there will be one folder underneath the
+ folder returned by this function. For example, if ``p`` is the path
+ returned by this function and ``ext`` the name of an extension, the build
+ folder for the extension will be ``p/ext``.
+ '''
+ # tempfile.gettempdir() will be /tmp on UNIX and \TEMP on Windows.
+ return os.path.realpath(os.path.join(tempfile.gettempdir(), 'torch_extensions'))
+
+
+def check_compiler_ok_for_platform(compiler):
+ '''
+ Verifies that the compiler is the expected one for the current platform.
+
+ Arguments:
+ compiler (str): The compiler executable to check.
+
+ Returns:
+ True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
+ and always True for Windows.
+ '''
+ if IS_WINDOWS:
+ return True
+ which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
+ # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
+ compiler_path = os.path.realpath(which.decode().strip())
+ return any(name in compiler_path for name in _accepted_compilers_for_platform())
+
+
+[docs]def check_compiler_abi_compatibility(compiler):
+ '''
+ Verifies that the given compiler is ABI-compatible with PyTorch.
+
+ Arguments:
+ compiler (str): The compiler executable name to check (e.g. ``g++``).
+ Must be executable in a shell process.
+
+ Returns:
+ False if the compiler is (likely) ABI-incompatible with PyTorch,
+ else True.
+ '''
+ if not _is_binary_build():
+ return True
+ if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
+ return True
+
+ # First check if the compiler is one of the expected ones for the particular platform.
+ if not check_compiler_ok_for_platform(compiler):
+ warnings.warn(WRONG_COMPILER_WARNING.format(
+ user_compiler=compiler,
+ pytorch_compiler=_accepted_compilers_for_platform()[0],
+ platform=sys.platform))
+ return False
+
+ if sys.platform.startswith('darwin'):
+ # There is no particular minimum version we need for clang, so we're good here.
+ return True
+ try:
+ if sys.platform.startswith('linux'):
+ minimum_required_version = MINIMUM_GCC_VERSION
+ version = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
+ version = version.decode().strip().split('.')
+ else:
+ minimum_required_version = MINIMUM_MSVC_VERSION
+ compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
+ match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode().strip())
+ version = (0, 0, 0) if match is None else match.groups()
+ except Exception:
+ _, error, _ = sys.exc_info()
+ warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error))
+ return False
+
+ if tuple(map(int, version)) >= minimum_required_version:
+ return True
+
+ compiler = '{} {}'.format(compiler, ".".join(version))
+ warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))
+
+ return False
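+
+# Illustrative sketch (not part of the original module): a quick manual check
+# of the compiler that a JIT extension build would use.
+# >>> from torch.utils.cpp_extension import check_compiler_abi_compatibility
+# >>> check_compiler_abi_compatibility(os.environ.get('CXX', 'c++'))  # True on a compatible toolchain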
+
+
+# See below for why we inherit BuildExtension from object.
+# https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when
+
+
+[docs]class BuildExtension(build_ext, object):
+ '''
+    A custom :mod:`setuptools` build extension.
+
+ This :class:`setuptools.build_ext` subclass takes care of passing the
+ minimum required compiler flags (e.g. ``-std=c++11``) as well as mixed
+ C++/CUDA compilation (and support for CUDA files in general).
+
+ When using :class:`BuildExtension`, it is allowed to supply a dictionary
+ for ``extra_compile_args`` (rather than the usual list) that maps from
+ languages (``cxx`` or ``cuda``) to a list of additional compiler flags to
+ supply to the compiler. This makes it possible to supply different flags to
+ the C++ and CUDA compiler during mixed compilation.
+ '''
+
+ @classmethod
+ def with_options(cls, **options):
+ '''
+        Returns an alternative constructor that extends the keyword arguments
+        passed to the original constructor with the given options.
+ '''
+ def init_with_options(*args, **kwargs):
+ kwargs = kwargs.copy()
+ kwargs.update(options)
+ return cls(*args, **kwargs)
+ return init_with_options
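+
+    # Illustrative sketch (not part of the original module): pre-binding
+    # options while still giving setuptools a class-like callable.
+    # >>> from torch.utils.cpp_extension import BuildExtension
+    # >>> cmdclass = {'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)}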
+
+ def __init__(self, *args, **kwargs):
+ super(BuildExtension, self).__init__(*args, **kwargs)
+ self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
+
+ def build_extensions(self):
+ self._check_abi()
+ for extension in self.extensions:
+ self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
+ self._define_torch_extension_name(extension)
+ self._add_gnu_cpp_abi_flag(extension)
+
+ # Register .cu and .cuh as valid source extensions.
+ self.compiler.src_extensions += ['.cu', '.cuh']
+ # Save the original _compile method for later.
+ if self.compiler.compiler_type == 'msvc':
+ self.compiler._cpp_extensions += ['.cu', '.cuh']
+ original_compile = self.compiler.compile
+ original_spawn = self.compiler.spawn
+ else:
+ original_compile = self.compiler._compile
+
+ def unix_wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+ # Copy before we make any modifications.
+ cflags = copy.deepcopy(extra_postargs)
+ try:
+ original_compiler = self.compiler.compiler_so
+ if _is_cuda_file(src):
+ nvcc = _join_cuda_home('bin', 'nvcc')
+ if not isinstance(nvcc, list):
+ nvcc = [nvcc]
+ self.compiler.set_executable('compiler_so', nvcc)
+ if isinstance(cflags, dict):
+ cflags = cflags['nvcc']
+ cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
+ elif isinstance(cflags, dict):
+ cflags = cflags['cxx']
+ # NVCC does not allow multiple -std to be passed, so we avoid
+ # overriding the option if the user explicitly passed it.
+ if not any(flag.startswith('-std=') for flag in cflags):
+ cflags.append('-std=c++11')
+
+ original_compile(obj, src, ext, cc_args, cflags, pp_opts)
+ finally:
+ # Put the original compiler back in place.
+ self.compiler.set_executable('compiler_so', original_compiler)
+
+ def win_wrap_compile(sources,
+ output_dir=None,
+ macros=None,
+ include_dirs=None,
+ debug=0,
+ extra_preargs=None,
+ extra_postargs=None,
+ depends=None):
+
+ self.cflags = copy.deepcopy(extra_postargs)
+ extra_postargs = None
+
+ def spawn(cmd):
+ # Using regex to match src, obj and include files
+ src_regex = re.compile('/T(p|c)(.*)')
+ src_list = [
+ m.group(2) for m in (src_regex.match(elem) for elem in cmd)
+ if m
+ ]
+
+ obj_regex = re.compile('/Fo(.*)')
+ obj_list = [
+ m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
+ if m
+ ]
+
+ include_regex = re.compile(r'((\-|\/)I.*)')
+ include_list = [
+ m.group(1)
+ for m in (include_regex.match(elem) for elem in cmd) if m
+ ]
+
+ if len(src_list) >= 1 and len(obj_list) >= 1:
+ src = src_list[0]
+ obj = obj_list[0]
+ if _is_cuda_file(src):
+ nvcc = _join_cuda_home('bin', 'nvcc')
+ if isinstance(self.cflags, dict):
+ cflags = self.cflags['nvcc']
+ elif isinstance(self.cflags, list):
+ cflags = self.cflags
+ else:
+ cflags = []
+ cmd = [
+ nvcc, '-c', src, '-o', obj, '-Xcompiler',
+ '/wd4819', '-Xcompiler', '/MD'
+ ] + include_list + cflags
+ elif isinstance(self.cflags, dict):
+ cflags = self.cflags['cxx'] + ['/MD']
+ cmd += cflags
+ elif isinstance(self.cflags, list):
+ cflags = self.cflags + ['/MD']
+ cmd += cflags
+
+ return original_spawn(cmd)
+
+ try:
+ self.compiler.spawn = spawn
+ return original_compile(sources, output_dir, macros,
+ include_dirs, debug, extra_preargs,
+ extra_postargs, depends)
+ finally:
+ self.compiler.spawn = original_spawn
+
+ # Monkey-patch the _compile method.
+ if self.compiler.compiler_type == 'msvc':
+ self.compiler.compile = win_wrap_compile
+ else:
+ self.compiler._compile = unix_wrap_compile
+
+ build_ext.build_extensions(self)
+
+ def get_ext_filename(self, ext_name):
+ # Get the original shared library name. For Python 3, this name will be
+ # suffixed with "<SOABI>.so", where <SOABI> will be something like
+ # cpython-37m-x86_64-linux-gnu. On Python 2, there is no such ABI name.
+ # The final extension, .so, would be .lib/.dll on Windows of course.
+ ext_filename = super(BuildExtension, self).get_ext_filename(ext_name)
+ # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
+ # component. This makes building shared libraries with setuptools that
+ # aren't Python modules nicer.
+ if self.no_python_abi_suffix and sys.version_info >= (3, 0):
+ # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
+ ext_filename_parts = ext_filename.split('.')
+ # Omit the second to last element.
+ without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
+ ext_filename = '.'.join(without_abi)
+ return ext_filename
+
+ def _check_abi(self):
+ # On some platforms, like Windows, compiler_cxx is not available.
+ if hasattr(self.compiler, 'compiler_cxx'):
+ compiler = self.compiler.compiler_cxx[0]
+ elif IS_WINDOWS:
+ compiler = os.environ.get('CXX', 'cl')
+ else:
+ compiler = os.environ.get('CXX', 'c++')
+ check_compiler_abi_compatibility(compiler)
+
+ def _add_compile_flag(self, extension, flag):
+ extension.extra_compile_args = copy.copy(extension.extra_compile_args)
+ if isinstance(extension.extra_compile_args, dict):
+ for args in extension.extra_compile_args.values():
+ args.append(flag)
+ else:
+ extension.extra_compile_args.append(flag)
+
+ def _define_torch_extension_name(self, extension):
+ # pybind11 doesn't support dots in the names
+ # so in order to support extensions in the packages
+ # like torch._C, we take the last part of the string
+ # as the library name
+ names = extension.name.split('.')
+ name = names[-1]
+ define = '-DTORCH_EXTENSION_NAME={}'.format(name)
+ self._add_compile_flag(extension, define)
+
+ def _add_gnu_cpp_abi_flag(self, extension):
+ # use the same CXX ABI as what PyTorch was compiled with
+ self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)))
+
+
+[docs]def CppExtension(name, sources, *args, **kwargs):
+ '''
+ Creates a :class:`setuptools.Extension` for C++.
+
+ Convenience method that creates a :class:`setuptools.Extension` with the
+ bare minimum (but often sufficient) arguments to build a C++ extension.
+
+ All arguments are forwarded to the :class:`setuptools.Extension`
+ constructor.
+
+ Example:
+ >>> from setuptools import setup
+ >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
+ >>> setup(
+ name='extension',
+ ext_modules=[
+ CppExtension(
+ name='extension',
+ sources=['extension.cpp'],
+ extra_compile_args=['-g']),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
+ '''
+ include_dirs = kwargs.get('include_dirs', [])
+ include_dirs += include_paths()
+ kwargs['include_dirs'] = include_dirs
+
+ if IS_WINDOWS:
+ library_dirs = kwargs.get('library_dirs', [])
+ library_dirs += library_paths()
+ kwargs['library_dirs'] = library_dirs
+
+ libraries = kwargs.get('libraries', [])
+ libraries.append('c10')
+ libraries.append('caffe2')
+ libraries.append('torch')
+ libraries.append('torch_python')
+ libraries.append('_C')
+ kwargs['libraries'] = libraries
+
+ kwargs['language'] = 'c++'
+ return setuptools.Extension(name, sources, *args, **kwargs)
+
+
+[docs]def CUDAExtension(name, sources, *args, **kwargs):
+ '''
+ Creates a :class:`setuptools.Extension` for CUDA/C++.
+
+ Convenience method that creates a :class:`setuptools.Extension` with the
+ bare minimum (but often sufficient) arguments to build a CUDA/C++
+ extension. This includes the CUDA include path, library path and runtime
+ library.
+
+ All arguments are forwarded to the :class:`setuptools.Extension`
+ constructor.
+
+ Example:
+ >>> from setuptools import setup
+ >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+ >>> setup(
+ name='cuda_extension',
+ ext_modules=[
+ CUDAExtension(
+ name='cuda_extension',
+ sources=['extension.cpp', 'extension_kernel.cu'],
+ extra_compile_args={'cxx': ['-g'],
+ 'nvcc': ['-O2']})
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
+ '''
+ library_dirs = kwargs.get('library_dirs', [])
+ library_dirs += library_paths(cuda=True)
+ kwargs['library_dirs'] = library_dirs
+
+ libraries = kwargs.get('libraries', [])
+ libraries.append('cudart')
+ if IS_WINDOWS:
+ libraries.append('c10')
+ libraries.append('c10_cuda')
+ libraries.append('caffe2')
+ libraries.append('torch')
+ libraries.append('torch_python')
+ libraries.append('caffe2_gpu')
+ libraries.append('_C')
+ kwargs['libraries'] = libraries
+
+ include_dirs = kwargs.get('include_dirs', [])
+ include_dirs += include_paths(cuda=True)
+ kwargs['include_dirs'] = include_dirs
+
+ kwargs['language'] = 'c++'
+
+ return setuptools.Extension(name, sources, *args, **kwargs)
+
+
+[docs]def include_paths(cuda=False):
+ '''
+ Get the include paths required to build a C++ or CUDA extension.
+
+ Args:
+ cuda: If `True`, includes CUDA-specific include paths.
+
+ Returns:
+ A list of include path strings.
+ '''
+ here = os.path.abspath(__file__)
+ torch_path = os.path.dirname(os.path.dirname(here))
+ lib_include = os.path.join(torch_path, 'include')
+ paths = [
+ lib_include,
+ # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
+ os.path.join(lib_include, 'torch', 'csrc', 'api', 'include'),
+ # Some internal (old) Torch headers don't properly prefix their includes,
+ # so we need to pass -Itorch/lib/include/TH as well.
+ os.path.join(lib_include, 'TH'),
+ os.path.join(lib_include, 'THC')
+ ]
+ if cuda:
+ cuda_home_include = _join_cuda_home('include')
+ # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
+        # but gcc doesn't like having /usr/include passed explicitly
+ if cuda_home_include != '/usr/include':
+ paths.append(cuda_home_include)
+ if CUDNN_HOME is not None:
+ paths.append(os.path.join(CUDNN_HOME, 'include'))
+ return paths
+
+
+def library_paths(cuda=False):
+ '''
+ Get the library paths required to build a C++ or CUDA extension.
+
+ Args:
+ cuda: If `True`, includes CUDA-specific library paths.
+
+ Returns:
+ A list of library path strings.
+ '''
+ paths = []
+
+ if IS_WINDOWS:
+ here = os.path.abspath(__file__)
+ torch_path = os.path.dirname(os.path.dirname(here))
+ lib_path = os.path.join(torch_path, 'lib')
+
+ paths.append(lib_path)
+
+ if cuda:
+ lib_dir = 'lib/x64' if IS_WINDOWS else 'lib64'
+ paths.append(_join_cuda_home(lib_dir))
+ if CUDNN_HOME is not None:
+ paths.append(os.path.join(CUDNN_HOME, lib_dir))
+ return paths
+
+
+[docs]def load(name,
+ sources,
+ extra_cflags=None,
+ extra_cuda_cflags=None,
+ extra_ldflags=None,
+ extra_include_paths=None,
+ build_directory=None,
+ verbose=False,
+ with_cuda=None,
+ is_python_module=True):
+ '''
+ Loads a PyTorch C++ extension just-in-time (JIT).
+
+ To load an extension, a Ninja build file is emitted, which is used to
+ compile the given sources into a dynamic library. This library is
+ subsequently loaded into the current Python process as a module and
+ returned from this function, ready for use.
+
+ By default, the directory to which the build file is emitted and the
+ resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
+ ``<tmp>`` is the temporary folder on the current platform and ``<name>``
+ the name of the extension. This location can be overridden in two ways.
+ First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
+ replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
+ into subfolders of this directory. Second, if the ``build_directory``
+ argument to this function is supplied, it overrides the entire path, i.e.
+ the library will be compiled into that folder directly.
+
+ To compile the sources, the default system compiler (``c++``) is used,
+ which can be overridden by setting the ``CXX`` environment variable. To pass
+ additional arguments to the compilation process, ``extra_cflags`` or
+ ``extra_ldflags`` can be provided. For example, to compile your extension
+ with optimizations, pass ``extra_cflags=['-O3']``. You can also use
+ ``extra_cflags`` to pass further include directories.
+
+ CUDA support with mixed compilation is provided. Simply pass CUDA source
+ files (``.cu`` or ``.cuh``) along with other sources. Such files will be
+ detected and compiled with nvcc rather than the C++ compiler. This includes
+ passing the CUDA lib64 directory as a library directory, and linking
+ ``cudart``. You can pass additional flags to nvcc via
+ ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
+ heuristics for finding the CUDA install directory are used, which usually
+ work fine. If not, setting the ``CUDA_HOME`` environment variable is the
+ safest option.
+
+ Args:
+ name: The name of the extension to build. This MUST be the same as the
+ name of the pybind11 module!
+ sources: A list of relative or absolute paths to C++ source files.
+ extra_cflags: optional list of compiler flags to forward to the build.
+ extra_cuda_cflags: optional list of compiler flags to forward to nvcc
+ when building CUDA sources.
+ extra_ldflags: optional list of linker flags to forward to the build.
+ extra_include_paths: optional list of include directories to forward
+ to the build.
+ build_directory: optional path to use as build workspace.
+ verbose: If ``True``, turns on verbose logging of load steps.
+ with_cuda: Determines whether CUDA headers and libraries are added to
+ the build. If set to ``None`` (default), this value is
+ automatically determined based on the existence of ``.cu`` or
+            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
+ and libraries to be included.
+ is_python_module: If ``True`` (default), imports the produced shared
+ library as a Python module. If ``False``, loads it into the process
+ as a plain dynamic library.
+
+ Returns:
+ If ``is_python_module`` is ``True``, returns the loaded PyTorch
+        extension as a Python module. If ``is_python_module`` is ``False``,
+ returns nothing (the shared library is loaded into the process as a side
+ effect).
+
+ Example:
+ >>> from torch.utils.cpp_extension import load
+ >>> module = load(
+ name='extension',
+ sources=['extension.cpp', 'extension_kernel.cu'],
+ extra_cflags=['-O2'],
+ verbose=True)
+ '''
+ return _jit_compile(
+ name,
+ [sources] if isinstance(sources, str) else sources,
+ extra_cflags,
+ extra_cuda_cflags,
+ extra_ldflags,
+ extra_include_paths,
+ build_directory or _get_build_directory(name, verbose),
+ verbose,
+ with_cuda,
+ is_python_module)
+
+
+[docs]def load_inline(name,
+ cpp_sources,
+ cuda_sources=None,
+ functions=None,
+ extra_cflags=None,
+ extra_cuda_cflags=None,
+ extra_ldflags=None,
+ extra_include_paths=None,
+ build_directory=None,
+ verbose=False,
+ with_cuda=None,
+ is_python_module=True):
+ '''
+ Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
+
+ This function behaves exactly like :func:`load`, but takes its sources as
+ strings rather than filenames. These strings are stored to files in the
+ build directory, after which the behavior of :func:`load_inline` is
+ identical to :func:`load`.
+
+ See `the
+ tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions.py>`_
+ for good examples of using this function.
+
+ Sources may omit two required parts of a typical non-inline C++ extension:
+ the necessary header includes, as well as the (pybind11) binding code. More
+ precisely, strings passed to ``cpp_sources`` are first concatenated into a
+ single ``.cpp`` file. This file is then prepended with ``#include
+ <torch/extension.h>``.
+
+ Furthermore, if the ``functions`` argument is supplied, bindings will be
+ automatically generated for each function specified. ``functions`` can
+ either be a list of function names, or a dictionary mapping from function
+ names to docstrings. If a list is given, the name of each function is used
+ as its docstring.
+
+ The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
+ file and prepended with ``torch/types.h``, ``cuda.h`` and
+ ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
+ separately, but ultimately linked into a single library. Note that no
+ bindings are generated for functions in ``cuda_sources`` per se. To bind
+ to a CUDA kernel, you must create a C++ function that calls it, and either
+ declare or define this C++ function in one of the ``cpp_sources`` (and
+ include its name in ``functions``).
+
+ See :func:`load` for a description of arguments omitted below.
+
+ Args:
+ cpp_sources: A string, or list of strings, containing C++ source code.
+ cuda_sources: A string, or list of strings, containing CUDA source code.
+ functions: A list of function names for which to generate function
+ bindings. If a dictionary is given, it should map function names to
+ docstrings (which are otherwise just the function names).
+ with_cuda: Determines whether CUDA headers and libraries are added to
+ the build. If set to ``None`` (default), this value is
+ automatically determined based on whether ``cuda_sources`` is
+            provided. Set it to ``True`` to force CUDA headers
+ and libraries to be included.
+
+ Example:
+ >>> from torch.utils.cpp_extension import load_inline
+ >>> source = \'\'\'
+ at::Tensor sin_add(at::Tensor x, at::Tensor y) {
+ return x.sin() + y.sin();
+ }
+ \'\'\'
+ >>> module = load_inline(name='inline_extension',
+ cpp_sources=[source],
+ functions=['sin_add'])
+ '''
+ build_directory = build_directory or _get_build_directory(name, verbose)
+
+ if isinstance(cpp_sources, str):
+ cpp_sources = [cpp_sources]
+ cuda_sources = cuda_sources or []
+ if isinstance(cuda_sources, str):
+ cuda_sources = [cuda_sources]
+
+ cpp_sources.insert(0, '#include <torch/extension.h>')
+
+ # If `functions` is supplied, we create the pybind11 bindings for the user.
+ # Here, `functions` is (or becomes, after some processing) a map from
+ # function names to function docstrings.
+ if functions is not None:
+ cpp_sources.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
+ if isinstance(functions, str):
+ functions = [functions]
+ if isinstance(functions, list):
+ # Make the function docstring the same as the function name.
+ functions = dict((f, f) for f in functions)
+ elif not isinstance(functions, dict):
+ raise ValueError(
+ "Expected 'functions' to be a list or dict, but was {}".format(
+ type(functions)))
+ for function_name, docstring in functions.items():
+ cpp_sources.append('m.def("{0}", &{0}, "{1}");'.format(
+ function_name, docstring))
+ cpp_sources.append('}')
+
+ cpp_source_path = os.path.join(build_directory, 'main.cpp')
+ with open(cpp_source_path, 'w') as cpp_source_file:
+ cpp_source_file.write('\n'.join(cpp_sources))
+
+ sources = [cpp_source_path]
+
+ if cuda_sources:
+ cuda_sources.insert(0, '#include <torch/types.h>')
+ cuda_sources.insert(1, '#include <cuda.h>')
+ cuda_sources.insert(2, '#include <cuda_runtime.h>')
+
+ cuda_source_path = os.path.join(build_directory, 'cuda.cu')
+ with open(cuda_source_path, 'w') as cuda_source_file:
+ cuda_source_file.write('\n'.join(cuda_sources))
+
+ sources.append(cuda_source_path)
+
+ return _jit_compile(
+ name,
+ sources,
+ extra_cflags,
+ extra_cuda_cflags,
+ extra_ldflags,
+ extra_include_paths,
+ build_directory,
+ verbose,
+ with_cuda,
+ is_python_module)
+
+
+def _jit_compile(name,
+ sources,
+ extra_cflags,
+ extra_cuda_cflags,
+ extra_ldflags,
+ extra_include_paths,
+ build_directory,
+ verbose,
+ with_cuda,
+ is_python_module):
+ old_version = JIT_EXTENSION_VERSIONER.get_version(name)
+ version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
+ name,
+ sources,
+ build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
+ build_directory=build_directory,
+ with_cuda=with_cuda
+ )
+ if version > 0:
+ if version != old_version and verbose:
+ print('The input conditions for extension module {} have changed. '.format(name) +
+ 'Bumping to version {0} and re-building as {1}_v{0}...'.format(version, name))
+ name = '{}_v{}'.format(name, version)
+
+ if version != old_version:
+ baton = FileBaton(os.path.join(build_directory, 'lock'))
+ if baton.try_acquire():
+ try:
+ _write_ninja_file_and_build(
+ name=name,
+ sources=sources,
+ extra_cflags=extra_cflags or [],
+ extra_cuda_cflags=extra_cuda_cflags or [],
+ extra_ldflags=extra_ldflags or [],
+ extra_include_paths=extra_include_paths or [],
+ build_directory=build_directory,
+ verbose=verbose,
+ with_cuda=with_cuda)
+ finally:
+ baton.release()
+ else:
+ baton.wait()
+ elif verbose:
+ print('No modifications detected for re-loaded extension '
+ 'module {}, skipping build step...'.format(name))
+
+ if verbose:
+ print('Loading extension module {}...'.format(name))
+ return _import_module_from_library(name, build_directory, is_python_module)
+
+
+def _write_ninja_file_and_build(name,
+ sources,
+ extra_cflags,
+ extra_cuda_cflags,
+ extra_ldflags,
+ extra_include_paths,
+ build_directory,
+ verbose,
+ with_cuda):
+ verify_ninja_availability()
+ if IS_WINDOWS:
+ compiler = os.environ.get('CXX', 'cl')
+ else:
+ compiler = os.environ.get('CXX', 'c++')
+ check_compiler_abi_compatibility(compiler)
+ if with_cuda is None:
+ with_cuda = any(map(_is_cuda_file, sources))
+ extra_ldflags = _prepare_ldflags(
+ extra_ldflags or [],
+ with_cuda,
+ verbose)
+ build_file_path = os.path.join(build_directory, 'build.ninja')
+ if verbose:
+ print(
+ 'Emitting ninja build file {}...'.format(build_file_path))
+ # NOTE: Emitting a new ninja build file does not cause re-compilation if
+ # the sources did not change, so it's ok to re-emit (and it's fast).
+ _write_ninja_file(
+ path=build_file_path,
+ name=name,
+ sources=sources,
+ extra_cflags=extra_cflags or [],
+ extra_cuda_cflags=extra_cuda_cflags or [],
+ extra_ldflags=extra_ldflags or [],
+ extra_include_paths=extra_include_paths or [],
+ with_cuda=with_cuda)
+
+ if verbose:
+ print('Building extension module {}...'.format(name))
+ _build_extension_module(name, build_directory, verbose)
+
+
+[docs]def verify_ninja_availability():
+ '''
+ Returns ``True`` if the `ninja <https://ninja-build.org/>`_ build system is
+ available on the system, and raises a ``RuntimeError`` otherwise.
+ '''
+ with open(os.devnull, 'wb') as devnull:
+ try:
+ subprocess.check_call('ninja --version'.split(), stdout=devnull)
+ except OSError:
+ raise RuntimeError("Ninja is required to load C++ extensions")
+ else:
+ return True
+
+
+def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
+ if IS_WINDOWS:
+ python_path = os.path.dirname(sys.executable)
+ python_lib_path = os.path.join(python_path, 'libs')
+
+ here = os.path.abspath(__file__)
+ torch_path = os.path.dirname(os.path.dirname(here))
+ lib_path = os.path.join(torch_path, 'lib')
+
+ extra_ldflags.append('c10.lib')
+ extra_ldflags.append('caffe2.lib')
+ extra_ldflags.append('torch.lib')
+ extra_ldflags.append('torch_python.lib')
+ if with_cuda:
+ extra_ldflags.append('caffe2_gpu.lib')
+ extra_ldflags.append('_C.lib')
+ extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path))
+ extra_ldflags.append('/LIBPATH:{}'.format(lib_path))
+
+ if with_cuda:
+ if verbose:
+ print('Detected CUDA files, patching ldflags')
+ if IS_WINDOWS:
+ extra_ldflags.append('/LIBPATH:{}'.format(
+ _join_cuda_home('lib/x64')))
+ extra_ldflags.append('cudart.lib')
+ if CUDNN_HOME is not None:
+ extra_ldflags.append('/LIBPATH:{}'.format(os.path.join(CUDNN_HOME, 'lib/x64')))
+ else:
+ extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64')))
+ extra_ldflags.append('-lcudart')
+ if CUDNN_HOME is not None:
+ extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64')))
+
+ return extra_ldflags
+
+
+def _get_build_directory(name, verbose):
+ root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
+ if root_extensions_directory is None:
+ root_extensions_directory = get_default_build_root()
+
+ if verbose:
+ print('Using {} as PyTorch extensions root...'.format(
+ root_extensions_directory))
+
+ build_directory = os.path.join(root_extensions_directory, name)
+ if not os.path.exists(build_directory):
+ if verbose:
+ print('Creating extension directory {}...'.format(build_directory))
+ # This is like mkdir -p, i.e. will also create parent directories.
+ os.makedirs(build_directory)
+
+ return build_directory
+
+
+def _build_extension_module(name, build_directory, verbose):
+ try:
+ sys.stdout.flush()
+ sys.stderr.flush()
+ if sys.version_info >= (3, 5):
+ subprocess.run(
+ ['ninja', '-v'],
+ stdout=None if verbose else subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ cwd=build_directory,
+ check=True)
+ else:
+ subprocess.check_output(
+ ['ninja', '-v'],
+ stderr=subprocess.STDOUT,
+ cwd=build_directory)
+ except subprocess.CalledProcessError:
+ # Python 2 and 3 compatible way of getting the error object.
+ _, error, _ = sys.exc_info()
+ # error.output contains the stdout and stderr of the build attempt.
+ message = "Error building extension '{}'".format(name)
+ if hasattr(error, 'output') and error.output:
+ message += ": {}".format(str(error.output))
+ raise RuntimeError(message)
+
+
+def _import_module_from_library(module_name, path, is_python_module):
+ # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
+ file, path, description = imp.find_module(module_name, [path])
+ # Close the .so file after load.
+ with file:
+ if is_python_module:
+ return imp.load_module(module_name, file, path, description)
+ else:
+ torch.ops.load_library(path)
+
+
+def _write_ninja_file(path,
+ name,
+ sources,
+ extra_cflags,
+ extra_cuda_cflags,
+ extra_ldflags,
+ extra_include_paths,
+ with_cuda):
+ extra_cflags = [flag.strip() for flag in extra_cflags]
+ extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
+ extra_ldflags = [flag.strip() for flag in extra_ldflags]
+ extra_include_paths = [flag.strip() for flag in extra_include_paths]
+
+ if IS_WINDOWS:
+ compiler = os.environ.get('CXX', 'cl')
+ else:
+ compiler = os.environ.get('CXX', 'c++')
+
+ # Version 1.3 is required for the `deps` directive.
+ config = ['ninja_required_version = 1.3']
+ config.append('cxx = {}'.format(compiler))
+ if with_cuda:
+ config.append('nvcc = {}'.format(_join_cuda_home('bin', 'nvcc')))
+
+ # Turn into absolute paths so we can emit them into the ninja build
+ # file wherever it is.
+ sources = [os.path.abspath(file) for file in sources]
+ user_includes = [os.path.abspath(file) for file in extra_include_paths]
+
+ # include_paths() gives us the location of torch/extension.h
+ system_includes = include_paths(with_cuda)
+ # sysconfig.get_paths()['include'] gives us the location of Python.h
+ system_includes.append(sysconfig.get_paths()['include'])
+
+ # Windoze does not understand `-isystem`.
+ if IS_WINDOWS:
+ user_includes += system_includes
+ system_includes.clear()
+
+ common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
+ common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
+ common_cflags += ['-I{}'.format(include) for include in user_includes]
+ common_cflags += ['-isystem {}'.format(include) for include in system_includes]
+
+ common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
+
+ cflags = common_cflags + ['-fPIC', '-std=c++11'] + extra_cflags
+ if IS_WINDOWS:
+ from distutils.spawn import _nt_quote_args
+ cflags = _nt_quote_args(cflags)
+ flags = ['cflags = {}'.format(' '.join(cflags))]
+
+ if with_cuda:
+ cuda_flags = common_cflags + COMMON_NVCC_FLAGS
+ if IS_WINDOWS:
+ cuda_flags = _nt_quote_args(cuda_flags)
+ cuda_flags += _nt_quote_args(extra_cuda_cflags)
+ else:
+ cuda_flags += ['--compiler-options', "'-fPIC'"]
+ cuda_flags += extra_cuda_cflags
+ if not any(flag.startswith('-std=') for flag in cuda_flags):
+ cuda_flags.append('-std=c++11')
+
+ flags.append('cuda_flags = {}'.format(' '.join(cuda_flags)))
+
+ if IS_WINDOWS:
+ ldflags = ['/DLL'] + extra_ldflags
+ else:
+ ldflags = ['-shared'] + extra_ldflags
+ # The darwin linker needs explicit consent to ignore unresolved symbols.
+ if sys.platform.startswith('darwin'):
+ ldflags.append('-undefined dynamic_lookup')
+ elif IS_WINDOWS:
+ ldflags = _nt_quote_args(ldflags)
+ flags.append('ldflags = {}'.format(' '.join(ldflags)))
+
+ # See https://ninja-build.org/build.ninja.html for reference.
+ compile_rule = ['rule compile']
+ if IS_WINDOWS:
+ compile_rule.append(
+ ' command = cl /showIncludes $cflags -c $in /Fo$out')
+ compile_rule.append(' deps = msvc')
+ else:
+ compile_rule.append(
+ ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out')
+ compile_rule.append(' depfile = $out.d')
+ compile_rule.append(' deps = gcc')
+
+ if with_cuda:
+ cuda_compile_rule = ['rule cuda_compile']
+ cuda_compile_rule.append(
+ ' command = $nvcc $cuda_flags -c $in -o $out')
+
+ link_rule = ['rule link']
+ if IS_WINDOWS:
+ cl_paths = subprocess.check_output(['where',
+ 'cl']).decode().split('\r\n')
+ if len(cl_paths) >= 1:
+ cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
+ else:
+ raise RuntimeError("MSVC is required to load C++ extensions")
+ link_rule.append(
+ ' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format(
+ cl_path))
+ else:
+ link_rule.append(' command = $cxx $in $ldflags -o $out')
+
+ # Emit one build rule per source to enable incremental build.
+ object_files = []
+ build = []
+ for source_file in sources:
+ # '/path/to/file.cpp' -> 'file'
+ file_name = os.path.splitext(os.path.basename(source_file))[0]
+ if _is_cuda_file(source_file) and with_cuda:
+ rule = 'cuda_compile'
+ # Use a different object filename in case a C++ and CUDA file have
+ # the same filename but different extension (.cpp vs. .cu).
+ target = '{}.cuda.o'.format(file_name)
+ else:
+ rule = 'compile'
+ target = '{}.o'.format(file_name)
+ object_files.append(target)
+ if IS_WINDOWS:
+ source_file = source_file.replace(':', '$:')
+ source_file = source_file.replace(" ", "$ ")
+ build.append('build {}: {} {}'.format(target, rule, source_file))
+
+ ext = 'pyd' if IS_WINDOWS else 'so'
+ library_target = '{}.{}'.format(name, ext)
+
+ link = ['build {}: link {}'.format(library_target, ' '.join(object_files))]
+
+ default = ['default {}'.format(library_target)]
+
+ # 'Blocks' should be separated by newlines, for visual benefit.
+ blocks = [config, flags, compile_rule]
+ if with_cuda:
+ blocks.append(cuda_compile_rule)
+ blocks += [link_rule, build, link, default]
+ with open(path, 'w') as build_file:
+ for block in blocks:
+ lines = '\n'.join(block)
+ build_file.write('{}\n\n'.format(lines))
+
+
+def _join_cuda_home(*paths):
+ '''
+ Joins paths with CUDA_HOME, or raises an error if CUDA_HOME is not set.
+
+ This is basically a lazy way of raising an error for missing $CUDA_HOME
+ only once we need to get any CUDA-specific path.
+ '''
+ if CUDA_HOME is None:
+ raise EnvironmentError('CUDA_HOME environment variable is not set. '
+ 'Please set it to your CUDA install root.')
+ return os.path.join(CUDA_HOME, *paths)
+
+
+def _is_cuda_file(path):
+ return os.path.splitext(path)[1] in ['.cu', '.cuh']
+
+r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
+
+To support these two classes, in `./_utils` we define many utility methods and
+functions to be run in multiprocessing. E.g., the data loading worker loop is
+in `./_utils/worker.py`.
+"""
+
+import torch
+import torch.multiprocessing as multiprocessing
+from . import SequentialSampler, RandomSampler, BatchSampler
+from . import _utils
+import threading
+from torch._six import queue
+
+
+# This function used to be defined in this file. However, it was moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly directly `import torch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate = _utils.collate.default_collate
+
+
+[docs]class DataLoader(object):
+ r"""
+ Data loader. Combines a dataset and a sampler, and provides
+ single- or multi-process iterators over the dataset.
+
+ Arguments:
+ dataset (Dataset): dataset from which to load the data.
+ batch_size (int, optional): how many samples per batch to load
+ (default: ``1``).
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ sampler (Sampler, optional): defines the strategy to draw samples from
+ the dataset. If specified, ``shuffle`` must be False.
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
+ indices at a time. Mutually exclusive with :attr:`batch_size`,
+ :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. 0 means that the data will be loaded in the main process.
+ (default: ``0``)
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+ into CUDA pinned memory before returning them. If your data elements
+ are a custom type, or your ``collate_fn`` returns a batch that is a custom type,
+ see the example below.
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+ if the dataset size is not divisible by the batch size. If ``False`` and
+ the size of dataset is not divisible by the batch size, then the last batch
+ will be smaller. (default: ``False``)
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
+ from workers. Should always be non-negative. (default: ``0``)
+ worker_init_fn (callable, optional): If not ``None``, this will be called on each
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+ input, after seeding and before data loading. (default: ``None``)
+
+ .. note:: When ``num_workers != 0``, the corresponding worker processes are created each time
+ an iterator for the DataLoader is obtained (as in when you call
+ ``enumerate(dataloader, 0)``).
+ At this point, the dataset, ``collate_fn`` and ``worker_init_fn`` are passed to each
+ worker, where they are used to access and initialize data based on the indices
+ queued up from the main process. This means that dataset access together with
+ its internal IO, transforms and collation runs in the worker, while any
+ shuffle randomization is done in the main process which guides loading by assigning
+ indices to load. Workers are shut down once the end of the iteration is reached.
+
+ Since workers rely on Python multiprocessing, worker launch behavior is different
+ on Windows compared to Unix. On Unix, fork() is used as the default
+ multiprocessing start method, so child workers typically can access the dataset and
+ Python argument functions directly through the cloned address space. On Windows, another
+ interpreter is launched which runs your main script, followed by the internal
+ worker function that receives the dataset, collate_fn and other arguments
+ through pickle serialization.
+
+ This separate serialization means that you should take two steps to ensure you
+ are compatible with Windows while using workers
+ (this also works equally well on Unix):
+
+ - Wrap most of your main script's code within an ``if __name__ == '__main__':`` block,
+ to make sure it doesn't run again (most likely generating an error) when each worker
+ process is launched. You can place your dataset and DataLoader instance creation
+ logic here, as it doesn't need to be re-executed in workers.
+ - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom dataset code
+ is declared as a top level def, outside of that ``__main__`` check. This ensures
+ they are available in workers as well
+ (this is needed since functions are pickled as references only, not bytecode).
+
+ By default, each worker will have its PyTorch seed set to
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
+ by the main process using its RNG. However, seeds for other libraries
+ may be duplicated upon initializing workers (e.g., NumPy), causing
+ each worker to return identical random numbers. (See
+ :ref:`dataloader-workers-random-seed` section in FAQ.) You may
+ use :func:`torch.initial_seed()` to access the PyTorch seed for
+ each worker in :attr:`worker_init_fn`, and use it to set other
+ seeds before data loading.
+
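+ For instance, a minimal sketch of a :attr:`worker_init_fn` that re-seeds
+ NumPy in every worker (the ``seed_numpy_worker`` name below is only
+ illustrative) could look like::
+
+     import numpy as np
+
+     def seed_numpy_worker(worker_id):
+         # derive a per-worker NumPy seed from the per-worker PyTorch seed
+         np.random.seed(torch.initial_seed() % 2 ** 32)
+
+     loader = DataLoader(dataset, num_workers=4,
+                         worker_init_fn=seed_numpy_worker)
+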
+ .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+ unpicklable object, e.g., a lambda function.
+
+ The default memory pinning logic only recognizes Tensors and maps and iterables
+ containing Tensors. By default, if the pinning logic sees a batch that is a custom type
+ (which will occur if you have a ``collate_fn`` that returns a custom batch type),
+ or if each element of your batch is a custom type, the pinning logic will not
+ recognize them, and it will return that batch (or those elements)
+ without pinning the memory. To enable memory pinning for custom batch or data types,
+ define a ``pin_memory`` method on your custom type(s).
+
+ Example::
+
+ class SimpleCustomBatch:
+ def __init__(self, data):
+ transposed_data = list(zip(*data))
+ self.inp = torch.stack(transposed_data[0], 0)
+ self.tgt = torch.stack(transposed_data[1], 0)
+
+ def pin_memory(self):
+ self.inp = self.inp.pin_memory()
+ self.tgt = self.tgt.pin_memory()
+ return self
+
+ def collate_wrapper(batch):
+ return SimpleCustomBatch(batch)
+
+ inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
+ dataset = TensorDataset(inps, tgts)
+
+ loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
+ pin_memory=True)
+
+ for batch_ndx, sample in enumerate(loader):
+ print(sample.inp.is_pinned())
+ print(sample.tgt.is_pinned())
+
+ """
+
+ __initialized = False
+
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+ batch_sampler=None, num_workers=0, collate_fn=default_collate,
+ pin_memory=False, drop_last=False, timeout=0,
+ worker_init_fn=None):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.collate_fn = collate_fn
+ self.pin_memory = pin_memory
+ self.drop_last = drop_last
+ self.timeout = timeout
+ self.worker_init_fn = worker_init_fn
+
+ if timeout < 0:
+ raise ValueError('timeout option should be non-negative')
+
+ if batch_sampler is not None:
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
+ raise ValueError('batch_sampler option is mutually exclusive '
+ 'with batch_size, shuffle, sampler, and '
+ 'drop_last')
+ self.batch_size = None
+ self.drop_last = None
+
+ if sampler is not None and shuffle:
+ raise ValueError('sampler option is mutually exclusive with '
+ 'shuffle')
+
+ if self.num_workers < 0:
+ raise ValueError('num_workers option cannot be negative; '
+ 'use num_workers=0 to disable multiprocessing.')
+
+ if batch_sampler is None:
+ if sampler is None:
+ if shuffle:
+ sampler = RandomSampler(dataset)
+ else:
+ sampler = SequentialSampler(dataset)
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+ self.sampler = sampler
+ self.batch_sampler = batch_sampler
+ self.__initialized = True
+
+ def __setattr__(self, attr, val):
+ if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+ raise ValueError('{} attribute should not be set after {} is '
+ 'initialized'.format(attr, self.__class__.__name__))
+
+ super(DataLoader, self).__setattr__(attr, val)
+
+ def __iter__(self):
+ return _DataLoaderIter(self)
+
+ def __len__(self):
+ return len(self.batch_sampler)
+
+
+class _DataLoaderIter(object):
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
+
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
+ #
+ # Preliminary:
+ #
+ # Our data model looks like this (queues are indicated with curly brackets):
+ #
+ # main process ||
+ # | ||
+ # {index_queue} ||
+ # | ||
+ # worker processes || DATA
+ # | ||
+ # {worker_result_queue} || FLOW
+ # | ||
+ # pin_memory_thread of main process || DIRECTION
+ # | ||
+ # {data_queue} ||
+ # | ||
+ # data output \/
+ #
+ # P.S. The `worker_result_queue` and `pin_memory_thread` parts may be omitted if
+ # `pin_memory=False`.
+ #
+ #
+ # Terminating multiprocessing logic requires very careful design. In
+ # particular, we need to make sure that
+ #
+ # 1. The iterator gracefully exits the workers when its last reference is
+ # gone or it is depleted.
+ #
+ # In this case, the workers should be gracefully exited because the
+ # main process may still need to continue to run, and we want cleaning
+ # up code in the workers to be executed (e.g., releasing GPU memory).
+ # Naturally, we implement the shutdown logic in `__del__` of
+ # DataLoaderIterator.
+ #
+ # We delay the discussion on the logic in this case until later.
+ #
+ # 2. The iterator exits the workers when the loader process and/or worker
+ # processes exits normally or with error.
+ #
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
+ #
+ # You may ask, why can't we make the workers non-daemonic, and
+ # gracefully exit using the same logic as we have in `__del__` when the
+ # iterator gets deleted (see 1 above)?
+ #
+ # First of all, `__del__` is **not** guaranteed to be called when
+ # interpreter exits. Even if it is called, by the time it executes,
+ # many Python core library resources may already be freed, and even
+ # simple things like acquiring an internal lock of a queue may hang.
+ # Therefore, in this case, we actually need to prevent `__del__` from
+ # being executed, and rely on the automatic termination of daemonic
+ # children. Thus, we register an `atexit` hook that sets a global flag
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
+ # reverse order of registration, we are guaranteed that this flag is
+ # set before library resources we use are freed. (Hooks freeing those
+ # resources are registered at importing the Python core libraries at
+ # the top of this file.) So in `__del__`, we check if
+ # `_utils.python_exit_status` is set or `None` (freed), and perform
+ # no-op if so.
+ #
+ # Another problem with `__del__` is also related to the library cleanup
+ # calls. When a process ends, it shuts all its daemonic children
+ # down with a SIGTERM (instead of joining them without a timeout).
+ # Similarly for threads, but by a different mechanism. This fact,
+ # together with a few implementation details of multiprocessing, forces
+ # us to make workers daemonic. All of our problems arise when a
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
+ # code which looks more or less like this:
+ #
+ # try:
+ # your_function_using_a_dataloader()
+ # finally:
+ # multiprocessing.util._exit_function()
+ #
+ # The joining/termination mentioned above happens inside
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
+ # throws, the stack trace stored in the exception will prevent the
+ # frame which uses `DataLoaderIter` from being freed. If the frame has any
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
+ # its `__del__`, which starts the shutdown procedure, will not be
+ # called. That, in turn, means that workers aren't notified. Attempting
+ # to join in `_exit_function` will then result in a hang.
+ #
+ # For context, `_exit_function` is also registered as an `atexit` call.
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
+ # The code dates back to 2008 and there is no comment on the original
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
+ # the finally block and the `atexit` registration) that explains this.
+ #
+ # Another choice is to just shutdown workers with logic in 1 above
+ # whenever we see an error in `next`. This isn't ideal because
+ # a. It prevents users from using try-catch to resume data loading.
+ # b. It doesn't prevent hanging if users have references to the
+ # iterator.
+ #
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
+ #
+ # As shown above, the workers are set as daemonic children of the main
+ # process. However, automatic cleaning-up of such child processes only
+ # happens if the parent process exits gracefully (e.g., not via fatal
+ # signals like SIGKILL). So we must ensure that each process will exit
+ # even if the processes that should send/receive data to/from it were
+ # killed, i.e.,
+ #
+ # a. A process won't hang when getting from a queue.
+ #
+ # Even with carefully designed data dependencies (i.e., a `put()`
+ # always corresponding to a `get()`), hanging on `get()` can still
+ # happen when data in queue is corrupted (e.g., due to
+ # `cancel_join_thread` or unexpected exit).
+ #
+ # For child exit, we set a timeout whenever we try to get data
+ # from `data_queue`, and check the workers' status on each timeout
+ # and error.
+ # See `_DataLoaderIter._get_batch()` and
+ # `_DataLoaderIter._try_get_batch()` for details.
+ #
+ # Additionally, for child exit on non-Windows platforms, we also
+ # register a SIGCHLD handler (which is not supported on Windows) on
+ # the main process, which checks if any of the workers fail in the
+ # (Python) handler. This is more efficient and faster in detecting
+ # worker failures, compared to only using the above mechanism.
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
+ #
+ # For `.get()` calls where the sender(s) is not the workers, we
+ # guard them with timeouts, and check the status of the sender
+ # when timeout happens:
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
+ # checks the status of the main process.
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
+ # check `pin_memory_thread` status periodically until `.get()`
+ # returns or see that `pin_memory_thread` died.
+ #
+ # b. A process won't hang when putting into a queue;
+ #
+ # We use `mp.Queue` which has a separate background thread to put
+ # objects from an unbounded buffer array. The background thread is
+ # daemonic and usually automatically joined when the process
+ # exits.
+ #
+ # However, in case that the receiver has ended abruptly while
+ # reading from the pipe, the join will hang forever. Therefore,
+ # for both `worker_result_queue` (worker -> main process/pin_memory_thread)
+ # and each `index_queue` (main process -> worker), we use
+ # `q.cancel_join_thread()` in sender process before any `q.put` to
+ # prevent this automatic join.
+ #
+ # Moreover, having called `cancel_join_thread` on all queues makes
+ # implementing graceful shutdown logic in `__del__` much easier.
+ # It won't need to get from any queue, which would also need to be
+ # guarded by periodic status checks.
+ #
+ # Note that this may leave corrupted data in the queue, but we
+ # don't care about the data anyways once we are shutting down.
+ #
+ #
+ # Now let's get back to 1:
+ # how we gracefully exit the workers when the last reference to the
+ # iterator is gone.
+ #
+ # To achieve this, we implement the following logic along with the design
+ # choices mentioned above:
+ #
+ # [worker processes]
+ # While loader process is alive:
+ # Get from index_queue.
+ # If got a `None`, exit.
+ # If got anything else,
+ # Check `done_event`.
+ # If set, continue to next iteration
+ # i.e., keep getting until see the `None`, then exit.
+ # Otherwise, process data.
+ # If timed out,
+ # Whether or not `done_event` is set (we still need to see `None`),
+ # must continue to the next iteration.
+ #
+ # [pin_memory_thread]
+ # # No need to check main thread. If this thread is alive, the main loader
+ # # thread must be alive, because this thread is set as daemonic.
+ # While True:
+ # Get from worker_result_queue.
+ # If got a `None`, exit.
+ # If got anything else,
+ # Check `done_event`.
+ # If set, continue to next iteration
+ # i.e., keep getting until see the `None`, then exit.
+ # Otherwise, process data.
+ #
+ # NOTE: we don't check the status of the main thread because
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
+ # ends.
+ # 2. in other cases, either the cleaning-up in __del__ or the
+ # automatic exit of daemonic thread will take care of it.
+ # This won't busy-wait either because `.get(timeout)` does not
+ # busy-wait.
+ #
+ # [main process]
+ # In the DataLoader Iter's `__del__`
+ # a. Set `done_event` (shared with `pin_memory_thread` and workers).
+ #
+ # Note: from here on, the workers & `pin_memory_thread` may exit at
+ # any time after they receive `None`.
+ #
+ # b. Exit `pin_memory_thread`
+ # i. Put `None` in `worker_result_queue`.
+ # ii. Join the `pin_memory_thread`.
+ #
+ # c. Exit the workers.
+ # i. Put `None` in each worker's `index_queue`.
+ # ii. Join the workers.
+ #
+ # NOTE: This has to be after (b) because it may leave corrupted data
+ # in `worker_result_queue`, which `pin_memory_thread` reads
+ # from.
+ #
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
+ # can be omitted
+ #
+ # NB: `done_event` isn't strictly needed. E.g., we can just check for
+ # `None` from `index_queue`, but it allows us to skip wasting resources
+ # processing indices already in `index_queue` if we are already shutting
+ # down.
+
+ def __init__(self, loader):
+ self.dataset = loader.dataset
+ self.collate_fn = loader.collate_fn
+ self.batch_sampler = loader.batch_sampler
+ self.num_workers = loader.num_workers
+ self.pin_memory = loader.pin_memory and torch.cuda.is_available()
+ self.timeout = loader.timeout
+
+ self.sample_iter = iter(self.batch_sampler)
+
+ base_seed = torch.LongTensor(1).random_().item()
+
+ if self.num_workers > 0:
+ self.worker_init_fn = loader.worker_init_fn
+ self.worker_queue_idx = 0
+ self.worker_result_queue = multiprocessing.Queue()
+ self.batches_outstanding = 0
+ self.worker_pids_set = False
+ self.shutdown = False
+ self.send_idx = 0
+ self.rcvd_idx = 0
+ self.reorder_dict = {}
+ self.done_event = multiprocessing.Event()
+
+ self.index_queues = []
+ self.workers = []
+ for i in range(self.num_workers):
+ index_queue = multiprocessing.Queue()
+ index_queue.cancel_join_thread()
+ w = multiprocessing.Process(
+ target=_utils.worker._worker_loop,
+ args=(self.dataset, index_queue,
+ self.worker_result_queue, self.done_event,
+ self.collate_fn, base_seed + i,
+ self.worker_init_fn, i))
+ w.daemon = True
+ # NB: Process.start() actually takes some time as it needs to
+ # start a process and pass the arguments over via a pipe.
+ # Therefore, we only add a worker to the self.workers list after
+ # it has started, so that we do not call .join() if the program dies
+ # before it starts, and __del__ tries to join but will get:
+ # AssertionError: can only join a started process.
+ w.start()
+ self.index_queues.append(index_queue)
+ self.workers.append(w)
+
+ if self.pin_memory:
+ self.data_queue = queue.Queue()
+ pin_memory_thread = threading.Thread(
+ target=_utils.pin_memory._pin_memory_loop,
+ args=(self.worker_result_queue, self.data_queue,
+ torch.cuda.current_device(), self.done_event))
+ pin_memory_thread.daemon = True
+ pin_memory_thread.start()
+ # Similar to workers (see comment above), we only register
+ # pin_memory_thread once it is started.
+ self.pin_memory_thread = pin_memory_thread
+ else:
+ self.data_queue = self.worker_result_queue
+
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+ _utils.signal_handling._set_SIGCHLD_handler()
+ self.worker_pids_set = True
+
+ # prime the prefetch loop
+ for _ in range(2 * self.num_workers):
+ self._put_indices()
+
+ def __len__(self):
+ return len(self.batch_sampler)
+
+ def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+ # Tries to fetch data from `data_queue` for a given timeout. This can
+ # also be used as inner loop of fetching without timeout, with the
+ # sender status as the loop condition.
+ #
+ # This raises a `RuntimeError` if any worker died unexpectedly. This error
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
+ # (only for non-Windows platforms), or the manual check below on errors
+ # and timeouts.
+ #
+ # Returns a 2-tuple:
+ # (bool: whether data was successfully fetched, any: the data if successful, else None)
+ try:
+ data = self.data_queue.get(timeout=timeout)
+ return (True, data)
+ except Exception as e:
+ # At timeout and error, we manually check whether any worker has
+ # failed. Note that this is the only mechanism for Windows to detect
+ # worker failures.
+ if not all(w.is_alive() for w in self.workers):
+ pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
+ raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
+ if isinstance(e, queue.Empty):
+ return (False, None)
+ raise
+
+ def _get_batch(self):
+ # Fetches data from `self.data_queue`.
+ #
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+ # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
+ # in a loop. This is the only mechanism to detect worker failures for
+ # Windows. For other platforms, a SIGCHLD handler is also used for
+ # worker failure detection.
+ #
+ # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
+ # died at timeouts.
+ if self.timeout > 0:
+ success, data = self._try_get_batch(self.timeout)
+ if success:
+ return data
+ else:
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
+ elif self.pin_memory:
+ while self.pin_memory_thread.is_alive():
+ success, data = self._try_get_batch()
+ if success:
+ return data
+ else:
+ # while condition is false, i.e., pin_memory_thread died.
+ raise RuntimeError('Pin memory thread exited unexpectedly')
+ # In this case, `self.data_queue` is a `queue.Queue`. But we don't
+ # need to call `.task_done()` because we don't use `.join()`.
+ else:
+ while True:
+ success, data = self._try_get_batch()
+ if success:
+ return data
+
+ def __next__(self):
+ if self.num_workers == 0: # same-process loading
+ indices = next(self.sample_iter) # may raise StopIteration
+ batch = self.collate_fn([self.dataset[i] for i in indices])
+ if self.pin_memory:
+ batch = _utils.pin_memory.pin_memory_batch(batch)
+ return batch
+
+ # check if the next sample has already been generated
+ if self.rcvd_idx in self.reorder_dict:
+ batch = self.reorder_dict.pop(self.rcvd_idx)
+ return self._process_next_batch(batch)
+
+ if self.batches_outstanding == 0:
+ self._shutdown_workers()
+ raise StopIteration
+
+ while True:
+ assert (not self.shutdown and self.batches_outstanding > 0)
+ idx, batch = self._get_batch()
+ self.batches_outstanding -= 1
+ if idx != self.rcvd_idx:
+ # store out-of-order samples
+ self.reorder_dict[idx] = batch
+ continue
+ return self._process_next_batch(batch)
+
+ next = __next__ # Python 2 compatibility
+
+ def __iter__(self):
+ return self
+
+ def _put_indices(self):
+ assert self.batches_outstanding < 2 * self.num_workers
+ indices = next(self.sample_iter, None)
+ if indices is None:
+ return
+ self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
+ self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
+ self.batches_outstanding += 1
+ self.send_idx += 1
+
+ def _process_next_batch(self, batch):
+ self.rcvd_idx += 1
+ self._put_indices()
+ if isinstance(batch, _utils.ExceptionWrapper):
+ # make multiline KeyError msg readable by working around
+ # a python bug https://bugs.python.org/issue2651
+ if batch.exc_type == KeyError and "\n" in batch.exc_msg:
+ raise Exception("KeyError:" + batch.exc_msg)
+ else:
+ raise batch.exc_type(batch.exc_msg)
+ return batch
+
+ def __getstate__(self):
+ # TODO: add limited pickling support for sharing an iterator
+ # across multiple threads for HOGWILD.
+ # Probably the best way to do this is by moving the sample pushing
+ # to a separate thread and then just sharing the data queue
+ # but signalling the end is tricky without a non-blocking API
+ raise NotImplementedError("_DataLoaderIter cannot be pickled")
+
+ def _shutdown_workers(self):
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+ # the logic of this function.
+ python_exit_status = _utils.python_exit_status
+ if python_exit_status is True or python_exit_status is None:
+ # See (2) of the note. If Python is shutting down, do no-op.
+ return
+ # Normal exit when last reference is gone / iterator is depleted.
+ # See (1) and the second half of the note.
+ if not self.shutdown:
+ self.shutdown = True
+ try:
+ self.done_event.set()
+
+ # Exit `pin_memory_thread` first because exiting workers may leave
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
+ # reads from.
+ if hasattr(self, 'pin_memory_thread'):
+ # Use hasattr in case error happens before we set the attribute.
+ # First time do `worker_result_queue.put` in this process.
+
+ # `cancel_join_thread` in case that `pin_memory_thread` exited.
+ self.worker_result_queue.cancel_join_thread()
+ self.worker_result_queue.put(None)
+ self.pin_memory_thread.join()
+ # Indicate that no more data will be put on this queue by the
+ # current process. This **must** be called after
+ # `pin_memory_thread` is joined because that thread shares the
+ # same pipe handles with this loader thread. If the handle is
+ # closed, Py3 will error in this case, but Py2 will just time
+ # out even if there is data in the queue.
+ self.worker_result_queue.close()
+
+ # Exit workers now.
+ for q in self.index_queues:
+ q.put(None)
+ # Indicate that no more data will be put on this queue by the
+ # current process.
+ q.close()
+ for w in self.workers:
+ w.join()
+ finally:
+ # Even though all this function does is putting into queues that
+ # we have called `cancel_join_thread` on, weird things can
+ # happen when a worker is killed by a signal, e.g., hanging in
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
+ # and remove pids from the C side data structure only at the
+ # end.
+ #
+ # FIXME: Unfortunately, for Windows, we are missing a worker
+ # error detection mechanism here in this function, as it
+ # doesn't provide a SIGCHLD handler.
+ if self.worker_pids_set:
+ _utils.signal_handling._remove_worker_pids(id(self))
+ self.worker_pids_set = False
+
+ def __del__(self):
+ if self.num_workers > 0:
+ self._shutdown_workers()
+
+import bisect
+import warnings
+
+from torch._utils import _accumulate
+from torch import randperm
+
+
+[docs]class Dataset(object):
+ """An abstract class representing a Dataset.
+
+ All other datasets should subclass it. All subclasses should override
+ ``__len__``, which provides the size of the dataset, and ``__getitem__``,
+ supporting integer indexing in the range from 0 to len(self) exclusive.
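+
+     Example (a minimal sketch of a custom dataset; the ``ListDataset`` name
+     is only illustrative)::
+
+         class ListDataset(Dataset):
+             def __init__(self, items):
+                 self.items = list(items)
+
+             def __getitem__(self, index):
+                 return self.items[index]
+
+             def __len__(self):
+                 return len(self.items)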
+ """
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def __add__(self, other):
+ return ConcatDataset([self, other])
+
+
+[docs]class TensorDataset(Dataset):
+ """Dataset wrapping tensors.
+
+ Each sample will be retrieved by indexing tensors along the first dimension.
+
+ Arguments:
+ *tensors (Tensor): tensors that have the same size in the first dimension.
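+
+     Example (a hypothetical pairing of 100 inputs with 100 targets)::
+
+         >>> dataset = TensorDataset(torch.randn(100, 5), torch.randint(0, 2, (100,)))
+         >>> inp, tgt = dataset[0]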
+ """
+
+ def __init__(self, *tensors):
+ assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
+ self.tensors = tensors
+
+ def __getitem__(self, index):
+ return tuple(tensor[index] for tensor in self.tensors)
+
+ def __len__(self):
+ return self.tensors[0].size(0)
+
+
+[docs]class ConcatDataset(Dataset):
+ """
+ Dataset to concatenate multiple datasets.
+ This class is useful for assembling different existing datasets, possibly
+ large-scale datasets, as the concatenation operation is done in an
+ on-the-fly manner.
+
+ Arguments:
+ datasets (sequence): List of datasets to be concatenated
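+
+     Example (a hypothetical concatenation of two existing datasets
+     ``dataset_a`` and ``dataset_b``)::
+
+         >>> combined = ConcatDataset([dataset_a, dataset_b])
+         >>> len(combined) == len(dataset_a) + len(dataset_b)
+         True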
+ """
+
+ @staticmethod
+ def cumsum(sequence):
+ r, s = [], 0
+ for e in sequence:
+ l = len(e)
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__()
+ assert len(datasets) > 0, 'datasets should not be an empty iterable'
+ self.datasets = list(datasets)
+ self.cumulative_sizes = self.cumsum(self.datasets)
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx]
+
+ @property
+ def cummulative_sizes(self):
+ warnings.warn("cummulative_sizes attribute is renamed to "
+ "cumulative_sizes", DeprecationWarning, stacklevel=2)
+ return self.cumulative_sizes
+
+
+[docs]class Subset(Dataset):
+ """
+ Subset of a dataset at specified indices.
+
+ Arguments:
+ dataset (Dataset): The whole Dataset
+ indices (sequence): Indices in the whole set selected for subset
+ """
+ def __init__(self, dataset, indices):
+ self.dataset = dataset
+ self.indices = indices
+
+ def __getitem__(self, idx):
+ return self.dataset[self.indices[idx]]
+
+ def __len__(self):
+ return len(self.indices)
+
+
+[docs]def random_split(dataset, lengths):
+ """
+ Randomly split a dataset into non-overlapping new datasets of given lengths.
+
+ Arguments:
+ dataset (Dataset): Dataset to be split
+ lengths (sequence): lengths of splits to be produced
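+
+     Example (a hypothetical 80/20 split, assuming ``dataset`` has 100 samples)::
+
+         >>> train_set, val_set = random_split(dataset, [80, 20])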
+ """
+ if sum(lengths) != len(dataset):
+ raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+
+ indices = randperm(sum(lengths)).tolist()
+ return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
+
+import math
+import torch
+from . import Sampler
+import torch.distributed as dist
+
+
+[docs]class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
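+
+     Example (a minimal sketch, assuming the default process group has already
+     been initialized and ``num_epochs`` is defined elsewhere)::
+
+         sampler = DistributedSampler(dataset)
+         loader = DataLoader(dataset, batch_size=32, sampler=sampler)
+         for epoch in range(num_epochs):
+             sampler.set_epoch(epoch)  # reshuffle with a different seed each epoch
+             for batch in loader:
+                 ...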
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+import torch
+from torch._six import int_classes as _int_classes
+
+
+[docs]class Sampler(object):
+ r"""Base class for all Samplers.
+
+ Every Sampler subclass has to provide an __iter__ method, providing a way
+ to iterate over indices of dataset elements, and a __len__ method that
+ returns the length of the returned iterator.
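+
+     Example (a minimal sketch of a custom sampler that yields indices in
+     reverse order; the name is only illustrative)::
+
+         class ReversedSampler(Sampler):
+             def __init__(self, data_source):
+                 self.data_source = data_source
+
+             def __iter__(self):
+                 return iter(range(len(self.data_source) - 1, -1, -1))
+
+             def __len__(self):
+                 return len(self.data_source)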
+ """
+
+ def __init__(self, data_source):
+ pass
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+
+[docs]class SequentialSampler(Sampler):
+ r"""Samples elements sequentially, always in the same order.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ """
+
+ def __init__(self, data_source):
+ self.data_source = data_source
+
+ def __iter__(self):
+ return iter(range(len(self.data_source)))
+
+ def __len__(self):
+ return len(self.data_source)
+
+
+[docs]class RandomSampler(Sampler):
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+ If with replacement, then user can specify ``num_samples`` to draw.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ replacement (bool): samples are drawn with replacement if ``True``, default=``False``
+ num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
+ is supposed to be specified only when `replacement` is ``True``.
+ """
+
+ def __init__(self, data_source, replacement=False, num_samples=None):
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+
+ if not isinstance(self.replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(self.replacement))
+
+ if self._num_samples is not None and not replacement:
+ raise ValueError("With replacement=False, num_samples should not be specified, "
+ "since a random permute will be performed.")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(self.num_samples))
+
+ @property
+ def num_samples(self):
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self):
+ n = len(self.data_source)
+ if self.replacement:
+ return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
+ return iter(torch.randperm(n).tolist())
+
+ def __len__(self):
+ return self.num_samples
+
+
+[docs]class SubsetRandomSampler(Sampler):
+ r"""Samples elements randomly from a given list of indices, without replacement.
+
+ Arguments:
+ indices (sequence): a sequence of indices
+ """
+
+ def __init__(self, indices):
+ self.indices = indices
+
+ def __iter__(self):
+ return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+ def __len__(self):
+ return len(self.indices)
+
+
+[docs]class WeightedRandomSampler(Sampler):
+ r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
+
+ Args:
+ weights (sequence): a sequence of weights, not necessarily summing up to one
+ num_samples (int): number of samples to draw
+ replacement (bool): if ``True``, samples are drawn with replacement.
+ If not, they are drawn without replacement, which means that when a
+ sample index is drawn for a row, it cannot be drawn again for that row.
+
+ Example:
+ >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
+ [0, 0, 0, 1, 0]
+ >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
+ [0, 1, 4, 3, 2]
+ """
+
+ def __init__(self, weights, num_samples, replacement=True):
+ if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
+ num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(num_samples))
+ if not isinstance(replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(replacement))
+ self.weights = torch.as_tensor(weights, dtype=torch.double)
+ self.num_samples = num_samples
+ self.replacement = replacement
+
+ def __iter__(self):
+ return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
+
+ def __len__(self):
+ return self.num_samples
+
+
+[docs]class BatchSampler(Sampler):
+ r"""Wraps another sampler to yield a mini-batch of indices.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+
+ Example:
+ >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+ >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+ """
+
+ def __init__(self, sampler, batch_size, drop_last):
+ if not isinstance(sampler, Sampler):
+ raise ValueError("sampler should be an instance of "
+ "torch.utils.data.Sampler, but got sampler={}"
+ .format(sampler))
+ if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
+ batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, "
+ "but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got "
+ "drop_last={}".format(drop_last))
+ self.sampler = sampler
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ def __iter__(self):
+ batch = []
+ for idx in self.sampler:
+ batch.append(idx)
+ if len(batch) == self.batch_size:
+ yield batch
+ batch = []
+ if len(batch) > 0 and not self.drop_last:
+ yield batch
+
+ def __len__(self):
+ if self.drop_last:
+ return len(self.sampler) // self.batch_size
+ else:
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
+
+"""Provides an API for writing protocol buffers to event files to be
+consumed by TensorBoard for visualization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+import six
+import time
+
+from tensorboard.compat.proto.event_pb2 import SessionLog
+from tensorboard.compat.proto.event_pb2 import Event
+from tensorboard.compat.proto import event_pb2
+from tensorboard.summary.writer.event_file_writer import EventFileWriter
+
+from ._convert_np import make_np
+from ._embedding import make_mat, make_sprite, make_tsv, append_pbtxt
+from ._onnx_graph import load_onnx_graph
+from ._pytorch_graph import graph
+from ._utils import figure_to_image
+from .summary import (
+ scalar, histogram, histogram_raw, image, audio, text,
+ pr_curve, pr_curve_raw, video, custom_scalars, image_boxes
+)
+
+
+class FileWriter(object):
+ """Writes protocol buffers to event files to be consumed by TensorBoard.
+
+ The `FileWriter` class provides a mechanism to create an event file in a
+ given directory and add summaries and events to it. The class updates the
+ file contents asynchronously. This allows a training program to call methods
+ to add data to the file directly from the training loop, without slowing down
+ training.
+ """
+
+ def __init__(self,
+ logdir,
+ max_queue=10,
+ flush_secs=120,
+ filename_suffix=''):
+ """Creates a `FileWriter` and an event file.
+ On construction the writer creates a new event file in `logdir`.
+ The other arguments to the constructor control the asynchronous writes to
+ the event file.
+
+ Args:
+ logdir: A string. Directory where event file will be written.
+ max_queue: Integer. Size of the queue for pending events and
+ summaries before one of the 'add' calls forces a flush to disk.
+ flush_secs: Number. How often, in seconds, to flush the
+ pending events and summaries to disk.
+ filename_suffix: A string. Suffix added to all event filenames.
+ More details on event filename construction in
+ tensorboard.summary.writer.event_file_writer.EventFileWriter.
+ """
+ # Sometimes PosixPath is passed in and we need to coerce it to
+ # a string in all cases
+ # TODO: See if we can remove this in the future if we are
+ # actually the ones passing in a PosixPath
+ logdir = str(logdir)
+ self.event_writer = EventFileWriter(
+ logdir, max_queue, flush_secs, filename_suffix)
+
+ def get_logdir(self):
+ """Returns the directory where event file will be written."""
+ return self.event_writer.get_logdir()
+
+ def add_event(self, event, step=None, walltime=None):
+ """Adds an event to the event file.
+ Args:
+ event: An `Event` protocol buffer.
+ step: Number. Optional global step value for training process
+ to record with the event.
+ walltime: float. Optional walltime to override the default (current)
+ walltime (from time.time()) seconds after epoch
+ """
+ event.wall_time = time.time() if walltime is None else walltime
+ if step is not None:
+ # Make sure step is converted from numpy or other formats
+ # since protobuf might not convert depending on version
+ event.step = int(step)
+ self.event_writer.add_event(event)
+
+ def add_summary(self, summary, global_step=None, walltime=None):
+ """Adds a `Summary` protocol buffer to the event file.
+ This method wraps the provided summary in an `Event` protocol buffer
+ and adds it to the event file.
+
+ Args:
+ summary: A `Summary` protocol buffer.
+ global_step: Number. Optional global step value for training process
+ to record with the summary.
+ walltime: float. Optional walltime to override the default (current)
+ walltime (from time.time()) seconds after epoch
+ """
+ event = event_pb2.Event(summary=summary)
+ self.add_event(event, global_step, walltime)
+
+ def add_graph(self, graph_profile, walltime=None):
+ """Adds a `Graph` and step stats protocol buffer to the event file.
+
+ Args:
+ graph_profile: A `Graph` and step stats protocol buffer.
+ walltime: float. Optional walltime to override the default (current)
+ walltime (from time.time()) seconds after epoch
+ """
+ graph = graph_profile[0]
+ stepstats = graph_profile[1]
+ event = event_pb2.Event(graph_def=graph.SerializeToString())
+ self.add_event(event, None, walltime)
+
+ trm = event_pb2.TaggedRunMetadata(
+ tag='step1', run_metadata=stepstats.SerializeToString())
+ event = event_pb2.Event(tagged_run_metadata=trm)
+ self.add_event(event, None, walltime)
+
+ def add_onnx_graph(self, graph, walltime=None):
+ """Adds a `Graph` protocol buffer to the event file.
+
+ Args:
+ graph: A `Graph` protocol buffer.
+ walltime: float. Optional walltime to override the default (current)
+ _get_file_writerfrom time.time())
+ """
+ event = event_pb2.Event(graph_def=graph.SerializeToString())
+ self.add_event(event, None, walltime)
+
+ def flush(self):
+ """Flushes the event file to disk.
+ Call this method to make sure that all pending events have been written to
+ disk.
+ """
+ self.event_writer.flush()
+
+ def close(self):
+ """Flushes the event file to disk and close the file.
+ Call this method when you do not need the summary writer anymore.
+ """
+ self.event_writer.close()
+
+ def reopen(self):
+ """Reopens the EventFileWriter.
+ Can be called after `close()` to add more events in the same directory.
+ The events will go into a new events file.
+ Does nothing if the EventFileWriter was not closed.
+ """
+ self.event_writer.reopen()
+
+
+[docs]class SummaryWriter(object):
+ """Writes entries directly to event files in the log_dir to be
+ consumed by TensorBoard.
+
+ The `SummaryWriter` class provides a high-level API to create an event file
+ in a given directory and add summaries and events to it. The class updates the
+ file contents asynchronously. This allows a training program to call methods
+ to add data to the file directly from the training loop, without slowing down
+ training.
+ """
+
+ def __init__(self, log_dir=None, comment='', **kwargs):
+ """Creates a `SummaryWriter` that will write out events and summaries
+ to the event file.
+
+ Args:
+            log_dir (string): Save directory location. Defaults to
+              ``runs/**CURRENT_DATETIME_HOSTNAME**`` (i.e. under ``./runs/``), which changes after each
+              run. Use a hierarchical folder structure to compare between runs easily, e.g. pass in
+              'runs/exp1', 'runs/exp2', etc. for each new experiment.
+            comment (string): Comment suffix appended to the default ``log_dir``.
+              If ``log_dir`` is assigned, this argument has no effect.
+            purge_step (int):
+              When logging crashes at step :math:`T+X` and restarts at step :math:`T`, any events
+              whose global_step is larger than or equal to :math:`T` will be purged and hidden from TensorBoard.
+              Note that the crashed and resumed experiments should have the same ``log_dir``.
+ filename_suffix (string):
+ Every event file's name is suffixed with suffix. Example: ``SummaryWriter(filename_suffix='.123')``
+ More details on event filename construction in
+ tensorboard.summary.writer.event_file_writer.EventFileWriter.
+ kwargs: extra keyword arguments for FileWriter (e.g. 'flush_secs'
+ controls how often to flush pending events). For more arguments
+ please refer to docs for 'tf.summary.FileWriter'.
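+
+        Examples (a minimal usage sketch; the run names below are illustrative)::
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            # default: creates a run under runs/CURRENT_DATETIME_HOSTNAME
+            writer = SummaryWriter()
+
+            # fixed location, convenient for comparing several experiments
+            writer = SummaryWriter('runs/exp1')
+
+            # default location with a comment suffix appended
+            writer = SummaryWriter(comment='_lr_0.01')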
+ """
+ if not log_dir:
+ import socket
+ from datetime import datetime
+ current_time = datetime.now().strftime('%b%d_%H-%M-%S')
+ log_dir = os.path.join(
+ 'runs', current_time + '_' + socket.gethostname() + comment)
+ self.log_dir = log_dir
+ self.kwargs = kwargs
+
+ # Initialize the file writers, but they can be cleared out on close
+ # and recreated later as needed.
+ self.file_writer = self.all_writers = None
+ self._get_file_writer()
+
+ # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
+ v = 1E-12
+ buckets = []
+ neg_buckets = []
+ while v < 1E20:
+ buckets.append(v)
+ neg_buckets.append(-v)
+ v *= 1.1
+ self.default_bins = neg_buckets[::-1] + [0] + buckets
+ self.scalar_dict = {}
+
+ def _append_to_scalar_dict(self, tag, scalar_value, global_step,
+ timestamp):
+ """This adds an entry to the self.scalar_dict datastructure with format
+ {writer_id : [[timestamp, step, value], ...], ...}.
+ """
+ if tag not in self.scalar_dict.keys():
+ self.scalar_dict[tag] = []
+ self.scalar_dict[tag].append(
+ [timestamp, global_step, float(make_np(scalar_value))])
+
+ def _check_caffe2_blob(self, item):
+ """
+ Caffe2 users have the option of passing a string representing the name of
+ a blob in the workspace instead of passing the actual Tensor/array containing
+ the numeric values. Thus, we need to check if we received a string as input
+ instead of an actual Tensor/array, and if so, we need to fetch the Blob
+ from the workspace corresponding to that name. Fetching can be done with the
+ following:
+
+ from caffe2.python import workspace (if not already imported)
+ workspace.FetchBlob(blob_name)
+ workspace.FetchBlobs([blob_name1, blob_name2, ...])
+ """
+ return isinstance(item, six.string_types)
+
+ def _get_file_writer(self):
+ """Returns the default FileWriter instance. Recreates it if closed."""
+ if self.all_writers is None or self.file_writer is None:
+ if 'purge_step' in self.kwargs.keys():
+ most_recent_step = self.kwargs.pop('purge_step')
+ self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
+ self.file_writer.add_event(
+ Event(step=most_recent_step, file_version='brain.Event:2'))
+ self.file_writer.add_event(
+ Event(step=most_recent_step, session_log=SessionLog(status=SessionLog.START)))
+ else:
+ self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
+ self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
+ return self.file_writer
+
+[docs] def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
+ """Add scalar data to summary.
+
+ Args:
+ tag (string): Data identifier
+ scalar_value (float or string/blobname): Value to save
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ with seconds after epoch of event
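+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            for step in range(100):
+                # log a decaying value under the tag 'loss/train'
+                writer.add_scalar('loss/train', 0.99 ** step, global_step=step)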
+ """
+ if self._check_caffe2_blob(scalar_value):
+ scalar_value = workspace.FetchBlob(scalar_value)
+ self._get_file_writer().add_summary(
+ scalar(tag, scalar_value), global_step, walltime)
+
+ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
+ """Adds many scalar data to summary.
+
+ Note that this function also keeps logged scalars in memory. In extreme case it explodes your RAM.
+
+ Args:
+ main_tag (string): The parent name for the tags
+ tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+
+ Examples::
+
+ writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
+ 'xcosx':i*np.cos(i/r),
+ 'arctanx': numsteps*np.arctan(i/r)}, i)
+ # This call adds three values to the same scalar plot with the tag
+ # 'run_14h' in TensorBoard's scalar section.
+ """
+ walltime = time.time() if walltime is None else walltime
+ fw_logdir = self._get_file_writer().get_logdir()
+ for tag, scalar_value in tag_scalar_dict.items():
+ fw_tag = fw_logdir + "/" + main_tag + "/" + tag
+ if fw_tag in self.all_writers.keys():
+ fw = self.all_writers[fw_tag]
+ else:
+ fw = FileWriter(logdir=fw_tag)
+ self.all_writers[fw_tag] = fw
+ if self._check_caffe2_blob(scalar_value):
+ scalar_value = workspace.FetchBlob(scalar_value)
+ fw.add_summary(scalar(main_tag, scalar_value),
+ global_step, walltime)
+ self._append_to_scalar_dict(
+ fw_tag, scalar_value, global_step, walltime)
+
+ def export_scalars_to_json(self, path):
+ """Exports to the given path an ASCII file containing all the scalars written
+ so far by this instance, with the following format:
+ {writer_id : [[timestamp, step, value], ...], ...}
+
+ The scalars saved by ``add_scalars()`` will be flushed after export.
+ """
+ with open(path, "w") as f:
+ json.dump(self.scalar_dict, f)
+ self.scalar_dict = {}
+
+[docs] def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
+ """Add histogram to summary.
+
+ Args:
+ tag (string): Data identifier
+ values (torch.Tensor, numpy.array, or string/blobname): Values to build histogram
+ global_step (int): Global step value to record
+            bins (string): One of {'tensorflow', 'auto', 'fd', ...}. This determines how the bins are made. You can find
+ other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
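+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import numpy as np
+            for step in range(10):
+                # log the distribution of 1000 random values at each step
+                writer.add_histogram('distribution', np.random.randn(1000), step)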
+ """
+ if self._check_caffe2_blob(values):
+ values = workspace.FetchBlob(values)
+ if isinstance(bins, six.string_types) and bins == 'tensorflow':
+ bins = self.default_bins
+ self._get_file_writer().add_summary(
+ histogram(tag, values, bins, max_bins=max_bins), global_step, walltime)
+
+ def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
+ bucket_limits, bucket_counts, global_step=None,
+ walltime=None):
+ """Adds histogram with raw data.
+
+ Args:
+ tag (string): Data identifier
+ min (float or int): Min value
+ max (float or int): Max value
+ num (int): Number of values
+ sum (float or int): Sum of all values
+ sum_squares (float or int): Sum of squares for all values
+ bucket_limits (torch.Tensor, numpy.array): Upper value per bucket
+ bucket_counts (torch.Tensor, numpy.array): Number of values per bucket
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
+ """
+ self._get_file_writer().add_summary(
+ histogram_raw(tag,
+ min,
+ max,
+ num,
+ sum,
+ sum_squares,
+ bucket_limits,
+ bucket_counts),
+ global_step,
+ walltime)
+
+[docs] def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
+ """Add image data to summary.
+
+ Note that this requires the ``pillow`` package.
+
+ Args:
+ tag (string): Data identifier
+ img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Shape:
+            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
+            convert a batch of tensors into 3xHxW format, or call ``add_images`` and let us do the job.
+            Tensors of shape :math:`(1, H, W)`, :math:`(H, W)`, or :math:`(H, W, 3)` are also suitable as long as
+            the corresponding ``dataformats`` argument is passed, e.g. CHW, HWC, HW.
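+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import torch
+            # a random 3x64x64 image with float values in [0, 1]
+            writer.add_image('sample_image', torch.rand(3, 64, 64), 0)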
+ """
+ if self._check_caffe2_blob(img_tensor):
+ img_tensor = workspace.FetchBlob(img_tensor)
+ self._get_file_writer().add_summary(
+ image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
+
+ def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
+ """Add batched image data to summary.
+
+ Note that this requires the ``pillow`` package.
+
+ Args:
+ tag (string): Data identifier
+ img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Shape:
+ img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be
+ accepted. e.g. NCHW or NHWC.
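+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import torch
+            # a batch of 16 random 3x64x64 images in the default NCHW layout
+            writer.add_images('image_batch', torch.rand(16, 3, 64, 64), 0)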
+ """
+ if self._check_caffe2_blob(img_tensor):
+ img_tensor = workspace.FetchBlob(img_tensor)
+ self._get_file_writer().add_summary(
+ image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
+
+ def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
+ walltime=None, dataformats='CHW', **kwargs):
+ """Add image and draw bounding boxes on the image.
+
+ Args:
+ tag (string): Data identifier
+ img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
+ box_tensor (torch.Tensor, numpy.array, or string/blobname): Box data (for detected objects)
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Shape:
+            img_tensor: Default is :math:`(3, H, W)`. It can be specified with the ``dataformats`` argument,
+            e.g. CHW or HWC.
+
+            box_tensor: (torch.Tensor, numpy.array, or string/blobname): Nx4, where N is the number of
+            boxes and the 4 elements in each row represent (xmin, ymin, xmax, ymax).
+ """
+ if self._check_caffe2_blob(img_tensor):
+ img_tensor = workspace.FetchBlob(img_tensor)
+ if self._check_caffe2_blob(box_tensor):
+ box_tensor = workspace.FetchBlob(box_tensor)
+ self._get_file_writer().add_summary(image_boxes(
+ tag, img_tensor, box_tensor, dataformats=dataformats, **kwargs), global_step, walltime)
+
+[docs] def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
+ """Render matplotlib figure into an image and add it to summary.
+
+ Note that this requires the ``matplotlib`` package.
+
+ Args:
+ tag (string): Data identifier
+            figure (matplotlib.pyplot.figure or list of figures): Figure or a list of figures
+ global_step (int): Global step value to record
+ close (bool): Flag to automatically close the figure
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
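+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import matplotlib.pyplot as plt
+            fig = plt.figure()
+            plt.plot([0, 1, 2, 3], [0, 1, 4, 9])
+            # the figure is rendered to an image; close=True closes it afterwards
+            writer.add_figure('quadratic', fig, global_step=0)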
+ """
+ if isinstance(figure, list):
+ self.add_image(tag, figure_to_image(figure, close), global_step, walltime, dataformats='NCHW')
+ else:
+ self.add_image(tag, figure_to_image(figure, close), global_step, walltime, dataformats='CHW')
+
+[docs] def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
+ """Add video data to summary.
+
+ Note that this requires the ``moviepy`` package.
+
+ Args:
+ tag (string): Data identifier
+ vid_tensor (torch.Tensor): Video data
+ global_step (int): Global step value to record
+ fps (float or int): Frames per second
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Shape:
+ vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
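+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance, and ``moviepy`` must be installed)::
+
+            import torch
+            # one clip of 16 random 3x32x32 frames with float values in [0, 1]
+            writer.add_video('random_clip', torch.rand(1, 16, 3, 32, 32), 0, fps=4)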
+ """
+ self._get_file_writer().add_summary(
+ video(tag, vid_tensor, fps), global_step, walltime)
+
+[docs] def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
+ """Add audio data to summary.
+
+ Args:
+ tag (string): Data identifier
+ snd_tensor (torch.Tensor): Sound data
+ global_step (int): Global step value to record
+ sample_rate (int): sample rate in Hz
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Shape:
+ snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
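+
+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import torch
+            # one second of random samples scaled into [-1, 1]
+            writer.add_audio('noise', torch.rand(1, 44100) * 2 - 1, 0, sample_rate=44100)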
+ """
+ if self._check_caffe2_blob(snd_tensor):
+ snd_tensor = workspace.FetchBlob(snd_tensor)
+ self._get_file_writer().add_summary(
+ audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime)
+
+[docs] def add_text(self, tag, text_string, global_step=None, walltime=None):
+ """Add text data to summary.
+
+ Args:
+ tag (string): Data identifier
+ text_string (string): String to save
+ global_step (int): Global step value to record
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ Examples::
+
+ writer.add_text('lstm', 'This is an lstm', 0)
+ writer.add_text('rnn', 'This is an rnn', 10)
+ """
+ self._get_file_writer().add_summary(
+ text(tag, text_string), global_step, walltime)
+
+ def add_onnx_graph(self, prototxt):
+ self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))
+
+[docs] def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
+ # prohibit second call?
+ # no, let tensorboard handle it and show its warning message.
+ """Add graph data to summary.
+
+ Args:
+ model (torch.nn.Module): model to draw.
+ input_to_model (torch.Tensor or list of torch.Tensor): a variable or a tuple of
+ variables to be fed.
+ verbose (bool): Whether to print graph structure in console.
+            omit_useless_nodes (bool): Defaults to ``True``, which eliminates unused nodes.
+            operator_export_type (string): One of ``"ONNX"`` or ``"RAW"``. This determines
+                the optimization level of the graph. If an error occurs while exporting
+                the graph, using ``"RAW"`` may help.
+
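+        Example (a minimal sketch; ``writer`` is an assumed ``SummaryWriter`` instance and
+        ``model`` any ``torch.nn.Module``; the input shape must match the model's forward())::
+
+            import torch
+            writer.add_graph(model, torch.rand(1, 3, 224, 224))
+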
+ """
+ if hasattr(model, 'forward'):
+ # A valid PyTorch model should have a 'forward' method
+ import torch
+ from distutils.version import LooseVersion
+ if LooseVersion(torch.__version__) >= LooseVersion("0.3.1"):
+ pass
+ else:
+ if LooseVersion(torch.__version__) >= LooseVersion("0.3.0"):
+ print('You are using PyTorch==0.3.0, use add_onnx_graph()')
+ return
+ if not hasattr(torch.autograd.Variable, 'grad_fn'):
+ print('add_graph() only supports PyTorch v0.2.')
+ return
+ self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
+ else:
+ # Caffe2 models do not have the 'forward' method
+ from caffe2.proto import caffe2_pb2
+ from caffe2.python import core
+ from ._caffe2_graph import (
+ model_to_graph_def, nets_to_graph_def, protos_to_graph_def
+ )
+ if isinstance(model, list):
+ if isinstance(model[0], core.Net):
+ current_graph = nets_to_graph_def(
+ model, **kwargs)
+ elif isinstance(model[0], caffe2_pb2.NetDef):
+ current_graph = protos_to_graph_def(
+ model, **kwargs)
+ else:
+ # Handles cnn.CNNModelHelper, model_helper.ModelHelper
+ current_graph = model_to_graph_def(
+ model, **kwargs)
+ event = event_pb2.Event(
+ graph_def=current_graph.SerializeToString())
+ self._get_file_writer().add_event(event)
+
+ @staticmethod
+ def _encode(rawstr):
+        # I'd use urllib, but I'm unsure about the differences between Python 2 and 3, etc.
+ retval = rawstr
+ retval = retval.replace("%", "%%%02x" % (ord("%")))
+ retval = retval.replace("/", "%%%02x" % (ord("/")))
+ retval = retval.replace("\\", "%%%02x" % (ord("\\")))
+ return retval
+
+[docs] def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None):
+ """Add embedding projector data to summary.
+
+ Args:
+            mat (torch.Tensor or numpy.array): A matrix where each row is the feature vector of a data point
+            metadata (list): A list of labels; each element will be converted to a string
+            label_img (torch.Tensor): Images corresponding to each data point
+ global_step (int): Global step value to record
+ tag (string): Name for the embedding
+ Shape:
+ mat: :math:`(N, D)`, where N is number of data and D is feature dimension
+
+ label_img: :math:`(N, C, H, W)`
+
+ Examples::
+
+ import keyword
+ import torch
+ meta = []
+ while len(meta)<100:
+ meta = meta+keyword.kwlist # get some strings
+ meta = meta[:100]
+
+ for i, v in enumerate(meta):
+ meta[i] = v+str(i)
+
+ label_img = torch.rand(100, 3, 10, 32)
+ for i in range(100):
+ label_img[i]*=i/100.0
+
+ writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
+ writer.add_embedding(torch.randn(100, 5), label_img=label_img)
+ writer.add_embedding(torch.randn(100, 5), metadata=meta)
+ """
+ mat = make_np(mat)
+ if global_step is None:
+ global_step = 0
+ # clear pbtxt?
+ # Maybe we should encode the tag so slashes don't trip us up?
+ # I don't think this will mess us up, but better safe than sorry.
+ subdir = "%s/%s" % (str(global_step).zfill(5), self._encode(tag))
+ save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)
+ try:
+ os.makedirs(save_path)
+ except OSError:
+ print(
+ 'warning: Embedding dir exists, did you set global_step for add_embedding()?')
+ if metadata is not None:
+            assert mat.shape[0] == len(
+                metadata), '#labels should equal #data points'
+ make_tsv(metadata, save_path, metadata_header=metadata_header)
+ if label_img is not None:
+            assert mat.shape[0] == label_img.shape[0], '#images should equal #data points'
+ make_sprite(label_img, save_path)
+ assert mat.ndim == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
+ make_mat(mat, save_path)
+        # new function to append a new embedding to the config file
+ append_pbtxt(metadata, label_img,
+ self._get_file_writer().get_logdir(), subdir, global_step, tag)
+
+[docs] def add_pr_curve(self, tag, labels, predictions, global_step=None,
+ num_thresholds=127, weights=None, walltime=None):
+ """Adds precision recall curve.
+
+ Args:
+ tag (string): Data identifier
+ labels (torch.Tensor, numpy.array, or string/blobname): Ground truth data. Binary label for each element.
+ predictions (torch.Tensor, numpy.array, or string/blobname):
+              The probability that an element is classified as true. Values should be in [0, 1].
+ global_step (int): Global step value to record
+ num_thresholds (int): Number of thresholds used to draw the curve.
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+
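+        Example (a minimal sketch; ``writer`` is assumed to be an existing
+        ``SummaryWriter`` instance)::
+
+            import numpy as np
+            labels = np.random.randint(2, size=100)    # binary ground truth
+            predictions = np.random.rand(100)          # predicted probabilities
+            writer.add_pr_curve('pr_curve', labels, predictions, global_step=0)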
+ """
+ labels, predictions = make_np(labels), make_np(predictions)
+ self._get_file_writer().add_summary(
+ pr_curve(tag, labels, predictions, num_thresholds, weights),
+ global_step, walltime)
+
+ def add_pr_curve_raw(self, tag, true_positive_counts,
+ false_positive_counts,
+ true_negative_counts,
+ false_negative_counts,
+ precision,
+ recall,
+ global_step=None,
+ num_thresholds=127,
+ weights=None,
+ walltime=None):
+ """Adds precision recall curve with raw data.
+
+ Args:
+ tag (string): Data identifier
+ true_positive_counts (torch.Tensor, numpy.array, or string/blobname): true positive counts
+ false_positive_counts (torch.Tensor, numpy.array, or string/blobname): false positive counts
+ true_negative_counts (torch.Tensor, numpy.array, or string/blobname): true negative counts
+ false_negative_counts (torch.Tensor, numpy.array, or string/blobname): false negative counts
+ precision (torch.Tensor, numpy.array, or string/blobname): precision
+ recall (torch.Tensor, numpy.array, or string/blobname): recall
+ global_step (int): Global step value to record
+ num_thresholds (int): Number of thresholds used to draw the curve.
+ walltime (float): Optional override default walltime (time.time())
+ seconds after epoch of event
+ see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
+ """
+ self._get_file_writer().add_summary(
+ pr_curve_raw(tag,
+ true_positive_counts,
+ false_positive_counts,
+ true_negative_counts,
+ false_negative_counts,
+ precision,
+ recall,
+ num_thresholds,
+ weights),
+ global_step,
+ walltime)
+
+ def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
+ """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument
+ is *tags*.
+
+ Args:
+ tags (list): list of tags that have been used in ``add_scalar()``
+
+ Examples::
+
+ writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])
+ """
+ layout = {category: {title: ['Multiline', tags]}}
+ self._get_file_writer().add_summary(custom_scalars(layout))
+
+ def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
+ """Shorthand for creating marginchart. Similar to ``add_custom_scalars()``, but the only necessary argument
+ is *tags*, which should have exactly 3 elements.
+
+ Args:
+ tags (list): list of tags that have been used in ``add_scalar()``
+
+ Examples::
+
+ writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])
+ """
+ assert len(tags) == 3
+ layout = {category: {title: ['Margin', tags]}}
+ self._get_file_writer().add_summary(custom_scalars(layout))
+
+[docs] def add_custom_scalars(self, layout):
+ """Create special chart by collecting charts tags in 'scalars'. Note that this function can only be called once
+ for each SummaryWriter() object. Because it only provides metadata to tensorboard, the function can be called
+ before or after the training loop.
+
+ Args:
+ layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary
+ {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type
+ (one of **Multiline** or **Margin**) and the second element should be a list containing the tags
+ you have used in add_scalar function, which will be collected into the new chart.
+
+ Examples::
+
+ layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]},
+ 'USA':{ 'dow':['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
+ 'nasdaq':['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}
+
+ writer.add_custom_scalars(layout)
+ """
+ self._get_file_writer().add_summary(custom_scalars(layout))
+
+ def close(self):
+ if self.all_writers is None:
+ return # ignore double close
+ for writer in self.all_writers.values():
+ writer.flush()
+ writer.close()
+ self.file_writer = self.all_writers = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+from torchvision import models
+from torchvision import datasets
+from torchvision import transforms
+from torchvision import utils
+
+try:
+ from .version import __version__ # noqa: F401
+except ImportError:
+ pass
+
+_image_backend = 'PIL'
+
+
+[docs]def set_image_backend(backend):
+ """
+ Specifies the package used to load images.
+
+ Args:
+ backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
+ The :mod:`accimage` package uses the Intel IPP library. It is
+ generally faster than PIL, but does not support as many operations.
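+
+    Example (a minimal sketch)::
+
+        import torchvision
+        torchvision.set_image_backend('PIL')   # 'PIL' is also the default backend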
+ """
+ global _image_backend
+ if backend not in ['PIL', 'accimage']:
+ raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
+ .format(backend))
+ _image_backend = backend
+
+
+[docs]def get_image_backend():
+ """
+ Gets the name of the package used to load images
+ """
+ return _image_backend
+
+from __future__ import print_function
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import sys
+
+if sys.version_info[0] == 2:
+ import cPickle as pickle
+else:
+ import pickle
+
+from .vision import VisionDataset
+from .utils import download_url, check_integrity
+
+
+[docs]class CIFAR10(VisionDataset):
+ """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where directory
+ ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+ train (bool, optional): If True, creates dataset from training set, otherwise
+ creates from test set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
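+    Example (a minimal usage sketch; the ``./data`` root path is an arbitrary choice)::
+
+        import torchvision.datasets as dset
+        import torchvision.transforms as transforms
+
+        trainset = dset.CIFAR10(root='./data', train=True, download=True,
+                                transform=transforms.ToTensor())
+        img, target = trainset[0]   # 3x32x32 float tensor in [0, 1] and an integer label
+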
+ """
+ base_folder = 'cifar-10-batches-py'
+ url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+ filename = "cifar-10-python.tar.gz"
+ tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
+ train_list = [
+ ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
+ ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
+ ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
+ ['data_batch_4', '634d18415352ddfa80567beed471001a'],
+ ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
+ ]
+
+ test_list = [
+ ['test_batch', '40351d587109b95175f43aff81a1287e'],
+ ]
+ meta = {
+ 'filename': 'batches.meta',
+ 'key': 'label_names',
+ 'md5': '5ff9c542aee3614f3951f8cda6e48888',
+ }
+
+ def __init__(self, root, train=True,
+ transform=None, target_transform=None,
+ download=False):
+
+ super(CIFAR10, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.train = train # training set or test set
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ if self.train:
+ downloaded_list = self.train_list
+ else:
+ downloaded_list = self.test_list
+
+ self.data = []
+ self.targets = []
+
+        # now load the pickled numpy arrays
+ for file_name, checksum in downloaded_list:
+ file_path = os.path.join(self.root, self.base_folder, file_name)
+ with open(file_path, 'rb') as f:
+ if sys.version_info[0] == 2:
+ entry = pickle.load(f)
+ else:
+ entry = pickle.load(f, encoding='latin1')
+ self.data.append(entry['data'])
+ if 'labels' in entry:
+ self.targets.extend(entry['labels'])
+ else:
+ self.targets.extend(entry['fine_labels'])
+
+ self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+ self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
+
+ self._load_meta()
+
+ def _load_meta(self):
+ path = os.path.join(self.root, self.base_folder, self.meta['filename'])
+ if not check_integrity(path, self.meta['md5']):
+ raise RuntimeError('Dataset metadata file not found or corrupted.' +
+ ' You can use download=True to download it')
+ with open(path, 'rb') as infile:
+ if sys.version_info[0] == 2:
+ data = pickle.load(infile)
+ else:
+ data = pickle.load(infile, encoding='latin1')
+ self.classes = data[self.meta['key']]
+ self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.data)
+
+ def _check_integrity(self):
+ root = self.root
+ for fentry in (self.train_list + self.test_list):
+ filename, md5 = fentry[0], fentry[1]
+ fpath = os.path.join(root, self.base_folder, filename)
+ if not check_integrity(fpath, md5):
+ return False
+ return True
+
+ def download(self):
+ import tarfile
+
+ if self._check_integrity():
+ print('Files already downloaded and verified')
+ return
+
+ download_url(self.url, self.root, self.filename, self.tgz_md5)
+
+ # extract file
+ with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+ tar.extractall(path=self.root)
+
+ def extra_repr(self):
+ return "Split: {}".format("Train" if self.train is True else "Test")
+
+
+[docs]class CIFAR100(CIFAR10):
+ """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+ This is a subclass of the `CIFAR10` Dataset.
+ """
+ base_folder = 'cifar-100-python'
+ url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+ filename = "cifar-100-python.tar.gz"
+ tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+ train_list = [
+ ['train', '16019d7e3df5f24257cddd939b257f8d'],
+ ]
+
+ test_list = [
+ ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+ ]
+ meta = {
+ 'filename': 'meta',
+ 'key': 'fine_label_names',
+ 'md5': '7973b15100ade9c7d40fb424638fde48',
+ }
+
+import json
+import os
+from collections import namedtuple
+
+from .vision import VisionDataset
+from PIL import Image
+
+
+[docs]class Cityscapes(VisionDataset):
+ """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where directory ``leftImg8bit``
+ and ``gtFine`` or ``gtCoarse`` are located.
+        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine",
+            otherwise ``train``, ``train_extra`` or ``val``.
+        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+ target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+ or ``color``. Can also be a list to output a tuple with all specified target types.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+
+ Examples:
+
+ Get semantic segmentation target
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+ target_type='semantic')
+
+ img, smnt = dataset[0]
+
+ Get multiple targets
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+ target_type=['instance', 'color', 'polygon'])
+
+ img, (inst, col, poly) = dataset[0]
+
+ Validate on the "coarse" set
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+ target_type='semantic')
+
+ img, smnt = dataset[0]
+ """
+
+ # Based on https://github.com/mcordts/cityscapesScripts
+ CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
+ 'has_instances', 'ignore_in_eval', 'color'])
+
+ classes = [
+ CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
+ CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
+ CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
+ CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
+ CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
+ CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
+ CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
+ CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
+ CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
+ CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
+ CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
+ CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
+ CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
+ CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
+ CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
+ CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
+ CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
+ CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
+ CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
+ CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
+ CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
+ CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
+ CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
+ CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
+ CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
+ CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
+ CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
+ CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
+ CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
+ CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
+ CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
+ CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
+ CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
+ CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
+ CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
+ ]
+
+ def __init__(self, root, split='train', mode='fine', target_type='instance',
+ transform=None, target_transform=None):
+ super(Cityscapes, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
+ self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
+ self.targets_dir = os.path.join(self.root, self.mode, split)
+ self.target_type = target_type
+ self.split = split
+ self.images = []
+ self.targets = []
+
+ if mode not in ['fine', 'coarse']:
+ raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
+
+ if mode == 'fine' and split not in ['train', 'test', 'val']:
+ raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
+ ' or split="val"')
+ elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
+ raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
+ ' or split="val"')
+
+ if not isinstance(target_type, list):
+ self.target_type = [target_type]
+
+ if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
+ raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
+ ' or "color"')
+
+ if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+ raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
+ ' specified "split" and "mode" are inside the "root" directory')
+
+ for city in os.listdir(self.images_dir):
+ img_dir = os.path.join(self.images_dir, city)
+ target_dir = os.path.join(self.targets_dir, city)
+ for file_name in os.listdir(img_dir):
+ target_types = []
+ for t in self.target_type:
+ target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
+ self._get_target_suffix(self.mode, t))
+ target_types.append(os.path.join(target_dir, target_name))
+
+ self.images.append(os.path.join(img_dir, file_name))
+ self.targets.append(target_types)
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+ than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
+ """
+
+ image = Image.open(self.images[index]).convert('RGB')
+
+ targets = []
+ for i, t in enumerate(self.target_type):
+ if t == 'polygon':
+ target = self._load_json(self.targets[index][i])
+ else:
+ target = Image.open(self.targets[index][i])
+
+ targets.append(target)
+
+ target = tuple(targets) if len(targets) > 1 else targets[0]
+
+ if self.transform:
+ image = self.transform(image)
+
+ if self.target_transform:
+ target = self.target_transform(target)
+
+ return image, target
+
+ def __len__(self):
+ return len(self.images)
+
+ def extra_repr(self):
+ lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+ return '\n'.join(lines).format(**self.__dict__)
+
+ def _load_json(self, path):
+ with open(path, 'r') as file:
+ data = json.load(file)
+ return data
+
+ def _get_target_suffix(self, mode, target_type):
+ if target_type == 'instance':
+ return '{}_instanceIds.png'.format(mode)
+ elif target_type == 'semantic':
+ return '{}_labelIds.png'.format(mode)
+ elif target_type == 'color':
+ return '{}_color.png'.format(mode)
+ else:
+ return '{}_polygons.json'.format(mode)
+
+from .vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+
+
+[docs]class CocoCaptions(VisionDataset):
+ """`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+
+ Example:
+
+ .. code:: python
+
+ import torchvision.datasets as dset
+ import torchvision.transforms as transforms
+ cap = dset.CocoCaptions(root = 'dir where images are',
+ annFile = 'json annotation file',
+ transform=transforms.ToTensor())
+
+ print('Number of samples: ', len(cap))
+ img, target = cap[3] # load 4th sample
+
+ print("Image Size: ", img.size())
+ print(target)
+
+ Output: ::
+
+ Number of samples: 82783
+ Image Size: (3L, 427L, 640L)
+ [u'A plane emitting smoke stream flying over a mountain.',
+ u'A plane darts across a bright blue sky behind a mountain covered in snow',
+ u'A plane leaves a contrail above the snowy mountain top.',
+ u'A mountain that has a plane flying overheard in the distance.',
+ u'A mountain view with a plume of smoke in the background']
+
+ """
+
+ def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+ super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
+ from pycocotools.coco import COCO
+ self.coco = COCO(annFile)
+ self.ids = list(sorted(self.coco.imgs.keys()))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is a list of captions for the image.
+ """
+ coco = self.coco
+ img_id = self.ids[index]
+ ann_ids = coco.getAnnIds(imgIds=img_id)
+ anns = coco.loadAnns(ann_ids)
+ target = [ann['caption'] for ann in anns]
+
+ path = coco.loadImgs(img_id)[0]['file_name']
+
+ img = Image.open(os.path.join(self.root, path)).convert('RGB')
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
+
+
+[docs]class CocoDetection(VisionDataset):
+ """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
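+
+    Example (a minimal sketch mirroring the captions example above; the paths are placeholders):
+
+        .. code:: python
+
+            import torchvision.datasets as dset
+            import torchvision.transforms as transforms
+            det = dset.CocoDetection(root = 'dir where images are',
+                                     annFile = 'json annotation file',
+                                     transform=transforms.ToTensor())
+
+            img, target = det[3]  # target is the annotation list from coco.loadAnns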
+ """
+
+ def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+ super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
+ from pycocotools.coco import COCO
+ self.coco = COCO(annFile)
+ self.ids = list(sorted(self.coco.imgs.keys()))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+ """
+ coco = self.coco
+ img_id = self.ids[index]
+ ann_ids = coco.getAnnIds(imgIds=img_id)
+ target = coco.loadAnns(ann_ids)
+
+ path = coco.loadImgs(img_id)[0]['file_name']
+
+ img = Image.open(os.path.join(self.root, path)).convert('RGB')
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
+
+import torch
+from .vision import VisionDataset
+from .. import transforms
+
+
+[docs]class FakeData(VisionDataset):
+ """A fake dataset that returns randomly generated images and returns them as PIL images
+
+ Args:
+ size (int, optional): Size of the dataset. Default: 1000 images
+        image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
+        num_classes(int, optional): Number of classes in the dataset. Default: 10
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ random_offset (int): Offsets the index-based random seed used to
+ generate each image. Default: 0
+
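+    Example (a minimal sketch; the chosen sizes are arbitrary)::
+
+        import torchvision.transforms as transforms
+        dataset = FakeData(size=100, image_size=(3, 32, 32), num_classes=5,
+                           transform=transforms.ToTensor())
+        img, target = dataset[0]   # a random image tensor and its random class label
+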
+ """
+
+ def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
+ transform=None, target_transform=None, random_offset=0):
+ super(FakeData, self).__init__(None)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.size = size
+ self.num_classes = num_classes
+ self.image_size = image_size
+ self.transform = transform
+ self.target_transform = target_transform
+ self.random_offset = random_offset
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ # create random image that is consistent with the index id
+ if index >= len(self):
+ raise IndexError("{} index out of range".format(self.__class__.__name__))
+ rng_state = torch.get_rng_state()
+ torch.manual_seed(index + self.random_offset)
+ img = torch.randn(*self.image_size)
+ target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+ torch.set_rng_state(rng_state)
+
+ # convert to PIL Image
+ img = transforms.ToPILImage()(img)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return self.size
+
+from collections import defaultdict
+from PIL import Image
+from six.moves import html_parser
+
+import glob
+import os
+from .vision import VisionDataset
+
+
+class Flickr8kParser(html_parser.HTMLParser):
+ """Parser for extracting captions from the Flickr8k dataset web page."""
+
+ def __init__(self, root):
+ super(Flickr8kParser, self).__init__()
+
+ self.root = root
+
+ # Data structure to store captions
+ self.annotations = {}
+
+ # State variables
+ self.in_table = False
+ self.current_tag = None
+ self.current_img = None
+
+ def handle_starttag(self, tag, attrs):
+ self.current_tag = tag
+
+ if tag == 'table':
+ self.in_table = True
+
+ def handle_endtag(self, tag):
+ self.current_tag = None
+
+ if tag == 'table':
+ self.in_table = False
+
+ def handle_data(self, data):
+ if self.in_table:
+ if data == 'Image Not Found':
+ self.current_img = None
+ elif self.current_tag == 'a':
+ img_id = data.split('/')[-2]
+ img_id = os.path.join(self.root, img_id + '_*.jpg')
+ img_id = glob.glob(img_id)[0]
+ self.current_img = img_id
+ self.annotations[img_id] = []
+ elif self.current_tag == 'li' and self.current_img:
+ img_id = self.current_img
+ self.annotations[img_id].append(data.strip())
+
+
+[docs]class Flickr8k(VisionDataset):
+ """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ ann_file (string): Path to annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+
+ def __init__(self, root, ann_file, transform=None, target_transform=None):
+ super(Flickr8k, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.ann_file = os.path.expanduser(ann_file)
+
+ # Read annotations and store in a dict
+ parser = Flickr8kParser(self.root)
+ with open(self.ann_file) as fh:
+ parser.feed(fh.read())
+ self.annotations = parser.annotations
+
+ self.ids = list(sorted(self.annotations.keys()))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is a list of captions for the image.
+ """
+ img_id = self.ids[index]
+
+ # Image
+ img = Image.open(img_id).convert('RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+
+ # Captions
+ target = self.annotations[img_id]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
+
+
+[docs]class Flickr30k(VisionDataset):
+ """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ ann_file (string): Path to annotation file.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.ToTensor``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+
+ def __init__(self, root, ann_file, transform=None, target_transform=None):
+ super(Flickr30k, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.ann_file = os.path.expanduser(ann_file)
+
+ # Read annotations and store in a dict
+ self.annotations = defaultdict(list)
+ with open(self.ann_file) as fh:
+ for line in fh:
+ img_id, caption = line.strip().split('\t')
+ self.annotations[img_id[:-2]].append(caption)
+
+ self.ids = list(sorted(self.annotations.keys()))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is a list of captions for the image.
+ """
+ img_id = self.ids[index]
+
+ # Image
+ filename = os.path.join(self.root, img_id)
+ img = Image.open(filename).convert('RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+
+ # Captions
+ target = self.annotations[img_id]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
+
+from .vision import VisionDataset
+
+from PIL import Image
+
+import os
+import os.path
+import sys
+
+
+def has_file_allowed_extension(filename, extensions):
+ """Checks if a file is an allowed extension.
+
+ Args:
+ filename (string): path to a file
+ extensions (tuple of strings): extensions to consider (lowercase)
+
+ Returns:
+        bool: True if the filename ends with one of the given extensions
+ """
+ return filename.lower().endswith(extensions)
+
+
+def is_image_file(filename):
+ """Checks if a file is an allowed image extension.
+
+ Args:
+ filename (string): path to a file
+
+ Returns:
+ bool: True if the filename ends with a known image extension
+ """
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
+ images = []
+ dir = os.path.expanduser(dir)
+ if not ((extensions is None) ^ (is_valid_file is None)):
+ raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+ if extensions is not None:
+ def is_valid_file(x):
+ return has_file_allowed_extension(x, extensions)
+ for target in sorted(class_to_idx.keys()):
+ d = os.path.join(dir, target)
+ if not os.path.isdir(d):
+ continue
+ for root, _, fnames in sorted(os.walk(d)):
+ for fname in sorted(fnames):
+ path = os.path.join(root, fname)
+ if is_valid_file(path):
+ item = (path, class_to_idx[target])
+ images.append(item)
+
+ return images
+
+
+[docs]class DatasetFolder(VisionDataset):
+ """A generic data loader where the samples are arranged in this way: ::
+
+ root/class_x/xxx.ext
+ root/class_x/xxy.ext
+ root/class_x/xxz.ext
+
+ root/class_y/123.ext
+ root/class_y/nsdf3.ext
+ root/class_y/asd932_.ext
+
+ Args:
+ root (string): Root directory path.
+ loader (callable): A function to load a sample given its path.
+ extensions (tuple[string]): A list of allowed extensions.
+ both extensions and is_valid_file should not be passed.
+ transform (callable, optional): A function/transform that takes in
+ a sample and returns a transformed version.
+ E.g, ``transforms.RandomCrop`` for images.
+ target_transform (callable, optional): A function/transform that takes
+ in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to check for corrupt files).
+            Both extensions and is_valid_file should not be passed.
+
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ samples (list): List of (sample path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
+ """
+
+ def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
+ super(DatasetFolder, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ classes, class_to_idx = self._find_classes(self.root)
+ samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
+ if len(samples) == 0:
+ raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
+ "Supported extensions are: " + ",".join(extensions)))
+
+ self.loader = loader
+ self.extensions = extensions
+
+ self.classes = classes
+ self.class_to_idx = class_to_idx
+ self.samples = samples
+ self.targets = [s[1] for s in samples]
+
+ def _find_classes(self, dir):
+ """
+ Finds the class folders in a dataset.
+
+ Args:
+ dir (string): Root directory path.
+
+ Returns:
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
+
+ Ensures:
+ No class is a subdirectory of another.
+ """
+ if sys.version_info >= (3, 5):
+ # Faster and available in Python 3.5 and above
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+ else:
+ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+ classes.sort()
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ return classes, class_to_idx
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target
+
+ def __len__(self):
+ return len(self.samples)
+
+
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ with open(path, 'rb') as f:
+ img = Image.open(f)
+ return img.convert('RGB')
+
+
+def accimage_loader(path):
+ import accimage
+ try:
+ return accimage.Image(path)
+ except IOError:
+ # Potentially a decoding problem, fall back to PIL.Image
+ return pil_loader(path)
+
+
+def default_loader(path):
+ from torchvision import get_image_backend
+ if get_image_backend() == 'accimage':
+ return accimage_loader(path)
+ else:
+ return pil_loader(path)
+
+
+[docs]class ImageFolder(DatasetFolder):
+ """A generic data loader where the images are arranged in this way: ::
+
+ root/dog/xxx.png
+ root/dog/xxy.png
+ root/dog/xxz.png
+
+ root/cat/123.png
+ root/cat/nsdf3.png
+ root/cat/asd932_.png
+
+ Args:
+ root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to check for corrupt files)
+
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ imgs (list): List of (image path, class_index) tuples
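+
+    Example (a minimal sketch; ``path/to/root`` is a placeholder for a directory laid out as above)::
+
+        import torchvision.transforms as transforms
+        dataset = ImageFolder('path/to/root', transform=transforms.ToTensor())
+        img, class_index = dataset[0]
+        print(dataset.classes)        # class names taken from the folder names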
+ """
+
+ def __init__(self, root, transform=None, target_transform=None,
+ loader=default_loader, is_valid_file=None):
+ super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
+ transform=transform,
+ target_transform=target_transform,
+ is_valid_file=is_valid_file)
+ self.imgs = self.samples
+
+from __future__ import print_function
+import os
+import shutil
+import torch
+from .folder import ImageFolder
+from .utils import check_integrity, download_url
+
+ARCHIVE_DICT = {
+ 'train': {
+ 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
+ 'md5': '1d675b47d978889d74fa0da5fadfb00e',
+ },
+ 'val': {
+ 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
+ 'md5': '29b22e2961454d5413ddabcf34fc5622',
+ },
+ 'devkit': {
+ 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
+ 'md5': 'fa75699e90414af021442c21a62c3abf',
+ }
+}
+
+
+[docs]class ImageNet(ImageFolder):
+ """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+ Args:
+ root (string): Root directory of the ImageNet Dataset.
+ split (string, optional): The dataset split, supports ``train``, or ``val``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ wnids (list): List of the WordNet IDs.
+ wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+ imgs (list): List of (image path, class_index) tuples
+ targets (list): The class_index value for each image in the dataset
+ """
+
+ def __init__(self, root, split='train', download=False, **kwargs):
+ root = self.root = os.path.expanduser(root)
+ self.split = self._verify_split(split)
+
+ if download:
+ self.download()
+ wnid_to_classes = self._load_meta_file()[0]
+
+ super(ImageNet, self).__init__(self.split_folder, **kwargs)
+ self.root = root
+
+ idcs = [idx for _, idx in self.imgs]
+ self.wnids = self.classes
+ self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)}
+ self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+ self.class_to_idx = {cls: idx
+ for clss, idx in zip(self.classes, idcs)
+ for cls in clss}
+
+ def download(self):
+ if not check_integrity(self.meta_file):
+ tmpdir = os.path.join(self.root, 'tmp')
+
+ archive_dict = ARCHIVE_DICT['devkit']
+ download_and_extract_tar(archive_dict['url'], self.root,
+ extract_root=tmpdir,
+ md5=archive_dict['md5'])
+ devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
+ meta = parse_devkit(os.path.join(tmpdir, devkit_folder))
+ self._save_meta_file(*meta)
+
+ shutil.rmtree(tmpdir)
+
+ if not os.path.isdir(self.split_folder):
+ archive_dict = ARCHIVE_DICT[self.split]
+ download_and_extract_tar(archive_dict['url'], self.root,
+ extract_root=self.split_folder,
+ md5=archive_dict['md5'])
+
+ if self.split == 'train':
+ prepare_train_folder(self.split_folder)
+ elif self.split == 'val':
+ val_wnids = self._load_meta_file()[1]
+ prepare_val_folder(self.split_folder, val_wnids)
+ else:
+            msg = ("You set download=True, but a folder '{}' already exists in "
+ "the root directory. If you want to re-download or re-extract the "
+ "archive, delete the folder.")
+ print(msg.format(self.split))
+
+ @property
+ def meta_file(self):
+ return os.path.join(self.root, 'meta.bin')
+
+ def _load_meta_file(self):
+ if check_integrity(self.meta_file):
+ return torch.load(self.meta_file)
+ else:
+ raise RuntimeError("Meta file not found or corrupted.",
+ "You can use download=True to create it.")
+
+ def _save_meta_file(self, wnid_to_class, val_wnids):
+ torch.save((wnid_to_class, val_wnids), self.meta_file)
+
+ def _verify_split(self, split):
+ if split not in self.valid_splits:
+            msg = "Unknown split {}.".format(split)
+            msg += " Valid splits are {}.".format(", ".join(self.valid_splits))
+ raise ValueError(msg)
+ return split
+
+ @property
+ def valid_splits(self):
+ return 'train', 'val'
+
+ @property
+ def split_folder(self):
+ return os.path.join(self.root, self.split)
+
+ def extra_repr(self):
+ return "Split: {split}".format(**self.__dict__)
+
+
+def extract_tar(src, dest=None, gzip=None, delete=False):
+ import tarfile
+
+ if dest is None:
+ dest = os.path.dirname(src)
+ if gzip is None:
+ gzip = src.lower().endswith('.gz')
+
+ mode = 'r:gz' if gzip else 'r'
+ with tarfile.open(src, mode) as tarfh:
+ tarfh.extractall(path=dest)
+
+ if delete:
+ os.remove(src)
+
+
+def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
+ md5=None, **kwargs):
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if filename is None:
+ filename = os.path.basename(url)
+
+ if not check_integrity(os.path.join(download_root, filename), md5):
+ download_url(url, download_root, filename=filename, md5=md5)
+
+ extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
+
+
+def parse_devkit(root):
+ idx_to_wnid, wnid_to_classes = parse_meta(root)
+ val_idcs = parse_val_groundtruth(root)
+ val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+ return wnid_to_classes, val_wnids
+
+
+def parse_meta(devkit_root, path='data', filename='meta.mat'):
+ import scipy.io as sio
+
+ metafile = os.path.join(devkit_root, path, filename)
+ meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+ nums_children = list(zip(*meta))[4]
+ meta = [meta[idx] for idx, num_children in enumerate(nums_children)
+ if num_children == 0]
+ idcs, wnids, classes = list(zip(*meta))[:3]
+ classes = [tuple(clss.split(', ')) for clss in classes]
+ idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+ wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+ return idx_to_wnid, wnid_to_classes
+
+
+def parse_val_groundtruth(devkit_root, path='data',
+ filename='ILSVRC2012_validation_ground_truth.txt'):
+ with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
+ val_idcs = txtfh.readlines()
+ return [int(val_idx) for val_idx in val_idcs]
+
+
+def prepare_train_folder(folder):
+ for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
+ extract_tar(archive, os.path.splitext(archive)[0], delete=True)
+
+
+def prepare_val_folder(folder, wnids):
+ img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
+
+ for wnid in set(wnids):
+ os.mkdir(os.path.join(folder, wnid))
+
+ for wnid, img_file in zip(wnids, img_files):
+ shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
+
+
+def _splitexts(root):
+ exts = []
+ ext = '.'
+ while ext:
+ root, ext = os.path.splitext(root)
+ exts.append(ext)
+ return root, ''.join(reversed(exts))
+
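+# A minimal usage sketch (illustrative, not part of the original module):
+# it assumes the archives have already been downloaded and processed under
+# 'path/to/imagenet' (a placeholder), so that meta.bin and the split folder exist.
+if __name__ == '__main__':
+    dataset = ImageNet('path/to/imagenet', split='val')
+    print(len(dataset), len(dataset.classes))
+    img, class_index = dataset[0]  # a PIL image and its class index
+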
+from .vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+import six
+import string
+import sys
+
+if sys.version_info[0] == 2:
+ import cPickle as pickle
+else:
+ import pickle
+
+
+class LSUNClass(VisionDataset):
+ def __init__(self, root, transform=None, target_transform=None):
+ import lmdb
+ super(LSUNClass, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
+ readahead=False, meminit=False)
+ with self.env.begin(write=False) as txn:
+ self.length = txn.stat()['entries']
+ cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
+ if os.path.isfile(cache_file):
+ self.keys = pickle.load(open(cache_file, "rb"))
+ else:
+ with self.env.begin(write=False) as txn:
+ self.keys = [key for key, _ in txn.cursor()]
+ pickle.dump(self.keys, open(cache_file, "wb"))
+
+ def __getitem__(self, index):
+ img, target = None, None
+ env = self.env
+ with env.begin(write=False) as txn:
+ imgbuf = txn.get(self.keys[index])
+
+ buf = six.BytesIO()
+ buf.write(imgbuf)
+ buf.seek(0)
+ img = Image.open(buf).convert('RGB')
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return self.length
+
+
+[docs]class LSUN(VisionDataset):
+ """
+ `LSUN <http://lsun.cs.princeton.edu>`_ dataset.
+
+ Args:
+ root (string): Root directory for the database files.
+ classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+
+ def __init__(self, root, classes='train',
+ transform=None, target_transform=None):
+ super(LSUN, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
+ 'conference_room', 'dining_room', 'kitchen',
+ 'living_room', 'restaurant', 'tower']
+ dset_opts = ['train', 'val', 'test']
+
+ if type(classes) == str and classes in dset_opts:
+ if classes == 'test':
+ classes = [classes]
+ else:
+ classes = [c + '_' + classes for c in categories]
+ elif type(classes) == list:
+ for c in classes:
+ c_short = c.split('_')
+ c_short.pop(len(c_short) - 1)
+ c_short = '_'.join(c_short)
+ if c_short not in categories:
+                    raise (ValueError('Unknown LSUN class: ' + c_short + '. '
+                                      'Options are: ' + str(categories)))
+ c_short = c.split('_')
+ c_short = c_short.pop(len(c_short) - 1)
+ if c_short not in dset_opts:
+                    raise (ValueError('Unknown postfix: ' + c_short + '. '
+                                      'Options are: ' + str(dset_opts)))
+ else:
+ raise (ValueError('Unknown option for classes'))
+ self.classes = classes
+
+ # for each class, create an LSUNClassDataset
+ self.dbs = []
+ for c in self.classes:
+ self.dbs.append(LSUNClass(
+ root=root + '/' + c + '_lmdb',
+ transform=transform))
+
+ self.indices = []
+ count = 0
+ for db in self.dbs:
+ count += len(db)
+ self.indices.append(count)
+
+ self.length = count
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target) where target is the index of the target category.
+ """
+ target = 0
+ sub = 0
+ for ind in self.indices:
+ if index < ind:
+ break
+ target += 1
+ sub = ind
+
+ db = self.dbs[target]
+ index = index - sub
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ img, _ = db[index]
+ return img, target
+
+ def __len__(self):
+ return self.length
+
+ def extra_repr(self):
+ return "Classes: {classes}".format(**self.__dict__)
+
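+# A minimal usage sketch (illustrative, not part of the original module):
+# it assumes the lmdb package is installed and that 'path/to/lsun' (a
+# placeholder) already contains the per-category LMDB folders, e.g.
+# bedroom_train_lmdb.
+if __name__ == '__main__':
+    dataset = LSUN('path/to/lsun', classes=['bedroom_train'])
+    print(len(dataset))
+    img, target = dataset[0]  # a PIL image and the index of its category
+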
+from __future__ import print_function
+from .vision import VisionDataset
+import warnings
+from PIL import Image
+import os
+import os.path
+import gzip
+import numpy as np
+import torch
+import codecs
+from .utils import download_url, makedir_exist_ok
+
+
+[docs]class MNIST(VisionDataset):
+ """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where ``MNIST/processed/training.pt``
+ and ``MNIST/processed/test.pt`` exist.
+ train (bool, optional): If True, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+ urls = [
+ 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
+ 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
+ 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
+ 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+ ]
+ training_file = 'training.pt'
+ test_file = 'test.pt'
+ classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+ '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+
+ @property
+ def train_labels(self):
+ warnings.warn("train_labels has been renamed targets")
+ return self.targets
+
+ @property
+ def test_labels(self):
+ warnings.warn("test_labels has been renamed targets")
+ return self.targets
+
+ @property
+ def train_data(self):
+ warnings.warn("train_data has been renamed data")
+ return self.data
+
+ @property
+ def test_data(self):
+ warnings.warn("test_data has been renamed data")
+ return self.data
+
+ def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
+ super(MNIST, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.train = train # training set or test set
+
+ if download:
+ self.download()
+
+ if not self._check_exists():
+ raise RuntimeError('Dataset not found.' +
+ ' You can use download=True to download it')
+
+ if self.train:
+ data_file = self.training_file
+ else:
+ data_file = self.test_file
+ self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.targets[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img.numpy(), mode='L')
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.data)
+
+ @property
+ def raw_folder(self):
+ return os.path.join(self.root, self.__class__.__name__, 'raw')
+
+ @property
+ def processed_folder(self):
+ return os.path.join(self.root, self.__class__.__name__, 'processed')
+
+ @property
+ def class_to_idx(self):
+ return {_class: i for i, _class in enumerate(self.classes)}
+
+ def _check_exists(self):
+ return (os.path.exists(os.path.join(self.processed_folder,
+ self.training_file)) and
+ os.path.exists(os.path.join(self.processed_folder,
+ self.test_file)))
+
+ @staticmethod
+ def extract_gzip(gzip_path, remove_finished=False):
+ print('Extracting {}'.format(gzip_path))
+ with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
+ gzip.GzipFile(gzip_path) as zip_f:
+ out_f.write(zip_f.read())
+ if remove_finished:
+ os.unlink(gzip_path)
+
+ def download(self):
+ """Download the MNIST data if it doesn't exist in processed_folder already."""
+
+ if self._check_exists():
+ return
+
+ makedir_exist_ok(self.raw_folder)
+ makedir_exist_ok(self.processed_folder)
+
+ # download files
+ for url in self.urls:
+ filename = url.rpartition('/')[2]
+ file_path = os.path.join(self.raw_folder, filename)
+ download_url(url, root=self.raw_folder, filename=filename, md5=None)
+ self.extract_gzip(gzip_path=file_path, remove_finished=True)
+
+ # process and save as torch files
+ print('Processing...')
+
+ training_set = (
+ read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
+ read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
+ )
+ test_set = (
+ read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
+ read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
+ )
+ with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
+ torch.save(training_set, f)
+ with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
+ torch.save(test_set, f)
+
+ print('Done!')
+
+ def extra_repr(self):
+ return "Split: {}".format("Train" if self.train is True else "Test")
+
+
+[docs]class FashionMNIST(MNIST):
+ """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
+ and ``Fashion-MNIST/processed/test.pt`` exist.
+ train (bool, optional): If True, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+ urls = [
+ 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
+ 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
+ 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
+ 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
+ ]
+ classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
+ 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
+
+
+[docs]class KMNIST(MNIST):
+ """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where ``KMNIST/processed/training.pt``
+ and ``KMNIST/processed/test.pt`` exist.
+ train (bool, optional): If True, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+ urls = [
+ 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
+ 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
+ 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
+ 'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
+ ]
+ classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
+
+
+[docs]class EMNIST(MNIST):
+ """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
+ and ``EMNIST/processed/test.pt`` exist.
+ split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+ ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+ which one to use.
+ train (bool, optional): If True, creates dataset from ``training.pt``,
+ otherwise from ``test.pt``.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+ # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
+ url = 'https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download'
+ splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+
+ def __init__(self, root, split, **kwargs):
+ if split not in self.splits:
+ raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+ split, ', '.join(self.splits),
+ ))
+ self.split = split
+ self.training_file = self._training_file(split)
+ self.test_file = self._test_file(split)
+ super(EMNIST, self).__init__(root, **kwargs)
+
+ @staticmethod
+ def _training_file(split):
+ return 'training_{}.pt'.format(split)
+
+ @staticmethod
+ def _test_file(split):
+ return 'test_{}.pt'.format(split)
+
+ def download(self):
+ """Download the EMNIST data if it doesn't exist in processed_folder already."""
+ import shutil
+ import zipfile
+
+ if self._check_exists():
+ return
+
+ makedir_exist_ok(self.raw_folder)
+ makedir_exist_ok(self.processed_folder)
+
+ # download files
+ filename = self.url.rpartition('/')[2]
+ file_path = os.path.join(self.raw_folder, filename)
+ download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
+
+ print('Extracting zip archive')
+ with zipfile.ZipFile(file_path) as zip_f:
+ zip_f.extractall(self.raw_folder)
+ os.unlink(file_path)
+ gzip_folder = os.path.join(self.raw_folder, 'gzip')
+ for gzip_file in os.listdir(gzip_folder):
+ if gzip_file.endswith('.gz'):
+ self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
+
+ # process and save as torch files
+ for split in self.splits:
+ print('Processing ' + split)
+ training_set = (
+ read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
+ read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
+ )
+ test_set = (
+ read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
+ read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
+ )
+ with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
+ torch.save(training_set, f)
+ with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
+ torch.save(test_set, f)
+ shutil.rmtree(gzip_folder)
+
+ print('Done!')
+
+
+def get_int(b):
+ return int(codecs.encode(b, 'hex'), 16)
+
+
+def read_label_file(path):
+ with open(path, 'rb') as f:
+ data = f.read()
+ assert get_int(data[:4]) == 2049
+ length = get_int(data[4:8])
+ parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
+ return torch.from_numpy(parsed).view(length).long()
+
+
+def read_image_file(path):
+ with open(path, 'rb') as f:
+ data = f.read()
+ assert get_int(data[:4]) == 2051
+ length = get_int(data[4:8])
+ num_rows = get_int(data[8:12])
+ num_cols = get_int(data[12:16])
+ parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
+ return torch.from_numpy(parsed).view(length, num_rows, num_cols)
+
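+# A minimal usage sketch (illustrative, not part of the original module):
+# downloads MNIST on first use (assuming the URLs above are still reachable)
+# and batches it with a DataLoader; 'path/to/data' is a placeholder.
+if __name__ == '__main__':
+    from torch.utils.data import DataLoader
+    from torchvision import transforms
+    train_set = MNIST('path/to/data', train=True, download=True,
+                      transform=transforms.ToTensor())
+    loader = DataLoader(train_set, batch_size=32, shuffle=True)
+    images, targets = next(iter(loader))  # shapes [32, 1, 28, 28] and [32]
+    print(images.shape, targets.shape)
+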
+import os
+import numpy as np
+from PIL import Image
+
+import torch
+from .vision import VisionDataset
+
+from .utils import download_url
+
+
+[docs]class PhotoTour(VisionDataset):
+ """`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
+
+
+ Args:
+ root (string): Root directory where images are.
+ name (string): Name of the dataset to load.
+        transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+ urls = {
+ 'notredame_harris': [
+ 'http://matthewalunbrown.com/patchdata/notredame_harris.zip',
+ 'notredame_harris.zip',
+ '69f8c90f78e171349abdf0307afefe4d'
+ ],
+ 'yosemite_harris': [
+ 'http://matthewalunbrown.com/patchdata/yosemite_harris.zip',
+ 'yosemite_harris.zip',
+ 'a73253d1c6fbd3ba2613c45065c00d46'
+ ],
+ 'liberty_harris': [
+ 'http://matthewalunbrown.com/patchdata/liberty_harris.zip',
+ 'liberty_harris.zip',
+ 'c731fcfb3abb4091110d0ae8c7ba182c'
+ ],
+ 'notredame': [
+ 'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip',
+ 'notredame.zip',
+ '509eda8535847b8c0a90bbb210c83484'
+ ],
+ 'yosemite': [
+ 'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip',
+ 'yosemite.zip',
+ '533b2e8eb7ede31be40abc317b2fd4f0'
+ ],
+ 'liberty': [
+ 'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
+ 'liberty.zip',
+ 'fdd9152f138ea5ef2091746689176414'
+ ],
+ }
+ mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
+ 'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
+ std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
+ 'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
+ lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
+ 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
+ image_ext = 'bmp'
+ info_file = 'info.txt'
+ matches_files = 'm50_100000_100000_0.txt'
+
+ def __init__(self, root, name, train=True, transform=None, download=False):
+ super(PhotoTour, self).__init__(root)
+ self.transform = transform
+ self.name = name
+ self.data_dir = os.path.join(self.root, name)
+ self.data_down = os.path.join(self.root, '{}.zip'.format(name))
+ self.data_file = os.path.join(self.root, '{}.pt'.format(name))
+
+ self.train = train
+ self.mean = self.mean[name]
+ self.std = self.std[name]
+
+ if download:
+ self.download()
+
+ if not self._check_datafile_exists():
+ raise RuntimeError('Dataset not found.' +
+ ' You can use download=True to download it')
+
+ # load the serialized data
+ self.data, self.labels, self.matches = torch.load(self.data_file)
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+            tuple: (data1, data2, matches) for the test split; a single
+                (optionally transformed) patch for the train split.
+ """
+ if self.train:
+ data = self.data[index]
+ if self.transform is not None:
+ data = self.transform(data)
+ return data
+ m = self.matches[index]
+ data1, data2 = self.data[m[0]], self.data[m[1]]
+ if self.transform is not None:
+ data1 = self.transform(data1)
+ data2 = self.transform(data2)
+ return data1, data2, m[2]
+
+ def __len__(self):
+ if self.train:
+ return self.lens[self.name]
+ return len(self.matches)
+
+ def _check_datafile_exists(self):
+ return os.path.exists(self.data_file)
+
+ def _check_downloaded(self):
+ return os.path.exists(self.data_dir)
+
+ def download(self):
+ if self._check_datafile_exists():
+ print('# Found cached data {}'.format(self.data_file))
+ return
+
+ if not self._check_downloaded():
+ # download files
+ url = self.urls[self.name][0]
+ filename = self.urls[self.name][1]
+ md5 = self.urls[self.name][2]
+ fpath = os.path.join(self.root, filename)
+
+ download_url(url, self.root, filename, md5)
+
+ print('# Extracting data {}\n'.format(self.data_down))
+
+ import zipfile
+ with zipfile.ZipFile(fpath, 'r') as z:
+ z.extractall(self.data_dir)
+
+ os.unlink(fpath)
+
+ # process and save as torch files
+ print('# Caching data {}'.format(self.data_file))
+
+ dataset = (
+ read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
+ read_info_file(self.data_dir, self.info_file),
+ read_matches_files(self.data_dir, self.matches_files)
+ )
+
+ with open(self.data_file, 'wb') as f:
+ torch.save(dataset, f)
+
+ def extra_repr(self):
+ return "Split: {}".format("Train" if self.train is True else "Test")
+
+
+def read_image_file(data_dir, image_ext, n):
+ """Return a Tensor containing the patches
+ """
+
+ def PIL2array(_img):
+ """Convert PIL image type to numpy 2D array
+ """
+ return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
+
+ def find_files(_data_dir, _image_ext):
+ """Return a list with the file names of the images containing the patches
+ """
+ files = []
+ # find those files with the specified extension
+ for file_dir in os.listdir(_data_dir):
+ if file_dir.endswith(_image_ext):
+ files.append(os.path.join(_data_dir, file_dir))
+ return sorted(files) # sort files in ascend order to keep relations
+
+ patches = []
+ list_files = find_files(data_dir, image_ext)
+
+ for fpath in list_files:
+ img = Image.open(fpath)
+ for y in range(0, 1024, 64):
+ for x in range(0, 1024, 64):
+ patch = img.crop((x, y, x + 64, y + 64))
+ patches.append(PIL2array(patch))
+ return torch.ByteTensor(np.array(patches[:n]))
+
+
+def read_info_file(data_dir, info_file):
+ """Return a Tensor containing the list of labels
+ Read the file and keep only the ID of the 3D point.
+ """
+ labels = []
+ with open(os.path.join(data_dir, info_file), 'r') as f:
+ labels = [int(line.split()[0]) for line in f]
+ return torch.LongTensor(labels)
+
+
+def read_matches_files(data_dir, matches_file):
+ """Return a Tensor containing the ground truth matches
+ Read the file and keep only 3D point ID.
+ Matches are represented with a 1, non matches with a 0.
+ """
+ matches = []
+ with open(os.path.join(data_dir, matches_file), 'r') as f:
+ for line in f:
+ line_split = line.split()
+ matches.append([int(line_split[0]), int(line_split[3]),
+ int(line_split[1] == line_split[4])])
+ return torch.LongTensor(matches)
+
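+# A minimal usage sketch (illustrative, not part of the original module):
+# with train=False each item is a (patch1, patch2, match) triple taken from
+# the m50_* ground-truth file; 'path/to/phototour' is a placeholder.
+if __name__ == '__main__':
+    dataset = PhotoTour('path/to/phototour', name='notredame', train=False,
+                        download=True)
+    patch1, patch2, is_match = dataset[0]  # two 64x64 byte tensors and a 0/1 flag
+    print(len(dataset), patch1.shape, int(is_match))
+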
+import os
+import shutil
+from .vision import VisionDataset
+
+import numpy as np
+
+from PIL import Image
+from .utils import download_url
+from .voc import download_extract
+
+
+[docs]class SBDataset(VisionDataset):
+ """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
+
+ The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
+
+ .. note ::
+
+ Please note that the train and val splits included with this dataset are different from
+ the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
+ VOC2012 val.
+ If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
+ which excludes all val images.
+
+ .. warning::
+
+ This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+ Args:
+ root (string): Root directory of the Semantic Boundaries Dataset
+ image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
+ Image set ``train_noval`` excludes VOC 2012 val images.
+ mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
+ In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
+ where `num_classes=20`.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transforms (callable, optional): A function/transform that takes an input sample and its target
+            and returns a transformed version. The input sample is a PIL image and the target is a numpy
+            array if `mode='boundaries'` or a PIL image if `mode='segmentation'`.
+ """
+
+ url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+ md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
+ filename = "benchmark.tgz"
+
+ voc_train_url = "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt"
+ voc_split_filename = "train_noval.txt"
+ voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
+
+ def __init__(self,
+ root,
+ image_set='train',
+ mode='boundaries',
+ download=False,
+ transforms=None):
+
+ try:
+ from scipy.io import loadmat
+ self._loadmat = loadmat
+ except ImportError:
+ raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
+ "pip install scipy")
+
+ super(SBDataset, self).__init__(root, transforms)
+
+ if mode not in ("segmentation", "boundaries"):
+ raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
+
+ self.image_set = image_set
+ self.mode = mode
+ self.num_classes = 20
+
+ sbd_root = self.root
+ image_dir = os.path.join(sbd_root, 'img')
+ mask_dir = os.path.join(sbd_root, 'cls')
+
+ if download:
+ download_extract(self.url, self.root, self.filename, self.md5)
+ extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
+ for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
+ old_path = os.path.join(extracted_ds_root, f)
+ shutil.move(old_path, sbd_root)
+ download_url(self.voc_train_url, sbd_root, self.voc_split_filename,
+ self.voc_split_md5)
+
+ if not os.path.isdir(sbd_root):
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
+
+ if not os.path.exists(split_f):
+ raise ValueError(
+ 'Wrong image_set entered! Please use image_set="train" '
+ 'or image_set="val" or image_set="train_noval"')
+
+ with open(os.path.join(split_f), "r") as f:
+ file_names = [x.strip() for x in f.readlines()]
+
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+ self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
+ assert (len(self.images) == len(self.masks))
+
+ self._get_target = self._get_segmentation_target \
+ if self.mode == "segmentation" else self._get_boundaries_target
+
+ def _get_segmentation_target(self, filepath):
+ mat = self._loadmat(filepath)
+ return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
+
+ def _get_boundaries_target(self, filepath):
+ mat = self._loadmat(filepath)
+ return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
+ for i in range(self.num_classes)], axis=0)
+
+ def __getitem__(self, index):
+ img = Image.open(self.images[index]).convert('RGB')
+ target = self._get_target(self.masks[index])
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.images)
+
+ def extra_repr(self):
+ lines = ["Image set: {image_set}", "Mode: {mode}"]
+ return '\n'.join(lines).format(**self.__dict__)
+
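+# A minimal usage sketch (illustrative, not part of the original module):
+# it assumes scipy is installed and that 'path/to/sbd' (a placeholder) already
+# holds the extracted benchmark; targets are PIL masks in 'segmentation' mode
+# and [20, H, W] numpy arrays in 'boundaries' mode.
+if __name__ == '__main__':
+    dataset = SBDataset('path/to/sbd', image_set='train_noval', mode='segmentation')
+    img, target = dataset[0]
+    print(len(dataset), img.size, target.size)
+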
+from PIL import Image
+from six.moves import zip
+from .utils import download_url, check_integrity
+
+import os
+from .vision import VisionDataset
+
+
+[docs]class SBU(VisionDataset):
+ """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where tarball
+ ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If True, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+ url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+ filename = "SBUCaptionedPhotoDataset.tar.gz"
+ md5_checksum = '9aec147b3488753cf758b4d493422285'
+
+ def __init__(self, root, transform=None, target_transform=None,
+ download=True):
+ super(SBU, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ # Read the caption for each photo
+ self.photos = []
+ self.captions = []
+
+ file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')
+ file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt')
+
+ for line1, line2 in zip(open(file1), open(file2)):
+ url = line1.rstrip()
+ photo = os.path.basename(url)
+ filename = os.path.join(self.root, 'dataset', photo)
+ if os.path.exists(filename):
+ caption = line2.rstrip()
+ self.photos.append(photo)
+ self.captions.append(caption)
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a caption for the photo.
+ """
+ filename = os.path.join(self.root, 'dataset', self.photos[index])
+ img = Image.open(filename).convert('RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+
+ target = self.captions[index]
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ """The number of photos in the dataset."""
+ return len(self.photos)
+
+ def _check_integrity(self):
+ """Check the md5 checksum of the downloaded tarball."""
+ root = self.root
+ fpath = os.path.join(root, self.filename)
+ if not check_integrity(fpath, self.md5_checksum):
+ return False
+ return True
+
+ def download(self):
+ """Download and extract the tarball, and download each individual photo."""
+ import tarfile
+
+ if self._check_integrity():
+ print('Files already downloaded and verified')
+ return
+
+ download_url(self.url, self.root, self.filename, self.md5_checksum)
+
+ # Extract file
+ with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
+ tar.extractall(path=self.root)
+
+ # Download individual photos
+ with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh:
+ for line in fh:
+ url = line.rstrip()
+ try:
+ download_url(url, os.path.join(self.root, 'dataset'))
+ except OSError:
+ # The images point to public images on Flickr.
+ # Note: Images might be removed by users at anytime.
+ pass
+
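+# A minimal usage sketch (illustrative, not part of the original module):
+# note that download=True also fetches the individual Flickr photos, which is
+# slow and silently skips images that have since been removed; 'path/to/sbu'
+# is a placeholder.
+if __name__ == '__main__':
+    dataset = SBU('path/to/sbu', download=True)
+    img, caption = dataset[0]  # a PIL image and its caption string
+    print(len(dataset), caption)
+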
+from __future__ import print_function
+from PIL import Image
+import os
+import os.path
+import numpy as np
+from .cifar import CIFAR10
+
+
+[docs]class STL10(CIFAR10):
+ """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
+
+ Args:
+ root (string): Root directory of dataset where directory
+ ``stl10_binary`` exists.
+ split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
+ Accordingly dataset is selected.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+ base_folder = 'stl10_binary'
+ url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
+ filename = "stl10_binary.tar.gz"
+ tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb'
+ class_names_file = 'class_names.txt'
+ train_list = [
+ ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'],
+ ['train_y.bin', '5a34089d4802c674881badbb80307741'],
+ ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4']
+ ]
+
+ test_list = [
+ ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
+ ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
+ ]
+ splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
+
+ def __init__(self, root, split='train',
+ transform=None, target_transform=None, download=False):
+ if split not in self.splits:
+ raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+ split, ', '.join(self.splits),
+ ))
+ self.root = os.path.expanduser(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.split = split # train/test/unlabeled set
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError(
+ 'Dataset not found or corrupted. '
+ 'You can use download=True to download it')
+
+ # now load the picked numpy arrays
+ if self.split == 'train':
+ self.data, self.labels = self.__loadfile(
+ self.train_list[0][0], self.train_list[1][0])
+ elif self.split == 'train+unlabeled':
+ self.data, self.labels = self.__loadfile(
+ self.train_list[0][0], self.train_list[1][0])
+ unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
+ self.data = np.concatenate((self.data, unlabeled_data))
+ self.labels = np.concatenate(
+ (self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
+
+ elif self.split == 'unlabeled':
+ self.data, _ = self.__loadfile(self.train_list[2][0])
+ self.labels = np.asarray([-1] * self.data.shape[0])
+ else: # self.split == 'test':
+ self.data, self.labels = self.__loadfile(
+ self.test_list[0][0], self.test_list[1][0])
+
+ class_file = os.path.join(
+ self.root, self.base_folder, self.class_names_file)
+ if os.path.isfile(class_file):
+ with open(class_file) as f:
+ self.classes = f.read().splitlines()
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ if self.labels is not None:
+ img, target = self.data[index], int(self.labels[index])
+ else:
+ img, target = self.data[index], None
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return self.data.shape[0]
+
+ def __loadfile(self, data_file, labels_file=None):
+ labels = None
+ if labels_file:
+ path_to_labels = os.path.join(
+ self.root, self.base_folder, labels_file)
+ with open(path_to_labels, 'rb') as f:
+ labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
+
+ path_to_data = os.path.join(self.root, self.base_folder, data_file)
+ with open(path_to_data, 'rb') as f:
+ # read whole file in uint8 chunks
+ everything = np.fromfile(f, dtype=np.uint8)
+ images = np.reshape(everything, (-1, 3, 96, 96))
+ images = np.transpose(images, (0, 1, 3, 2))
+
+ return images, labels
+
+ def extra_repr(self):
+ return "Split: {split}".format(**self.__dict__)
+
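+# A minimal usage sketch (illustrative, not part of the original module):
+# the 'unlabeled' and 'train+unlabeled' splits use -1 as the label for the
+# unlabeled images; 'path/to/stl10' is a placeholder.
+if __name__ == '__main__':
+    dataset = STL10('path/to/stl10', split='train', download=True)
+    img, target = dataset[0]  # a 96x96 PIL image and its class index
+    print(len(dataset), img.size, target)
+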
+from __future__ import print_function
+from .vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+import numpy as np
+from .utils import download_url, check_integrity
+
+
+[docs]class SVHN(VisionDataset):
+ """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
+ Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
+ we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
+ expect the class labels to be in the range `[0, C-1]`
+
+ Args:
+ root (string): Root directory of dataset where directory
+ ``SVHN`` exists.
+ split (string): One of {'train', 'test', 'extra'}.
+ Accordingly dataset is selected. 'extra' is Extra training set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+ url = ""
+ filename = ""
+ file_md5 = ""
+
+ split_list = {
+ 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+ "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
+ 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+ "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
+ 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+ "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
+
+ def __init__(self, root, split='train',
+ transform=None, target_transform=None, download=False):
+ super(SVHN, self).__init__(root)
+ self.transform = transform
+ self.target_transform = target_transform
+ self.split = split # training set or test set or extra set
+
+ if self.split not in self.split_list:
+ raise ValueError('Wrong split entered! Please use split="train" '
+ 'or split="extra" or split="test"')
+
+ self.url = self.split_list[split][0]
+ self.filename = self.split_list[split][1]
+ self.file_md5 = self.split_list[split][2]
+
+ if download:
+ self.download()
+
+ if not self._check_integrity():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ # import here rather than at top of file because this is
+ # an optional dependency for torchvision
+ import scipy.io as sio
+
+ # reading(loading) mat file as array
+ loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+ self.data = loaded_mat['X']
+ # loading from the .mat file gives an np array of type np.uint8
+ # converting to np.int64, so that we have a LongTensor after
+ # the conversion from the numpy array
+ # the squeeze is needed to obtain a 1D tensor
+ self.labels = loaded_mat['y'].astype(np.int64).squeeze()
+
+ # the svhn dataset assigns the class label "10" to the digit 0
+ # this makes it inconsistent with several loss functions
+ # which expect the class labels to be in the range [0, C-1]
+ np.place(self.labels, self.labels == 10, 0)
+ self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.labels[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.data)
+
+ def _check_integrity(self):
+ root = self.root
+ md5 = self.split_list[self.split][2]
+ fpath = os.path.join(root, self.filename)
+ return check_integrity(fpath, md5)
+
+ def download(self):
+ md5 = self.split_list[self.split][2]
+ download_url(self.url, self.root, self.filename, md5)
+
+ def extra_repr(self):
+ return "Split: {split}".format(**self.__dict__)
+
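+# A minimal usage sketch (illustrative, not part of the original module):
+# labels are already remapped so that digit 0 has label 0; 'path/to/svhn'
+# is a placeholder.
+if __name__ == '__main__':
+    dataset = SVHN('path/to/svhn', split='train', download=True)
+    img, target = dataset[0]  # a 32x32 PIL image and a label in [0, 9]
+    print(len(dataset), img.size, target)
+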
+from __future__ import print_function
+from PIL import Image
+import os
+import numpy as np
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+[docs]class USPS(VisionDataset):
+ """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
+    The data format is: [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
+ The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
+ and make pixel values in ``[0, 255]``.
+
+ Args:
+        root (string): Root directory of dataset to store ``USPS`` data files.
+ train (bool, optional): If True, creates dataset from ``usps.bz2``,
+ otherwise from ``usps.t.bz2``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+
+ """
+ split_list = {
+ 'train': [
+ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+ "usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197'
+ ],
+ 'test': [
+ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+ "usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118'
+ ],
+ }
+
+ def __init__(self, root, train=True, transform=None, target_transform=None,
+ download=False):
+ super(USPS, self).__init__(root, transform=transform,
+ target_transform=target_transform)
+ split = 'train' if train else 'test'
+ url, filename, checksum = self.split_list[split]
+ full_path = os.path.join(self.root, filename)
+
+ if download and not os.path.exists(full_path):
+ download_url(url, self.root, filename, md5=checksum)
+
+ import bz2
+ with bz2.open(full_path) as fp:
+ raw_data = [l.decode().split() for l in fp.readlines()]
+ imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
+ imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
+ imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+ targets = [int(d[0]) - 1 for d in raw_data]
+
+ self.data = imgs
+ self.targets = targets
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], int(self.targets[index])
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img, mode='L')
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.data)
+
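+# A minimal usage sketch (illustrative, not part of the original module):
+# pixel values are rescaled to [0, 255] and labels shifted to [0, 9];
+# 'path/to/usps' is a placeholder.
+if __name__ == '__main__':
+    dataset = USPS('path/to/usps', train=True, download=True)
+    img, target = dataset[0]  # a 16x16 PIL image and a label in [0, 9]
+    print(len(dataset), img.size, target)
+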
+import os
+import sys
+import tarfile
+import collections
+from .vision import VisionDataset
+
+if sys.version_info[0] == 2:
+ import xml.etree.cElementTree as ET
+else:
+ import xml.etree.ElementTree as ET
+
+from PIL import Image
+from .utils import download_url, check_integrity
+
+DATASET_YEAR_DICT = {
+ '2012': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
+ 'filename': 'VOCtrainval_11-May-2012.tar',
+ 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
+ 'base_dir': 'VOCdevkit/VOC2012'
+ },
+ '2011': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
+ 'filename': 'VOCtrainval_25-May-2011.tar',
+ 'md5': '6c3384ef61512963050cb5d687e5bf1e',
+ 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
+ },
+ '2010': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
+ 'filename': 'VOCtrainval_03-May-2010.tar',
+ 'md5': 'da459979d0c395079b5c75ee67908abb',
+ 'base_dir': 'VOCdevkit/VOC2010'
+ },
+ '2009': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
+ 'filename': 'VOCtrainval_11-May-2009.tar',
+ 'md5': '59065e4b188729180974ef6572f6a212',
+ 'base_dir': 'VOCdevkit/VOC2009'
+ },
+ '2008': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
+        'filename': 'VOCtrainval_14-Jul-2008.tar',
+ 'md5': '2629fa636546599198acfcfbfcf1904a',
+ 'base_dir': 'VOCdevkit/VOC2008'
+ },
+ '2007': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
+ 'filename': 'VOCtrainval_06-Nov-2007.tar',
+ 'md5': 'c52e279531787c972589f7e41ab4ae64',
+ 'base_dir': 'VOCdevkit/VOC2007'
+ }
+}
+
+
+[docs]class VOCSegmentation(VisionDataset):
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+ Args:
+ root (string): Root directory of the VOC Dataset.
+ year (string, optional): The dataset year, supports years 2007 to 2012.
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ """
+
+ def __init__(self,
+ root,
+ year='2012',
+ image_set='train',
+ download=False,
+ transform=None,
+ target_transform=None,
+ transforms=None):
+ super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
+ self.year = year
+ self.url = DATASET_YEAR_DICT[year]['url']
+ self.filename = DATASET_YEAR_DICT[year]['filename']
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
+ self.image_set = image_set
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
+ voc_root = os.path.join(self.root, base_dir)
+ image_dir = os.path.join(voc_root, 'JPEGImages')
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
+
+ if download:
+ download_extract(self.url, self.root, self.filename, self.md5)
+
+ if not os.path.isdir(voc_root):
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
+
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
+
+ if not os.path.exists(split_f):
+ raise ValueError(
+ 'Wrong image_set entered! Please use image_set="train" '
+ 'or image_set="trainval" or image_set="val"')
+
+ with open(os.path.join(split_f), "r") as f:
+ file_names = [x.strip() for x in f.readlines()]
+
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
+ assert (len(self.images) == len(self.masks))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is the image segmentation.
+ """
+ img = Image.open(self.images[index]).convert('RGB')
+ target = Image.open(self.masks[index])
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.images)
+
+
+[docs]class VOCDetection(VisionDataset):
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+ Args:
+ root (string): Root directory of the VOC Dataset.
+ year (string, optional): The dataset year, supports years 2007 to 2012.
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+ """
+
+ def __init__(self,
+ root,
+ year='2012',
+ image_set='train',
+ download=False,
+ transform=None,
+ target_transform=None,
+ transforms=None):
+ super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
+ self.year = year
+ self.url = DATASET_YEAR_DICT[year]['url']
+ self.filename = DATASET_YEAR_DICT[year]['filename']
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
+ self.image_set = image_set
+
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
+ voc_root = os.path.join(self.root, base_dir)
+ image_dir = os.path.join(voc_root, 'JPEGImages')
+ annotation_dir = os.path.join(voc_root, 'Annotations')
+
+ if download:
+ download_extract(self.url, self.root, self.filename, self.md5)
+
+ if not os.path.isdir(voc_root):
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ splits_dir = os.path.join(voc_root, 'ImageSets/Main')
+
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
+
+ if not os.path.exists(split_f):
+ raise ValueError(
+ 'Wrong image_set entered! Please use image_set="train" '
+                'or image_set="trainval" or image_set="val" or a valid '
+                'image_set from the VOC ImageSets/Main folder.')
+
+ with open(os.path.join(split_f), "r") as f:
+ file_names = [x.strip() for x in f.readlines()]
+
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+ self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
+ assert (len(self.images) == len(self.annotations))
+
+[docs] def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is a dictionary of the XML tree.
+ """
+ img = Image.open(self.images[index]).convert('RGB')
+ target = self.parse_voc_xml(
+ ET.parse(self.annotations[index]).getroot())
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.images)
+
+ def parse_voc_xml(self, node):
+ voc_dict = {}
+ children = list(node)
+ if children:
+ def_dic = collections.defaultdict(list)
+ for dc in map(self.parse_voc_xml, children):
+ for ind, v in dc.items():
+ def_dic[ind].append(v)
+ voc_dict = {
+ node.tag:
+ {ind: v[0] if len(v) == 1 else v
+ for ind, v in def_dic.items()}
+ }
+ if node.text:
+ text = node.text.strip()
+ if not children:
+ voc_dict[node.tag] = text
+ return voc_dict
+
+
+def download_extract(url, root, filename, md5):
+ download_url(url, root, filename, md5)
+ with tarfile.open(os.path.join(root, filename), "r") as tar:
+ tar.extractall(path=root)
+
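+# A minimal usage sketch (illustrative, not part of the original module):
+# it assumes the VOCdevkit has already been extracted under 'path/to/voc'
+# (a placeholder). VOCSegmentation yields (image, mask) PIL pairs, while
+# VOCDetection yields the parsed XML annotation as a nested dict
+# (typically annotation['annotation']['object'] lists the boxes).
+if __name__ == '__main__':
+    seg = VOCSegmentation('path/to/voc', year='2012', image_set='val')
+    img, mask = seg[0]
+    det = VOCDetection('path/to/voc', year='2012', image_set='val')
+    img, annotation = det[0]
+    print(len(seg), len(det))
+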
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['AlexNet', 'alexnet']
+
+
+model_urls = {
+ 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
+}
+
+
+class AlexNet(nn.Module):
+
+ def __init__(self, num_classes=1000):
+ super(AlexNet, self).__init__()
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.avgpool(x)
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+ return x
+
+
+[docs]def alexnet(pretrained=False, progress=True, **kwargs):
+ r"""AlexNet model architecture from the
+ `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = AlexNet(**kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['alexnet'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
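+# A minimal usage sketch (illustrative, not part of the original module):
+# loads the pretrained weights and runs a forward pass on a random batch;
+# 224x224 inputs with ImageNet normalization are the usual assumption.
+if __name__ == '__main__':
+    import torch
+    model = alexnet(pretrained=True)
+    model.eval()
+    with torch.no_grad():
+        logits = model(torch.randn(1, 3, 224, 224))  # shape [1, 1000]
+    print(logits.argmax(dim=1))
+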
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .utils import load_state_dict_from_url
+from collections import OrderedDict
+
+__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
+
+model_urls = {
+ 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
+ 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
+ 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
+ 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
+}
+
+
+class _DenseLayer(nn.Sequential):
+ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
+ super(_DenseLayer, self).__init__()
+ self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
+ self.add_module('relu1', nn.ReLU(inplace=True)),
+ self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
+ growth_rate, kernel_size=1, stride=1,
+ bias=False)),
+ self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
+ self.add_module('relu2', nn.ReLU(inplace=True)),
+ self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
+ kernel_size=3, stride=1, padding=1,
+ bias=False)),
+ self.drop_rate = drop_rate
+
+ def forward(self, x):
+ new_features = super(_DenseLayer, self).forward(x)
+ if self.drop_rate > 0:
+ new_features = F.dropout(new_features, p=self.drop_rate,
+ training=self.training)
+ return torch.cat([x, new_features], 1)
+
+
+class _DenseBlock(nn.Sequential):
+ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
+ super(_DenseBlock, self).__init__()
+ for i in range(num_layers):
+ layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
+ bn_size, drop_rate)
+ self.add_module('denselayer%d' % (i + 1), layer)
+
+
+class _Transition(nn.Sequential):
+ def __init__(self, num_input_features, num_output_features):
+ super(_Transition, self).__init__()
+ self.add_module('norm', nn.BatchNorm2d(num_input_features))
+ self.add_module('relu', nn.ReLU(inplace=True))
+ self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
+ kernel_size=1, stride=1, bias=False))
+ self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+class DenseNet(nn.Module):
+ r"""Densenet-BC model class, based on
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+ Args:
+ growth_rate (int) - how many filters to add each layer (`k` in paper)
+ block_config (list of 4 ints) - how many layers in each pooling block
+ num_init_features (int) - the number of filters to learn in the first convolution layer
+        bn_size (int) - multiplicative factor for number of bottleneck layers
+ (i.e. bn_size * k features in the bottleneck layer)
+ drop_rate (float) - dropout rate after each dense layer
+ num_classes (int) - number of classification classes
+ """
+
+ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
+ num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
+
+ super(DenseNet, self).__init__()
+
+ # First convolution
+ self.features = nn.Sequential(OrderedDict([
+ ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
+ padding=3, bias=False)),
+ ('norm0', nn.BatchNorm2d(num_init_features)),
+ ('relu0', nn.ReLU(inplace=True)),
+ ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+ ]))
+
+ # Each denseblock
+ num_features = num_init_features
+ for i, num_layers in enumerate(block_config):
+ block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
+ bn_size=bn_size, growth_rate=growth_rate,
+ drop_rate=drop_rate)
+ self.features.add_module('denseblock%d' % (i + 1), block)
+ num_features = num_features + num_layers * growth_rate
+ if i != len(block_config) - 1:
+ trans = _Transition(num_input_features=num_features,
+ num_output_features=num_features // 2)
+ self.features.add_module('transition%d' % (i + 1), trans)
+ num_features = num_features // 2
+
+ # Final batch norm
+ self.features.add_module('norm5', nn.BatchNorm2d(num_features))
+
+ # Linear layer
+ self.classifier = nn.Linear(num_features, num_classes)
+
+ # Official init from torch repo.
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ features = self.features(x)
+ out = F.relu(features, inplace=True)
+ out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
+ out = self.classifier(out)
+ return out
+
+
+def _load_state_dict(model, model_url, progress):
+ # '.'s are no longer allowed in module names, but previous _DenseLayer
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+ # They are also in the checkpoints in model_urls. This pattern is used
+ # to find such keys.
+ pattern = re.compile(
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+
+ state_dict = load_state_dict_from_url(model_url, progress=progress)
+ for key in list(state_dict.keys()):
+ res = pattern.match(key)
+ if res:
+ new_key = res.group(1) + res.group(2)
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+ model.load_state_dict(state_dict)
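+
+
+# Small sketch (hypothetical helper, not part of the original module) showing
+# what the renaming above does to one legacy checkpoint key.
+def _densenet_key_rename_sketch():
+    pattern = re.compile(
+        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+    old_key = 'features.denseblock1.denselayer1.norm.1.weight'
+    res = pattern.match(old_key)
+    return res.group(1) + res.group(2)  # 'features.denseblock1.denselayer1.norm1.weight'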
+
+
+def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress,
+ **kwargs):
+ model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
+ if pretrained:
+ _load_state_dict(model, model_urls[arch], progress)
+ return model
+
+
+def densenet121(pretrained=False, progress=True, **kwargs):
+ r"""Densenet-121 model from
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
+ **kwargs)
+
+
+def densenet161(pretrained=False, progress=True, **kwargs):
+ r"""Densenet-161 model from
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
+ **kwargs)
+
+
+def densenet169(pretrained=False, progress=True, **kwargs):
+ r"""Densenet-169 model from
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
+ **kwargs)
+
+
+def densenet201(pretrained=False, progress=True, **kwargs):
+ r"""Densenet-201 model from
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
+ **kwargs)
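+
+
+# Sketch (not part of the original module): the factories above differ only in
+# growth rate, block configuration and initial feature count. This hypothetical
+# helper recomputes the classifier input width those settings imply, e.g.
+# 1024 for densenet121 and 2208 for densenet161.
+def _densenet_feature_width_sketch(growth_rate=32, block_config=(6, 12, 24, 16),
+                                   num_init_features=64):
+    num_features = num_init_features
+    for i, num_layers in enumerate(block_config):
+        num_features += num_layers * growth_rate      # dense block adds growth_rate per layer
+        if i != len(block_config) - 1:
+            num_features = num_features // 2          # transition layer halves the channels
+    return num_features  # equals DenseNet(...).classifier.in_features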
+
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops import MultiScaleRoIAlign
+
+from ..utils import load_state_dict_from_url
+
+from .generalized_rcnn import GeneralizedRCNN
+from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
+from .roi_heads import RoIHeads
+from .transform import GeneralizedRCNNTransform
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = [
+ "FasterRCNN", "fasterrcnn_resnet50_fpn",
+]
+
+
+class FasterRCNN(GeneralizedRCNN):
+ """
+ Implements Faster R-CNN.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
+ between 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses for both the RPN and the R-CNN.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows:
+ - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
+ 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+
+ Arguments:
+ backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+ num_classes (int): number of output classes of the model (including the background).
+ If box_predictor is specified, num_classes should be None.
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been trained
+ on
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained on
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes
+ box_head (nn.Module): module that takes the cropped feature maps as input
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
+ classification logits and box regression deltas.
+ box_score_thresh (float): during inference, only return proposals with a classification score
+ greater than box_score_thresh
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+ considered as positive during training of the classification head
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+ considered as negative during training of the classification head
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
+ classification head
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+ of the classification head
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+ bounding boxes
+
+ Example::
+
+ >>> import torch
+ >>> import torchvision
+ >>> from torchvision.models.detection import FasterRCNN
+ >>> from torchvision.models.detection.rpn import AnchorGenerator
+ >>> # load a pre-trained model for classification and return
+ >>> # only the features
+ >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
+ >>> # FasterRCNN needs to know the number of
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+ >>> # so we need to add it here
+ >>> backbone.out_channels = 1280
+ >>>
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
+ >>> # location, with 5 different sizes and 3 different aspect
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+ >>> # map could potentially have different sizes and
+ >>> # aspect ratios
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
+ >>>
+ >>> # let's define what are the feature maps that we will
+ >>> # use to perform the region of interest cropping, as well as
+ >>> # the size of the crop after rescaling.
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
+ >>> # be [0]. More generally, the backbone should return an
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+ >>> # feature maps to use.
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
+ >>> output_size=7,
+ >>> sampling_ratio=2)
+ >>>
+ >>> # put the pieces together inside a FasterRCNN model
+ >>> model = FasterRCNN(backbone,
+ >>> num_classes=2,
+ >>> rpn_anchor_generator=anchor_generator,
+ >>> box_roi_pool=roi_pooler)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+ """
+
+ def __init__(self, backbone, num_classes=None,
+ # transform parameters
+ min_size=800, max_size=1333,
+ image_mean=None, image_std=None,
+ # RPN parameters
+ rpn_anchor_generator=None, rpn_head=None,
+ rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
+ rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
+ rpn_nms_thresh=0.7,
+ rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
+ rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+ # Box parameters
+ box_roi_pool=None, box_head=None, box_predictor=None,
+ box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
+ box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
+ box_batch_size_per_image=512, box_positive_fraction=0.25,
+ bbox_reg_weights=None):
+
+ if not hasattr(backbone, "out_channels"):
+ raise ValueError(
+ "backbone should contain an attribute out_channels "
+ "specifying the number of output channels (assumed to be the "
+ "same for all the levels)")
+
+ assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
+ assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
+
+ if num_classes is not None:
+ if box_predictor is not None:
+ raise ValueError("num_classes should be None when box_predictor is specified")
+ else:
+ if box_predictor is None:
+ raise ValueError("num_classes should not be None when box_predictor "
+ "is not specified")
+
+ out_channels = backbone.out_channels
+
+ if rpn_anchor_generator is None:
+ anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+ rpn_anchor_generator = AnchorGenerator(
+ anchor_sizes, aspect_ratios
+ )
+ if rpn_head is None:
+ rpn_head = RPNHead(
+ out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
+ )
+
+ rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+ rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+ rpn = RegionProposalNetwork(
+ rpn_anchor_generator, rpn_head,
+ rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+ rpn_batch_size_per_image, rpn_positive_fraction,
+ rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+
+ if box_roi_pool is None:
+ box_roi_pool = MultiScaleRoIAlign(
+ featmap_names=[0, 1, 2, 3],
+ output_size=7,
+ sampling_ratio=2)
+
+ if box_head is None:
+ resolution = box_roi_pool.output_size[0]
+ representation_size = 1024
+ box_head = TwoMLPHead(
+ out_channels * resolution ** 2,
+ representation_size)
+
+ if box_predictor is None:
+ representation_size = 1024
+ box_predictor = FastRCNNPredictor(
+ representation_size,
+ num_classes)
+
+ roi_heads = RoIHeads(
+ # Box
+ box_roi_pool, box_head, box_predictor,
+ box_fg_iou_thresh, box_bg_iou_thresh,
+ box_batch_size_per_image, box_positive_fraction,
+ bbox_reg_weights,
+ box_score_thresh, box_nms_thresh, box_detections_per_img)
+
+ if image_mean is None:
+ image_mean = [0.485, 0.456, 0.406]
+ if image_std is None:
+ image_std = [0.229, 0.224, 0.225]
+ transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+
+ super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
+
+
+class TwoMLPHead(nn.Module):
+ """
+ Standard heads for FPN-based models
+
+ Arguments:
+ in_channels (int): number of input channels
+ representation_size (int): size of the intermediate representation
+ """
+
+ def __init__(self, in_channels, representation_size):
+ super(TwoMLPHead, self).__init__()
+
+ self.fc6 = nn.Linear(in_channels, representation_size)
+ self.fc7 = nn.Linear(representation_size, representation_size)
+
+ def forward(self, x):
+ x = x.flatten(start_dim=1)
+
+ x = F.relu(self.fc6(x))
+ x = F.relu(self.fc7(x))
+
+ return x
+
+
+class FastRCNNPredictor(nn.Module):
+ """
+ Standard classification + bounding box regression layers
+ for Fast R-CNN.
+
+ Arguments:
+ in_channels (int): number of input channels
+ num_classes (int): number of output classes (including background)
+ """
+
+ def __init__(self, in_channels, num_classes):
+ super(FastRCNNPredictor, self).__init__()
+ self.cls_score = nn.Linear(in_channels, num_classes)
+ self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+ def forward(self, x):
+ if x.ndimension() == 4:
+ assert list(x.shape[2:]) == [1, 1]
+ x = x.flatten(start_dim=1)
+ scores = self.cls_score(x)
+ bbox_deltas = self.bbox_pred(x)
+
+ return scores, bbox_deltas
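+
+
+# Common fine-tuning sketch (not part of the original module): swap the box
+# predictor of an existing model so the classification and regression heads
+# match a new number of classes; num_classes=3 is an arbitrary illustration.
+def _replace_box_predictor_sketch(model, num_classes=3):
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+    return model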
+
+
+model_urls = {
+ 'fasterrcnn_resnet50_fpn_coco':
+ 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
+}
+
+
+def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
+ num_classes=91, pretrained_backbone=True, **kwargs):
+ """
+ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
+
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+ image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
+ between ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+ losses for both the RPN and the R-CNN.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+ follows:
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
+ ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the predicted labels for each image
+        - scores (``Tensor[N]``): the scores of each prediction
+
+ Example::
+
+ >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+
+ Arguments:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ if pretrained:
+ # no need to download the backbone if pretrained is set
+ pretrained_backbone = False
+ backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
+ model = FasterRCNN(backbone, num_classes, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
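+
+
+# Minimal training-mode sketch (illustrative only): each image is paired with a
+# target dict holding `boxes` in [x1, y1, x2, y2] format and integer `labels`,
+# and the call returns a dict of RPN and R-CNN losses.
+def _fasterrcnn_training_sketch():
+    model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
+    model.train()
+    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+    targets = [{'boxes': torch.tensor([[30., 40., 120., 200.]]),
+                'labels': torch.tensor([1])},
+               {'boxes': torch.tensor([[10., 10., 250., 300.]]),
+                'labels': torch.tensor([2])}]
+    loss_dict = model(images, targets)
+    return sum(loss for loss in loss_dict.values())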
+
+import torch
+from torch import nn
+
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops import MultiScaleRoIAlign
+
+from ..utils import load_state_dict_from_url
+
+from .faster_rcnn import FasterRCNN
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = [
+ "KeypointRCNN", "keypointrcnn_resnet50_fpn"
+]
+
+
+class KeypointRCNN(FasterRCNN):
+ """
+ Implements Keypoint R-CNN.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
+ between 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - keypoints (FloatTensor[N, K, 3]): the K keypoint locations for each of the N instances, in the
+ format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses for both the RPN and the R-CNN, and the keypoint loss.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows:
+ - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
+ 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+ - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
+
+ Arguments:
+ backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+ num_classes (int): number of output classes of the model (including the background).
+ If box_predictor is specified, num_classes should be None.
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been trained
+ on
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained on
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes
+ box_head (nn.Module): module that takes the cropped feature maps as input
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
+ classification logits and box regression deltas.
+ box_score_thresh (float): during inference, only return proposals with a classification score
+ greater than box_score_thresh
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+ considered as positive during training of the classification head
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+ considered as negative during training of the classification head
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
+ classification head
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+ of the classification head
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+ bounding boxes
+ keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes, which will be used for the keypoint head.
+ keypoint_head (nn.Module): module that takes the cropped feature maps as input
+ keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
+ heatmap logits
+
+ Example::
+
+ >>> import torchvision
+ >>> from torchvision.models.detection import KeypointRCNN
+ >>> from torchvision.models.detection.rpn import AnchorGenerator
+ >>>
+ >>> # load a pre-trained model for classification and return
+ >>> # only the features
+ >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
+ >>> # KeypointRCNN needs to know the number of
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+ >>> # so we need to add it here
+ >>> backbone.out_channels = 1280
+ >>>
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
+ >>> # location, with 5 different sizes and 3 different aspect
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+ >>> # map could potentially have different sizes and
+ >>> # aspect ratios
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
+ >>>
+ >>> # let's define what are the feature maps that we will
+ >>> # use to perform the region of interest cropping, as well as
+ >>> # the size of the crop after rescaling.
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
+ >>> # be [0]. More generally, the backbone should return an
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+ >>> # feature maps to use.
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
+ >>> output_size=7,
+ >>> sampling_ratio=2)
+ >>>
+ >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
+ >>> output_size=14,
+ >>> sampling_ratio=2)
+ >>> # put the pieces together inside a FasterRCNN model
+ >>> model = KeypointRCNN(backbone,
+ >>> num_classes=2,
+ >>> rpn_anchor_generator=anchor_generator,
+ >>> box_roi_pool=roi_pooler,
+ >>> keypoint_roi_pool=keypoint_roi_pooler)
+        >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+ """
+ def __init__(self, backbone, num_classes=None,
+ # transform parameters
+ min_size=None, max_size=1333,
+ image_mean=None, image_std=None,
+ # RPN parameters
+ rpn_anchor_generator=None, rpn_head=None,
+ rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
+ rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
+ rpn_nms_thresh=0.7,
+ rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
+ rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+ # Box parameters
+ box_roi_pool=None, box_head=None, box_predictor=None,
+ box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
+ box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
+ box_batch_size_per_image=512, box_positive_fraction=0.25,
+ bbox_reg_weights=None,
+ # keypoint parameters
+ keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
+ num_keypoints=17):
+
+ assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
+ if min_size is None:
+ min_size = (640, 672, 704, 736, 768, 800)
+
+ if num_classes is not None:
+ if keypoint_predictor is not None:
+ raise ValueError("num_classes should be None when keypoint_predictor is specified")
+
+ out_channels = backbone.out_channels
+
+ if keypoint_roi_pool is None:
+ keypoint_roi_pool = MultiScaleRoIAlign(
+ featmap_names=[0, 1, 2, 3],
+ output_size=14,
+ sampling_ratio=2)
+
+ if keypoint_head is None:
+ keypoint_layers = tuple(512 for _ in range(8))
+ keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)
+
+ if keypoint_predictor is None:
+ keypoint_dim_reduced = 512 # == keypoint_layers[-1]
+ keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
+
+ super(KeypointRCNN, self).__init__(
+ backbone, num_classes,
+ # transform parameters
+ min_size, max_size,
+ image_mean, image_std,
+ # RPN-specific parameters
+ rpn_anchor_generator, rpn_head,
+ rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
+ rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
+ rpn_nms_thresh,
+ rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+ rpn_batch_size_per_image, rpn_positive_fraction,
+ # Box parameters
+ box_roi_pool, box_head, box_predictor,
+ box_score_thresh, box_nms_thresh, box_detections_per_img,
+ box_fg_iou_thresh, box_bg_iou_thresh,
+ box_batch_size_per_image, box_positive_fraction,
+ bbox_reg_weights)
+
+ self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
+ self.roi_heads.keypoint_head = keypoint_head
+ self.roi_heads.keypoint_predictor = keypoint_predictor
+
+
+class KeypointRCNNHeads(nn.Sequential):
+ def __init__(self, in_channels, layers):
+ d = []
+ next_feature = in_channels
+ for l in layers:
+ d.append(misc_nn_ops.Conv2d(next_feature, l, 3, stride=1, padding=1))
+ d.append(nn.ReLU(inplace=True))
+ next_feature = l
+ super(KeypointRCNNHeads, self).__init__(*d)
+ for m in self.children():
+ if isinstance(m, misc_nn_ops.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ nn.init.constant_(m.bias, 0)
+
+
+class KeypointRCNNPredictor(nn.Module):
+ def __init__(self, in_channels, num_keypoints):
+ super(KeypointRCNNPredictor, self).__init__()
+ input_features = in_channels
+ deconv_kernel = 4
+ self.kps_score_lowres = misc_nn_ops.ConvTranspose2d(
+ input_features,
+ num_keypoints,
+ deconv_kernel,
+ stride=2,
+ padding=deconv_kernel // 2 - 1,
+ )
+ nn.init.kaiming_normal_(
+ self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
+ )
+ nn.init.constant_(self.kps_score_lowres.bias, 0)
+ self.up_scale = 2
+ self.out_channels = num_keypoints
+
+ def forward(self, x):
+ x = self.kps_score_lowres(x)
+ x = misc_nn_ops.interpolate(
+ x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
+ )
+ return x
+
+
+model_urls = {
+ 'keypointrcnn_resnet50_fpn_coco':
+ 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
+}
+
+
+def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
+ num_classes=2, num_keypoints=17,
+ pretrained_backbone=True, **kwargs):
+ """
+ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+ image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
+ between ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoint locations for each of the ``N`` instances, in the
+ format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+ losses for both the RPN and the R-CNN, and the keypoint loss.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+ follows:
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
+ ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the predicted labels for each image
+        - scores (``Tensor[N]``): the scores of each prediction
+ - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+ Example::
+
+ >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+
+ Arguments:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ if pretrained:
+ # no need to download the backbone if pretrained is set
+ pretrained_backbone = False
+ backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
+ model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
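+
+
+# Training-target sketch (illustrative only): with the default num_keypoints=17
+# (the COCO person keypoints), every instance needs a [17, 3] tensor of
+# [x, y, visibility] rows in addition to its box and label.
+def _keypointrcnn_training_sketch():
+    model = keypointrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
+    model.train()
+    images = [torch.rand(3, 300, 400)]
+    keypoints = torch.zeros(1, 17, 3)
+    keypoints[0, :, 0] = 60.0   # x
+    keypoints[0, :, 1] = 80.0   # y
+    keypoints[0, :, 2] = 1.0    # visible
+    targets = [{'boxes': torch.tensor([[30., 40., 120., 200.]]),
+                'labels': torch.tensor([1]),
+                'keypoints': keypoints}]
+    return model(images, targets)  # dict of RPN, box and keypoint losses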
+
+from collections import OrderedDict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops import MultiScaleRoIAlign
+
+from ..utils import load_state_dict_from_url
+
+from .faster_rcnn import FasterRCNN
+from .backbone_utils import resnet_fpn_backbone
+
+__all__ = [
+ "MaskRCNN", "maskrcnn_resnet50_fpn",
+]
+
+
+class MaskRCNN(FasterRCNN):
+ """
+ Implements Mask R-CNN.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
+ between 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+ - masks (UInt8Tensor[N, 1, H, W]): the segmentation binary masks for each instance
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses for both the RPN and the R-CNN, and the mask loss.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows:
+ - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
+ 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+ - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
+ obtain the final segmentation masks, the soft masks can be thresholded, generally
+ with a value of 0.5 (mask >= 0.5)
+
+ Arguments:
+ backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+ num_classes (int): number of output classes of the model (including the background).
+ If box_predictor is specified, num_classes should be None.
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been trained
+ on
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained on
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes
+ box_head (nn.Module): module that takes the cropped feature maps as input
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
+ classification logits and box regression deltas.
+ box_score_thresh (float): during inference, only return proposals with a classification score
+ greater than box_score_thresh
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+ considered as positive during training of the classification head
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+ considered as negative during training of the classification head
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
+ classification head
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+ of the classification head
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+ bounding boxes
+ mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes, which will be used for the mask head.
+ mask_head (nn.Module): module that takes the cropped feature maps as input
+ mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
+ segmentation mask logits
+
+ Example::
+
+ >>> import torchvision
+ >>> from torchvision.models.detection import MaskRCNN
+ >>> from torchvision.models.detection.rpn import AnchorGenerator
+ >>>
+ >>> # load a pre-trained model for classification and return
+ >>> # only the features
+ >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
+ >>> # MaskRCNN needs to know the number of
+ >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+ >>> # so we need to add it here
+ >>> backbone.out_channels = 1280
+ >>>
+ >>> # let's make the RPN generate 5 x 3 anchors per spatial
+ >>> # location, with 5 different sizes and 3 different aspect
+ >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+ >>> # map could potentially have different sizes and
+ >>> # aspect ratios
+ >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+ >>> aspect_ratios=((0.5, 1.0, 2.0),))
+ >>>
+ >>> # let's define what are the feature maps that we will
+ >>> # use to perform the region of interest cropping, as well as
+ >>> # the size of the crop after rescaling.
+ >>> # if your backbone returns a Tensor, featmap_names is expected to
+ >>> # be [0]. More generally, the backbone should return an
+ >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+ >>> # feature maps to use.
+ >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
+ >>> output_size=7,
+ >>> sampling_ratio=2)
+ >>>
+ >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
+ >>> output_size=14,
+ >>> sampling_ratio=2)
+ >>> # put the pieces together inside a FasterRCNN model
+ >>> model = MaskRCNN(backbone,
+ >>> num_classes=2,
+ >>> rpn_anchor_generator=anchor_generator,
+ >>> box_roi_pool=roi_pooler,
+ >>> mask_roi_pool=mask_roi_pooler)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+ """
+ def __init__(self, backbone, num_classes=None,
+ # transform parameters
+ min_size=800, max_size=1333,
+ image_mean=None, image_std=None,
+ # RPN parameters
+ rpn_anchor_generator=None, rpn_head=None,
+ rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
+ rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
+ rpn_nms_thresh=0.7,
+ rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
+ rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+ # Box parameters
+ box_roi_pool=None, box_head=None, box_predictor=None,
+ box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
+ box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
+ box_batch_size_per_image=512, box_positive_fraction=0.25,
+ bbox_reg_weights=None,
+ # Mask parameters
+ mask_roi_pool=None, mask_head=None, mask_predictor=None):
+
+ assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
+
+ if num_classes is not None:
+ if mask_predictor is not None:
+ raise ValueError("num_classes should be None when mask_predictor is specified")
+
+ out_channels = backbone.out_channels
+
+ if mask_roi_pool is None:
+ mask_roi_pool = MultiScaleRoIAlign(
+ featmap_names=[0, 1, 2, 3],
+ output_size=14,
+ sampling_ratio=2)
+
+ if mask_head is None:
+ mask_layers = (256, 256, 256, 256)
+ mask_dilation = 1
+ mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
+
+ if mask_predictor is None:
+ mask_predictor_in_channels = 256 # == mask_layers[-1]
+ mask_dim_reduced = 256
+ mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels,
+ mask_dim_reduced, num_classes)
+
+ super(MaskRCNN, self).__init__(
+ backbone, num_classes,
+ # transform parameters
+ min_size, max_size,
+ image_mean, image_std,
+ # RPN-specific parameters
+ rpn_anchor_generator, rpn_head,
+ rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
+ rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
+ rpn_nms_thresh,
+ rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+ rpn_batch_size_per_image, rpn_positive_fraction,
+ # Box parameters
+ box_roi_pool, box_head, box_predictor,
+ box_score_thresh, box_nms_thresh, box_detections_per_img,
+ box_fg_iou_thresh, box_bg_iou_thresh,
+ box_batch_size_per_image, box_positive_fraction,
+ bbox_reg_weights)
+
+ self.roi_heads.mask_roi_pool = mask_roi_pool
+ self.roi_heads.mask_head = mask_head
+ self.roi_heads.mask_predictor = mask_predictor
+
+
+class MaskRCNNHeads(nn.Sequential):
+ def __init__(self, in_channels, layers, dilation):
+ """
+ Arguments:
+            in_channels (int): number of input channels
+            layers (tuple of int): number of channels of each conv layer in the mask head
+            dilation (int): dilation (and padding) used by the mask head convolutions
+ """
+ d = OrderedDict()
+ next_feature = in_channels
+ for layer_idx, layer_features in enumerate(layers, 1):
+ d["mask_fcn{}".format(layer_idx)] = misc_nn_ops.Conv2d(
+ next_feature, layer_features, kernel_size=3,
+ stride=1, padding=dilation, dilation=dilation)
+ d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
+ next_feature = layer_features
+
+ super(MaskRCNNHeads, self).__init__(d)
+ for name, param in self.named_parameters():
+ if "weight" in name:
+ nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+ # elif "bias" in name:
+ # nn.init.constant_(param, 0)
+
+
+class MaskRCNNPredictor(nn.Sequential):
+ def __init__(self, in_channels, dim_reduced, num_classes):
+ super(MaskRCNNPredictor, self).__init__(OrderedDict([
+ ("conv5_mask", misc_nn_ops.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
+ ("relu", nn.ReLU(inplace=True)),
+ ("mask_fcn_logits", misc_nn_ops.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
+ ]))
+
+ for name, param in self.named_parameters():
+ if "weight" in name:
+ nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+ # elif "bias" in name:
+ # nn.init.constant_(param, 0)
+
+
+model_urls = {
+ 'maskrcnn_resnet50_fpn_coco':
+ 'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth',
+}
+
+
+def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
+ num_classes=91, pretrained_backbone=True, **kwargs):
+ """
+ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.
+
+ The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+ image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (a list of dictionaries),
+ containing:
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
+ between ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+ - masks (``UInt8Tensor[N, 1, H, W]``): the segmentation binary masks for each instance
+
+ The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+ losses for both the RPN and the R-CNN, and the mask loss.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+ follows:
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
+ ``0`` and ``H`` and ``0`` and ``W``
+ - labels (``Int64Tensor[N]``): the predicted labels for each image
+        - scores (``Tensor[N]``): the scores of each prediction
+ - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
+ obtain the final segmentation masks, the soft masks can be thresholded, generally
+ with a value of 0.5 (``mask >= 0.5``)
+
+ Example::
+
+ >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
+ >>> model.eval()
+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+ >>> predictions = model(x)
+
+ Arguments:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ if pretrained:
+ # no need to download the backbone if pretrained is set
+ pretrained_backbone = False
+ backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
+ model = MaskRCNN(backbone, num_classes, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
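+
+
+# Inference post-processing sketch (illustrative only): the `masks` field holds
+# soft per-instance masks of shape [N, 1, H, W]; thresholding at 0.5, as the
+# docstring above suggests, yields binary segmentation masks.
+def _maskrcnn_binary_masks_sketch():
+    model = maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
+    model.eval()
+    with torch.no_grad():
+        predictions = model([torch.rand(3, 300, 400)])
+    soft_masks = predictions[0]['masks']
+    return soft_masks >= 0.5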
+
+import warnings
+from collections import namedtuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .utils import load_state_dict_from_url
+
+__all__ = ['GoogLeNet', 'googlenet']
+
+model_urls = {
+ # GoogLeNet ported from TensorFlow
+ 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
+}
+
+_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])
+
+
+def googlenet(pretrained=False, progress=True, **kwargs):
+ r"""GoogLeNet (Inception v1) model architecture from
+ `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ aux_logits (bool): If True, adds two auxiliary branches that can improve training.
+            Default: *False* when pretrained is True, otherwise *True*
+ transform_input (bool): If True, preprocesses the input according to the method with which it
+ was trained on ImageNet. Default: *False*
+ """
+ if pretrained:
+ if 'transform_input' not in kwargs:
+ kwargs['transform_input'] = True
+ if 'aux_logits' not in kwargs:
+ kwargs['aux_logits'] = False
+ if kwargs['aux_logits']:
+ warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
+ 'so make sure to train them')
+ original_aux_logits = kwargs['aux_logits']
+ kwargs['aux_logits'] = True
+ kwargs['init_weights'] = False
+ model = GoogLeNet(**kwargs)
+ state_dict = load_state_dict_from_url(model_urls['googlenet'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ if not original_aux_logits:
+ model.aux_logits = False
+ del model.aux1, model.aux2
+ return model
+
+ return GoogLeNet(**kwargs)
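+
+
+# Training-mode sketch (illustrative only): with aux_logits=True the forward
+# pass returns a namedtuple (logits, aux_logits2, aux_logits1); the auxiliary
+# losses are commonly down-weighted (0.3 in the paper) before being added to
+# the main loss. init_weights=False just skips the scipy-based initialization.
+def _googlenet_aux_loss_sketch():
+    model = GoogLeNet(num_classes=10, aux_logits=True, init_weights=False)
+    model.train()
+    x = torch.rand(2, 3, 224, 224)
+    target = torch.tensor([0, 1])
+    outputs = model(x)
+    criterion = nn.CrossEntropyLoss()
+    return (criterion(outputs.logits, target)
+            + 0.3 * criterion(outputs.aux_logits2, target)
+            + 0.3 * criterion(outputs.aux_logits1, target))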
+
+
+class GoogLeNet(nn.Module):
+
+ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
+ super(GoogLeNet, self).__init__()
+ self.aux_logits = aux_logits
+ self.transform_input = transform_input
+
+ self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
+ self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+ self.conv2 = BasicConv2d(64, 64, kernel_size=1)
+ self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
+ self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+ self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
+ self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
+ self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+ self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
+ self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
+ self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
+ self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
+ self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
+ self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+ self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
+ self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
+
+ if aux_logits:
+ self.aux1 = InceptionAux(512, num_classes)
+ self.aux2 = InceptionAux(528, num_classes)
+
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.dropout = nn.Dropout(0.2)
+ self.fc = nn.Linear(1024, num_classes)
+
+ if init_weights:
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+ import scipy.stats as stats
+ X = stats.truncnorm(-2, 2, scale=0.01)
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+ values = values.view(m.weight.size())
+ with torch.no_grad():
+ m.weight.copy_(values)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ if self.transform_input:
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+
+ # N x 3 x 224 x 224
+ x = self.conv1(x)
+ # N x 64 x 112 x 112
+ x = self.maxpool1(x)
+ # N x 64 x 56 x 56
+ x = self.conv2(x)
+ # N x 64 x 56 x 56
+ x = self.conv3(x)
+ # N x 192 x 56 x 56
+ x = self.maxpool2(x)
+
+ # N x 192 x 28 x 28
+ x = self.inception3a(x)
+ # N x 256 x 28 x 28
+ x = self.inception3b(x)
+ # N x 480 x 28 x 28
+ x = self.maxpool3(x)
+ # N x 480 x 14 x 14
+ x = self.inception4a(x)
+ # N x 512 x 14 x 14
+ if self.training and self.aux_logits:
+ aux1 = self.aux1(x)
+
+ x = self.inception4b(x)
+ # N x 512 x 14 x 14
+ x = self.inception4c(x)
+ # N x 512 x 14 x 14
+ x = self.inception4d(x)
+ # N x 528 x 14 x 14
+ if self.training and self.aux_logits:
+ aux2 = self.aux2(x)
+
+ x = self.inception4e(x)
+ # N x 832 x 14 x 14
+ x = self.maxpool4(x)
+ # N x 832 x 7 x 7
+ x = self.inception5a(x)
+ # N x 832 x 7 x 7
+ x = self.inception5b(x)
+ # N x 1024 x 7 x 7
+
+ x = self.avgpool(x)
+ # N x 1024 x 1 x 1
+ x = x.view(x.size(0), -1)
+ # N x 1024
+ x = self.dropout(x)
+ x = self.fc(x)
+ # N x 1000 (num_classes)
+ if self.training and self.aux_logits:
+ return _GoogLeNetOuputs(x, aux2, aux1)
+ return x
+
+
+class Inception(nn.Module):
+
+ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
+ super(Inception, self).__init__()
+
+ self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(in_channels, ch3x3red, kernel_size=1),
+ BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
+ )
+
+ self.branch3 = nn.Sequential(
+ BasicConv2d(in_channels, ch5x5red, kernel_size=1),
+ BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
+ )
+
+ self.branch4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
+ BasicConv2d(in_channels, pool_proj, kernel_size=1)
+ )
+
+ def forward(self, x):
+ branch1 = self.branch1(x)
+ branch2 = self.branch2(x)
+ branch3 = self.branch3(x)
+ branch4 = self.branch4(x)
+
+ outputs = [branch1, branch2, branch3, branch4]
+ return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+ def __init__(self, in_channels, num_classes):
+ super(InceptionAux, self).__init__()
+ self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
+
+ self.fc1 = nn.Linear(2048, 1024)
+ self.fc2 = nn.Linear(1024, num_classes)
+
+ def forward(self, x):
+ # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
+ x = F.adaptive_avg_pool2d(x, (4, 4))
+ # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
+ x = self.conv(x)
+ # N x 128 x 4 x 4
+ x = x.view(x.size(0), -1)
+ # N x 2048
+        x = F.relu(self.fc1(x), inplace=True)
+        # N x 1024
+        x = F.dropout(x, 0.7, training=self.training)
+        # N x 1024
+        x = self.fc2(x)
+        # N x num_classes
+
+ return x
+
+
+class BasicConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
+
+from collections import namedtuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['Inception3', 'inception_v3']
+
+
+model_urls = {
+ # Inception v3 ported from TensorFlow
+ 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
+}
+
+_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])
+
+
+def inception_v3(pretrained=False, progress=True, **kwargs):
+ r"""Inception v3 model architecture from
+ `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
+
+ .. note::
+        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
+ N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+        aux_logits (bool): If True, adds an auxiliary branch that can improve training.
+ Default: *True*
+ transform_input (bool): If True, preprocesses the input according to the method with which it
+ was trained on ImageNet. Default: *False*
+ """
+ if pretrained:
+ if 'transform_input' not in kwargs:
+ kwargs['transform_input'] = True
+ if 'aux_logits' in kwargs:
+ original_aux_logits = kwargs['aux_logits']
+ kwargs['aux_logits'] = True
+ else:
+ original_aux_logits = True
+ model = Inception3(**kwargs)
+ state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ if not original_aux_logits:
+ model.aux_logits = False
+ del model.AuxLogits
+ return model
+
+ return Inception3(**kwargs)
+
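+# A hedged usage sketch (added for clarity, not part of the torchvision source):
+# inception_v3 expects 299 x 299 inputs; in eval mode it returns a plain logits
+# tensor, while in training mode with aux_logits=True it returns the
+# (logits, aux_logits) namedtuple defined above.
+# >>> model = inception_v3(pretrained=False).eval()
+# >>> model(torch.rand(1, 3, 299, 299)).shape
+# torch.Size([1, 1000])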
+
+class Inception3(nn.Module):
+
+ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
+ super(Inception3, self).__init__()
+ self.aux_logits = aux_logits
+ self.transform_input = transform_input
+ self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
+ self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
+ self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+ self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
+ self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+ self.Mixed_5b = InceptionA(192, pool_features=32)
+ self.Mixed_5c = InceptionA(256, pool_features=64)
+ self.Mixed_5d = InceptionA(288, pool_features=64)
+ self.Mixed_6a = InceptionB(288)
+ self.Mixed_6b = InceptionC(768, channels_7x7=128)
+ self.Mixed_6c = InceptionC(768, channels_7x7=160)
+ self.Mixed_6d = InceptionC(768, channels_7x7=160)
+ self.Mixed_6e = InceptionC(768, channels_7x7=192)
+ if aux_logits:
+ self.AuxLogits = InceptionAux(768, num_classes)
+ self.Mixed_7a = InceptionD(768)
+ self.Mixed_7b = InceptionE(1280)
+ self.Mixed_7c = InceptionE(2048)
+ self.fc = nn.Linear(2048, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+ import scipy.stats as stats
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
+ X = stats.truncnorm(-2, 2, scale=stddev)
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+ values = values.view(m.weight.size())
+ with torch.no_grad():
+ m.weight.copy_(values)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ if self.transform_input:
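+            # Inputs normalized with the standard ImageNet mean/std are remapped
+            # to the (0.5, 0.5) normalization that the TensorFlow-ported weights
+            # were trained with: x * (std / 0.5) + (mean - 0.5) / 0.5 per channel.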
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+ # N x 3 x 299 x 299
+ x = self.Conv2d_1a_3x3(x)
+ # N x 32 x 149 x 149
+ x = self.Conv2d_2a_3x3(x)
+ # N x 32 x 147 x 147
+ x = self.Conv2d_2b_3x3(x)
+ # N x 64 x 147 x 147
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
+ # N x 64 x 73 x 73
+ x = self.Conv2d_3b_1x1(x)
+ # N x 80 x 73 x 73
+ x = self.Conv2d_4a_3x3(x)
+ # N x 192 x 71 x 71
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
+ # N x 192 x 35 x 35
+ x = self.Mixed_5b(x)
+ # N x 256 x 35 x 35
+ x = self.Mixed_5c(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_5d(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_6a(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6b(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6c(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6d(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6e(x)
+ # N x 768 x 17 x 17
+ if self.training and self.aux_logits:
+ aux = self.AuxLogits(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_7a(x)
+ # N x 1280 x 8 x 8
+ x = self.Mixed_7b(x)
+ # N x 2048 x 8 x 8
+ x = self.Mixed_7c(x)
+ # N x 2048 x 8 x 8
+ # Adaptive average pooling
+ x = F.adaptive_avg_pool2d(x, (1, 1))
+ # N x 2048 x 1 x 1
+ x = F.dropout(x, training=self.training)
+ # N x 2048 x 1 x 1
+ x = x.view(x.size(0), -1)
+ # N x 2048
+ x = self.fc(x)
+ # N x 1000 (num_classes)
+ if self.training and self.aux_logits:
+            return _InceptionOutputs(x, aux)
+ return x
+
+
+class InceptionA(nn.Module):
+
+ def __init__(self, in_channels, pool_features):
+ super(InceptionA, self).__init__()
+ self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
+
+ self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
+ self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
+
+ self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
+ self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
+ self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
+
+ self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+
+ def __init__(self, in_channels):
+ super(InceptionB, self).__init__()
+ self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
+
+ self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
+ self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
+ self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
+
+ def forward(self, x):
+ branch3x3 = self.branch3x3(x)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+
+ def __init__(self, in_channels, channels_7x7):
+ super(InceptionC, self).__init__()
+ self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
+
+ c7 = channels_7x7
+ self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
+ self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+ self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
+ self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+ self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+
+ def __init__(self, in_channels):
+ super(InceptionD, self).__init__()
+ self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
+ self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
+
+ self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
+ self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)
+
+ def forward(self, x):
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = self.branch3x3_2(branch3x3)
+
+ branch7x7x3 = self.branch7x7x3_1(x)
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+ outputs = [branch3x3, branch7x7x3, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+
+ def __init__(self, in_channels):
+ super(InceptionE, self).__init__()
+ self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
+
+ self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
+ self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
+ self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+ self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
+ self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
+ self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
+ self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+ self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+ def __init__(self, in_channels, num_classes):
+ super(InceptionAux, self).__init__()
+ self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
+ self.conv1 = BasicConv2d(128, 768, kernel_size=5)
+ self.conv1.stddev = 0.01
+ self.fc = nn.Linear(768, num_classes)
+ self.fc.stddev = 0.001
+
+ def forward(self, x):
+ # N x 768 x 17 x 17
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
+ # N x 768 x 5 x 5
+ x = self.conv0(x)
+ # N x 128 x 5 x 5
+ x = self.conv1(x)
+ # N x 768 x 1 x 1
+ # Adaptive average pooling
+ x = F.adaptive_avg_pool2d(x, (1, 1))
+ # N x 768 x 1 x 1
+ x = x.view(x.size(0), -1)
+ # N x 768
+ x = self.fc(x)
+ # N x 1000
+ return x
+
+
+class BasicConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
+
+import math
+
+import torch
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
+
+_MODEL_URLS = {
+ "mnasnet0_5":
+ "https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
+ "mnasnet0_75": None,
+ "mnasnet1_0":
+ "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
+ "mnasnet1_3": None
+}
+
+# The paper suggests a BatchNorm momentum of 0.9997 (TensorFlow convention);
+# the equivalent PyTorch momentum is 1 - 0.9997.
+_BN_MOMENTUM = 1 - 0.9997
+
+
+class _InvertedResidual(nn.Module):
+
+ def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
+ bn_momentum=0.1):
+ super(_InvertedResidual, self).__init__()
+ assert stride in [1, 2]
+ assert kernel_size in [3, 5]
+ mid_ch = in_ch * expansion_factor
+ self.apply_residual = (in_ch == out_ch and stride == 1)
+ self.layers = nn.Sequential(
+ # Pointwise
+ nn.Conv2d(in_ch, mid_ch, 1, bias=False),
+ nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
+ nn.ReLU(inplace=True),
+ # Depthwise
+ nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
+ stride=stride, groups=mid_ch, bias=False),
+ nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
+ nn.ReLU(inplace=True),
+ # Linear pointwise. Note that there's no activation.
+ nn.Conv2d(mid_ch, out_ch, 1, bias=False),
+ nn.BatchNorm2d(out_ch, momentum=bn_momentum))
+
+ def forward(self, input):
+ if self.apply_residual:
+ return self.layers(input) + input
+ else:
+ return self.layers(input)
+
+
+def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
+ bn_momentum):
+ """ Creates a stack of inverted residuals. """
+ assert repeats >= 1
+ # First one has no skip, because feature map size changes.
+ first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
+ bn_momentum=bn_momentum)
+ remaining = []
+ for _ in range(1, repeats):
+ remaining.append(
+ _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
+ bn_momentum=bn_momentum))
+ return nn.Sequential(first, *remaining)
+
+
+def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
+ """ Asymmetric rounding to make `val` divisible by `divisor`. With default
+ bias, will round up, unless the number is no more than 10% greater than the
+ smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
+ assert 0.0 < round_up_bias < 1.0
+ new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
+ return new_val if new_val >= round_up_bias * val else new_val + divisor
+
+
+def _scale_depths(depths, alpha):
+ """ Scales tensor depths as in reference MobileNet code, prefers rouding up
+ rather than down. """
+ return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
+
+
+class MNASNet(torch.nn.Module):
+ """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
+    >>> model = MNASNet(1.0, num_classes=1000)
+    >>> x = torch.rand(1, 3, 224, 224)
+    >>> y = model(x)
+    >>> y.dim()
+    2
+ >>> y.nelement()
+ 1000
+ """
+
+ def __init__(self, alpha, num_classes=1000, dropout=0.2):
+ super(MNASNet, self).__init__()
+ depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
+ layers = [
+ # First layer: regular conv.
+ nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
+ nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ # Depthwise separable, no skip.
+ nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
+ nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
+ nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+ # MNASNet blocks: stacks of inverted residuals.
+ _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
+ _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
+ _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
+ _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
+ _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
+ _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
+ # Final mapping to classifier input.
+ nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
+ nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True),
+ ]
+ self.layers = nn.Sequential(*layers)
+ self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
+ nn.Linear(1280, num_classes))
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.layers(x)
+ # Equivalent to global avgpool and removing H and W dimensions.
+ x = x.mean([2, 3])
+ return self.classifier(x)
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out",
+ nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+
+def _load_pretrained(model_name, model, progress):
+ if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
+ raise ValueError(
+ "No checkpoint is available for model type {}".format(model_name))
+ checkpoint_url = _MODEL_URLS[model_name]
+ model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
+
+
+[docs]def mnasnet0_5(pretrained=False, progress=True, **kwargs):
+ """MNASNet with depth multiplier of 0.5 from
+ `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
+ <https://arxiv.org/pdf/1807.11626.pdf>`_.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MNASNet(0.5, **kwargs)
+ if pretrained:
+ _load_pretrained("mnasnet0_5", model, progress)
+ return model
+
+
+[docs]def mnasnet0_75(pretrained=False, progress=True, **kwargs):
+ """MNASNet with depth multiplier of 0.75 from
+ `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
+ <https://arxiv.org/pdf/1807.11626.pdf>`_.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MNASNet(0.75, **kwargs)
+ if pretrained:
+ _load_pretrained("mnasnet0_75", model, progress)
+ return model
+
+
+[docs]def mnasnet1_0(pretrained=False, progress=True, **kwargs):
+ """MNASNet with depth multiplier of 1.0 from
+ `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
+ <https://arxiv.org/pdf/1807.11626.pdf>`_.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MNASNet(1.0, **kwargs)
+ if pretrained:
+ _load_pretrained("mnasnet1_0", model, progress)
+ return model
+
+
+[docs]def mnasnet1_3(pretrained=False, progress=True, **kwargs):
+ """MNASNet with depth multiplier of 1.3 from
+ `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
+ <https://arxiv.org/pdf/1807.11626.pdf>`_.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MNASNet(1.3, **kwargs)
+ if pretrained:
+ _load_pretrained("mnasnet1_3", model, progress)
+ return model
+
+from torch import nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['MobileNetV2', 'mobilenet_v2']
+
+
+model_urls = {
+ 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
+}
+
+
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
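+# Worked examples (added for clarity, not part of the torchvision source):
+# _make_divisible(27, 8) -> 32   # nearest multiple 24 is more than 10% below 27, so round up
+# _make_divisible(83, 8) -> 80   # nearest multiple 80 is within 10% of 83, so keep it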
+
+class ConvBNReLU(nn.Sequential):
+ def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+ padding = (kernel_size - 1) // 2
+ super(ConvBNReLU, self).__init__(
+ nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.ReLU6(inplace=True)
+ )
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = int(round(inp * expand_ratio))
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ layers = []
+ if expand_ratio != 1:
+ # pw
+ layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+ layers.extend([
+ # dw
+ ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ ])
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
+ """
+ MobileNet V2 main class
+
+ Args:
+ num_classes (int): Number of classes
+ width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+ inverted_residual_setting: Network structure
+ round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+ Set to 1 to turn off rounding
+ """
+ super(MobileNetV2, self).__init__()
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+
+ if inverted_residual_setting is None:
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
+
+ # only check the first element, assuming user knows t,c,n,s are required
+ if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+ raise ValueError("inverted_residual_setting should be non-empty "
+ "or a 4-element list, got {}".format(inverted_residual_setting))
+
+ # building first layer
+ input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+ self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+ features = [ConvBNReLU(3, input_channel, stride=2)]
+ # building inverted residual blocks
+ for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(c * width_mult, round_nearest)
+ for i in range(n):
+ stride = s if i == 0 else 1
+ features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+ input_channel = output_channel
+ # building last several layers
+ features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
+ # make it nn.Sequential
+ self.features = nn.Sequential(*features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, num_classes),
+ )
+
+ # weight initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ x = self.features(x)
+ x = x.mean([2, 3])
+ x = self.classifier(x)
+ return x
+
+
+[docs]def mobilenet_v2(pretrained=False, progress=True, **kwargs):
+ """
+ Constructs a MobileNetV2 architecture from
+ `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ model = MobileNetV2(**kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
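+# A hedged usage sketch (added for clarity, not part of the torchvision source):
+# width_mult scales the per-layer channel counts; for width_mult <= 1.0 the
+# classifier input stays at 1280 features.
+# >>> import torch
+# >>> model = MobileNetV2(width_mult=0.5)
+# >>> model(torch.rand(1, 3, 224, 224)).shape
+# torch.Size([1, 1000])
+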
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
+ norm_layer=None):
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0)
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+[docs]def resnet18(pretrained=False, progress=True, **kwargs):
+ """Constructs a ResNet-18 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
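+# A hedged usage sketch (added for clarity, not part of the torchvision source):
+# extra keyword arguments are forwarded to the ResNet constructor, e.g.
+# zero_init_residual or norm_layer.
+# >>> import torch
+# >>> model = resnet18(pretrained=False, zero_init_residual=True)
+# >>> model(torch.rand(1, 3, 224, 224)).shape
+# torch.Size([1, 1000])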
+
+[docs]def resnet34(pretrained=False, progress=True, **kwargs):
+ """Constructs a ResNet-34 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+[docs]def resnet50(pretrained=False, progress=True, **kwargs):
+ """Constructs a ResNet-50 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+[docs]def resnet101(pretrained=False, progress=True, **kwargs):
+ """Constructs a ResNet-101 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+[docs]def resnet152(pretrained=False, progress=True, **kwargs):
+ """Constructs a ResNet-152 model.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(**kwargs):
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained=False, progress=True, **kwargs)
+
+
+def resnext101_32x8d(**kwargs):
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained=False, progress=True, **kwargs)
+
+from .._utils import IntermediateLayerGetter
+from ..utils import load_state_dict_from_url
+from .. import resnet
+from .deeplabv3 import DeepLabHead, DeepLabV3
+from .fcn import FCN, FCNHead
+
+
+__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
+
+
+model_urls = {
+ 'fcn_resnet50_coco': None,
+ 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
+ 'deeplabv3_resnet50_coco': None,
+ 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
+}
+
+
+def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
+ backbone = resnet.__dict__[backbone_name](
+ pretrained=pretrained_backbone,
+ replace_stride_with_dilation=[False, True, True])
+
+ return_layers = {'layer4': 'out'}
+ if aux:
+ return_layers['layer3'] = 'aux'
+ backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+ aux_classifier = None
+ if aux:
+ inplanes = 1024
+ aux_classifier = FCNHead(inplanes, num_classes)
+
+ model_map = {
+ 'deeplabv3': (DeepLabHead, DeepLabV3),
+ 'fcn': (FCNHead, FCN),
+ }
+ inplanes = 2048
+ classifier = model_map[name][0](inplanes, num_classes)
+ base_model = model_map[name][1]
+
+ model = base_model(backbone, classifier, aux_classifier)
+ return model
+
+
+def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
+ if pretrained:
+ aux_loss = True
+ model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
+ if pretrained:
+ arch = arch_type + '_' + backbone + '_coco'
+ model_url = model_urls[arch]
+ if model_url is None:
+ raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
+ else:
+ state_dict = load_state_dict_from_url(model_url, progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+[docs]def fcn_resnet50(pretrained=False, progress=True,
+ num_classes=21, aux_loss=None, **kwargs):
+ """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+ contains the same classes as Pascal VOC
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
+
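+# A hedged usage sketch (added for clarity, not part of the torchvision source).
+# Assuming the FCN/DeepLabV3 wrappers return a dict keyed by the return_layers
+# names above ('out', plus 'aux' when aux_loss is enabled) and upsample the
+# logits back to the input resolution, a forward pass looks roughly like:
+# >>> import torch
+# >>> model = fcn_resnet50(pretrained=False, pretrained_backbone=False).eval()
+# >>> out = model(torch.rand(1, 3, 224, 224))
+# >>> out['out'].shape
+# torch.Size([1, 21, 224, 224])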
+
+[docs]def fcn_resnet101(pretrained=False, progress=True,
+ num_classes=21, aux_loss=None, **kwargs):
+ """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+ contains the same classes as Pascal VOC
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
+
+
+[docs]def deeplabv3_resnet50(pretrained=False, progress=True,
+ num_classes=21, aux_loss=None, **kwargs):
+ """Constructs a DeepLabV3 model with a ResNet-50 backbone.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+ contains the same classes as Pascal VOC
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
+
+
+[docs]def deeplabv3_resnet101(pretrained=False, progress=True,
+ num_classes=21, aux_loss=None, **kwargs):
+ """Constructs a DeepLabV3 model with a ResNet-101 backbone.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+ contains the same classes as Pascal VOC
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
+
+import torch
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = [
+ 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
+ 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
+]
+
+model_urls = {
+ 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
+ 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
+ 'shufflenetv2_x1.5': None,
+ 'shufflenetv2_x2.0': None,
+}
+
+
+def channel_shuffle(x, groups):
+ batchsize, num_channels, height, width = x.data.size()
+ channels_per_group = num_channels // groups
+
+ # reshape
+ x = x.view(batchsize, groups,
+ channels_per_group, height, width)
+
+ x = torch.transpose(x, 1, 2).contiguous()
+
+ # flatten
+ x = x.view(batchsize, -1, height, width)
+
+ return x
+
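+# A small worked example (added for clarity, not part of the torchvision source):
+# with 4 channels and groups=2, channel order [0, 1, 2, 3] becomes [0, 2, 1, 3].
+# >>> x = torch.arange(4.).view(1, 4, 1, 1)
+# >>> channel_shuffle(x, 2).flatten().tolist()
+# [0.0, 2.0, 1.0, 3.0]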
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride):
+ super(InvertedResidual, self).__init__()
+
+ if not (1 <= stride <= 3):
+ raise ValueError('illegal stride value')
+ self.stride = stride
+
+ branch_features = oup // 2
+ assert (self.stride != 1) or (inp == branch_features << 1)
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(inp),
+ nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.ReLU(inplace=True),
+ )
+
+ self.branch2 = nn.Sequential(
+ nn.Conv2d(inp if (self.stride > 1) else branch_features,
+ branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.ReLU(inplace=True),
+ self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(branch_features),
+ nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.ReLU(inplace=True),
+ )
+
+ @staticmethod
+ def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+ return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+ def forward(self, x):
+ if self.stride == 1:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+ else:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+
+ out = channel_shuffle(out, 2)
+
+ return out
+
+
+class ShuffleNetV2(nn.Module):
+ def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
+ super(ShuffleNetV2, self).__init__()
+
+ if len(stages_repeats) != 3:
+ raise ValueError('expected stages_repeats as list of 3 positive ints')
+ if len(stages_out_channels) != 5:
+ raise ValueError('expected stages_out_channels as list of 5 positive ints')
+ self._stage_out_channels = stages_out_channels
+
+ input_channels = 3
+ output_channels = self._stage_out_channels[0]
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(output_channels),
+ nn.ReLU(inplace=True),
+ )
+ input_channels = output_channels
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
+ for name, repeats, output_channels in zip(
+ stage_names, stages_repeats, self._stage_out_channels[1:]):
+ seq = [InvertedResidual(input_channels, output_channels, 2)]
+ for i in range(repeats - 1):
+ seq.append(InvertedResidual(output_channels, output_channels, 1))
+ setattr(self, name, nn.Sequential(*seq))
+ input_channels = output_channels
+
+ output_channels = self._stage_out_channels[-1]
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(output_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ self.fc = nn.Linear(output_channels, num_classes)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.maxpool(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.stage4(x)
+ x = self.conv5(x)
+ x = x.mean([2, 3]) # globalpool
+ x = self.fc(x)
+ return x
+
+
+def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
+ model = ShuffleNetV2(*args, **kwargs)
+
+ if pretrained:
+ model_url = model_urls[arch]
+ if model_url is None:
+ raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
+ else:
+ state_dict = load_state_dict_from_url(model_url, progress=progress)
+ model.load_state_dict(state_dict)
+
+ return model
+
+
+[docs]def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
+ """
+ Constructs a ShuffleNetV2 with 0.5x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
+ [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
+
+
+[docs]def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
+ """
+ Constructs a ShuffleNetV2 with 1.0x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
+ [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
+
+
+[docs]def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
+ """
+ Constructs a ShuffleNetV2 with 1.5x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
+ [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
+
+
+[docs]def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
+ """
+ Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
+ [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from .utils import load_state_dict_from_url
+
+__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
+
+model_urls = {
+ 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
+ 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
+}
+
+
+class Fire(nn.Module):
+
+ def __init__(self, inplanes, squeeze_planes,
+ expand1x1_planes, expand3x3_planes):
+ super(Fire, self).__init__()
+ self.inplanes = inplanes
+ self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
+ self.squeeze_activation = nn.ReLU(inplace=True)
+ self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
+ kernel_size=1)
+ self.expand1x1_activation = nn.ReLU(inplace=True)
+ self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
+ kernel_size=3, padding=1)
+ self.expand3x3_activation = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.squeeze_activation(self.squeeze(x))
+ return torch.cat([
+ self.expand1x1_activation(self.expand1x1(x)),
+ self.expand3x3_activation(self.expand3x3(x))
+ ], 1)
+
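+# A small worked example (added for clarity, not part of the torchvision source):
+# the Fire module's output channel count is expand1x1_planes + expand3x3_planes.
+# >>> fire = Fire(96, 16, 64, 64)
+# >>> fire(torch.rand(1, 96, 54, 54)).shape
+# torch.Size([1, 128, 54, 54])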
+
+class SqueezeNet(nn.Module):
+
+ def __init__(self, version='1_0', num_classes=1000):
+ super(SqueezeNet, self).__init__()
+ self.num_classes = num_classes
+ if version == '1_0':
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 96, kernel_size=7, stride=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(96, 16, 64, 64),
+ Fire(128, 16, 64, 64),
+ Fire(128, 32, 128, 128),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(256, 32, 128, 128),
+ Fire(256, 48, 192, 192),
+ Fire(384, 48, 192, 192),
+ Fire(384, 64, 256, 256),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(512, 64, 256, 256),
+ )
+ elif version == '1_1':
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=3, stride=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(64, 16, 64, 64),
+ Fire(128, 16, 64, 64),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(128, 32, 128, 128),
+ Fire(256, 32, 128, 128),
+ nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
+ Fire(256, 48, 192, 192),
+ Fire(384, 48, 192, 192),
+ Fire(384, 64, 256, 256),
+ Fire(512, 64, 256, 256),
+ )
+ else:
+ # FIXME: Is this needed? SqueezeNet should only be called from the
+ # FIXME: squeezenet1_x() functions
+ # FIXME: This checking is not done for the other models
+ raise ValueError("Unsupported SqueezeNet version {version}:"
+ "1_0 or 1_1 expected".format(version=version))
+
+ # Final convolution is initialized differently from the rest
+ final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
+ self.classifier = nn.Sequential(
+ nn.Dropout(p=0.5),
+ final_conv,
+ nn.ReLU(inplace=True),
+ nn.AdaptiveAvgPool2d((1, 1))
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m is final_conv:
+ init.normal_(m.weight, mean=0.0, std=0.01)
+ else:
+ init.kaiming_uniform_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.classifier(x)
+ return x.view(x.size(0), self.num_classes)
+
+
+def _squeezenet(version, pretrained, progress, **kwargs):
+ model = SqueezeNet(version, **kwargs)
+ if pretrained:
+ arch = 'squeezenet' + version
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+[docs]def squeezenet1_0(pretrained=False, progress=True, **kwargs):
+ r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
+ accuracy with 50x fewer parameters and <0.5MB model size"
+ <https://arxiv.org/abs/1602.07360>`_ paper.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _squeezenet('1_0', pretrained, progress, **kwargs)
+
+
+[docs]def squeezenet1_1(pretrained=False, progress=True, **kwargs):
+ r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
+ <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
+ SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
+ than SqueezeNet 1.0, without sacrificing accuracy.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _squeezenet('1_1', pretrained, progress, **kwargs)
+
+import torch.nn as nn
+from .utils import load_state_dict_from_url
+
+
+__all__ = [
+ 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
+ 'vgg19_bn', 'vgg19',
+]
+
+
+model_urls = {
+ 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+ 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
+ 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+ 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+ 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+ 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
+ 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+ 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class VGG(nn.Module):
+
+ def __init__(self, features, num_classes=1000, init_weights=True):
+ super(VGG, self).__init__()
+ self.features = features
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+ if init_weights:
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ return x
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.constant_(m.bias, 0)
+
+
+def make_layers(cfg, batch_norm=False):
+ layers = []
+ in_channels = 3
+ for v in cfg:
+ if v == 'M':
+ layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+ else:
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+ if batch_norm:
+ layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+ else:
+ layers += [conv2d, nn.ReLU(inplace=True)]
+ in_channels = v
+ return nn.Sequential(*layers)
+
+
+cfgs = {
+ 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+ 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
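+# A hedged note (added for clarity, not part of the torchvision source): each key
+# selects a configuration; integers are conv output channels and 'M' inserts a
+# 2x2 max-pool. For example, the VGG-11 ("A") feature extractor is built as
+# VGG(make_layers(cfgs['A'], batch_norm=False)), which is what _vgg below does.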
+
+def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
+ if pretrained:
+ kwargs['init_weights'] = False
+ model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+[docs]def vgg11(pretrained=False, progress=True, **kwargs):
+ """VGG 11-layer model (configuration "A")
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
+
+
+[docs]def vgg11_bn(pretrained=False, progress=True, **kwargs):
+ """VGG 11-layer model (configuration "A") with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
+
+
+[docs]def vgg13(pretrained=False, progress=True, **kwargs):
+ """VGG 13-layer model (configuration "B")
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
+
+
+[docs]def vgg13_bn(pretrained=False, progress=True, **kwargs):
+ """VGG 13-layer model (configuration "B") with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
+
+
+[docs]def vgg16(pretrained=False, progress=True, **kwargs):
+ """VGG 16-layer model (configuration "D")
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
+
+
+[docs]def vgg16_bn(pretrained=False, progress=True, **kwargs):
+ """VGG 16-layer model (configuration "D") with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
+
+
+[docs]def vgg19(pretrained=False, progress=True, **kwargs):
+ """VGG 19-layer model (configuration "E")
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
+
+
+[docs]def vgg19_bn(pretrained=False, progress=True, **kwargs):
+ """VGG 19-layer model (configuration 'E') with batch normalization
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
+
+from __future__ import division
+import torch
+import sys
+import math
+from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
+try:
+ import accimage
+except ImportError:
+ accimage = None
+import numpy as np
+import numbers
+import collections
+import warnings
+
+if sys.version_info < (3, 3):
+ Sequence = collections.Sequence
+ Iterable = collections.Iterable
+else:
+ Sequence = collections.abc.Sequence
+ Iterable = collections.abc.Iterable
+
+
+def _is_pil_image(img):
+ if accimage is not None:
+ return isinstance(img, (Image.Image, accimage.Image))
+ else:
+ return isinstance(img, Image.Image)
+
+
+def _is_tensor_image(img):
+ return torch.is_tensor(img) and img.ndimension() == 3
+
+
+def _is_numpy_image(img):
+ return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+[docs]def to_tensor(pic):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+ See ``ToTensor`` for more details.
+
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ if not(_is_pil_image(pic) or _is_numpy_image(pic)):
+ raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+ if isinstance(pic, np.ndarray):
+ # handle numpy array
+ if pic.ndim == 2:
+ pic = pic[:, :, None]
+
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
+ # backward compatibility
+ if isinstance(img, torch.ByteTensor):
+ return img.float().div(255)
+ else:
+ return img
+
+ if accimage is not None and isinstance(pic, accimage.Image):
+ nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
+ pic.copyto(nppic)
+ return torch.from_numpy(nppic)
+
+ # handle PIL Image
+ if pic.mode == 'I':
+ img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+ elif pic.mode == 'I;16':
+ img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+ elif pic.mode == 'F':
+ img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+ elif pic.mode == '1':
+ img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
+ else:
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+ # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
+ if pic.mode == 'YCbCr':
+ nchannel = 3
+ elif pic.mode == 'I;16':
+ nchannel = 1
+ else:
+ nchannel = len(pic.mode)
+ img = img.view(pic.size[1], pic.size[0], nchannel)
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ if isinstance(img, torch.ByteTensor):
+ return img.float().div(255)
+ else:
+ return img
+
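+# A small worked example (added for clarity, not part of the torchvision source):
+# a uint8 H x W x C array becomes a float C x H x W tensor scaled to [0, 1].
+# >>> arr = np.zeros((4, 5, 3), dtype=np.uint8)
+# >>> t = to_tensor(arr)
+# >>> t.shape, t.dtype
+# (torch.Size([3, 4, 5]), torch.float32)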
+
+[docs]def to_pil_image(pic, mode=None):
+ """Convert a tensor or an ndarray to PIL Image.
+
+ See :class:`~torchvision.transforms.ToPILImage` for more details.
+
+ Args:
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+ mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+
+ .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+
+ Returns:
+ PIL Image: Image converted to PIL Image.
+ """
+ if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
+ raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
+
+ elif isinstance(pic, torch.Tensor):
+ if pic.ndimension() not in {2, 3}:
+ raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
+
+ elif pic.ndimension() == 2:
+ # if 2D image, add channel dimension (CHW)
+ pic = pic.unsqueeze(0)
+
+ elif isinstance(pic, np.ndarray):
+ if pic.ndim not in {2, 3}:
+ raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
+
+ elif pic.ndim == 2:
+ # if 2D image, add channel dimension (HWC)
+ pic = np.expand_dims(pic, 2)
+
+ npimg = pic
+ if isinstance(pic, torch.FloatTensor):
+ pic = pic.mul(255).byte()
+ if isinstance(pic, torch.Tensor):
+ npimg = np.transpose(pic.numpy(), (1, 2, 0))
+
+ if not isinstance(npimg, np.ndarray):
+ raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
+ 'not {}'.format(type(npimg)))
+
+ if npimg.shape[2] == 1:
+ expected_mode = None
+ npimg = npimg[:, :, 0]
+ if npimg.dtype == np.uint8:
+ expected_mode = 'L'
+ elif npimg.dtype == np.int16:
+ expected_mode = 'I;16'
+ elif npimg.dtype == np.int32:
+ expected_mode = 'I'
+ elif npimg.dtype == np.float32:
+ expected_mode = 'F'
+ if mode is not None and mode != expected_mode:
+ raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
+ .format(mode, npimg.dtype, expected_mode))
+ mode = expected_mode
+
+ elif npimg.shape[2] == 2:
+ permitted_2_channel_modes = ['LA']
+ if mode is not None and mode not in permitted_2_channel_modes:
+ raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
+
+ if mode is None and npimg.dtype == np.uint8:
+ mode = 'LA'
+
+ elif npimg.shape[2] == 4:
+ permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
+ if mode is not None and mode not in permitted_4_channel_modes:
+ raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
+
+ if mode is None and npimg.dtype == np.uint8:
+ mode = 'RGBA'
+ else:
+ permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
+ if mode is not None and mode not in permitted_3_channel_modes:
+ raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
+ if mode is None and npimg.dtype == np.uint8:
+ mode = 'RGB'
+
+ if mode is None:
+ raise TypeError('Input type {} is not supported'.format(npimg.dtype))
+
+ return Image.fromarray(npimg, mode=mode)
+
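+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``t`` is a (3, H, W) float tensor with values in [0.0, 1.0] (e.g. the output
+# of ``to_tensor``), it can be converted back to a PIL Image:
+#
+#     >>> pil = to_pil_image(t)      # mode inferred from shape/dtype, here 'RGB'
+#     >>> pil.size                   # (W, H)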
+
+[docs]def normalize(tensor, mean, std, inplace=False):
+ """Normalize a tensor image with mean and standard deviation.
+
+ .. note::
+ This transform acts out of place by default, i.e., it does not mutate the input tensor.
+
+ See :class:`~torchvision.transforms.Normalize` for more details.
+
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+ if not _is_tensor_image(tensor):
+ raise TypeError('tensor is not a torch image.')
+
+ if not inplace:
+ tensor = tensor.clone()
+
+ mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
+ std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
+ tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
+ return tensor
+
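+# Illustrative usage sketch (not part of the torchvision source). The mean/std
+# values below are the commonly used ImageNet statistics and serve only as an
+# example; ``t`` is assumed to be a (3, H, W) float tensor in [0.0, 1.0]:
+#
+#     >>> out = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+#     >>> out is t                   # False: the input is cloned by default (inplace=False)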
+
+[docs]def resize(img, size, interpolation=Image.BILINEAR):
+ r"""Resize the input PIL Image to the given size.
+
+ Args:
+ img (PIL Image): Image to be resized.
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), the output size will be matched to this. If size is an int,
+ the smaller edge of the image will be matched to this number maintaining
+ the aspect ratio. i.e., if height > width, then image will be rescaled to
+ :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
+ interpolation (int, optional): Desired interpolation. Default is
+ ``PIL.Image.BILINEAR``
+
+ Returns:
+ PIL Image: Resized image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+ if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
+ raise TypeError('Got inappropriate size arg: {}'.format(size))
+
+ if isinstance(size, int):
+ w, h = img.size
+ if (w <= h and w == size) or (h <= w and h == size):
+ return img
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ return img.resize((ow, oh), interpolation)
+ else:
+ oh = size
+ ow = int(size * w / h)
+ return img.resize((ow, oh), interpolation)
+ else:
+ return img.resize(size[::-1], interpolation)
+
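+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is a 640x480 (W x H) PIL Image, an int size matches the smaller
+# edge while a (h, w) sequence is used as-is:
+#
+#     >>> resize(img, 256).size          # (341, 256): height 480 -> 256, width scaled to keep the aspect ratio
+#     >>> resize(img, (224, 224)).size   # (224, 224)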
+
+def scale(*args, **kwargs):
+ warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+ "please use transforms.Resize instead.")
+ return resize(*args, **kwargs)
+
+
+[docs]def pad(img, padding, fill=0, padding_mode='constant'):
+ r"""Pad the given PIL Image on all sides with specified padding mode and fill value.
+
+ Args:
+ img (PIL Image): Image to be padded.
+ padding (int or tuple): Padding on each border. If a single int is provided this
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
+ this is the padding for the left, top, right and bottom borders
+ respectively.
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+ length 3, it is used to fill R, G, B channels respectively.
+ This value is only used when the padding_mode is constant
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+
+ - constant: pads with a constant value, this value is specified with fill
+
+ - edge: pads with the last value on the edge of the image
+
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
+
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
+
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
+ Returns:
+ PIL Image: Padded image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ if not isinstance(padding, (numbers.Number, tuple)):
+ raise TypeError('Got inappropriate padding arg')
+ if not isinstance(fill, (numbers.Number, str, tuple)):
+ raise TypeError('Got inappropriate fill arg')
+ if not isinstance(padding_mode, str):
+ raise TypeError('Got inappropriate padding_mode arg')
+
+ if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
+ "{} element tuple".format(len(padding)))
+
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
+ 'Padding mode should be either constant, edge, reflect or symmetric'
+
+ if padding_mode == 'constant':
+ if img.mode == 'P':
+ palette = img.getpalette()
+ image = ImageOps.expand(img, border=padding, fill=fill)
+ image.putpalette(palette)
+ return image
+
+ return ImageOps.expand(img, border=padding, fill=fill)
+ else:
+ if isinstance(padding, int):
+ pad_left = pad_right = pad_top = pad_bottom = padding
+ if isinstance(padding, Sequence) and len(padding) == 2:
+ pad_left = pad_right = padding[0]
+ pad_top = pad_bottom = padding[1]
+ if isinstance(padding, Sequence) and len(padding) == 4:
+ pad_left = padding[0]
+ pad_top = padding[1]
+ pad_right = padding[2]
+ pad_bottom = padding[3]
+
+ if img.mode == 'P':
+ palette = img.getpalette()
+ img = np.asarray(img)
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
+ img = Image.fromarray(img)
+ img.putpalette(palette)
+ return img
+
+ img = np.asarray(img)
+ # RGB image
+ if len(img.shape) == 3:
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
+ # Grayscale image
+ if len(img.shape) == 2:
+ img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
+
+ return Image.fromarray(img)
+
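+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is a 100x100 RGB PIL Image:
+#
+#     >>> pad(img, 10).size                               # (120, 120): constant (zero) padding on all sides
+#     >>> pad(img, (5, 10), padding_mode='reflect').size  # (110, 120): 5 px left/right, 10 px top/bottom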
+
+[docs]def crop(img, i, j, h, w):
+ """Crop the given PIL Image.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ i (int): i in (i, j), i.e., the vertical coordinate of the upper left corner.
+ j (int): j in (i, j), i.e., the horizontal coordinate of the upper left corner.
+ h (int): Height of the cropped image.
+ w (int): Width of the cropped image.
+
+ Returns:
+ PIL Image: Cropped image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ return img.crop((j, i, j + w, i + h))
+
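+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is a PIL Image, (i, j) is the upper left corner in (row, column)
+# order, so this takes a 100x200 (H x W) patch starting 10 px down and 20 px
+# to the right:
+#
+#     >>> patch = crop(img, 10, 20, 100, 200)
+#     >>> patch.size                 # (200, 100), i.e. (W, H)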
+
+def center_crop(img, output_size):
+ if isinstance(output_size, numbers.Number):
+ output_size = (int(output_size), int(output_size))
+ w, h = img.size
+ th, tw = output_size
+ i = int(round((h - th) / 2.))
+ j = int(round((w - tw) / 2.))
+ return crop(img, i, j, th, tw)
+
+
+[docs]def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
+ """Crop the given PIL Image and resize it to desired size.
+
+ Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ i (int): i in (i, j), i.e., the vertical coordinate of the upper left corner.
+ j (int): j in (i, j), i.e., the horizontal coordinate of the upper left corner.
+ h (int): Height of the cropped image.
+ w (int): Width of the cropped image.
+ size (sequence or int): Desired output size. Same semantics as ``resize``.
+ interpolation (int, optional): Desired interpolation. Default is
+ ``PIL.Image.BILINEAR``.
+ Returns:
+ PIL Image: Cropped image.
+ """
+ assert _is_pil_image(img), 'img should be PIL Image'
+ img = crop(img, i, j, h, w)
+ img = resize(img, size, interpolation)
+ return img
+
+
+[docs]def hflip(img):
+ """Horizontally flip the given PIL Image.
+
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Horizontally flipped image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+
+def _get_perspective_coeffs(startpoints, endpoints):
+ """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
+
+ In a perspective transform, each pixel (x, y) in the original image gets transformed as,
+ (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
+
+ Args:
+ List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
+ List containing [top-left, top-right, bottom-right, bottom-left] of the transformed
+ image
+ Returns:
+ octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
+ """
+ matrix = []
+
+ for p1, p2 in zip(endpoints, startpoints):
+ matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
+ matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
+
+ A = torch.tensor(matrix, dtype=torch.float)
+ B = torch.tensor(startpoints, dtype=torch.float).view(8)
+ res = torch.gels(B, A)[0]
+ return res.squeeze_(1).tolist()
+
+
+[docs]def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
+ """Perform perspective transform of the given PIL Image.
+
+ Args:
+ img (PIL Image): Image to be transformed.
+ coeffs (tuple) : 8-tuple (a, b, c, d, e, f, g, h) which contains the coefficients.
+ for a perspective transform.
+ interpolation: Default- Image.BICUBIC
+ Returns:
+ PIL Image: Perspectively transformed Image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ coeffs = _get_perspective_coeffs(startpoints, endpoints)
+ return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)
+
+
+[docs]def vflip(img):
+ """Vertically flip the given PIL Image.
+
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Vertically flipped image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+
+[docs]def five_crop(img, size):
+ """Crop the given PIL Image into four corners and the central crop.
+
+ .. Note::
+ This transform returns a tuple of images and there may be a
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+
+ Returns:
+ tuple: tuple (tl, tr, bl, br, center)
+ Corresponding top left, top right, bottom left, bottom right and center crop.
+ """
+ if isinstance(size, numbers.Number):
+ size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+ w, h = img.size
+ crop_h, crop_w = size
+ if crop_w > w or crop_h > h:
+ raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
+ (h, w)))
+ tl = img.crop((0, 0, crop_w, crop_h))
+ tr = img.crop((w - crop_w, 0, w, crop_h))
+ bl = img.crop((0, h - crop_h, crop_w, h))
+ br = img.crop((w - crop_w, h - crop_h, w, h))
+ center = center_crop(img, (crop_h, crop_w))
+ return (tl, tr, bl, br, center)
+
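+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is a 256x256 PIL Image:
+#
+#     >>> crops = five_crop(img, 224)
+#     >>> len(crops)                 # 5: four corners plus the center crop
+#     >>> crops[0].size              # (224, 224)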
+
+[docs]def ten_crop(img, size, vertical_flip=False):
+ r"""Crop the given PIL Image into four corners and the central crop plus the
+ flipped version of these (horizontal flipping is used by default).
+
+ .. Note::
+ This transform returns a tuple of images and there may be a
+ mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ vertical_flip (bool): Use vertical flipping instead of horizontal
+
+ Returns:
+ tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+ Corresponding top left, top right, bottom left, bottom right and center crop
+ and same for the flipped image.
+ """
+ if isinstance(size, numbers.Number):
+ size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+ first_five = five_crop(img, size)
+
+ if vertical_flip:
+ img = vflip(img)
+ else:
+ img = hflip(img)
+
+ second_five = five_crop(img, size)
+ return first_five + second_five
+
+
+[docs]def adjust_brightness(img, brightness_factor):
+ """Adjust brightness of an Image.
+
+ Args:
+ img (PIL Image): PIL Image to be adjusted.
+ brightness_factor (float): How much to adjust the brightness. Can be
+ any non negative number. 0 gives a black image, 1 gives the
+ original image while 2 increases the brightness by a factor of 2.
+
+ Returns:
+ PIL Image: Brightness adjusted image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ enhancer = ImageEnhance.Brightness(img)
+ img = enhancer.enhance(brightness_factor)
+ return img
+
+
+[docs]def adjust_contrast(img, contrast_factor):
+ """Adjust contrast of an Image.
+
+ Args:
+ img (PIL Image): PIL Image to be adjusted.
+ contrast_factor (float): How much to adjust the contrast. Can be any
+ non negative number. 0 gives a solid gray image, 1 gives the
+ original image while 2 increases the contrast by a factor of 2.
+
+ Returns:
+ PIL Image: Contrast adjusted image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ enhancer = ImageEnhance.Contrast(img)
+ img = enhancer.enhance(contrast_factor)
+ return img
+
+
+[docs]def adjust_saturation(img, saturation_factor):
+ """Adjust color saturation of an image.
+
+ Args:
+ img (PIL Image): PIL Image to be adjusted.
+ saturation_factor (float): How much to adjust the saturation. 0 will
+ give a black and white image, 1 will give the original image while
+ 2 will enhance the saturation by a factor of 2.
+
+ Returns:
+ PIL Image: Saturation adjusted image.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ enhancer = ImageEnhance.Color(img)
+ img = enhancer.enhance(saturation_factor)
+ return img
+
+
+[docs]def adjust_hue(img, hue_factor):
+ """Adjust hue of an image.
+
+ The image hue is adjusted by converting the image to HSV and
+ cyclically shifting the intensities in the hue channel (H).
+ The image is then converted back to original image mode.
+
+ `hue_factor` is the amount of shift in H channel and must be in the
+ interval `[-0.5, 0.5]`.
+
+ See `Hue`_ for more details.
+
+ .. _Hue: https://en.wikipedia.org/wiki/Hue
+
+ Args:
+ img (PIL Image): PIL Image to be adjusted.
+ hue_factor (float): How much to shift the hue channel. Should be in
+ [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+ HSV space in positive and negative direction respectively.
+ 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+ with complementary colors while 0 gives the original image.
+
+ Returns:
+ PIL Image: Hue adjusted image.
+ """
+ if not(-0.5 <= hue_factor <= 0.5):
+ raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
+
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ input_mode = img.mode
+ if input_mode in {'L', '1', 'I', 'F'}:
+ return img
+
+ h, s, v = img.convert('HSV').split()
+
+ np_h = np.array(h, dtype=np.uint8)
+ # uint8 addition takes care of rotation across boundaries
+ with np.errstate(over='ignore'):
+ np_h += np.uint8(hue_factor * 255)
+ h = Image.fromarray(np_h, 'L')
+
+ img = Image.merge('HSV', (h, s, v)).convert(input_mode)
+ return img
+
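+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is an RGB PIL Image, a hue_factor of 0.5 (or -0.5) shifts every hue
+# halfway around the HSV circle, giving complementary colors:
+#
+#     >>> shifted = adjust_hue(img, 0.25)     # quarter-turn hue shift
+#     >>> unchanged = adjust_hue(img, 0.0)    # no shift apart from the HSV round trip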
+
+[docs]def adjust_gamma(img, gamma, gain=1):
+ r"""Perform gamma correction on an image.
+
+ Also known as Power Law Transform. Intensities in RGB mode are adjusted
+ based on the following equation:
+
+ .. math::
+ I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
+
+ See `Gamma Correction`_ for more details.
+
+ .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
+
+ Args:
+ img (PIL Image): PIL Image to be adjusted.
+ gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
+ gamma larger than 1 makes the shadows darker,
+ while gamma smaller than 1 makes dark regions lighter.
+ gain (float): The constant multiplier.
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ if gamma < 0:
+ raise ValueError('Gamma should be a non-negative real number')
+
+ input_mode = img.mode
+ img = img.convert('RGB')
+
+ gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
+ img = img.point(gamma_map) # use PIL's point-function to accelerate this part
+
+ img = img.convert(input_mode)
+ return img
+
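+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is an RGB PIL Image, gamma > 1 darkens midtones and shadows while
+# gamma < 1 brightens them:
+#
+#     >>> darker = adjust_gamma(img, gamma=2.0)
+#     >>> lighter = adjust_gamma(img, gamma=0.5)
+#     >>> identity = adjust_gamma(img, gamma=1.0, gain=1)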
+
+[docs]def rotate(img, angle, resample=False, expand=False, center=None):
+ """Rotate the image by angle.
+
+
+ Args:
+ img (PIL Image): PIL Image to be rotated.
+ angle (float or int): Rotation angle in degrees, counter-clockwise.
+ resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+ An optional resampling filter. See `filters`_ for more information.
+ If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+ expand (bool, optional): Optional expansion flag.
+ If true, expands the output image to make it large enough to hold the entire rotated image.
+ If false or omitted, make the output image the same size as the input image.
+ Note that the expand flag assumes rotation around the center and no translation.
+ center (2-tuple, optional): Optional center of rotation.
+ Origin is the upper left corner.
+ Default is the center of the image.
+
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+ """
+
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ return img.rotate(angle, resample, expand, center)
+
+
+def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
+ # Helper method to compute inverse matrix for affine transformation
+
+ # As it is explained in PIL.Image.rotate
+ # We need to compute the INVERSE of the affine transformation matrix: M = T * C * RSS * C^-1
+ # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+ # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+ # RSS is rotation with scale and shear matrix
+ # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0]
+ # [ sin(a)*scale cos(a + shear)*scale 0]
+ # [ 0 0 1]
+ # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
+
+ angle = math.radians(angle)
+ shear = math.radians(shear)
+ scale = 1.0 / scale
+
+ # Inverted rotation matrix with scale and shear
+ d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
+ matrix = [
+ math.cos(angle + shear), math.sin(angle + shear), 0,
+ -math.sin(angle), math.cos(angle), 0
+ ]
+ matrix = [scale / d * m for m in matrix]
+
+ # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+ matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
+ matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
+
+ # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+ matrix[2] += center[0]
+ matrix[5] += center[1]
+ return matrix
+
+
+[docs]def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
+ """Apply affine transformation on the image keeping image center invariant
+
+ Args:
+ img (PIL Image): PIL Image to be transformed.
+ angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
+ translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
+ scale (float): overall scale
+ shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
+ resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+ An optional resampling filter.
+ See `filters`_ for more information.
+ If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+ fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+ "Argument translate should be a list or tuple of length 2"
+
+ assert scale > 0.0, "Argument scale should be positive"
+
+ output_size = img.size
+ center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
+ matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
+ kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {}
+ return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
+
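+# Illustrative usage sketch (not part of the torchvision source). Assuming
+# ``img`` is a PIL Image, this rotates by 15 degrees about the image center,
+# shifts by (10, -5) pixels, scales by 1.2 and applies no shear:
+#
+#     >>> out = affine(img, angle=15, translate=(10, -5), scale=1.2, shear=0)
+#     >>> out.size == img.size       # True: the output keeps the input size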
+
+[docs]def to_grayscale(img, num_output_channels=1):
+ """Convert image to grayscale version of image.
+
+ Args:
+ img (PIL Image): Image to be converted to grayscale.
+
+ Returns:
+ PIL Image: Grayscale version of the image.
+ if num_output_channels = 1 : returned image is single channel
+
+ if num_output_channels = 3 : returned image is 3 channel with r = g = b
+ """
+ if not _is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ if num_output_channels == 1:
+ img = img.convert('L')
+ elif num_output_channels == 3:
+ img = img.convert('L')
+ np_img = np.array(img, dtype=np.uint8)
+ np_img = np.dstack([np_img, np_img, np_img])
+ img = Image.fromarray(np_img, 'RGB')
+ else:
+ raise ValueError('num_output_channels should be either 1 or 3')
+
+ return img
+
+from __future__ import division
+import torch
+import math
+import sys
+import random
+from PIL import Image
+try:
+ import accimage
+except ImportError:
+ accimage = None
+import numpy as np
+import numbers
+import types
+import collections
+import warnings
+
+from . import functional as F
+
+if sys.version_info < (3, 3):
+ Sequence = collections.Sequence
+ Iterable = collections.Iterable
+else:
+ Sequence = collections.abc.Sequence
+ Iterable = collections.abc.Iterable
+
+
+__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
+ "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
+ "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
+ "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
+ "RandomPerspective"]
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: 'PIL.Image.NEAREST',
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
+ Image.HAMMING: 'PIL.Image.HAMMING',
+ Image.BOX: 'PIL.Image.BOX',
+}
+
+
+[docs]class Compose(object):
+ """Composes several transforms together.
+
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms to compose.
+
+ Example:
+ >>> transforms.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, img):
+ for t in self.transforms:
+ img = t(img)
+ return img
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+[docs]class ToTensor(object):
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
+ if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+ or if the numpy.ndarray has dtype = np.uint8
+
+ In the other cases, tensors are returned without scaling.
+ """
+
+[docs] def __call__(self, pic):
+ """
+ Args:
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+ Returns:
+ Tensor: Converted image.
+ """
+ return F.to_tensor(pic)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+[docs]class ToPILImage(object):
+ """Convert a tensor or an ndarray to PIL Image.
+
+ Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
+ H x W x C to a PIL Image while preserving the value range.
+
+ Args:
+ mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+ If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+ - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
+ - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
+ - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
+ - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
+ ``short``).
+
+ .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+ """
+ def __init__(self, mode=None):
+ self.mode = mode
+
+[docs] def __call__(self, pic):
+ """
+ Args:
+ pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+
+ Returns:
+ PIL Image: Image converted to PIL Image.
+
+ """
+ return F.to_pil_image(pic, self.mode)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ if self.mode is not None:
+ format_string += 'mode={0}'.format(self.mode)
+ format_string += ')'
+ return format_string
+
+
+[docs]class Normalize(object):
+ """Normalize a tensor image with mean and standard deviation.
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+ .. note::
+ This transform acts out of place, i.e., it does not mutate the input tensor.
+
+ Args:
+ mean (sequence): Sequence of means for each channel.
+ std (sequence): Sequence of standard deviations for each channel.
+ """
+
+ def __init__(self, mean, std, inplace=False):
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+
+[docs] def __call__(self, tensor):
+ """
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+
+ Returns:
+ Tensor: Normalized Tensor image.
+ """
+ return F.normalize(tensor, self.mean, self.std, self.inplace)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
+
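+# Illustrative usage sketch (not part of the torchvision source). Normalize is
+# typically composed after ToTensor; the statistics below are the commonly
+# used ImageNet values and serve only as an example. ``pil_image`` is assumed
+# to be an RGB PIL Image loaded elsewhere:
+#
+#     >>> preprocess = Compose([
+#     >>>     ToTensor(),
+#     >>>     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+#     >>> ])
+#     >>> tensor = preprocess(pil_image)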
+
+[docs]class Resize(object):
+ """Resize the input PIL Image to the given size.
+
+ Args:
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), output size will be matched to this. If size is an int,
+ smaller edge of the image will be matched to this number.
+ i.e, if height > width, then image will be rescaled to
+ (size * height / width, size)
+ interpolation (int, optional): Desired interpolation. Default is
+ ``PIL.Image.BILINEAR``
+ """
+
+ def __init__(self, size, interpolation=Image.BILINEAR):
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+ self.size = size
+ self.interpolation = interpolation
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be scaled.
+
+ Returns:
+ PIL Image: Rescaled image.
+ """
+ return F.resize(img, self.size, self.interpolation)
+
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+
+
+[docs]class Scale(Resize):
+ """
+ Note: This transform is deprecated in favor of Resize.
+ """
+ def __init__(self, *args, **kwargs):
+ warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+ "please use transforms.Resize instead.")
+ super(Scale, self).__init__(*args, **kwargs)
+
+
+[docs]class CenterCrop(object):
+ """Crops the given PIL Image at the center.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ """
+
+ def __init__(self, size):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+
+ Returns:
+ PIL Image: Cropped image.
+ """
+ return F.center_crop(img, self.size)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+[docs]class Pad(object):
+ """Pad the given PIL Image on all sides with the given "pad" value.
+
+ Args:
+ padding (int or tuple): Padding on each border. If a single int is provided this
+ is used to pad all borders. If tuple of length 2 is provided this is the padding
+ on left/right and top/bottom respectively. If a tuple of length 4 is provided
+ this is the padding for the left, top, right and bottom borders
+ respectively.
+ fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
+ length 3, it is used to fill R, G, B channels respectively.
+ This value is only used when the padding_mode is constant
+ padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
+ Default is constant.
+
+ - constant: pads with a constant value, this value is specified with fill
+
+ - edge: pads with the last value at the edge of the image
+
+ - reflect: pads with reflection of image without repeating the last value on the edge
+
+ For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+ - symmetric: pads with reflection of image repeating the last value on the edge
+
+ For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
+ """
+
+ def __init__(self, padding, fill=0, padding_mode='constant'):
+ assert isinstance(padding, (numbers.Number, tuple))
+ assert isinstance(fill, (numbers.Number, str, tuple))
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+ if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
+ raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
+ "{} element tuple".format(len(padding)))
+
+ self.padding = padding
+ self.fill = fill
+ self.padding_mode = padding_mode
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be padded.
+
+ Returns:
+ PIL Image: Padded image.
+ """
+ return F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
+ format(self.padding, self.fill, self.padding_mode)
+
+
+[docs]class Lambda(object):
+ """Apply a user-defined lambda as a transform.
+
+ Args:
+ lambd (function): Lambda/function to be used for transform.
+ """
+
+ def __init__(self, lambd):
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
+ self.lambd = lambd
+
+ def __call__(self, img):
+ return self.lambd(img)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
+
+class RandomTransforms(object):
+ """Base class for a list of transformations with randomness
+
+ Args:
+ transforms (list or tuple): list of transformations
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, (list, tuple))
+ self.transforms = transforms
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+[docs]class RandomApply(RandomTransforms):
+ """Apply randomly a list of transformations with a given probability
+
+ Args:
+ transforms (list or tuple): list of transformations
+ p (float): probability
+ """
+
+ def __init__(self, transforms, p=0.5):
+ super(RandomApply, self).__init__(transforms)
+ self.p = p
+
+ def __call__(self, img):
+ if self.p < random.random():
+ return img
+ for t in self.transforms:
+ img = t(img)
+ return img
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ format_string += '\n p={}'.format(self.p)
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
+
+
+[docs]class RandomOrder(RandomTransforms):
+ """Apply a list of transformations in a random order
+ """
+ def __call__(self, img):
+ order = list(range(len(self.transforms)))
+ random.shuffle(order)
+ for i in order:
+ img = self.transforms[i](img)
+ return img
+
+
+[docs]class RandomChoice(RandomTransforms):
+ """Apply single transformation randomly picked from a list
+ """
+ def __call__(self, img):
+ t = random.choice(self.transforms)
+ return t(img)
+
+
+[docs]class RandomCrop(object):
+ """Crop the given PIL Image at a random location.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ padding (int or sequence, optional): Optional padding on each border
+ of the image. Default is None, i.e no padding. If a sequence of length
+ 4 is provided, it is used to pad left, top, right, bottom borders
+ respectively. If a sequence of length 2 is provided, it is used to
+ pad left/right, top/bottom borders, respectively.
+ pad_if_needed (boolean): It will pad the image if smaller than the
+ desired size to avoid raising an exception. Since cropping is done
+ after padding, the padding seems to be done at a random offset.
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+ length 3, it is used to fill R, G, B channels respectively.
+ This value is only used when the padding_mode is constant
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+
+ - constant: pads with a constant value, this value is specified with fill
+
+ - edge: pads with the last value on the edge of the image
+
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
+
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
+
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
+ """
+
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+ self.pad_if_needed = pad_if_needed
+ self.fill = fill
+ self.padding_mode = padding_mode
+
+ @staticmethod
+ def get_params(img, output_size):
+ """Get parameters for ``crop`` for a random crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ output_size (tuple): Expected output size of the crop.
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+ """
+ w, h = img.size
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+
+ Returns:
+ PIL Image: Cropped image.
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ # pad the width if needed
+ if self.pad_if_needed and img.size[0] < self.size[1]:
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and img.size[1] < self.size[0]:
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+
+ return F.crop(img, i, j, h, w)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
+
+
+[docs]class RandomHorizontalFlip(object):
+ """Horizontally flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Randomly flipped image.
+ """
+ if random.random() < self.p:
+ return F.hflip(img)
+ return img
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+[docs]class RandomVerticalFlip(object):
+ """Vertically flip the given PIL Image randomly with a given probability.
+
+ Args:
+ p (float): probability of the image being flipped. Default value is 0.5
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be flipped.
+
+ Returns:
+ PIL Image: Randomly flipped image.
+ """
+ if random.random() < self.p:
+ return F.vflip(img)
+ return img
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+[docs]class RandomPerspective(object):
+ """Performs Perspective transformation of the given PIL Image randomly with a given probability.
+
+ Args:
+ interpolation : Default- Image.BICUBIC
+
+ p (float): probability of the image being perspectively transformed. Default value is 0.5
+
+ distortion_scale (float): controls the degree of distortion; ranges from 0 to 1. Default value is 0.5.
+
+ """
+
+ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
+ self.p = p
+ self.interpolation = interpolation
+ self.distortion_scale = distortion_scale
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be Perspectively transformed.
+
+ Returns:
+ PIL Image: Randomly perspectively transformed image.
+ """
+ if not F._is_pil_image(img):
+ raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+ if random.random() < self.p:
+ width, height = img.size
+ startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
+ return F.perspective(img, startpoints, endpoints, self.interpolation)
+ return img
+
+ @staticmethod
+ def get_params(width, height, distortion_scale):
+ """Get parameters for ``perspective`` for a random perspective transform.
+
+ Args:
+ width : width of the image.
+ height : height of the image.
+
+ Returns:
+ List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
+ List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
+ """
+ half_height = int(height / 2)
+ half_width = int(width / 2)
+ topleft = (random.randint(0, int(distortion_scale * half_width)),
+ random.randint(0, int(distortion_scale * half_height)))
+ topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+ random.randint(0, int(distortion_scale * half_height)))
+ botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+ botleft = (random.randint(0, int(distortion_scale * half_width)),
+ random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
+ startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+ endpoints = [topleft, topright, botright, botleft]
+ return startpoints, endpoints
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+[docs]class RandomResizedCrop(object):
+ """Crop the given PIL Image to random size and aspect ratio.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
+ if isinstance(size, tuple):
+ self.size = size
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ warnings.warn("range should be of kind (min, max)")
+
+ self.interpolation = interpolation
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ area = img.size[0] * img.size[1]
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if (in_ratio < min(ratio)):
+ w = img.size[0]
+ h = w / min(ratio)
+ elif (in_ratio > max(ratio)):
+ h = img.size[1]
+ w = h * max(ratio)
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+
+ Returns:
+ PIL Image: Randomly cropped and resized image.
+ """
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+ format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
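+# Illustrative usage sketch (not part of the torchvision source). A typical
+# Inception-style training crop that samples 8%-100% of the image area and a
+# 3/4-4/3 aspect ratio before resizing to 224x224 (the defaults, written out
+# explicitly here); ``img`` is assumed to be a PIL Image:
+#
+#     >>> rrc = RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
+#     >>> out = rrc(img)             # out.size == (224, 224)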
+
+[docs]class RandomSizedCrop(RandomResizedCrop):
+ """
+ Note: This transform is deprecated in favor of RandomResizedCrop.
+ """
+ def __init__(self, *args, **kwargs):
+ warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
+ "please use transforms.RandomResizedCrop instead.")
+ super(RandomSizedCrop, self).__init__(*args, **kwargs)
+
+
+[docs]class FiveCrop(object):
+ """Crop the given PIL Image into four corners and the central crop
+
+ .. Note::
+ This transform returns a tuple of images and there may be a mismatch in the number of
+ inputs and targets your Dataset returns. See below for an example of how to deal with
+ this.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an ``int``
+ instead of sequence like (h, w), a square crop of size (size, size) is made.
+
+ Example:
+ >>> transform = Compose([
+ >>> FiveCrop(size), # this is a list of PIL Images
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+ >>> ])
+ >>> #In your test loop you can do the following:
+ >>> input, target = batch # input is a 5d tensor, target is 2d
+ >>> bs, ncrops, c, h, w = input.size()
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+ """
+
+ def __init__(self, size):
+ self.size = size
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+ self.size = size
+
+ def __call__(self, img):
+ return F.five_crop(img, self.size)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+[docs]class TenCrop(object):
+ """Crop the given PIL Image into four corners and the central crop plus the flipped version of
+ these (horizontal flipping is used by default)
+
+ .. Note::
+ This transform returns a tuple of images and there may be a mismatch in the number of
+ inputs and targets your Dataset returns. See below for an example of how to deal with
+ this.
+
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ vertical_flip(bool): Use vertical flipping instead of horizontal
+
+ Example:
+ >>> transform = Compose([
+ >>> TenCrop(size), # this is a list of PIL Images
+ >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+ >>> ])
+ >>> #In your test loop you can do the following:
+ >>> input, target = batch # input is a 5d tensor, target is 2d
+ >>> bs, ncrops, c, h, w = input.size()
+ >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
+ >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
+ """
+
+ def __init__(self, size, vertical_flip=False):
+ self.size = size
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+ self.size = size
+ self.vertical_flip = vertical_flip
+
+ def __call__(self, img):
+ return F.ten_crop(img, self.size, self.vertical_flip)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
+
+
+[docs]class LinearTransformation(object):
+ """Transform a tensor image with a square transformation matrix and a mean_vector computed
+ offline.
+ Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
+ subtract mean_vector from it which is then followed by computing the dot
+ product with the transformation matrix and then reshaping the tensor to its
+ original shape.
+ Applications:
+ - whitening transformation: Suppose X is a column vector zero-centered data.
+ Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
+ perform SVD on this matrix and pass it as transformation_matrix.
+ Args:
+ transformation_matrix (Tensor): tensor [D x D], D = C x H x W
+ mean_vector (Tensor): tensor [D], D = C x H x W
+ """
+
+ def __init__(self, transformation_matrix, mean_vector):
+ if transformation_matrix.size(0) != transformation_matrix.size(1):
+ raise ValueError("transformation_matrix should be square. Got " +
+ "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
+
+ if mean_vector.size(0) != transformation_matrix.size(0):
+ raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
+ " as any one of the dimensions of the transformation_matrix [{} x {}]"
+ .format(transformation_matrix.size()))
+
+ self.transformation_matrix = transformation_matrix
+ self.mean_vector = mean_vector
+
+ def __call__(self, tensor):
+ """
+ Args:
+ tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
+
+ Returns:
+ Tensor: Transformed image.
+ """
+ if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
+ raise ValueError("tensor and transformation matrix have incompatible shape." +
+ "[{} x {} x {}] != ".format(*tensor.size()) +
+ "{}".format(self.transformation_matrix.size(0)))
+ flat_tensor = tensor.view(1, -1) - self.mean_vector
+ transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
+ tensor = transformed_tensor.view(tensor.size())
+ return tensor
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '(transformation_matrix='
+ format_string += (str(self.transformation_matrix.tolist()) + ')')
+ format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')')
+ return format_string
+
+
+[docs]class ColorJitter(object):
+ """Randomly change the brightness, contrast and saturation of an image.
+
+ Args:
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+ or the given [min, max]. Should be non negative numbers.
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+ or the given [min, max]. Should be non negative numbers.
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+ or the given [min, max]. Should be non negative numbers.
+ hue (float or tuple of float (min, max)): How much to jitter hue.
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+ Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+ """
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ self.brightness = self._check_input(brightness, 'brightness')
+ self.contrast = self._check_input(contrast, 'contrast')
+ self.saturation = self._check_input(saturation, 'saturation')
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
+ clip_first_on_zero=False)
+
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
+ value = [center - value, center + value]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError("{} values should be between {}".format(name, bound))
+ else:
+ raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
+
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ return value
+
+ @staticmethod
+ def get_params(brightness, contrast, saturation, hue):
+ """Get a randomized transform to be applied on image.
+
+ Arguments are same as that of __init__.
+
+ Returns:
+ Transform which randomly adjusts brightness, contrast and
+ saturation in a random order.
+ """
+ transforms = []
+
+ if brightness is not None:
+ brightness_factor = random.uniform(brightness[0], brightness[1])
+ transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+
+ if contrast is not None:
+ contrast_factor = random.uniform(contrast[0], contrast[1])
+ transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+
+ if saturation is not None:
+ saturation_factor = random.uniform(saturation[0], saturation[1])
+ transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+
+ if hue is not None:
+ hue_factor = random.uniform(hue[0], hue[1])
+ transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+
+ random.shuffle(transforms)
+ transform = Compose(transforms)
+
+ return transform
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Input image.
+
+ Returns:
+ PIL Image: Color jittered image.
+ """
+ transform = self.get_params(self.brightness, self.contrast,
+ self.saturation, self.hue)
+ return transform(img)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ format_string += 'brightness={0}'.format(self.brightness)
+ format_string += ', contrast={0}'.format(self.contrast)
+ format_string += ', saturation={0}'.format(self.saturation)
+ format_string += ', hue={0})'.format(self.hue)
+ return format_string
+
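+# Illustrative usage sketch (not part of the torchvision source). Single
+# numbers define symmetric ranges around 1 (around 0 for hue), so this jitters
+# brightness/contrast/saturation by up to +/-40% and hue by up to +/-0.1;
+# ``img`` is assumed to be a PIL Image:
+#
+#     >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
+#     >>> out = jitter(img)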
+
+[docs]class RandomRotation(object):
+ """Rotate the image by angle.
+
+ Args:
+ degrees (sequence or float or int): Range of degrees to select from.
+ If degrees is a number instead of sequence like (min, max), the range of degrees
+ will be (-degrees, +degrees).
+ resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+ An optional resampling filter. See `filters`_ for more information.
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+ expand (bool, optional): Optional expansion flag.
+ If true, expands the output to make it large enough to hold the entire rotated image.
+ If false or omitted, make the output image the same size as the input image.
+ Note that the expand flag assumes rotation around the center and no translation.
+ center (2-tuple, optional): Optional center of rotation.
+ Origin is the upper left corner.
+ Default is the center of the image.
+
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+ """
+
+ def __init__(self, degrees, resample=False, expand=False, center=None):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError("If degrees is a single number, it must be positive.")
+ self.degrees = (-degrees, degrees)
+ else:
+ if len(degrees) != 2:
+ raise ValueError("If degrees is a sequence, it must be of len 2.")
+ self.degrees = degrees
+
+ self.resample = resample
+ self.expand = expand
+ self.center = center
+
+ @staticmethod
+ def get_params(degrees):
+ """Get parameters for ``rotate`` for a random rotation.
+
+ Returns:
+ sequence: params to be passed to ``rotate`` for random rotation.
+ """
+ angle = random.uniform(degrees[0], degrees[1])
+
+ return angle
+
+ def __call__(self, img):
+ """
+ img (PIL Image): Image to be rotated.
+
+ Returns:
+ PIL Image: Rotated image.
+ """
+
+ angle = self.get_params(self.degrees)
+
+ return F.rotate(img, angle, self.resample, self.expand, self.center)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
+ format_string += ', resample={0}'.format(self.resample)
+ format_string += ', expand={0}'.format(self.expand)
+ if self.center is not None:
+ format_string += ', center={0}'.format(self.center)
+ format_string += ')'
+ return format_string
+
+
+[docs]class RandomAffine(object):
+ """Random affine transformation of the image keeping center invariant
+
+ Args:
+ degrees (sequence or float or int): Range of degrees to select from.
+ If degrees is a number instead of sequence like (min, max), the range of degrees
+ will be (-degrees, +degrees). Set to 0 to deactivate rotations.
+ translate (tuple, optional): tuple of maximum absolute fraction for horizontal
+ and vertical translations. For example translate=(a, b), then horizontal shift
+ is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
+ randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
+ scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
+ randomly sampled from the range a <= scale <= b. Will keep original scale by default.
+ shear (sequence or float or int, optional): Range of degrees to select from.
+ If degrees is a number instead of sequence like (min, max), the range of degrees
+ will be (-degrees, +degrees). Will not apply shear by default
+ resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+ An optional resampling filter. See `filters`_ for more information.
+ If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
+ fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+
+ .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+ """
+
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
+ if isinstance(degrees, numbers.Number):
+ if degrees < 0:
+ raise ValueError("If degrees is a single number, it must be positive.")
+ self.degrees = (-degrees, degrees)
+ else:
+ assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+ "degrees should be a list or tuple and it must be of length 2."
+ self.degrees = degrees
+
+ if translate is not None:
+ assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+ "translate should be a list or tuple and it must be of length 2."
+ for t in translate:
+ if not (0.0 <= t <= 1.0):
+ raise ValueError("translation values should be between 0 and 1")
+ self.translate = translate
+
+ if scale is not None:
+ assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+ "scale should be a list or tuple and it must be of length 2."
+ for s in scale:
+ if s <= 0:
+ raise ValueError("scale values should be positive")
+ self.scale = scale
+
+ if shear is not None:
+ if isinstance(shear, numbers.Number):
+ if shear < 0:
+ raise ValueError("If shear is a single number, it must be positive.")
+ self.shear = (-shear, shear)
+ else:
+ assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
+ "shear should be a list or tuple and it must be of length 2."
+ self.shear = shear
+ else:
+ self.shear = shear
+
+ self.resample = resample
+ self.fillcolor = fillcolor
+
+ @staticmethod
+ def get_params(degrees, translate, scale_ranges, shears, img_size):
+ """Get parameters for affine transformation
+
+ Returns:
+ sequence: params to be passed to the affine transformation
+ """
+ angle = random.uniform(degrees[0], degrees[1])
+ if translate is not None:
+ max_dx = translate[0] * img_size[0]
+ max_dy = translate[1] * img_size[1]
+ translations = (np.round(random.uniform(-max_dx, max_dx)),
+ np.round(random.uniform(-max_dy, max_dy)))
+ else:
+ translations = (0, 0)
+
+ if scale_ranges is not None:
+ scale = random.uniform(scale_ranges[0], scale_ranges[1])
+ else:
+ scale = 1.0
+
+ if shears is not None:
+ shear = random.uniform(shears[0], shears[1])
+ else:
+ shear = 0.0
+
+ return angle, translations, scale, shear
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be transformed.
+
+ Returns:
+ PIL Image: Affine transformed image.
+ """
+ ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
+ return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
+
+ def __repr__(self):
+ s = '{name}(degrees={degrees}'
+ if self.translate is not None:
+ s += ', translate={translate}'
+ if self.scale is not None:
+ s += ', scale={scale}'
+ if self.shear is not None:
+ s += ', shear={shear}'
+ if self.resample > 0:
+ s += ', resample={resample}'
+ if self.fillcolor != 0:
+ s += ', fillcolor={fillcolor}'
+ s += ')'
+ d = dict(self.__dict__)
+ d['resample'] = _pil_interpolation_to_str[d['resample']]
+ return s.format(name=self.__class__.__name__, **d)
+
+
+[docs]class Grayscale(object):
+ """Convert image to grayscale.
+
+ Args:
+ num_output_channels (int): (1 or 3) number of channels desired for output image
+
+ Returns:
+ PIL Image: Grayscale version of the input.
+ - If num_output_channels == 1 : returned image is single channel
+ - If num_output_channels == 3 : returned image is 3 channel with r == g == b
+
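+ Example (illustrative sketch added for clarity, not part of the original
+ docstring; assumes a PIL image ``img`` has already been loaded)::
+
+ >>> from torchvision import transforms
+ >>> gray = transforms.Grayscale(num_output_channels=3)(img) # 3 channels, r == g == b
+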
+ """
+
+ def __init__(self, num_output_channels=1):
+ self.num_output_channels = num_output_channels
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be converted to grayscale.
+
+ Returns:
+ PIL Image: Grayscale version of the input image.
+ """
+ return F.to_grayscale(img, num_output_channels=self.num_output_channels)
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
+
+
+[docs]class RandomGrayscale(object):
+ """Randomly convert image to grayscale with a probability of p (default 0.1).
+
+ Args:
+ p (float): probability that image should be converted to grayscale.
+
+ Returns:
+ PIL Image: Grayscale version of the input image with probability p and unchanged
+ with probability (1-p).
+ - If input image is 1 channel: grayscale version is 1 channel
+ - If input image is 3 channel: grayscale version is 3 channel with r == g == b
+
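+ Example (illustrative sketch added for clarity, not part of the original
+ docstring; assumes a PIL image ``img`` has already been loaded)::
+
+ >>> from torchvision import transforms
+ >>> transform = transforms.Compose([
+ ... transforms.RandomGrayscale(p=0.2), # converted with probability 0.2
+ ... transforms.ToTensor(),
+ ... ])
+ >>> tensor_img = transform(img)
+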
+ """
+
+ def __init__(self, p=0.1):
+ self.p = p
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be converted to grayscale.
+
+ Returns:
+ PIL Image: Randomly grayscaled image.
+ """
+ num_output_channels = 1 if img.mode == 'L' else 3
+ if random.random() < self.p:
+ return F.to_grayscale(img, num_output_channels=num_output_channels)
+ return img
+
+ def __repr__(self):
+ return self.__class__.__name__ + '(p={0})'.format(self.p)
+
+import torch
+import math
+irange = range
+
+
+[docs]def make_grid(tensor, nrow=8, padding=2,
+ normalize=False, range=None, scale_each=False, pad_value=0):
+ """Make a grid of images.
+
+ Args:
+ tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
+ or a list of images all of the same size.
+ nrow (int, optional): Number of images displayed in each row of the grid.
+ The final grid size is ``(B / nrow, nrow)``. Default is 8.
+ padding (int, optional): amount of padding. Default is 2.
+ normalize (bool, optional): If True, shift the image to the range (0, 1),
+ by subtracting the minimum and dividing by the maximum minus minimum pixel value.
+ range (tuple, optional): tuple (min, max), where min and max are numbers;
+ these numbers are used to normalize the image. By default, min and max
+ are computed from the tensor.
+ scale_each (bool, optional): If True, scale each image in the batch of
+ images separately rather than the (min, max) over all images.
+ pad_value (float, optional): Value for the padded pixels.
+
+ Example:
+ See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
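+
+ A minimal illustrative sketch (added for clarity, not part of the original
+ docstring)::
+
+ >>> import torch
+ >>> from torchvision.utils import make_grid
+ >>> batch = torch.rand(16, 3, 32, 32) # 16 random RGB images
+ >>> grid = make_grid(batch, nrow=4, padding=2, normalize=True)
+ >>> grid.shape # 4 x 4 grid, each 34 px cell plus the outer 2 px border
+ torch.Size([3, 138, 138])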
+
+ """
+ if not (torch.is_tensor(tensor) or
+ (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
+
+ # if list of tensors, convert to a 4D mini-batch Tensor
+ if isinstance(tensor, list):
+ tensor = torch.stack(tensor, dim=0)
+
+ if tensor.dim() == 2: # single image H x W
+ tensor = tensor.unsqueeze(0)
+ if tensor.dim() == 3: # single image
+ if tensor.size(0) == 1: # if single-channel, convert to 3-channel
+ tensor = torch.cat((tensor, tensor, tensor), 0)
+ tensor = tensor.unsqueeze(0)
+
+ if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
+ tensor = torch.cat((tensor, tensor, tensor), 1)
+
+ if normalize is True:
+ tensor = tensor.clone() # avoid modifying tensor in-place
+ if range is not None:
+ assert isinstance(range, tuple), \
+ "range has to be a tuple (min, max) if specified. min and max are numbers"
+
+ def norm_ip(img, min, max):
+ img.clamp_(min=min, max=max)
+ img.add_(-min).div_(max - min + 1e-5)
+
+ def norm_range(t, range):
+ if range is not None:
+ norm_ip(t, range[0], range[1])
+ else:
+ norm_ip(t, float(t.min()), float(t.max()))
+
+ if scale_each is True:
+ for t in tensor: # loop over mini-batch dimension
+ norm_range(t, range)
+ else:
+ norm_range(tensor, range)
+
+ if tensor.size(0) == 1:
+ return tensor.squeeze()
+
+ # make the mini-batch of images into a grid
+ nmaps = tensor.size(0)
+ xmaps = min(nrow, nmaps)
+ ymaps = int(math.ceil(float(nmaps) / xmaps))
+ height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
+ grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value)
+ k = 0
+ for y in irange(ymaps):
+ for x in irange(xmaps):
+ if k >= nmaps:
+ break
+ grid.narrow(1, y * height + padding, height - padding)\
+ .narrow(2, x * width + padding, width - padding)\
+ .copy_(tensor[k])
+ k = k + 1
+ return grid
+
+
+[docs]def save_image(tensor, filename, nrow=8, padding=2,
+ normalize=False, range=None, scale_each=False, pad_value=0):
+ """Save a given Tensor into an image file.
+
+ Args:
+ tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
+ saves the tensor as a grid of images by calling ``make_grid``.
+ nrow, padding, normalize, range, scale_each, pad_value: These arguments are
+ forwarded to ``make_grid`` and are documented there.
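+
+ A minimal illustrative sketch (added for clarity, not part of the original
+ docstring; the output path ``'sample_grid.png'`` is hypothetical)::
+
+ >>> import torch
+ >>> from torchvision.utils import save_image
+ >>> batch = torch.rand(8, 3, 64, 64) # 8 random RGB images
+ >>> save_image(batch, 'sample_grid.png', nrow=4, normalize=True)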
+ """
+ from PIL import Image
+ grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
+ normalize=normalize, range=range, scale_each=scale_each)
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+ ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+ im = Image.fromarray(ndarr)
+ im.save(filename)
+' + _('Hide Search Matches') + '
') + .appendTo($('#searchbox')); + } + }, + + /** + * init the domain index toggle buttons + */ + initIndexTable : function() { + var togglers = $('img.toggler').click(function() { + var src = $(this).attr('src'); + var idnum = $(this).attr('id').substr(7); + $('tr.cg-' + idnum).toggle(); + if (src.substr(-9) === 'minus.png') + $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); + else + $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); + }).css('display', ''); + if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { + togglers.click(); + } + }, + + /** + * helper function to hide the search marks again + */ + hideSearchWords : function() { + $('#searchbox .highlight-link').fadeOut(300); + $('span.highlighted').removeClass('highlighted'); + }, + + /** + * make the url absolute + */ + makeURL : function(relativeURL) { + return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; + }, + + /** + * get the current relative url + */ + getCurrentURL : function() { + var path = document.location.pathname; + var parts = path.split(/\//); + $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { + if (this === '..') + parts.pop(); + }); + var url = parts.join('/'); + return path.substring(url.lastIndexOf('/') + 1, path.length - 1); + }, + + initOnKeyListeners: function() { + $(document).keyup(function(event) { + var activeElementType = document.activeElement.tagName; + // don't navigate when in search box or textarea + if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT') { + switch (event.keyCode) { + case 37: // left + var prevHref = $('link[rel="prev"]').prop('href'); + if (prevHref) { + window.location.href = prevHref; + return false; + } + case 39: // right + var nextHref = $('link[rel="next"]').prop('href'); + if (nextHref) { + window.location.href = nextHref; + return false; + } + } + } + }); + } +}; + +// quick alias for translations +_ = Documentation.gettext; + +$(document).ready(function() { + Documentation.init(); +}); diff --git a/docs/1.1.0/_static/documentation_options.js b/docs/1.1.0/_static/documentation_options.js new file mode 100644 index 000000000000..f8c020d77901 --- /dev/null +++ b/docs/1.1.0/_static/documentation_options.js @@ -0,0 +1,10 @@ +var DOCUMENTATION_OPTIONS = { + URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), + VERSION: 'master', + LANGUAGE: 'None', + COLLAPSE_INDEX: false, + FILE_SUFFIX: '.html', + HAS_SOURCE: true, + SOURCELINK_SUFFIX: '.txt', + NAVIGATION_WITH_KEYS: false +}; \ No newline at end of file diff --git a/docs/1.1.0/_static/file.png b/docs/1.1.0/_static/file.png new file mode 100644 index 000000000000..a858a410e4fa Binary files /dev/null and b/docs/1.1.0/_static/file.png differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff new file mode 100644 index 000000000000..e317248423c7 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff2 new file mode 100644 index 000000000000..cec2dc94fbb5 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold-italic.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff new file mode 100644 
index 000000000000..de46625edfc8 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff2 new file mode 100644 index 000000000000..dc05cd82bc4d Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-bold.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff new file mode 100644 index 000000000000..a50e5038a405 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff2 new file mode 100644 index 000000000000..fe284db6614a Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book-italic.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff new file mode 100644 index 000000000000..6ab8775f00b1 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff2 new file mode 100644 index 000000000000..2688739f1f0b Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-book.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff new file mode 100644 index 000000000000..beda58d4e218 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff2 new file mode 100644 index 000000000000..e2fa0134b1a5 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light-italic.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff new file mode 100644 index 000000000000..226a0bf83583 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff2 new file mode 100644 index 000000000000..6d8ff2c045b0 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-light.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff new file mode 100644 index 000000000000..a42115d63b39 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff2 new file mode 100644 index 000000000000..16a7713a451a Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium-italic.woff2 differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff 
b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff new file mode 100644 index 000000000000..5ea34539c6f5 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff differ diff --git a/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff2 b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff2 new file mode 100644 index 000000000000..c58b6a528bb6 Binary files /dev/null and b/docs/1.1.0/_static/fonts/FreightSans/freight-sans-medium.woff2 differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff new file mode 100644 index 000000000000..cf37a5c50bdb Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff2 b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff2 new file mode 100644 index 000000000000..955a6eab5bb8 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Light.woff2 differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff new file mode 100644 index 000000000000..fc65a679c226 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2 b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2 new file mode 100644 index 000000000000..c352e40e34a3 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2 differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff new file mode 100644 index 000000000000..7d63d89f24bc Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff2 b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff2 new file mode 100644 index 000000000000..d0d7ded90791 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-Regular.woff2 differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff new file mode 100644 index 000000000000..1da7753cf283 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff differ diff --git a/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2 b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2 new file mode 100644 index 000000000000..79dffdb85f74 Binary files /dev/null and b/docs/1.1.0/_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2 differ diff --git a/docs/1.1.0/_static/images/arrow-down-orange.svg b/docs/1.1.0/_static/images/arrow-down-orange.svg new file mode 100644 index 000000000000..e9d8e9ecf248 --- /dev/null +++ b/docs/1.1.0/_static/images/arrow-down-orange.svg @@ -0,0 +1,19 @@ + + \ No newline at end of file diff --git a/docs/1.1.0/_static/images/arrow-right-with-tail.svg b/docs/1.1.0/_static/images/arrow-right-with-tail.svg new file mode 100644 index 000000000000..5843588fca6f --- /dev/null +++ b/docs/1.1.0/_static/images/arrow-right-with-tail.svg @@ -0,0 +1,19 @@ + + \ No newline at end of file diff --git a/docs/1.1.0/_static/images/chevron-down-grey.svg b/docs/1.1.0/_static/images/chevron-down-grey.svg new file mode 100644 index 
000000000000..82d6514f2506 --- /dev/null +++ b/docs/1.1.0/_static/images/chevron-down-grey.svg @@ -0,0 +1,18 @@ + + + + diff --git a/docs/1.1.0/_static/images/chevron-right-orange.svg b/docs/1.1.0/_static/images/chevron-right-orange.svg new file mode 100644 index 000000000000..7033fc93bf4f --- /dev/null +++ b/docs/1.1.0/_static/images/chevron-right-orange.svg @@ -0,0 +1,17 @@ + + + + diff --git a/docs/1.1.0/_static/images/chevron-right-white.svg b/docs/1.1.0/_static/images/chevron-right-white.svg new file mode 100644 index 000000000000..dd9e77f26165 --- /dev/null +++ b/docs/1.1.0/_static/images/chevron-right-white.svg @@ -0,0 +1,17 @@ + + + + \ No newline at end of file diff --git a/docs/1.1.0/_static/images/home-footer-background.jpg b/docs/1.1.0/_static/images/home-footer-background.jpg new file mode 100644 index 000000000000..b307bb57f485 Binary files /dev/null and b/docs/1.1.0/_static/images/home-footer-background.jpg differ diff --git a/docs/1.1.0/_static/images/icon-close.svg b/docs/1.1.0/_static/images/icon-close.svg new file mode 100644 index 000000000000..348964e79f7f --- /dev/null +++ b/docs/1.1.0/_static/images/icon-close.svg @@ -0,0 +1,21 @@ + + \ No newline at end of file diff --git a/docs/1.1.0/_static/images/icon-menu-dots-dark.svg b/docs/1.1.0/_static/images/icon-menu-dots-dark.svg new file mode 100644 index 000000000000..fa2ad044b3f6 --- /dev/null +++ b/docs/1.1.0/_static/images/icon-menu-dots-dark.svg @@ -0,0 +1,42 @@ + + \ No newline at end of file diff --git a/docs/1.1.0/_static/images/logo-dark.svg b/docs/1.1.0/_static/images/logo-dark.svg new file mode 100644 index 000000000000..9b4c1a56ac65 --- /dev/null +++ b/docs/1.1.0/_static/images/logo-dark.svg @@ -0,0 +1,30 @@ + + + + diff --git a/docs/1.1.0/_static/images/logo-facebook-dark.svg b/docs/1.1.0/_static/images/logo-facebook-dark.svg new file mode 100644 index 000000000000..cff17915c4f5 --- /dev/null +++ b/docs/1.1.0/_static/images/logo-facebook-dark.svg @@ -0,0 +1,8 @@ + + + + diff --git a/docs/1.1.0/_static/images/logo-icon.svg b/docs/1.1.0/_static/images/logo-icon.svg new file mode 100644 index 000000000000..575f6823e476 --- /dev/null +++ b/docs/1.1.0/_static/images/logo-icon.svg @@ -0,0 +1,12 @@ + + + + diff --git a/docs/1.1.0/_static/images/logo-twitter-dark.svg b/docs/1.1.0/_static/images/logo-twitter-dark.svg new file mode 100644 index 000000000000..1572570f88cc --- /dev/null +++ b/docs/1.1.0/_static/images/logo-twitter-dark.svg @@ -0,0 +1,16 @@ + + + + diff --git a/docs/1.1.0/_static/images/logo.svg b/docs/1.1.0/_static/images/logo.svg new file mode 100644 index 000000000000..f8d44b98425f --- /dev/null +++ b/docs/1.1.0/_static/images/logo.svg @@ -0,0 +1,31 @@ + + + + diff --git a/docs/1.1.0/_static/images/pytorch-colab.svg b/docs/1.1.0/_static/images/pytorch-colab.svg new file mode 100644 index 000000000000..2ab15e2f3071 --- /dev/null +++ b/docs/1.1.0/_static/images/pytorch-colab.svg @@ -0,0 +1,24 @@ + + + diff --git a/docs/1.1.0/_static/images/pytorch-download.svg b/docs/1.1.0/_static/images/pytorch-download.svg new file mode 100644 index 000000000000..cc37d638e926 --- /dev/null +++ b/docs/1.1.0/_static/images/pytorch-download.svg @@ -0,0 +1,10 @@ + + + diff --git a/docs/1.1.0/_static/images/pytorch-github.svg b/docs/1.1.0/_static/images/pytorch-github.svg new file mode 100644 index 000000000000..2c2570da1de9 --- /dev/null +++ b/docs/1.1.0/_static/images/pytorch-github.svg @@ -0,0 +1,15 @@ + + + diff --git a/docs/1.1.0/_static/images/pytorch-x.svg b/docs/1.1.0/_static/images/pytorch-x.svg new file mode 
100644 index 000000000000..74856ea9fdae --- /dev/null +++ b/docs/1.1.0/_static/images/pytorch-x.svg @@ -0,0 +1,10 @@ + + + diff --git a/docs/1.1.0/_static/images/search-icon.svg b/docs/1.1.0/_static/images/search-icon.svg new file mode 100644 index 000000000000..ebb0df867733 --- /dev/null +++ b/docs/1.1.0/_static/images/search-icon.svg @@ -0,0 +1,19 @@ + + diff --git a/docs/1.1.0/_static/images/view-page-source-icon.svg b/docs/1.1.0/_static/images/view-page-source-icon.svg new file mode 100644 index 000000000000..6f5bbe0748fc --- /dev/null +++ b/docs/1.1.0/_static/images/view-page-source-icon.svg @@ -0,0 +1,13 @@ + + + diff --git a/docs/1.1.0/_static/img/aliastracker_graph.png b/docs/1.1.0/_static/img/aliastracker_graph.png new file mode 100644 index 000000000000..11c66e64d81b Binary files /dev/null and b/docs/1.1.0/_static/img/aliastracker_graph.png differ diff --git a/docs/1.1.0/_static/img/dynamic_graph.gif b/docs/1.1.0/_static/img/dynamic_graph.gif new file mode 100644 index 000000000000..f6fde3158010 Binary files /dev/null and b/docs/1.1.0/_static/img/dynamic_graph.gif differ diff --git a/docs/1.1.0/_static/img/pytorch-logo-dark-unstable.png b/docs/1.1.0/_static/img/pytorch-logo-dark-unstable.png new file mode 100644 index 000000000000..5934fc3499da Binary files /dev/null and b/docs/1.1.0/_static/img/pytorch-logo-dark-unstable.png differ diff --git a/docs/1.1.0/_static/img/pytorch-logo-dark.png b/docs/1.1.0/_static/img/pytorch-logo-dark.png new file mode 100644 index 000000000000..b7a1ceb964af Binary files /dev/null and b/docs/1.1.0/_static/img/pytorch-logo-dark.png differ diff --git a/docs/1.1.0/_static/img/pytorch-logo-dark.svg b/docs/1.1.0/_static/img/pytorch-logo-dark.svg new file mode 100644 index 000000000000..5e5300038589 --- /dev/null +++ b/docs/1.1.0/_static/img/pytorch-logo-dark.svg @@ -0,0 +1,33 @@ + + + diff --git a/docs/1.1.0/_static/img/pytorch-logo-flame.png b/docs/1.1.0/_static/img/pytorch-logo-flame.png new file mode 100644 index 000000000000..370633f2ec2b Binary files /dev/null and b/docs/1.1.0/_static/img/pytorch-logo-flame.png differ diff --git a/docs/1.1.0/_static/img/pytorch-logo-flame.svg b/docs/1.1.0/_static/img/pytorch-logo-flame.svg new file mode 100644 index 000000000000..5f2fb76be773 --- /dev/null +++ b/docs/1.1.0/_static/img/pytorch-logo-flame.svg @@ -0,0 +1,33 @@ + + diff --git a/docs/1.1.0/_static/img/tensor_illustration.png b/docs/1.1.0/_static/img/tensor_illustration.png new file mode 100644 index 000000000000..b0039c7f3f3e Binary files /dev/null and b/docs/1.1.0/_static/img/tensor_illustration.png differ diff --git a/docs/1.1.0/_static/jquery-3.2.1.js b/docs/1.1.0/_static/jquery-3.2.1.js new file mode 100644 index 000000000000..d2d8ca4790e5 --- /dev/null +++ b/docs/1.1.0/_static/jquery-3.2.1.js @@ -0,0 +1,10253 @@ +/*! + * jQuery JavaScript Library v3.2.1 + * https://jquery.com/ + * + * Includes Sizzle.js + * https://sizzlejs.com/ + * + * Copyright JS Foundation and other contributors + * Released under the MIT license + * https://jquery.org/license + * + * Date: 2017-03-20T18:59Z + */ +( function( global, factory ) { + + "use strict"; + + if ( typeof module === "object" && typeof module.exports === "object" ) { + + // For CommonJS and CommonJS-like environments where a proper `window` + // is present, execute the factory and get jQuery. + // For environments that do not have a `window` with a `document` + // (such as Node.js), expose a factory as module.exports. + // This accentuates the need for the creation of a real `window`. + // e.g. 
var jQuery = require("jquery")(window); + // See ticket #14549 for more info. + module.exports = global.document ? + factory( global, true ) : + function( w ) { + if ( !w.document ) { + throw new Error( "jQuery requires a window with a document" ); + } + return factory( w ); + }; + } else { + factory( global ); + } + +// Pass this if window is not defined yet +} )( typeof window !== "undefined" ? window : this, function( window, noGlobal ) { + +// Edge <= 12 - 13+, Firefox <=18 - 45+, IE 10 - 11, Safari 5.1 - 9+, iOS 6 - 9.1 +// throw exceptions when non-strict code (e.g., ASP.NET 4.5) accesses strict mode +// arguments.callee.caller (trac-13335). But as of jQuery 3.0 (2016), strict mode should be common +// enough that all such attempts are guarded in a try block. +"use strict"; + +var arr = []; + +var document = window.document; + +var getProto = Object.getPrototypeOf; + +var slice = arr.slice; + +var concat = arr.concat; + +var push = arr.push; + +var indexOf = arr.indexOf; + +var class2type = {}; + +var toString = class2type.toString; + +var hasOwn = class2type.hasOwnProperty; + +var fnToString = hasOwn.toString; + +var ObjectFunctionString = fnToString.call( Object ); + +var support = {}; + + + + function DOMEval( code, doc ) { + doc = doc || document; + + var script = doc.createElement( "script" ); + + script.text = code; + doc.head.appendChild( script ).parentNode.removeChild( script ); + } +/* global Symbol */ +// Defining this global in .eslintrc.json would create a danger of using the global +// unguarded in another place, it seems safer to define global only for this module + + + +var + version = "3.2.1", + + // Define a local copy of jQuery + jQuery = function( selector, context ) { + + // The jQuery object is actually just the init constructor 'enhanced' + // Need init if jQuery is called (just allow error to be thrown if not included) + return new jQuery.fn.init( selector, context ); + }, + + // Support: Android <=4.0 only + // Make sure we trim BOM and NBSP + rtrim = /^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g, + + // Matches dashed string for camelizing + rmsPrefix = /^-ms-/, + rdashAlpha = /-([a-z])/g, + + // Used by jQuery.camelCase as callback to replace() + fcamelCase = function( all, letter ) { + return letter.toUpperCase(); + }; + +jQuery.fn = jQuery.prototype = { + + // The current version of jQuery being used + jquery: version, + + constructor: jQuery, + + // The default length of a jQuery object is 0 + length: 0, + + toArray: function() { + return slice.call( this ); + }, + + // Get the Nth element in the matched element set OR + // Get the whole matched element set as a clean array + get: function( num ) { + + // Return all the elements in a clean array + if ( num == null ) { + return slice.call( this ); + } + + // Return just the one element from the set + return num < 0 ? this[ num + this.length ] : this[ num ]; + }, + + // Take an array of elements and push it onto the stack + // (returning the new matched element set) + pushStack: function( elems ) { + + // Build a new jQuery matched element set + var ret = jQuery.merge( this.constructor(), elems ); + + // Add the old object onto the stack (as a reference) + ret.prevObject = this; + + // Return the newly-formed element set + return ret; + }, + + // Execute a callback for every element in the matched set. 
+ each: function( callback ) { + return jQuery.each( this, callback ); + }, + + map: function( callback ) { + return this.pushStack( jQuery.map( this, function( elem, i ) { + return callback.call( elem, i, elem ); + } ) ); + }, + + slice: function() { + return this.pushStack( slice.apply( this, arguments ) ); + }, + + first: function() { + return this.eq( 0 ); + }, + + last: function() { + return this.eq( -1 ); + }, + + eq: function( i ) { + var len = this.length, + j = +i + ( i < 0 ? len : 0 ); + return this.pushStack( j >= 0 && j < len ? [ this[ j ] ] : [] ); + }, + + end: function() { + return this.prevObject || this.constructor(); + }, + + // For internal use only. + // Behaves like an Array's method, not like a jQuery method. + push: push, + sort: arr.sort, + splice: arr.splice +}; + +jQuery.extend = jQuery.fn.extend = function() { + var options, name, src, copy, copyIsArray, clone, + target = arguments[ 0 ] || {}, + i = 1, + length = arguments.length, + deep = false; + + // Handle a deep copy situation + if ( typeof target === "boolean" ) { + deep = target; + + // Skip the boolean and the target + target = arguments[ i ] || {}; + i++; + } + + // Handle case when target is a string or something (possible in deep copy) + if ( typeof target !== "object" && !jQuery.isFunction( target ) ) { + target = {}; + } + + // Extend jQuery itself if only one argument is passed + if ( i === length ) { + target = this; + i--; + } + + for ( ; i < length; i++ ) { + + // Only deal with non-null/undefined values + if ( ( options = arguments[ i ] ) != null ) { + + // Extend the base object + for ( name in options ) { + src = target[ name ]; + copy = options[ name ]; + + // Prevent never-ending loop + if ( target === copy ) { + continue; + } + + // Recurse if we're merging plain objects or arrays + if ( deep && copy && ( jQuery.isPlainObject( copy ) || + ( copyIsArray = Array.isArray( copy ) ) ) ) { + + if ( copyIsArray ) { + copyIsArray = false; + clone = src && Array.isArray( src ) ? src : []; + + } else { + clone = src && jQuery.isPlainObject( src ) ? 
src : {}; + } + + // Never move original objects, clone them + target[ name ] = jQuery.extend( deep, clone, copy ); + + // Don't bring in undefined values + } else if ( copy !== undefined ) { + target[ name ] = copy; + } + } + } + } + + // Return the modified object + return target; +}; + +jQuery.extend( { + + // Unique for each copy of jQuery on the page + expando: "jQuery" + ( version + Math.random() ).replace( /\D/g, "" ), + + // Assume jQuery is ready without the ready module + isReady: true, + + error: function( msg ) { + throw new Error( msg ); + }, + + noop: function() {}, + + isFunction: function( obj ) { + return jQuery.type( obj ) === "function"; + }, + + isWindow: function( obj ) { + return obj != null && obj === obj.window; + }, + + isNumeric: function( obj ) { + + // As of jQuery 3.0, isNumeric is limited to + // strings and numbers (primitives or objects) + // that can be coerced to finite numbers (gh-2662) + var type = jQuery.type( obj ); + return ( type === "number" || type === "string" ) && + + // parseFloat NaNs numeric-cast false positives ("") + // ...but misinterprets leading-number strings, particularly hex literals ("0x...") + // subtraction forces infinities to NaN + !isNaN( obj - parseFloat( obj ) ); + }, + + isPlainObject: function( obj ) { + var proto, Ctor; + + // Detect obvious negatives + // Use toString instead of jQuery.type to catch host objects + if ( !obj || toString.call( obj ) !== "[object Object]" ) { + return false; + } + + proto = getProto( obj ); + + // Objects with no prototype (e.g., `Object.create( null )`) are plain + if ( !proto ) { + return true; + } + + // Objects with prototype are plain iff they were constructed by a global Object function + Ctor = hasOwn.call( proto, "constructor" ) && proto.constructor; + return typeof Ctor === "function" && fnToString.call( Ctor ) === ObjectFunctionString; + }, + + isEmptyObject: function( obj ) { + + /* eslint-disable no-unused-vars */ + // See https://github.com/eslint/eslint/issues/6125 + var name; + + for ( name in obj ) { + return false; + } + return true; + }, + + type: function( obj ) { + if ( obj == null ) { + return obj + ""; + } + + // Support: Android <=2.3 only (functionish RegExp) + return typeof obj === "object" || typeof obj === "function" ? + class2type[ toString.call( obj ) ] || "object" : + typeof obj; + }, + + // Evaluates a script in a global context + globalEval: function( code ) { + DOMEval( code ); + }, + + // Convert dashed to camelCase; used by the css and data modules + // Support: IE <=9 - 11, Edge 12 - 13 + // Microsoft forgot to hump their vendor prefix (#9572) + camelCase: function( string ) { + return string.replace( rmsPrefix, "ms-" ).replace( rdashAlpha, fcamelCase ); + }, + + each: function( obj, callback ) { + var length, i = 0; + + if ( isArrayLike( obj ) ) { + length = obj.length; + for ( ; i < length; i++ ) { + if ( callback.call( obj[ i ], i, obj[ i ] ) === false ) { + break; + } + } + } else { + for ( i in obj ) { + if ( callback.call( obj[ i ], i, obj[ i ] ) === false ) { + break; + } + } + } + + return obj; + }, + + // Support: Android <=4.0 only + trim: function( text ) { + return text == null ? + "" : + ( text + "" ).replace( rtrim, "" ); + }, + + // results is for internal usage only + makeArray: function( arr, results ) { + var ret = results || []; + + if ( arr != null ) { + if ( isArrayLike( Object( arr ) ) ) { + jQuery.merge( ret, + typeof arr === "string" ? 
+ [ arr ] : arr + ); + } else { + push.call( ret, arr ); + } + } + + return ret; + }, + + inArray: function( elem, arr, i ) { + return arr == null ? -1 : indexOf.call( arr, elem, i ); + }, + + // Support: Android <=4.0 only, PhantomJS 1 only + // push.apply(_, arraylike) throws on ancient WebKit + merge: function( first, second ) { + var len = +second.length, + j = 0, + i = first.length; + + for ( ; j < len; j++ ) { + first[ i++ ] = second[ j ]; + } + + first.length = i; + + return first; + }, + + grep: function( elems, callback, invert ) { + var callbackInverse, + matches = [], + i = 0, + length = elems.length, + callbackExpect = !invert; + + // Go through the array, only saving the items + // that pass the validator function + for ( ; i < length; i++ ) { + callbackInverse = !callback( elems[ i ], i ); + if ( callbackInverse !== callbackExpect ) { + matches.push( elems[ i ] ); + } + } + + return matches; + }, + + // arg is for internal usage only + map: function( elems, callback, arg ) { + var length, value, + i = 0, + ret = []; + + // Go through the array, translating each of the items to their new values + if ( isArrayLike( elems ) ) { + length = elems.length; + for ( ; i < length; i++ ) { + value = callback( elems[ i ], i, arg ); + + if ( value != null ) { + ret.push( value ); + } + } + + // Go through every key on the object, + } else { + for ( i in elems ) { + value = callback( elems[ i ], i, arg ); + + if ( value != null ) { + ret.push( value ); + } + } + } + + // Flatten any nested arrays + return concat.apply( [], ret ); + }, + + // A global GUID counter for objects + guid: 1, + + // Bind a function to a context, optionally partially applying any + // arguments. + proxy: function( fn, context ) { + var tmp, args, proxy; + + if ( typeof context === "string" ) { + tmp = fn[ context ]; + context = fn; + fn = tmp; + } + + // Quick check to determine if target is callable, in the spec + // this throws a TypeError, but we will just return undefined. + if ( !jQuery.isFunction( fn ) ) { + return undefined; + } + + // Simulated bind + args = slice.call( arguments, 2 ); + proxy = function() { + return fn.apply( context || this, args.concat( slice.call( arguments ) ) ); + }; + + // Set the guid of unique handler to the same of original handler, so it can be removed + proxy.guid = fn.guid = fn.guid || jQuery.guid++; + + return proxy; + }, + + now: Date.now, + + // jQuery.support is not used in Core but other projects attach their + // properties to it so it needs to exist. + support: support +} ); + +if ( typeof Symbol === "function" ) { + jQuery.fn[ Symbol.iterator ] = arr[ Symbol.iterator ]; +} + +// Populate the class2type map +jQuery.each( "Boolean Number String Function Array Date RegExp Object Error Symbol".split( " " ), +function( i, name ) { + class2type[ "[object " + name + "]" ] = name.toLowerCase(); +} ); + +function isArrayLike( obj ) { + + // Support: real iOS 8.2 only (not reproducible in simulator) + // `in` check used to prevent JIT error (gh-2145) + // hasOwn isn't used here due to false negatives + // regarding Nodelist length in IE + var length = !!obj && "length" in obj && obj.length, + type = jQuery.type( obj ); + + if ( type === "function" || jQuery.isWindow( obj ) ) { + return false; + } + + return type === "array" || length === 0 || + typeof length === "number" && length > 0 && ( length - 1 ) in obj; +} +var Sizzle = +/*! 
+ * Sizzle CSS Selector Engine v2.3.3 + * https://sizzlejs.com/ + * + * Copyright jQuery Foundation and other contributors + * Released under the MIT license + * http://jquery.org/license + * + * Date: 2016-08-08 + */ +(function( window ) { + +var i, + support, + Expr, + getText, + isXML, + tokenize, + compile, + select, + outermostContext, + sortInput, + hasDuplicate, + + // Local document vars + setDocument, + document, + docElem, + documentIsHTML, + rbuggyQSA, + rbuggyMatches, + matches, + contains, + + // Instance-specific data + expando = "sizzle" + 1 * new Date(), + preferredDoc = window.document, + dirruns = 0, + done = 0, + classCache = createCache(), + tokenCache = createCache(), + compilerCache = createCache(), + sortOrder = function( a, b ) { + if ( a === b ) { + hasDuplicate = true; + } + return 0; + }, + + // Instance methods + hasOwn = ({}).hasOwnProperty, + arr = [], + pop = arr.pop, + push_native = arr.push, + push = arr.push, + slice = arr.slice, + // Use a stripped-down indexOf as it's faster than native + // https://jsperf.com/thor-indexof-vs-for/5 + indexOf = function( list, elem ) { + var i = 0, + len = list.length; + for ( ; i < len; i++ ) { + if ( list[i] === elem ) { + return i; + } + } + return -1; + }, + + booleans = "checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped", + + // Regular expressions + + // http://www.w3.org/TR/css3-selectors/#whitespace + whitespace = "[\\x20\\t\\r\\n\\f]", + + // http://www.w3.org/TR/CSS21/syndata.html#value-def-identifier + identifier = "(?:\\\\.|[\\w-]|[^\0-\\xa0])+", + + // Attribute selectors: http://www.w3.org/TR/selectors/#attribute-selectors + attributes = "\\[" + whitespace + "*(" + identifier + ")(?:" + whitespace + + // Operator (capture 2) + "*([*^$|!~]?=)" + whitespace + + // "Attribute values must be CSS identifiers [capture 5] or strings [capture 3 or capture 4]" + "*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|(" + identifier + "))|)" + whitespace + + "*\\]", + + pseudos = ":(" + identifier + ")(?:\\((" + + // To reduce the number of selectors needing tokenize in the preFilter, prefer arguments: + // 1. quoted (capture 3; capture 4 or capture 5) + "('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|" + + // 2. simple (capture 6) + "((?:\\\\.|[^\\\\()[\\]]|" + attributes + ")*)|" + + // 3. 
anything else (capture 2) + ".*" + + ")\\)|)", + + // Leading and non-escaped trailing whitespace, capturing some non-whitespace characters preceding the latter + rwhitespace = new RegExp( whitespace + "+", "g" ), + rtrim = new RegExp( "^" + whitespace + "+|((?:^|[^\\\\])(?:\\\\.)*)" + whitespace + "+$", "g" ), + + rcomma = new RegExp( "^" + whitespace + "*," + whitespace + "*" ), + rcombinators = new RegExp( "^" + whitespace + "*([>+~]|" + whitespace + ")" + whitespace + "*" ), + + rattributeQuotes = new RegExp( "=" + whitespace + "*([^\\]'\"]*?)" + whitespace + "*\\]", "g" ), + + rpseudo = new RegExp( pseudos ), + ridentifier = new RegExp( "^" + identifier + "$" ), + + matchExpr = { + "ID": new RegExp( "^#(" + identifier + ")" ), + "CLASS": new RegExp( "^\\.(" + identifier + ")" ), + "TAG": new RegExp( "^(" + identifier + "|[*])" ), + "ATTR": new RegExp( "^" + attributes ), + "PSEUDO": new RegExp( "^" + pseudos ), + "CHILD": new RegExp( "^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\(" + whitespace + + "*(even|odd|(([+-]|)(\\d*)n|)" + whitespace + "*(?:([+-]|)" + whitespace + + "*(\\d+)|))" + whitespace + "*\\)|)", "i" ), + "bool": new RegExp( "^(?:" + booleans + ")$", "i" ), + // For use in libraries implementing .is() + // We use this for POS matching in `select` + "needsContext": new RegExp( "^" + whitespace + "*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\(" + + whitespace + "*((?:-\\d)?\\d*)" + whitespace + "*\\)|)(?=[^-]|$)", "i" ) + }, + + rinputs = /^(?:input|select|textarea|button)$/i, + rheader = /^h\d$/i, + + rnative = /^[^{]+\{\s*\[native \w/, + + // Easily-parseable/retrievable ID or TAG or CLASS selectors + rquickExpr = /^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/, + + rsibling = /[+~]/, + + // CSS escapes + // http://www.w3.org/TR/CSS21/syndata.html#escaped-characters + runescape = new RegExp( "\\\\([\\da-f]{1,6}" + whitespace + "?|(" + whitespace + ")|.)", "ig" ), + funescape = function( _, escaped, escapedWhitespace ) { + var high = "0x" + escaped - 0x10000; + // NaN means non-codepoint + // Support: Firefox<24 + // Workaround erroneous numeric interpretation of +"0x" + return high !== high || escapedWhitespace ? + escaped : + high < 0 ? 
+ // BMP codepoint + String.fromCharCode( high + 0x10000 ) : + // Supplemental Plane codepoint (surrogate pair) + String.fromCharCode( high >> 10 | 0xD800, high & 0x3FF | 0xDC00 ); + }, + + // CSS string/identifier serialization + // https://drafts.csswg.org/cssom/#common-serializing-idioms + rcssescape = /([\0-\x1f\x7f]|^-?\d)|^-$|[^\0-\x1f\x7f-\uFFFF\w-]/g, + fcssescape = function( ch, asCodePoint ) { + if ( asCodePoint ) { + + // U+0000 NULL becomes U+FFFD REPLACEMENT CHARACTER + if ( ch === "\0" ) { + return "\uFFFD"; + } + + // Control characters and (dependent upon position) numbers get escaped as code points + return ch.slice( 0, -1 ) + "\\" + ch.charCodeAt( ch.length - 1 ).toString( 16 ) + " "; + } + + // Other potentially-special ASCII characters get backslash-escaped + return "\\" + ch; + }, + + // Used for iframes + // See setDocument() + // Removing the function wrapper causes a "Permission Denied" + // error in IE + unloadHandler = function() { + setDocument(); + }, + + disabledAncestor = addCombinator( + function( elem ) { + return elem.disabled === true && ("form" in elem || "label" in elem); + }, + { dir: "parentNode", next: "legend" } + ); + +// Optimize for push.apply( _, NodeList ) +try { + push.apply( + (arr = slice.call( preferredDoc.childNodes )), + preferredDoc.childNodes + ); + // Support: Android<4.0 + // Detect silently failing push.apply + arr[ preferredDoc.childNodes.length ].nodeType; +} catch ( e ) { + push = { apply: arr.length ? + + // Leverage slice if possible + function( target, els ) { + push_native.apply( target, slice.call(els) ); + } : + + // Support: IE<9 + // Otherwise append directly + function( target, els ) { + var j = target.length, + i = 0; + // Can't trust NodeList.length + while ( (target[j++] = els[i++]) ) {} + target.length = j - 1; + } + }; +} + +function Sizzle( selector, context, results, seed ) { + var m, i, elem, nid, match, groups, newSelector, + newContext = context && context.ownerDocument, + + // nodeType defaults to 9, since context defaults to document + nodeType = context ? context.nodeType : 9; + + results = results || []; + + // Return early from calls with invalid selector or context + if ( typeof selector !== "string" || !selector || + nodeType !== 1 && nodeType !== 9 && nodeType !== 11 ) { + + return results; + } + + // Try to shortcut find operations (as opposed to filters) in HTML documents + if ( !seed ) { + + if ( ( context ? 
context.ownerDocument || context : preferredDoc ) !== document ) { + setDocument( context ); + } + context = context || document; + + if ( documentIsHTML ) { + + // If the selector is sufficiently simple, try using a "get*By*" DOM method + // (excepting DocumentFragment context, where the methods don't exist) + if ( nodeType !== 11 && (match = rquickExpr.exec( selector )) ) { + + // ID selector + if ( (m = match[1]) ) { + + // Document context + if ( nodeType === 9 ) { + if ( (elem = context.getElementById( m )) ) { + + // Support: IE, Opera, Webkit + // TODO: identify versions + // getElementById can match elements by name instead of ID + if ( elem.id === m ) { + results.push( elem ); + return results; + } + } else { + return results; + } + + // Element context + } else { + + // Support: IE, Opera, Webkit + // TODO: identify versions + // getElementById can match elements by name instead of ID + if ( newContext && (elem = newContext.getElementById( m )) && + contains( context, elem ) && + elem.id === m ) { + + results.push( elem ); + return results; + } + } + + // Type selector + } else if ( match[2] ) { + push.apply( results, context.getElementsByTagName( selector ) ); + return results; + + // Class selector + } else if ( (m = match[3]) && support.getElementsByClassName && + context.getElementsByClassName ) { + + push.apply( results, context.getElementsByClassName( m ) ); + return results; + } + } + + // Take advantage of querySelectorAll + if ( support.qsa && + !compilerCache[ selector + " " ] && + (!rbuggyQSA || !rbuggyQSA.test( selector )) ) { + + if ( nodeType !== 1 ) { + newContext = context; + newSelector = selector; + + // qSA looks outside Element context, which is not what we want + // Thanks to Andrew Dupont for this workaround technique + // Support: IE <=8 + // Exclude object elements + } else if ( context.nodeName.toLowerCase() !== "object" ) { + + // Capture the context ID, setting it first if necessary + if ( (nid = context.getAttribute( "id" )) ) { + nid = nid.replace( rcssescape, fcssescape ); + } else { + context.setAttribute( "id", (nid = expando) ); + } + + // Prefix every selector in the list + groups = tokenize( selector ); + i = groups.length; + while ( i-- ) { + groups[i] = "#" + nid + " " + toSelector( groups[i] ); + } + newSelector = groups.join( "," ); + + // Expand context for sibling selectors + newContext = rsibling.test( selector ) && testContext( context.parentNode ) || + context; + } + + if ( newSelector ) { + try { + push.apply( results, + newContext.querySelectorAll( newSelector ) + ); + return results; + } catch ( qsaError ) { + } finally { + if ( nid === expando ) { + context.removeAttribute( "id" ); + } + } + } + } + } + } + + // All others + return select( selector.replace( rtrim, "$1" ), context, results, seed ); +} + +/** + * Create key-value caches of limited size + * @returns {function(string, object)} Returns the Object data after storing it on itself with + * property name the (space-suffixed) string and (if the cache is larger than Expr.cacheLength) + * deleting the oldest entry + */ +function createCache() { + var keys = []; + + function cache( key, value ) { + // Use (key + " ") to avoid collision with native prototype properties (see Issue #157) + if ( keys.push( key + " " ) > Expr.cacheLength ) { + // Only keep the most recent entries + delete cache[ keys.shift() ]; + } + return (cache[ key + " " ] = value); + } + return cache; +} + +/** + * Mark a function for special use by Sizzle + * @param {Function} fn The function to mark + */ 
+function markFunction( fn ) { + fn[ expando ] = true; + return fn; +} + +/** + * Support testing using an element + * @param {Function} fn Passed the created element and returns a boolean result + */ +function assert( fn ) { + var el = document.createElement("fieldset"); + + try { + return !!fn( el ); + } catch (e) { + return false; + } finally { + // Remove from its parent by default + if ( el.parentNode ) { + el.parentNode.removeChild( el ); + } + // release memory in IE + el = null; + } +} + +/** + * Adds the same handler for all of the specified attrs + * @param {String} attrs Pipe-separated list of attributes + * @param {Function} handler The method that will be applied + */ +function addHandle( attrs, handler ) { + var arr = attrs.split("|"), + i = arr.length; + + while ( i-- ) { + Expr.attrHandle[ arr[i] ] = handler; + } +} + +/** + * Checks document order of two siblings + * @param {Element} a + * @param {Element} b + * @returns {Number} Returns less than 0 if a precedes b, greater than 0 if a follows b + */ +function siblingCheck( a, b ) { + var cur = b && a, + diff = cur && a.nodeType === 1 && b.nodeType === 1 && + a.sourceIndex - b.sourceIndex; + + // Use IE sourceIndex if available on both nodes + if ( diff ) { + return diff; + } + + // Check if b follows a + if ( cur ) { + while ( (cur = cur.nextSibling) ) { + if ( cur === b ) { + return -1; + } + } + } + + return a ? 1 : -1; +} + +/** + * Returns a function to use in pseudos for input types + * @param {String} type + */ +function createInputPseudo( type ) { + return function( elem ) { + var name = elem.nodeName.toLowerCase(); + return name === "input" && elem.type === type; + }; +} + +/** + * Returns a function to use in pseudos for buttons + * @param {String} type + */ +function createButtonPseudo( type ) { + return function( elem ) { + var name = elem.nodeName.toLowerCase(); + return (name === "input" || name === "button") && elem.type === type; + }; +} + +/** + * Returns a function to use in pseudos for :enabled/:disabled + * @param {Boolean} disabled true for :disabled; false for :enabled + */ +function createDisabledPseudo( disabled ) { + + // Known :disabled false positives: fieldset[disabled] > legend:nth-of-type(n+2) :can-disable + return function( elem ) { + + // Only certain elements can match :enabled or :disabled + // https://html.spec.whatwg.org/multipage/scripting.html#selector-enabled + // https://html.spec.whatwg.org/multipage/scripting.html#selector-disabled + if ( "form" in elem ) { + + // Check for inherited disabledness on relevant non-disabled elements: + // * listed form-associated elements in a disabled fieldset + // https://html.spec.whatwg.org/multipage/forms.html#category-listed + // https://html.spec.whatwg.org/multipage/forms.html#concept-fe-disabled + // * option elements in a disabled optgroup + // https://html.spec.whatwg.org/multipage/forms.html#concept-option-disabled + // All such elements have a "form" property. 
+ if ( elem.parentNode && elem.disabled === false ) { + + // Option elements defer to a parent optgroup if present + if ( "label" in elem ) { + if ( "label" in elem.parentNode ) { + return elem.parentNode.disabled === disabled; + } else { + return elem.disabled === disabled; + } + } + + // Support: IE 6 - 11 + // Use the isDisabled shortcut property to check for disabled fieldset ancestors + return elem.isDisabled === disabled || + + // Where there is no isDisabled, check manually + /* jshint -W018 */ + elem.isDisabled !== !disabled && + disabledAncestor( elem ) === disabled; + } + + return elem.disabled === disabled; + + // Try to winnow out elements that can't be disabled before trusting the disabled property. + // Some victims get caught in our net (label, legend, menu, track), but it shouldn't + // even exist on them, let alone have a boolean value. + } else if ( "label" in elem ) { + return elem.disabled === disabled; + } + + // Remaining elements are neither :enabled nor :disabled + return false; + }; +} + +/** + * Returns a function to use in pseudos for positionals + * @param {Function} fn + */ +function createPositionalPseudo( fn ) { + return markFunction(function( argument ) { + argument = +argument; + return markFunction(function( seed, matches ) { + var j, + matchIndexes = fn( [], seed.length, argument ), + i = matchIndexes.length; + + // Match elements found at the specified indexes + while ( i-- ) { + if ( seed[ (j = matchIndexes[i]) ] ) { + seed[j] = !(matches[j] = seed[j]); + } + } + }); + }); +} + +/** + * Checks a node for validity as a Sizzle context + * @param {Element|Object=} context + * @returns {Element|Object|Boolean} The input node if acceptable, otherwise a falsy value + */ +function testContext( context ) { + return context && typeof context.getElementsByTagName !== "undefined" && context; +} + +// Expose support vars for convenience +support = Sizzle.support = {}; + +/** + * Detects XML nodes + * @param {Element|Object} elem An element or a document + * @returns {Boolean} True iff elem is a non-HTML XML node + */ +isXML = Sizzle.isXML = function( elem ) { + // documentElement is verified for cases where it doesn't yet exist + // (such as loading iframes in IE - #4833) + var documentElement = elem && (elem.ownerDocument || elem).documentElement; + return documentElement ? documentElement.nodeName !== "HTML" : false; +}; + +/** + * Sets document-related variables once based on the current document + * @param {Element|Object} [doc] An element or document object to use to set the document + * @returns {Object} Returns the current document + */ +setDocument = Sizzle.setDocument = function( node ) { + var hasCompare, subWindow, + doc = node ? 
node.ownerDocument || node : preferredDoc; + + // Return early if doc is invalid or already selected + if ( doc === document || doc.nodeType !== 9 || !doc.documentElement ) { + return document; + } + + // Update global variables + document = doc; + docElem = document.documentElement; + documentIsHTML = !isXML( document ); + + // Support: IE 9-11, Edge + // Accessing iframe documents after unload throws "permission denied" errors (jQuery #13936) + if ( preferredDoc !== document && + (subWindow = document.defaultView) && subWindow.top !== subWindow ) { + + // Support: IE 11, Edge + if ( subWindow.addEventListener ) { + subWindow.addEventListener( "unload", unloadHandler, false ); + + // Support: IE 9 - 10 only + } else if ( subWindow.attachEvent ) { + subWindow.attachEvent( "onunload", unloadHandler ); + } + } + + /* Attributes + ---------------------------------------------------------------------- */ + + // Support: IE<8 + // Verify that getAttribute really returns attributes and not properties + // (excepting IE8 booleans) + support.attributes = assert(function( el ) { + el.className = "i"; + return !el.getAttribute("className"); + }); + + /* getElement(s)By* + ---------------------------------------------------------------------- */ + + // Check if getElementsByTagName("*") returns only elements + support.getElementsByTagName = assert(function( el ) { + el.appendChild( document.createComment("") ); + return !el.getElementsByTagName("*").length; + }); + + // Support: IE<9 + support.getElementsByClassName = rnative.test( document.getElementsByClassName ); + + // Support: IE<10 + // Check if getElementById returns elements by name + // The broken getElementById methods don't pick up programmatically-set names, + // so use a roundabout getElementsByName test + support.getById = assert(function( el ) { + docElem.appendChild( el ).id = expando; + return !document.getElementsByName || !document.getElementsByName( expando ).length; + }); + + // ID filter and find + if ( support.getById ) { + Expr.filter["ID"] = function( id ) { + var attrId = id.replace( runescape, funescape ); + return function( elem ) { + return elem.getAttribute("id") === attrId; + }; + }; + Expr.find["ID"] = function( id, context ) { + if ( typeof context.getElementById !== "undefined" && documentIsHTML ) { + var elem = context.getElementById( id ); + return elem ? [ elem ] : []; + } + }; + } else { + Expr.filter["ID"] = function( id ) { + var attrId = id.replace( runescape, funescape ); + return function( elem ) { + var node = typeof elem.getAttributeNode !== "undefined" && + elem.getAttributeNode("id"); + return node && node.value === attrId; + }; + }; + + // Support: IE 6 - 7 only + // getElementById is not reliable as a find shortcut + Expr.find["ID"] = function( id, context ) { + if ( typeof context.getElementById !== "undefined" && documentIsHTML ) { + var node, i, elems, + elem = context.getElementById( id ); + + if ( elem ) { + + // Verify the id attribute + node = elem.getAttributeNode("id"); + if ( node && node.value === id ) { + return [ elem ]; + } + + // Fall back on getElementsByName + elems = context.getElementsByName( id ); + i = 0; + while ( (elem = elems[i++]) ) { + node = elem.getAttributeNode("id"); + if ( node && node.value === id ) { + return [ elem ]; + } + } + } + + return []; + } + }; + } + + // Tag + Expr.find["TAG"] = support.getElementsByTagName ? 
+ function( tag, context ) { + if ( typeof context.getElementsByTagName !== "undefined" ) { + return context.getElementsByTagName( tag ); + + // DocumentFragment nodes don't have gEBTN + } else if ( support.qsa ) { + return context.querySelectorAll( tag ); + } + } : + + function( tag, context ) { + var elem, + tmp = [], + i = 0, + // By happy coincidence, a (broken) gEBTN appears on DocumentFragment nodes too + results = context.getElementsByTagName( tag ); + + // Filter out possible comments + if ( tag === "*" ) { + while ( (elem = results[i++]) ) { + if ( elem.nodeType === 1 ) { + tmp.push( elem ); + } + } + + return tmp; + } + return results; + }; + + // Class + Expr.find["CLASS"] = support.getElementsByClassName && function( className, context ) { + if ( typeof context.getElementsByClassName !== "undefined" && documentIsHTML ) { + return context.getElementsByClassName( className ); + } + }; + + /* QSA/matchesSelector + ---------------------------------------------------------------------- */ + + // QSA and matchesSelector support + + // matchesSelector(:active) reports false when true (IE9/Opera 11.5) + rbuggyMatches = []; + + // qSa(:focus) reports false when true (Chrome 21) + // We allow this because of a bug in IE8/9 that throws an error + // whenever `document.activeElement` is accessed on an iframe + // So, we allow :focus to pass through QSA all the time to avoid the IE error + // See https://bugs.jquery.com/ticket/13378 + rbuggyQSA = []; + + if ( (support.qsa = rnative.test( document.querySelectorAll )) ) { + // Build QSA regex + // Regex strategy adopted from Diego Perini + assert(function( el ) { + // Select is set to empty string on purpose + // This is to test IE's treatment of not explicitly + // setting a boolean content attribute, + // since its presence should be enough + // https://bugs.jquery.com/ticket/12359 + docElem.appendChild( el ).innerHTML = "<a id='" + expando + "'></a>" + + "<select id='" + expando + "-\r\\' msallowcapture=''>" + + "<option selected=''></option></select>"; + + // Support: IE8, Opera 11-12.16 + // Nothing should be selected when empty strings follow ^= or $= or *= + // The test attribute must be unknown in Opera but "safe" for WinRT + // https://msdn.microsoft.com/en-us/library/ie/hh465388.aspx#attribute_section + if ( el.querySelectorAll("[msallowcapture^='']").length ) { + rbuggyQSA.push( "[*^$]=" + whitespace + "*(?:''|\"\")" ); + } + + // Support: IE8 + // Boolean attributes and "value" are not treated correctly + if ( !el.querySelectorAll("[selected]").length ) { + rbuggyQSA.push( "\\[" + whitespace + "*(?:value|" + booleans + ")" ); + } + + // Support: Chrome<29, Android<4.4, Safari<7.0+, iOS<7.0+, PhantomJS<1.9.8+ + if ( !el.querySelectorAll( "[id~=" + expando + "-]" ).length ) { + rbuggyQSA.push("~="); + } + + // Webkit/Opera - :checked should return selected option elements + // http://www.w3.org/TR/2011/REC-css3-selectors-20110929/#checked + // IE8 throws error here and will not see later tests + if ( !el.querySelectorAll(":checked").length ) { + rbuggyQSA.push(":checked"); + } + + // Support: Safari 8+, iOS 8+ + // https://bugs.webkit.org/show_bug.cgi?id=136851 + // In-page `selector#id sibling-combinator selector` fails + if ( !el.querySelectorAll( "a#" + expando + "+*" ).length ) { + rbuggyQSA.push(".#.+[+~]"); + } + }); + + assert(function( el ) { + el.innerHTML = "<a href='' disabled='disabled'></a>" + + "<select disabled='disabled'><option/></select>"; + + // Support: Windows 8 Native Apps + // The type and name attributes are restricted during .innerHTML assignment + var input = document.createElement("input"); + input.setAttribute( "type", "hidden" ); + el.appendChild( input ).setAttribute( "name", "D" ); + + // Support: IE8 + 
// Enforce case-sensitivity of name attribute + if ( el.querySelectorAll("[name=d]").length ) { + rbuggyQSA.push( "name" + whitespace + "*[*^$|!~]?=" ); + } + + // FF 3.5 - :enabled/:disabled and hidden elements (hidden elements are still enabled) + // IE8 throws error here and will not see later tests + if ( el.querySelectorAll(":enabled").length !== 2 ) { + rbuggyQSA.push( ":enabled", ":disabled" ); + } + + // Support: IE9-11+ + // IE's :disabled selector does not pick up the children of disabled fieldsets + docElem.appendChild( el ).disabled = true; + if ( el.querySelectorAll(":disabled").length !== 2 ) { + rbuggyQSA.push( ":enabled", ":disabled" ); + } + + // Opera 10-11 does not throw on post-comma invalid pseudos + el.querySelectorAll("*,:x"); + rbuggyQSA.push(",.*:"); + }); + } + + if ( (support.matchesSelector = rnative.test( (matches = docElem.matches || + docElem.webkitMatchesSelector || + docElem.mozMatchesSelector || + docElem.oMatchesSelector || + docElem.msMatchesSelector) )) ) { + + assert(function( el ) { + // Check to see if it's possible to do matchesSelector + // on a disconnected node (IE 9) + support.disconnectedMatch = matches.call( el, "*" ); + + // This should fail with an exception + // Gecko does not error, returns false instead + matches.call( el, "[s!='']:x" ); + rbuggyMatches.push( "!=", pseudos ); + }); + } + + rbuggyQSA = rbuggyQSA.length && new RegExp( rbuggyQSA.join("|") ); + rbuggyMatches = rbuggyMatches.length && new RegExp( rbuggyMatches.join("|") ); + + /* Contains + ---------------------------------------------------------------------- */ + hasCompare = rnative.test( docElem.compareDocumentPosition ); + + // Element contains another + // Purposefully self-exclusive + // As in, an element does not contain itself + contains = hasCompare || rnative.test( docElem.contains ) ? + function( a, b ) { + var adown = a.nodeType === 9 ? a.documentElement : a, + bup = b && b.parentNode; + return a === bup || !!( bup && bup.nodeType === 1 && ( + adown.contains ? + adown.contains( bup ) : + a.compareDocumentPosition && a.compareDocumentPosition( bup ) & 16 + )); + } : + function( a, b ) { + if ( b ) { + while ( (b = b.parentNode) ) { + if ( b === a ) { + return true; + } + } + } + return false; + }; + + /* Sorting + ---------------------------------------------------------------------- */ + + // Document order sorting + sortOrder = hasCompare ? + function( a, b ) { + + // Flag for duplicate removal + if ( a === b ) { + hasDuplicate = true; + return 0; + } + + // Sort on method existence if only one input has compareDocumentPosition + var compare = !a.compareDocumentPosition - !b.compareDocumentPosition; + if ( compare ) { + return compare; + } + + // Calculate position if both inputs belong to the same document + compare = ( a.ownerDocument || a ) === ( b.ownerDocument || b ) ? + a.compareDocumentPosition( b ) : + + // Otherwise we know they are disconnected + 1; + + // Disconnected nodes + if ( compare & 1 || + (!support.sortDetached && b.compareDocumentPosition( a ) === compare) ) { + + // Choose the first element that is related to our preferred document + if ( a === document || a.ownerDocument === preferredDoc && contains(preferredDoc, a) ) { + return -1; + } + if ( b === document || b.ownerDocument === preferredDoc && contains(preferredDoc, b) ) { + return 1; + } + + // Maintain original order + return sortInput ? + ( indexOf( sortInput, a ) - indexOf( sortInput, b ) ) : + 0; + } + + return compare & 4 ? 
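/* DOCUMENT_POSITION_FOLLOWING (bit 4): b follows a in the document, so a sorts first */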
-1 : 1; + } : + function( a, b ) { + // Exit early if the nodes are identical + if ( a === b ) { + hasDuplicate = true; + return 0; + } + + var cur, + i = 0, + aup = a.parentNode, + bup = b.parentNode, + ap = [ a ], + bp = [ b ]; + + // Parentless nodes are either documents or disconnected + if ( !aup || !bup ) { + return a === document ? -1 : + b === document ? 1 : + aup ? -1 : + bup ? 1 : + sortInput ? + ( indexOf( sortInput, a ) - indexOf( sortInput, b ) ) : + 0; + + // If the nodes are siblings, we can do a quick check + } else if ( aup === bup ) { + return siblingCheck( a, b ); + } + + // Otherwise we need full lists of their ancestors for comparison + cur = a; + while ( (cur = cur.parentNode) ) { + ap.unshift( cur ); + } + cur = b; + while ( (cur = cur.parentNode) ) { + bp.unshift( cur ); + } + + // Walk down the tree looking for a discrepancy + while ( ap[i] === bp[i] ) { + i++; + } + + return i ? + // Do a sibling check if the nodes have a common ancestor + siblingCheck( ap[i], bp[i] ) : + + // Otherwise nodes in our document sort first + ap[i] === preferredDoc ? -1 : + bp[i] === preferredDoc ? 1 : + 0; + }; + + return document; +}; + +Sizzle.matches = function( expr, elements ) { + return Sizzle( expr, null, null, elements ); +}; + +Sizzle.matchesSelector = function( elem, expr ) { + // Set document vars if needed + if ( ( elem.ownerDocument || elem ) !== document ) { + setDocument( elem ); + } + + // Make sure that attribute selectors are quoted + expr = expr.replace( rattributeQuotes, "='$1']" ); + + if ( support.matchesSelector && documentIsHTML && + !compilerCache[ expr + " " ] && + ( !rbuggyMatches || !rbuggyMatches.test( expr ) ) && + ( !rbuggyQSA || !rbuggyQSA.test( expr ) ) ) { + + try { + var ret = matches.call( elem, expr ); + + // IE 9's matchesSelector returns false on disconnected nodes + if ( ret || support.disconnectedMatch || + // As well, disconnected nodes are said to be in a document + // fragment in IE 9 + elem.document && elem.document.nodeType !== 11 ) { + return ret; + } + } catch (e) {} + } + + return Sizzle( expr, document, null, [ elem ] ).length > 0; +}; + +Sizzle.contains = function( context, elem ) { + // Set document vars if needed + if ( ( context.ownerDocument || context ) !== document ) { + setDocument( context ); + } + return contains( context, elem ); +}; + +Sizzle.attr = function( elem, name ) { + // Set document vars if needed + if ( ( elem.ownerDocument || elem ) !== document ) { + setDocument( elem ); + } + + var fn = Expr.attrHandle[ name.toLowerCase() ], + // Don't get fooled by Object.prototype properties (jQuery #13807) + val = fn && hasOwn.call( Expr.attrHandle, name.toLowerCase() ) ? + fn( elem, name, !documentIsHTML ) : + undefined; + + return val !== undefined ? + val : + support.attributes || !documentIsHTML ? + elem.getAttribute( name ) : + (val = elem.getAttributeNode(name)) && val.specified ? 
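/* Support: IE<8 - when getAttribute returns properties instead of attributes, read the attribute node's value directly */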
+ val.value : + null; +}; + +Sizzle.escape = function( sel ) { + return (sel + "").replace( rcssescape, fcssescape ); +}; + +Sizzle.error = function( msg ) { + throw new Error( "Syntax error, unrecognized expression: " + msg ); +}; + +/** + * Document sorting and removing duplicates + * @param {ArrayLike} results + */ +Sizzle.uniqueSort = function( results ) { + var elem, + duplicates = [], + j = 0, + i = 0; + + // Unless we *know* we can detect duplicates, assume their presence + hasDuplicate = !support.detectDuplicates; + sortInput = !support.sortStable && results.slice( 0 ); + results.sort( sortOrder ); + + if ( hasDuplicate ) { + while ( (elem = results[i++]) ) { + if ( elem === results[ i ] ) { + j = duplicates.push( i ); + } + } + while ( j-- ) { + results.splice( duplicates[ j ], 1 ); + } + } + + // Clear input after sorting to release objects + // See https://github.com/jquery/sizzle/pull/225 + sortInput = null; + + return results; +}; + +/** + * Utility function for retrieving the text value of an array of DOM nodes + * @param {Array|Element} elem + */ +getText = Sizzle.getText = function( elem ) { + var node, + ret = "", + i = 0, + nodeType = elem.nodeType; + + if ( !nodeType ) { + // If no nodeType, this is expected to be an array + while ( (node = elem[i++]) ) { + // Do not traverse comment nodes + ret += getText( node ); + } + } else if ( nodeType === 1 || nodeType === 9 || nodeType === 11 ) { + // Use textContent for elements + // innerText usage removed for consistency of new lines (jQuery #11153) + if ( typeof elem.textContent === "string" ) { + return elem.textContent; + } else { + // Traverse its children + for ( elem = elem.firstChild; elem; elem = elem.nextSibling ) { + ret += getText( elem ); + } + } + } else if ( nodeType === 3 || nodeType === 4 ) { + return elem.nodeValue; + } + // Do not include comment or processing instruction nodes + + return ret; +}; + +Expr = Sizzle.selectors = { + + // Can be adjusted by the user + cacheLength: 50, + + createPseudo: markFunction, + + match: matchExpr, + + attrHandle: {}, + + find: {}, + + relative: { + ">": { dir: "parentNode", first: true }, + " ": { dir: "parentNode" }, + "+": { dir: "previousSibling", first: true }, + "~": { dir: "previousSibling" } + }, + + preFilter: { + "ATTR": function( match ) { + match[1] = match[1].replace( runescape, funescape ); + + // Move the given value to match[3] whether quoted or unquoted + match[3] = ( match[3] || match[4] || match[5] || "" ).replace( runescape, funescape ); + + if ( match[2] === "~=" ) { + match[3] = " " + match[3] + " "; + } + + return match.slice( 0, 4 ); + }, + + "CHILD": function( match ) { + /* matches from matchExpr["CHILD"] + 1 type (only|nth|...) + 2 what (child|of-type) + 3 argument (even|odd|\d*|\d*n([+-]\d+)?|...) + 4 xn-component of xn+y argument ([+-]?\d*n|) + 5 sign of xn-component + 6 x of xn-component + 7 sign of y-component + 8 y of y-component + */ + match[1] = match[1].toLowerCase(); + + if ( match[1].slice( 0, 3 ) === "nth" ) { + // nth-* requires argument + if ( !match[3] ) { + Sizzle.error( match[0] ); + } + + // numeric x and y parameters for Expr.filter.CHILD + // remember that false/true cast respectively to 0/1 + match[4] = +( match[4] ? 
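/* explicit xn term: sign concatenated with its digits (a bare "n" counts as 1); otherwise "even"/"odd" imply a step of 2 */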
match[5] + (match[6] || 1) : 2 * ( match[3] === "even" || match[3] === "odd" ) ); + match[5] = +( ( match[7] + match[8] ) || match[3] === "odd" ); + + // other types prohibit arguments + } else if ( match[3] ) { + Sizzle.error( match[0] ); + } + + return match; + }, + + "PSEUDO": function( match ) { + var excess, + unquoted = !match[6] && match[2]; + + if ( matchExpr["CHILD"].test( match[0] ) ) { + return null; + } + + // Accept quoted arguments as-is + if ( match[3] ) { + match[2] = match[4] || match[5] || ""; + + // Strip excess characters from unquoted arguments + } else if ( unquoted && rpseudo.test( unquoted ) && + // Get excess from tokenize (recursively) + (excess = tokenize( unquoted, true )) && + // advance to the next closing parenthesis + (excess = unquoted.indexOf( ")", unquoted.length - excess ) - unquoted.length) ) { + + // excess is a negative index + match[0] = match[0].slice( 0, excess ); + match[2] = unquoted.slice( 0, excess ); + } + + // Return only captures needed by the pseudo filter method (type and argument) + return match.slice( 0, 3 ); + } + }, + + filter: { + + "TAG": function( nodeNameSelector ) { + var nodeName = nodeNameSelector.replace( runescape, funescape ).toLowerCase(); + return nodeNameSelector === "*" ? + function() { return true; } : + function( elem ) { + return elem.nodeName && elem.nodeName.toLowerCase() === nodeName; + }; + }, + + "CLASS": function( className ) { + var pattern = classCache[ className + " " ]; + + return pattern || + (pattern = new RegExp( "(^|" + whitespace + ")" + className + "(" + whitespace + "|$)" )) && + classCache( className, function( elem ) { + return pattern.test( typeof elem.className === "string" && elem.className || typeof elem.getAttribute !== "undefined" && elem.getAttribute("class") || "" ); + }); + }, + + "ATTR": function( name, operator, check ) { + return function( elem ) { + var result = Sizzle.attr( elem, name ); + + if ( result == null ) { + return operator === "!="; + } + if ( !operator ) { + return true; + } + + result += ""; + + return operator === "=" ? result === check : + operator === "!=" ? result !== check : + operator === "^=" ? check && result.indexOf( check ) === 0 : + operator === "*=" ? check && result.indexOf( check ) > -1 : + operator === "$=" ? check && result.slice( -check.length ) === check : + operator === "~=" ? ( " " + result.replace( rwhitespace, " " ) + " " ).indexOf( check ) > -1 : + operator === "|=" ? result === check || result.slice( 0, check.length + 1 ) === check + "-" : + false; + }; + }, + + "CHILD": function( type, what, argument, first, last ) { + var simple = type.slice( 0, 3 ) !== "nth", + forward = type.slice( -4 ) !== "last", + ofType = what === "of-type"; + + return first === 1 && last === 0 ? + + // Shortcut for :nth-*(n) + function( elem ) { + return !!elem.parentNode; + } : + + function( elem, context, xml ) { + var cache, uniqueCache, outerCache, node, nodeIndex, start, + dir = simple !== forward ? "nextSibling" : "previousSibling", + parent = elem.parentNode, + name = ofType && elem.nodeName.toLowerCase(), + useCache = !xml && !ofType, + diff = false; + + if ( parent ) { + + // :(first|last|only)-(child|of-type) + if ( simple ) { + while ( dir ) { + node = elem; + while ( (node = node[ dir ]) ) { + if ( ofType ? 
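/* any sibling of the matching kind in this direction disqualifies :first-/:last-/:only-* */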
+ node.nodeName.toLowerCase() === name : + node.nodeType === 1 ) { + + return false; + } + } + // Reverse direction for :only-* (if we haven't yet done so) + start = dir = type === "only" && !start && "nextSibling"; + } + return true; + } + + start = [ forward ? parent.firstChild : parent.lastChild ]; + + // non-xml :nth-child(...) stores cache data on `parent` + if ( forward && useCache ) { + + // Seek `elem` from a previously-cached index + + // ...in a gzip-friendly way + node = parent; + outerCache = node[ expando ] || (node[ expando ] = {}); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + (outerCache[ node.uniqueID ] = {}); + + cache = uniqueCache[ type ] || []; + nodeIndex = cache[ 0 ] === dirruns && cache[ 1 ]; + diff = nodeIndex && cache[ 2 ]; + node = nodeIndex && parent.childNodes[ nodeIndex ]; + + while ( (node = ++nodeIndex && node && node[ dir ] || + + // Fallback to seeking `elem` from the start + (diff = nodeIndex = 0) || start.pop()) ) { + + // When found, cache indexes on `parent` and break + if ( node.nodeType === 1 && ++diff && node === elem ) { + uniqueCache[ type ] = [ dirruns, nodeIndex, diff ]; + break; + } + } + + } else { + // Use previously-cached element index if available + if ( useCache ) { + // ...in a gzip-friendly way + node = elem; + outerCache = node[ expando ] || (node[ expando ] = {}); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + (outerCache[ node.uniqueID ] = {}); + + cache = uniqueCache[ type ] || []; + nodeIndex = cache[ 0 ] === dirruns && cache[ 1 ]; + diff = nodeIndex; + } + + // xml :nth-child(...) + // or :nth-last-child(...) or :nth(-last)?-of-type(...) + if ( diff === false ) { + // Use the same loop as above to seek `elem` from the start + while ( (node = ++nodeIndex && node && node[ dir ] || + (diff = nodeIndex = 0) || start.pop()) ) { + + if ( ( ofType ? + node.nodeName.toLowerCase() === name : + node.nodeType === 1 ) && + ++diff ) { + + // Cache the index of each encountered element + if ( useCache ) { + outerCache = node[ expando ] || (node[ expando ] = {}); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ node.uniqueID ] || + (outerCache[ node.uniqueID ] = {}); + + uniqueCache[ type ] = [ dirruns, diff ]; + } + + if ( node === elem ) { + break; + } + } + } + } + } + + // Incorporate the offset, then check against cycle size + diff -= last; + return diff === first || ( diff % first === 0 && diff / first >= 0 ); + } + }; + }, + + "PSEUDO": function( pseudo, argument ) { + // pseudo-class names are case-insensitive + // http://www.w3.org/TR/selectors/#pseudo-classes + // Prioritize by case sensitivity in case custom pseudos are added with uppercase letters + // Remember that setFilters inherits from pseudos + var args, + fn = Expr.pseudos[ pseudo ] || Expr.setFilters[ pseudo.toLowerCase() ] || + Sizzle.error( "unsupported pseudo: " + pseudo ); + + // The user may use createPseudo to indicate that + // arguments are needed to create the filter function + // just as Sizzle does + if ( fn[ expando ] ) { + return fn( argument ); + } + + // But maintain support for old signatures + if ( fn.length > 1 ) { + args = [ pseudo, pseudo, "", argument ]; + return Expr.setFilters.hasOwnProperty( pseudo.toLowerCase() ) ? 
+ markFunction(function( seed, matches ) { + var idx, + matched = fn( seed, argument ), + i = matched.length; + while ( i-- ) { + idx = indexOf( seed, matched[i] ); + seed[ idx ] = !( matches[ idx ] = matched[i] ); + } + }) : + function( elem ) { + return fn( elem, 0, args ); + }; + } + + return fn; + } + }, + + pseudos: { + // Potentially complex pseudos + "not": markFunction(function( selector ) { + // Trim the selector passed to compile + // to avoid treating leading and trailing + // spaces as combinators + var input = [], + results = [], + matcher = compile( selector.replace( rtrim, "$1" ) ); + + return matcher[ expando ] ? + markFunction(function( seed, matches, context, xml ) { + var elem, + unmatched = matcher( seed, null, xml, [] ), + i = seed.length; + + // Match elements unmatched by `matcher` + while ( i-- ) { + if ( (elem = unmatched[i]) ) { + seed[i] = !(matches[i] = elem); + } + } + }) : + function( elem, context, xml ) { + input[0] = elem; + matcher( input, null, xml, results ); + // Don't keep the element (issue #299) + input[0] = null; + return !results.pop(); + }; + }), + + "has": markFunction(function( selector ) { + return function( elem ) { + return Sizzle( selector, elem ).length > 0; + }; + }), + + "contains": markFunction(function( text ) { + text = text.replace( runescape, funescape ); + return function( elem ) { + return ( elem.textContent || elem.innerText || getText( elem ) ).indexOf( text ) > -1; + }; + }), + + // "Whether an element is represented by a :lang() selector + // is based solely on the element's language value + // being equal to the identifier C, + // or beginning with the identifier C immediately followed by "-". + // The matching of C against the element's language value is performed case-insensitively. + // The identifier C does not have to be a valid language name." + // http://www.w3.org/TR/selectors/#lang-pseudo + "lang": markFunction( function( lang ) { + // lang value must be a valid identifier + if ( !ridentifier.test(lang || "") ) { + Sizzle.error( "unsupported lang: " + lang ); + } + lang = lang.replace( runescape, funescape ).toLowerCase(); + return function( elem ) { + var elemLang; + do { + if ( (elemLang = documentIsHTML ? 
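/* HTML elements expose a lang property; XML documents need the xml:lang/lang attribute */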
+ elem.lang : + elem.getAttribute("xml:lang") || elem.getAttribute("lang")) ) { + + elemLang = elemLang.toLowerCase(); + return elemLang === lang || elemLang.indexOf( lang + "-" ) === 0; + } + } while ( (elem = elem.parentNode) && elem.nodeType === 1 ); + return false; + }; + }), + + // Miscellaneous + "target": function( elem ) { + var hash = window.location && window.location.hash; + return hash && hash.slice( 1 ) === elem.id; + }, + + "root": function( elem ) { + return elem === docElem; + }, + + "focus": function( elem ) { + return elem === document.activeElement && (!document.hasFocus || document.hasFocus()) && !!(elem.type || elem.href || ~elem.tabIndex); + }, + + // Boolean properties + "enabled": createDisabledPseudo( false ), + "disabled": createDisabledPseudo( true ), + + "checked": function( elem ) { + // In CSS3, :checked should return both checked and selected elements + // http://www.w3.org/TR/2011/REC-css3-selectors-20110929/#checked + var nodeName = elem.nodeName.toLowerCase(); + return (nodeName === "input" && !!elem.checked) || (nodeName === "option" && !!elem.selected); + }, + + "selected": function( elem ) { + // Accessing this property makes selected-by-default + // options in Safari work properly + if ( elem.parentNode ) { + elem.parentNode.selectedIndex; + } + + return elem.selected === true; + }, + + // Contents + "empty": function( elem ) { + // http://www.w3.org/TR/selectors/#empty-pseudo + // :empty is negated by element (1) or content nodes (text: 3; cdata: 4; entity ref: 5), + // but not by others (comment: 8; processing instruction: 7; etc.) + // nodeType < 6 works because attributes (2) do not appear as children + for ( elem = elem.firstChild; elem; elem = elem.nextSibling ) { + if ( elem.nodeType < 6 ) { + return false; + } + } + return true; + }, + + "parent": function( elem ) { + return !Expr.pseudos["empty"]( elem ); + }, + + // Element/input types + "header": function( elem ) { + return rheader.test( elem.nodeName ); + }, + + "input": function( elem ) { + return rinputs.test( elem.nodeName ); + }, + + "button": function( elem ) { + var name = elem.nodeName.toLowerCase(); + return name === "input" && elem.type === "button" || name === "button"; + }, + + "text": function( elem ) { + var attr; + return elem.nodeName.toLowerCase() === "input" && + elem.type === "text" && + + // Support: IE<8 + // New HTML5 attribute values (e.g., "search") appear with elem.type === "text" + ( (attr = elem.getAttribute("type")) == null || attr.toLowerCase() === "text" ); + }, + + // Position-in-collection + "first": createPositionalPseudo(function() { + return [ 0 ]; + }), + + "last": createPositionalPseudo(function( matchIndexes, length ) { + return [ length - 1 ]; + }), + + "eq": createPositionalPseudo(function( matchIndexes, length, argument ) { + return [ argument < 0 ? argument + length : argument ]; + }), + + "even": createPositionalPseudo(function( matchIndexes, length ) { + var i = 0; + for ( ; i < length; i += 2 ) { + matchIndexes.push( i ); + } + return matchIndexes; + }), + + "odd": createPositionalPseudo(function( matchIndexes, length ) { + var i = 1; + for ( ; i < length; i += 2 ) { + matchIndexes.push( i ); + } + return matchIndexes; + }), + + "lt": createPositionalPseudo(function( matchIndexes, length, argument ) { + var i = argument < 0 ? argument + length : argument; + for ( ; --i >= 0; ) { + matchIndexes.push( i ); + } + return matchIndexes; + }), + + "gt": createPositionalPseudo(function( matchIndexes, length, argument ) { + var i = argument < 0 ? 
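/* a negative argument counts back from the end of the collection */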
argument + length : argument; + for ( ; ++i < length; ) { + matchIndexes.push( i ); + } + return matchIndexes; + }) + } +}; + +Expr.pseudos["nth"] = Expr.pseudos["eq"]; + +// Add button/input type pseudos +for ( i in { radio: true, checkbox: true, file: true, password: true, image: true } ) { + Expr.pseudos[ i ] = createInputPseudo( i ); +} +for ( i in { submit: true, reset: true } ) { + Expr.pseudos[ i ] = createButtonPseudo( i ); +} + +// Easy API for creating new setFilters +function setFilters() {} +setFilters.prototype = Expr.filters = Expr.pseudos; +Expr.setFilters = new setFilters(); + +tokenize = Sizzle.tokenize = function( selector, parseOnly ) { + var matched, match, tokens, type, + soFar, groups, preFilters, + cached = tokenCache[ selector + " " ]; + + if ( cached ) { + return parseOnly ? 0 : cached.slice( 0 ); + } + + soFar = selector; + groups = []; + preFilters = Expr.preFilter; + + while ( soFar ) { + + // Comma and first run + if ( !matched || (match = rcomma.exec( soFar )) ) { + if ( match ) { + // Don't consume trailing commas as valid + soFar = soFar.slice( match[0].length ) || soFar; + } + groups.push( (tokens = []) ); + } + + matched = false; + + // Combinators + if ( (match = rcombinators.exec( soFar )) ) { + matched = match.shift(); + tokens.push({ + value: matched, + // Cast descendant combinators to space + type: match[0].replace( rtrim, " " ) + }); + soFar = soFar.slice( matched.length ); + } + + // Filters + for ( type in Expr.filter ) { + if ( (match = matchExpr[ type ].exec( soFar )) && (!preFilters[ type ] || + (match = preFilters[ type ]( match ))) ) { + matched = match.shift(); + tokens.push({ + value: matched, + type: type, + matches: match + }); + soFar = soFar.slice( matched.length ); + } + } + + if ( !matched ) { + break; + } + } + + // Return the length of the invalid excess + // if we're just parsing + // Otherwise, throw an error or return tokens + return parseOnly ? + soFar.length : + soFar ? + Sizzle.error( selector ) : + // Cache the tokens + tokenCache( selector, groups ).slice( 0 ); +}; + +function toSelector( tokens ) { + var i = 0, + len = tokens.length, + selector = ""; + for ( ; i < len; i++ ) { + selector += tokens[i].value; + } + return selector; +} + +function addCombinator( matcher, combinator, base ) { + var dir = combinator.dir, + skip = combinator.next, + key = skip || dir, + checkNonElements = base && key === "parentNode", + doneName = done++; + + return combinator.first ? 
+ // Check against closest ancestor/preceding element + function( elem, context, xml ) { + while ( (elem = elem[ dir ]) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + return matcher( elem, context, xml ); + } + } + return false; + } : + + // Check against all ancestor/preceding elements + function( elem, context, xml ) { + var oldCache, uniqueCache, outerCache, + newCache = [ dirruns, doneName ]; + + // We can't set arbitrary data on XML nodes, so they don't benefit from combinator caching + if ( xml ) { + while ( (elem = elem[ dir ]) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + if ( matcher( elem, context, xml ) ) { + return true; + } + } + } + } else { + while ( (elem = elem[ dir ]) ) { + if ( elem.nodeType === 1 || checkNonElements ) { + outerCache = elem[ expando ] || (elem[ expando ] = {}); + + // Support: IE <9 only + // Defend against cloned attroperties (jQuery gh-1709) + uniqueCache = outerCache[ elem.uniqueID ] || (outerCache[ elem.uniqueID ] = {}); + + if ( skip && skip === elem.nodeName.toLowerCase() ) { + elem = elem[ dir ] || elem; + } else if ( (oldCache = uniqueCache[ key ]) && + oldCache[ 0 ] === dirruns && oldCache[ 1 ] === doneName ) { + + // Assign to newCache so results back-propagate to previous elements + return (newCache[ 2 ] = oldCache[ 2 ]); + } else { + // Reuse newcache so results back-propagate to previous elements + uniqueCache[ key ] = newCache; + + // A match means we're done; a fail means we have to keep checking + if ( (newCache[ 2 ] = matcher( elem, context, xml )) ) { + return true; + } + } + } + } + } + return false; + }; +} + +function elementMatcher( matchers ) { + return matchers.length > 1 ? + function( elem, context, xml ) { + var i = matchers.length; + while ( i-- ) { + if ( !matchers[i]( elem, context, xml ) ) { + return false; + } + } + return true; + } : + matchers[0]; +} + +function multipleContexts( selector, contexts, results ) { + var i = 0, + len = contexts.length; + for ( ; i < len; i++ ) { + Sizzle( selector, contexts[i], results ); + } + return results; +} + +function condense( unmatched, map, filter, context, xml ) { + var elem, + newUnmatched = [], + i = 0, + len = unmatched.length, + mapped = map != null; + + for ( ; i < len; i++ ) { + if ( (elem = unmatched[i]) ) { + if ( !filter || filter( elem, context, xml ) ) { + newUnmatched.push( elem ); + if ( mapped ) { + map.push( i ); + } + } + } + } + + return newUnmatched; +} + +function setMatcher( preFilter, selector, matcher, postFilter, postFinder, postSelector ) { + if ( postFilter && !postFilter[ expando ] ) { + postFilter = setMatcher( postFilter ); + } + if ( postFinder && !postFinder[ expando ] ) { + postFinder = setMatcher( postFinder, postSelector ); + } + return markFunction(function( seed, results, context, xml ) { + var temp, i, elem, + preMap = [], + postMap = [], + preexisting = results.length, + + // Get initial elements from seed or context + elems = seed || multipleContexts( selector || "*", context.nodeType ? [ context ] : context, [] ), + + // Prefilter to get matcher input, preserving a map for seed-results synchronization + matcherIn = preFilter && ( seed || !selector ) ? + condense( elems, preMap, preFilter, context, xml ) : + elems, + + matcherOut = matcher ? + // If we have a postFinder, or filtered seed, or non-seed postFilter or preexisting results, + postFinder || ( seed ? preFilter : preexisting || postFilter ) ? 
+ + // ...intermediate processing is necessary + [] : + + // ...otherwise use results directly + results : + matcherIn; + + // Find primary matches + if ( matcher ) { + matcher( matcherIn, matcherOut, context, xml ); + } + + // Apply postFilter + if ( postFilter ) { + temp = condense( matcherOut, postMap ); + postFilter( temp, [], context, xml ); + + // Un-match failing elements by moving them back to matcherIn + i = temp.length; + while ( i-- ) { + if ( (elem = temp[i]) ) { + matcherOut[ postMap[i] ] = !(matcherIn[ postMap[i] ] = elem); + } + } + } + + if ( seed ) { + if ( postFinder || preFilter ) { + if ( postFinder ) { + // Get the final matcherOut by condensing this intermediate into postFinder contexts + temp = []; + i = matcherOut.length; + while ( i-- ) { + if ( (elem = matcherOut[i]) ) { + // Restore matcherIn since elem is not yet a final match + temp.push( (matcherIn[i] = elem) ); + } + } + postFinder( null, (matcherOut = []), temp, xml ); + } + + // Move matched elements from seed to results to keep them synchronized + i = matcherOut.length; + while ( i-- ) { + if ( (elem = matcherOut[i]) && + (temp = postFinder ? indexOf( seed, elem ) : preMap[i]) > -1 ) { + + seed[temp] = !(results[temp] = elem); + } + } + } + + // Add elements to results, through postFinder if defined + } else { + matcherOut = condense( + matcherOut === results ? + matcherOut.splice( preexisting, matcherOut.length ) : + matcherOut + ); + if ( postFinder ) { + postFinder( null, results, matcherOut, xml ); + } else { + push.apply( results, matcherOut ); + } + } + }); +} + +function matcherFromTokens( tokens ) { + var checkContext, matcher, j, + len = tokens.length, + leadingRelative = Expr.relative[ tokens[0].type ], + implicitRelative = leadingRelative || Expr.relative[" "], + i = leadingRelative ? 1 : 0, + + // The foundational matcher ensures that elements are reachable from top-level context(s) + matchContext = addCombinator( function( elem ) { + return elem === checkContext; + }, implicitRelative, true ), + matchAnyContext = addCombinator( function( elem ) { + return indexOf( checkContext, elem ) > -1; + }, implicitRelative, true ), + matchers = [ function( elem, context, xml ) { + var ret = ( !leadingRelative && ( xml || context !== outermostContext ) ) || ( + (checkContext = context).nodeType ? + matchContext( elem, context, xml ) : + matchAnyContext( elem, context, xml ) ); + // Avoid hanging onto element (issue #299) + checkContext = null; + return ret; + } ]; + + for ( ; i < len; i++ ) { + if ( (matcher = Expr.relative[ tokens[i].type ]) ) { + matchers = [ addCombinator(elementMatcher( matchers ), matcher) ]; + } else { + matcher = Expr.filter[ tokens[i].type ].apply( null, tokens[i].matches ); + + // Return special upon seeing a positional matcher + if ( matcher[ expando ] ) { + // Find the next relative operator (if any) for proper handling + j = ++i; + for ( ; j < len; j++ ) { + if ( Expr.relative[ tokens[j].type ] ) { + break; + } + } + return setMatcher( + i > 1 && elementMatcher( matchers ), + i > 1 && toSelector( + // If the preceding token was a descendant combinator, insert an implicit any-element `*` + tokens.slice( 0, i - 1 ).concat({ value: tokens[ i - 2 ].type === " " ? 
"*" : "" }) + ).replace( rtrim, "$1" ), + matcher, + i < j && matcherFromTokens( tokens.slice( i, j ) ), + j < len && matcherFromTokens( (tokens = tokens.slice( j )) ), + j < len && toSelector( tokens ) + ); + } + matchers.push( matcher ); + } + } + + return elementMatcher( matchers ); +} + +function matcherFromGroupMatchers( elementMatchers, setMatchers ) { + var bySet = setMatchers.length > 0, + byElement = elementMatchers.length > 0, + superMatcher = function( seed, context, xml, results, outermost ) { + var elem, j, matcher, + matchedCount = 0, + i = "0", + unmatched = seed && [], + setMatched = [], + contextBackup = outermostContext, + // We must always have either seed elements or outermost context + elems = seed || byElement && Expr.find["TAG"]( "*", outermost ), + // Use integer dirruns iff this is the outermost matcher + dirrunsUnique = (dirruns += contextBackup == null ? 1 : Math.random() || 0.1), + len = elems.length; + + if ( outermost ) { + outermostContext = context === document || context || outermost; + } + + // Add elements passing elementMatchers directly to results + // Support: IE<9, Safari + // Tolerate NodeList properties (IE: "length"; Safari: