Commit bd53030

Refactor autograd package to separate Python dependencies. (pytorch#662)
The core autograd Variable, Function, and Engine no longer depend on the Python API. This lets us implement functions in C++. In the future, we can also multithread the engine and release the GIL for most of the non-Python backwards.
1 parent 16d2c3d commit bd53030


44 files changed: +2970 −1767 lines

setup.py

Lines changed: 11 additions & 2 deletions
@@ -220,14 +220,23 @@ def run(self):
         "torch/csrc/Exceptions.cpp",
         "torch/csrc/Tensor.cpp",
         "torch/csrc/Storage.cpp",
+        "torch/csrc/DynamicTypes.cpp",
         "torch/csrc/byte_order.cpp",
         "torch/csrc/utils.cpp",
+        "torch/csrc/utils/object_ptr.cpp",
         "torch/csrc/allocators.cpp",
         "torch/csrc/serialization.cpp",
         "torch/csrc/autograd/init.cpp",
-        "torch/csrc/autograd/variable.cpp",
-        "torch/csrc/autograd/function.cpp",
         "torch/csrc/autograd/engine.cpp",
+        "torch/csrc/autograd/function.cpp",
+        "torch/csrc/autograd/variable.cpp",
+        "torch/csrc/autograd/grad_buffer.cpp",
+        "torch/csrc/autograd/python_function.cpp",
+        "torch/csrc/autograd/python_cpp_function.cpp",
+        "torch/csrc/autograd/python_variable.cpp",
+        "torch/csrc/autograd/python_engine.cpp",
+        "torch/csrc/autograd/functions/batch_normalization.cpp",
+        "torch/csrc/autograd/functions/init.cpp",
         "torch/csrc/nn/THNN_generic.cpp",
     ]

test/test_autograd.py

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ def bw_hook(inc, grad):
            counter[0] += inc

        z = x ** 2 + x * 2 + x * y + y
+        x.register_hook(lambda *args: bw_hook(0, *args))
        test = z.register_hook(lambda *args: bw_hook(1, *args))
        z.backward(torch.ones(5, 5), retain_variables=True)
        self.assertEqual(counter[0], 1)
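For context, here is a minimal sketch (not part of the diff; `v`, `out`, `grads`, and `handle` are illustrative names) of the hook behavior this test exercises: a hook fires once per backward pass with the gradient of its variable, and the returned handle is assumed to follow the torch.utils.hooks removable-handle API:

    import torch
    from torch.autograd import Variable

    v = Variable(torch.ones(5, 5), requires_grad=True)
    out = v * 2

    grads = []
    handle = out.register_hook(lambda grad: grads.append(grad))
    out.backward(torch.ones(5, 5))
    assert len(grads) == 1   # the hook fired exactly once
    handle.remove()          # assumed removable-handle API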

test/test_multiprocessing.py

Lines changed: 5 additions & 7 deletions
@@ -79,9 +79,8 @@ def autograd_sharing(queue, ready, master_modified):
    is_ok = var.data.equal(expected_var)
    var.data[:] = torch.ones(5, 5)

-    if var.grad is not None:
-        is_ok &= var.grad.data.equal(torch.ones(5, 5) * 4)
-        var.grad.data[:] = torch.ones(5, 5)
+    is_ok &= var.grad.data.equal(torch.zeros(5, 5))
+    var.grad.data[:] = torch.ones(5, 5)

    queue.put(is_ok)

@@ -357,20 +356,19 @@ def _test_autograd_sharing(self, var):
        queue = mp.Queue()
        p = mp.Process(target=autograd_sharing, args=(queue, ready, master_modified))
        p.start()
+        var.grad.data.zero_()
        queue.put(var)

        ready.wait()
        var.data[0, 0] = 1000
-        if var.grad is not None:
-            var.grad.data[:] = torch.ones(5, 5) * 4
+        var.grad.data[:] = torch.ones(5, 5) * 4
        master_modified.set()

        worker_ok = queue.get()
        self.assertTrue(worker_ok)

        self.assertEqual(var.data, torch.ones(5, 5))
-        if var.grad is not None:
-            self.assertEqual(var.grad.data, torch.ones(5, 5))
+        self.assertEqual(var.grad.data, torch.ones(5, 5) * 4)
        p.join()

    def test_variable_sharing(self):
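The `if var.grad is not None` guards could be dropped because accessing `.grad` on a leaf Variable now always yields a usable, zero-initialized Variable before any backward pass has run. A minimal sketch of the behavior the updated test assumes:

    import torch
    from torch.autograd import Variable

    v = Variable(torch.ones(5, 5), requires_grad=True)
    # no backward has run yet, but .grad is already usable:
    v.grad.data.zero_()      # the same call _test_autograd_sharing now makes
    assert v.grad.data.equal(torch.zeros(5, 5))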

torch/autograd/function.py

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,6 @@
 import torch._C as _C
 import torch.utils.hooks as hooks
 from collections import OrderedDict
-from itertools import chain


 class Function(_C._FunctionBase):

@@ -98,9 +97,9 @@ def mark_non_differentiable(self, *args):
        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be outputs.**

-        This will mark outputs as non requiring gradient, increasing the
+        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
-        for this output in :meth:`~Function.backward`, but it's always going to
+        for each output in :meth:`~Function.backward`, but it's always going to
        be ``None``.

        This is used e.g. for indices returned from a max :class:`Function`.
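To make the docstring concrete, here is an illustrative old-style Function (`SortRows` is a hypothetical example, not part of this commit) whose second output is marked non-differentiable; `backward` still accepts a gradient slot for that output, and it is always ``None``:

    import torch
    from torch.autograd import Function

    class SortRows(Function):

        def forward(self, input):
            output, indices = input.sort(1)
            # indices can never receive a meaningful gradient
            self.mark_non_differentiable(indices)
            self.save_for_backward(indices)
            return output, indices

        def backward(self, grad_output, grad_indices):
            # grad_indices is always None, as documented above
            indices, = self.saved_tensors
            grad_input = grad_output.new(grad_output.size()).zero_()
            # scatter each gradient back to its pre-sort position
            return grad_input.scatter_(1, indices, grad_output)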

torch/autograd/variable.py

Lines changed: 15 additions & 28 deletions
@@ -56,30 +56,6 @@ class Variable(_C._VariableBase):
        'is_cuda',
    }

-    @property
-    def grad(self):
-        if self.requires_grad and self._grad is None:
-            # TODO: this won't have to be zeroed in the future
-            self._grad = Variable(self.data.new(self.data.size()).zero_())
-        return self._grad
-
-    @property
-    def requires_grad(self):
-        return self._requires_grad
-
-    @requires_grad.setter
-    def requires_grad(self, value):
-        if self.creator is not None:
-            if value is False:
-                hint = (" If you want to use a computed variable in a subgraph "
-                        "that doesn't require differentiation use "
-                        "var_no_grad = var.detach().")
-            else:
-                hint = ''
-            raise RuntimeError("you can only change requires_grad flags of "
-                               "leaf variables." + hint)
-        self._requires_grad = value
-
    def __getattr__(self, name):
        if name in self._fallthrough_methods:
            return getattr(self.data, name)

@@ -108,19 +84,30 @@ def __deepcopy__(self, memo):
        if self.creator is not None:
            raise RuntimeError("Only Variables created explicitly by the user "
                               "(graph leaves) support the deepcopy protocol at the moment")
-        result = type(self)(self.data.clone(), requires_grad=self.requires_grad,
-                            volatile=self.volatile)
+        result = type(self)(self.data.clone())
+        result.requires_grad = self.requires_grad
+        result.volatile = self.volatile
        memo[id(self)] = result
        return result

    def __reduce_ex__(self, proto):
+        state = (self.requires_grad, self.volatile, self._backward_hooks)
        if proto > 1:
-            return super(Variable, self).__reduce_ex__(proto)
+            return type(self), (self.data,), state
        if sys.version_info[0] == 2:
            from copy_reg import __newobj__
        else:
            from copyreg import __newobj__
-        return __newobj__, (type(self),), self.__getstate__()
+        return __newobj__, (type(self), self.data), state
+
+    def __setstate__(self, state):
+        if len(state) == 5:
+            # legacy serialization of Variable
+            self.data = state[0]
+            state = (state[3], state[4], state[2])
+        if self.creator is not None:
+            raise RuntimeError('__setstate__ can be only called on leaf variables')
+        self.requires_grad, self.volatile, self._backward_hooks = state

    def __repr__(self):
        return 'Variable containing:' + self.data.__repr__()
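The new `__reduce_ex__`/`__setstate__` pair follows the standard pickle contract: given a `(callable, args, state)` triple, pickle rebuilds the object as `callable(*args)` and then passes `state` to `__setstate__`. A self-contained sketch of that mechanism (`Leaf` is a stand-in for a leaf Variable, not real torch code):

    import pickle

    class Leaf(object):
        def __init__(self, data):
            self.data = data
            self.requires_grad = False

        def __reduce_ex__(self, proto):
            # pickle will call Leaf(data), then __setstate__(state)
            return type(self), (self.data,), (self.requires_grad,)

        def __setstate__(self, state):
            self.requires_grad, = state

    v = Leaf([1.0, 2.0])
    v.requires_grad = True
    v2 = pickle.loads(pickle.dumps(v, protocol=2))
    assert v2.data == [1.0, 2.0] and v2.requires_grad is True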

torch/csrc/DynamicTypes.cpp

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+#include "DynamicTypes.h"
+
+#include "THP.h"
+#include <vector>
+#include <unordered_map>
+#include <THPP/tensors/THTensor.hpp>
+#include <THPP/tensors/THSTensor.hpp>
+
+#ifdef WITH_CUDA
+#include <THC/THC.h>
+#include <THPP/tensors/THCTensor.hpp>
+extern THCState* state;
+#endif
+
+
+using namespace thpp;
+
+namespace torch {
+
+struct TensorType {
+  Type data_type;
+  bool is_cuda;
+  bool is_sparse;
+
+  friend bool operator==(const TensorType &t1, const TensorType &t2)
+  {
+    return (t1.data_type == t2.data_type &&
+            t1.is_cuda == t2.is_cuda &&
+            t1.is_sparse == t2.is_sparse);
+  }
+
+  friend bool operator!=(const TensorType &t1, const TensorType &t2)
+  {
+    return !(t1 == t2);
+  }
+};
+
+struct TensorTypeHasher
+{
+  std::size_t operator()(const TensorType& k) const
+  {
+    size_t hash = static_cast<size_t>(k.data_type);
+    hash = (hash << 8) + k.is_cuda;
+    hash = (hash << 1) + k.is_sparse;
+    return hash;
+  }
+};
+
+static std::unordered_map<std::string, Type> type_names = {
+  {"Float", Type::FLOAT},
+  {"Double", Type::DOUBLE},
+  {"Half", Type::HALF},
+  {"Byte", Type::UCHAR},
+  {"Char", Type::CHAR},
+  {"Short", Type::SHORT},
+  {"Int", Type::INT},
+  {"Long", Type::LONG},
+};
+static std::unordered_map<PyTypeObject*, TensorType> pytype_to_tensortype;
+static std::unordered_map<TensorType, PyTypeObject*, TensorTypeHasher> tensortype_to_pytype;
+
+void registerPyTypeObject(PyTypeObject *pytype, const std::string& name, bool is_cuda, bool is_sparse)
+{
+  TensorType type;
+  type.data_type = type_names.at(name);
+  type.is_cuda = is_cuda;
+  type.is_sparse = is_sparse;
+
+  pytype_to_tensortype[pytype] = type;
+  tensortype_to_pytype[type] = pytype;
+}
+
+PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor)
+{
+  TensorType type;
+  type.data_type = tensor.type();
+  type.is_cuda = tensor.isCuda();
+  type.is_sparse = tensor.isSparse();
+
+  return tensortype_to_pytype.at(type);
+}
+
+static std::unique_ptr<Tensor> createTensor(void *tensor, Type type, bool is_cuda, bool is_sparse)
+{
+  if (is_cuda) {
+#ifdef WITH_CUDA
+    if (type == Type::UCHAR) {
+      return std::unique_ptr<Tensor>(new THCTensor<unsigned char>(state, (THCudaByteTensor*)tensor));
+    } else if (type == Type::CHAR) {
+      return std::unique_ptr<Tensor>(new THCTensor<char>(state, (THCudaCharTensor*)tensor));
+    } else if (type == Type::SHORT) {
+      return std::unique_ptr<Tensor>(new THCTensor<short>(state, (THCudaShortTensor*)tensor));
+    } else if (type == Type::INT) {
+      return std::unique_ptr<Tensor>(new THCTensor<int>(state, (THCudaIntTensor*)tensor));
+    } else if (type == Type::LONG) {
+      return std::unique_ptr<Tensor>(new THCTensor<long>(state, (THCudaLongTensor*)tensor));
+    } else if (type == Type::FLOAT) {
+      return std::unique_ptr<Tensor>(new THCTensor<float>(state, (THCudaTensor*)tensor));
+    } else if (type == Type::DOUBLE) {
+      return std::unique_ptr<Tensor>(new THCTensor<double>(state, (THCudaDoubleTensor*)tensor));
+    } else if (type == Type::HALF) {
+      return std::unique_ptr<Tensor>(new THCTensor<half>(state, (THCudaHalfTensor*)tensor));
+    }
+#else
+    throw std::runtime_error("Compiled without CUDA support");
+#endif
+  } else if (is_sparse) {
+    if (type == Type::UCHAR) {
+      return std::unique_ptr<Tensor>(new THSTensor<unsigned char>((THSByteTensor*)tensor));
+    } else if (type == Type::CHAR) {
+      return std::unique_ptr<Tensor>(new THSTensor<char>((THSCharTensor*)tensor));
+    } else if (type == Type::SHORT) {
+      return std::unique_ptr<Tensor>(new THSTensor<short>((THSShortTensor*)tensor));
+    } else if (type == Type::INT) {
+      return std::unique_ptr<Tensor>(new THSTensor<int>((THSIntTensor*)tensor));
+    } else if (type == Type::LONG) {
+      return std::unique_ptr<Tensor>(new THSTensor<long>((THSLongTensor*)tensor));
+    } else if (type == Type::FLOAT) {
+      return std::unique_ptr<Tensor>(new THSTensor<float>((THSFloatTensor*)tensor));
+    } else if (type == Type::DOUBLE) {
+      return std::unique_ptr<Tensor>(new THSTensor<double>((THSDoubleTensor*)tensor));
+    }
+  } else if (type == Type::UCHAR) {
+    return std::unique_ptr<Tensor>(new THTensor<unsigned char>((THByteTensor*)tensor));
+  } else if (type == Type::CHAR) {
+    return std::unique_ptr<Tensor>(new THTensor<char>((THCharTensor*)tensor));
+  } else if (type == Type::SHORT) {
+    return std::unique_ptr<Tensor>(new THTensor<short>((THShortTensor*)tensor));
+  } else if (type == Type::INT) {
+    return std::unique_ptr<Tensor>(new THTensor<int>((THIntTensor*)tensor));
+  } else if (type == Type::LONG) {
+    return std::unique_ptr<Tensor>(new THTensor<long>((THLongTensor*)tensor));
+  } else if (type == Type::FLOAT) {
+    return std::unique_ptr<Tensor>(new THTensor<float>((THFloatTensor*)tensor));
+  } else if (type == Type::DOUBLE) {
+    return std::unique_ptr<Tensor>(new THTensor<double>((THDoubleTensor*)tensor));
+  }
+  throw std::invalid_argument("Unsupported tensor type");
+}
+
+std::unique_ptr<Tensor> createTensor(PyObject *data)
+{
+  auto tensor_type = pytype_to_tensortype.at(Py_TYPE(data));
+  auto type = tensor_type.data_type;
+  auto tensor = ((THPVoidTensor *)data)->cdata;
+  auto wrapper = createTensor(tensor, type, tensor_type.is_cuda, tensor_type.is_sparse);
+  wrapper->retain();
+  return wrapper;
+}
+
+PyObject* createPyObject(const thpp::Tensor& tensor)
+{
+  auto type = getPyTypeObject(tensor);
+  PyObject *obj = type->tp_alloc(type, 0);
+  if (obj) {
+    ((THPVoidTensor*)obj)->cdata = (THVoidTensor *)const_cast<thpp::Tensor&>(tensor).retain().cdata();
+  }
+  return obj;
+}
+
+} // namespace
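TensorTypeHasher packs the enum and the two flags into disjoint bit ranges, so every distinct (data_type, is_cuda, is_sparse) triple gets a distinct hash. The same packing, sketched in Python purely for illustration:

    def tensor_type_hash(data_type, is_cuda, is_sparse):
        h = int(data_type)            # enum value in the high bits
        h = (h << 8) + int(is_cuda)   # each bool gets its own bit range
        h = (h << 1) + int(is_sparse)
        return h

    # distinct triples never collide, since the bools occupy separate bits
    assert tensor_type_hash(5, False, True) != tensor_type_hash(5, True, False)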

torch/csrc/DynamicTypes.h

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#pragma once
+
+// Provides conversions between Python tensor objects and thpp::Tensors.
+
+#include <memory>
+#include <Python.h>
+#include <THPP/THPP.h>
+
+namespace torch {
+
+// Register a PyTypeObject* with the given attributes
+void registerPyTypeObject(
+    PyTypeObject *pytype, const std::string& name,
+    bool is_cuda, bool is_sparse);
+
+// Gets the PyTypeObject* corresponding to the Tensor
+PyTypeObject* getPyTypeObject(const thpp::Tensor& tensor);
+
+// Creates a Tensor from a Python tensor object
+std::unique_ptr<thpp::Tensor> createTensor(PyObject *data);
+
+// Creates Python tensor object from a Tensor
+PyObject* createPyObject(const thpp::Tensor& tensor);
+
+} // namespace torch

torch/csrc/Exceptions.h

Lines changed: 5 additions & 2 deletions
@@ -5,8 +5,6 @@
 #include <stdexcept>
 #include <string>

-#include "THP.h"
-
 #define HANDLE_TH_ERRORS \
   try {

@@ -21,6 +19,11 @@
 extern PyObject *THPException_FatalError;

 #ifdef _THP_CORE
+
+// Throwing this exception means that the python error flags have been already
+// set and control should be immediately returned to the interpreter.
+class python_error : public std::exception {};
+
 struct THException: public std::exception {
   THException(const char* msg): msg(msg) {};
