forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpy_rref.cpp
368 lines (324 loc) · 12.3 KB
/
py_rref.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
namespace {
// Flattens an RRefForkData into a Python tuple. Field order is: owner id,
// RRef id (createdOn, localId), fork id (createdOn, localId), parent id and
// the type string.
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
  // Hold the GIL: the tuple being built is a new py::object.
  pybind11::gil_scoped_acquire ag;
  const auto& rrefId = rrefForkData.rrefId_;
  const auto& forkId = rrefForkData.forkId_;
  return py::make_tuple(
      rrefForkData.ownerId_,
      rrefId.createdOn_,
      rrefId.localId_,
      forkId.createdOn_,
      forkId.localId_,
      rrefForkData.parent_,
      rrefForkData.typeStr_);
}
// Rebuilds an RRefForkData from a Python tuple. The tuple must have exactly
// RFD_TUPLE_SIZE entries; each field is read from its *_IDX slot.
RRefForkData fromPyTuple(const py::tuple& pyTuple) {
  // Hold the GIL: every element access below touches a py::object.
  pybind11::gil_scoped_acquire ag;
  TORCH_INTERNAL_ASSERT(
      pyTuple.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain ",
      RFD_TUPLE_SIZE,
      " numbers.");

  const auto ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
  const RRefId rrefId(
      pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
      pyTuple[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId forkId(
      pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
      pyTuple[FORKID_ID_IDX].cast<local_id_t>());
  const auto parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
  const auto typeStr = pyTuple[TYPE_IDX].cast<std::string>();

  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
// Chooses the element type for an RRef that will hold `value`.
//  - ScriptModule values require `type_hint`; the interface type it names is
//    returned after checking that the module type is a subtype of it.
//  - For any other value `type_hint` must be None. The result is the type
//    inferred by jit::tryToInferType, or PyObjectType when inference fails or
//    when the value's class is compilable but has no registered script class.
TypePtr tryInferTypeWithTypeHint(
    const py::object& value,
    const py::object& type_hint) {
  if (auto scriptModule = jit::as_module(value)) {
    // A ScriptModule payload is only accepted together with its
    // ModuleInterface type.
    TORCH_CHECK(
        !type_hint.is_none(),
        "The RRef being created contains a ScriptModule, "
        "must provide its ModuleInterface type hint. ");
    const py::object qualifiedNameFn =
        py::module::import("torch._jit_internal").attr("_qualified_name");
    const auto hintName =
        c10::QualifiedName(py::cast<std::string>(qualifiedNameFn(type_hint)));
    TypePtr interfaceType = jit::get_python_cu()->get_interface(hintName);
    std::ostringstream whyNotSubtype;
    TORCH_CHECK(
        interfaceType != nullptr &&
            scriptModule->type()->isSubtypeOfExt(
                *interfaceType, &whyNotSubtype),
        scriptModule->type()->repr_str(),
        " is not a subtype of the type hint: ",
        hintName.qualifiedName(),
        ", did you pass a valid interface type?\n",
        whyNotSubtype.str());
    return interfaceType;
  }

  TORCH_CHECK(
      type_hint.is_none(),
      "type_hint should only be specified when the RRef being created contains a ScriptModule.");

  // Only run type inference on instances of an existing ScriptClass;
  // otherwise inference would try to script the class of `value`, which must
  // be avoided.
  py::bool_ compilable = py::module::import("torch._jit_internal")
                             .attr("can_compile_class")(value.get_type());
  if (py::cast<bool>(compilable) &&
      py::module::import("torch.jit._state")
          .attr("_get_script_class")(value.get_type())
          .is_none()) {
    return PyObjectType::get();
  }

  // NB: `jit::tryToInferType(..)` handles ScriptClass but not ScriptModule.
  jit::InferredType inferred = jit::tryToInferType(value);
  if (!inferred.success()) {
    // Pure Python object: the RRef will hold an IValue wrapping a pyobject.
    return PyObjectType::get();
  }
  // The type could be inferred, so the RRef holds an IValue of that type.
  return inferred.type();
}
} // namespace
/////////////////////////// PyRRef //////////////////////////////////
// Wraps an existing RRef. The pointer must be non-null; the profiling future
// starts out unset.
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
    : rref_(std::move(rref)),
      profilingFuture_(c10::nullopt) {
  TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
  C10_LOG_API_USAGE_ONCE("torch.distributed.rref");
}
// Creates an owner RRef holding `value`. The element type comes from
// tryInferTypeWithTypeHint(value, type_hint); construction is delegated to
// the intrusive_ptr constructor.
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
    : PyRRef([&]() {
        TypePtr elemType = tryInferTypeWithTypeHint(value, type_hint);
        auto ownerRRef = RRefContext::getInstance().createOwnerRRef(elemType);
        // jit::toIValue takes a py::handle and increfs the object through
        // py::handle.cast<py::object>(); the resulting IValue then keeps the
        // reference alive.
        // NB: `value` has to stay alive until jit::toIValue has returned
        // (i.e. the incref is done), which is why this ctor may only be
        // called while holding the GIL.
        ownerRRef->setValue(jit::toIValue(value, elemType));
        return ownerRRef;
      }()) {}
// Releases the cached Python type object, if one was stored, while holding
// the GIL.
PyRRef::~PyRRef() {
  if (!type_.has_value()) {
    return;
  }
  pybind11::gil_scoped_acquire ag;
  type_->dec_ref();
  // Null out the PyObject* so that py::object's own destructor does not
  // decref the same object a second time.
  // See Note [Destructing py::object] in python_ivalue.h
  type_->ptr() = nullptr;
}
// Returns the RRef's owner-creation future converted with toPyJitFuture.
c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
  // The future only signals the profiler to update its profiling result; the
  // profiler never reads a value from it, hence hasValue is false.
  constexpr bool kHasValue = false;
  auto ownerCreationFuture = rref_->getOwnerCreationFuture();
  return toPyJitFuture(ownerCreationFuture, kHasValue);
}
// Returns the stored profiling future; asserts that one has been set.
c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
  TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
  return profilingFuture_.value();
}
// Stores `profilingFuture` (taking ownership of the passed pointer) so that
// getProfilingFuture() can return it later.
void PyRRef::setProfilingFuture(
    c10::intrusive_ptr<JitFuture> profilingFuture) {
  profilingFuture_ = std::move(profilingFuture);
}
// Reports whether the wrapped RRef is the owner; forwards to RRef::isOwner().
bool PyRRef::isOwner() const {
  const bool ownedHere = rref_->isOwner();
  return ownedHere;
}
// Forwards to RRef::confirmedByOwner() on the wrapped RRef.
bool PyRRef::confirmedByOwner() const {
  const bool confirmed = rref_->confirmedByOwner();
  return confirmed;
}
// Looks up the WorkerInfo of the RRef's owner through the current RRefContext
// agent.
WorkerInfo PyRRef::owner() const {
  const auto ownerId = rref_->owner();
  return RRefContext::getInstance().agent()->getWorkerInfo(ownerId);
}
// Returns the owner's name as reported by RRef::ownerName().
std::string PyRRef::ownerName() const {
  std::string name = rref_->ownerName();
  return name;
}
// Returns the RRef's value as a Python object. On the owner this is
// localValue(); on a user the value is fetched through UserRRef::toHere with
// the given timeout and then converted to Python.
py::object PyRRef::toHere(const float timeoutSeconds) const {
  C10_LOG_API_USAGE_ONCE("torch.distributed.rref.to_here");
  if (rref_->isOwner()) {
    return localValue();
  }

  // toHere() goes through python_rpc_handler, which acquires the GIL when the
  // UserRRef holds a python object.
  IValue fetched = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
      timeoutSeconds);

  if (!rref_->isPyObj()) {
    // torch::jit::toPyObject creates a new py::object without grabbing the
    // GIL itself, so take it here.
    pybind11::gil_scoped_acquire ag;
    return torch::jit::toPyObject(std::move(fetched));
  }

  // python_rpc_handler deserialization acquires the GIL on its own.
  auto serializedParts = fetched.toTupleRef().elements().vec();
  auto& handler = PythonRpcHandler::getInstance();
  auto pyResult = handler.deserialize(
      SerializedPyObj::fromIValues(std::move(serializedParts)));
  handler.handleException(pyResult);
  return pyResult;
}
// Returns the value held by this owner RRef as a Python object.
//
// Fails a TORCH_CHECK when called on a user RRef. The converted object is
// passed to PythonRpcHandler::handleExceptionGILHeld before being returned.
py::object PyRRef::localValue() const {
  TORCH_CHECK(
      rref_->isOwner(),
      "For ",
      *rref_,
      ", can't call localValue() on user ",
      RRefContext::getInstance().agent()->getWorkerInfo(),
      ". Call it on owner ",
      owner());

  // Fetch the value and the handler before taking the GIL, exactly as the
  // original ordering did.
  auto value =
      c10::static_intrusive_pointer_cast<const OwnerRRef>(rref_)->getValue();
  auto& rpcHandler = PythonRpcHandler::getInstance();

  // torch::jit::toPyObject creates a new py::object without grabbing the GIL.
  // The guard is declared before `res` so that it outlives it: if
  // handleExceptionGILHeld throws, `res` is destroyed (and its reference
  // dropped) while the GIL is still held, rather than after it was released.
  pybind11::gil_scoped_acquire ag;
  py::object res = torch::jit::toPyObject(std::move(value));
  rpcHandler.handleExceptionGILHeld(res);
  return res;
}
std::string PyRRef::str() const {
if (rref_->isOwner()) {
return c10::str("OwnerRRef(", rref_->rrefId(), ")");
} else {
return c10::str(
"UserRRef(RRefId = ",
rref_->rrefId(),
", ForkId = ",
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
")");
}
}
// Constructs a Python RRef proxy object for this RRef. The proxy constructor
// receives this PyRRef, the Python RPC function matching `type` (rpcSync_,
// rpcAsync_ or remote_) and the timeout. An unknown `type` trips an internal
// assert.
py::object PyRRef::createRRefProxy(
    const RRefProxyType& type,
    float timeoutSeconds) const {
  // The handler is looked up before the GIL is taken, as in the original.
  auto& handler = PythonRpcHandler::getInstance();
  pybind11::gil_scoped_acquire ag;
  auto& proxyFns = handler.getRRefProxyFunctions();
  auto& proxyCtor = proxyFns.rrefProxyCtor_;

  // Shared call shape for all three proxy kinds.
  const auto makeProxy = [&](const auto& rpcFn) {
    return proxyCtor(*this, rpcFn, timeoutSeconds);
  };

  switch (type) {
    case RRefProxyType::RPC_SYNC:
      return makeProxy(proxyFns.rpcSync_);
    case RRefProxyType::RPC_ASYNC:
      return makeProxy(proxyFns.rpcAsync_);
    case RRefProxyType::REMOTE:
      return makeProxy(proxyFns.remote_);
    default:
      TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
  }
}
// Returns the Python-side type of this RRef, computing it on first use and
// caching it in type_. `timeout` is only forwarded on the user path
// (onUser_); `blocking` is forwarded on both paths.
py::object PyRRef::getRRefType(float timeout, bool blocking) {
// NOTE(review): the original note here says the GIL is not released when
// this function is called, i.e. presumably the caller still holds it on
// entry — confirm against the Python binding.
if (!type_.has_value()) {
// Drop the GIL while looking up the PythonRpcHandler singleton and its type
// functions (presumably because getInstance() should not run with the GIL
// held — TODO confirm).
pybind11::gil_scoped_release release;
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
// Take the GIL again before invoking the Python callables and storing the
// resulting py::object in type_.
pybind11::gil_scoped_acquire acquire;
type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking)
: typeFuncs.onUser_(*this, timeout, blocking);
}
// Returns py::object that can be Python type or future.
return *type_;
}
// Serializes this RRef for pickling: asks the RRefContext to prepare a child
// fork and converts the resulting RRefForkData into a Python tuple.
py::tuple PyRRef::pickle() const {
  auto forkData = RRefContext::getInstance().prepareChildFork(rref_);
  return toPyTuple(forkData);
}
// Reconstructs a PyRRef from the tuple produced by pickle(): decodes the fork
// data, parses the element type from its type string, gets or creates the
// RRef in the context, and notifies the owner and parent of the new fork.
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
  auto& context = RRefContext::getInstance();
  auto forkData = fromPyTuple(pyTuple);
  TypePtr elemType =
      PythonRpcHandler::getInstance().parseTypeFromStr(forkData.typeStr_);
  c10::intrusive_ptr<RRef> restored =
      context.getOrCreateRRef(forkData, elemType);
  context.notifyOwnerAndParentOfFork(
      forkData.forkId_, forkData.parent_, restored);
  return PyRRef(std::move(restored));
}
// Wraps the underlying RRef in an IValue.
c10::IValue PyRRef::toIValue() const {
  // An IValue stores the RRef through its RRefInterface base, so cast first.
  return IValue(
      c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_));
}
// Runs backward with this PyRRef's own RRef as the root; forwards to the
// three-argument overload.
void PyRRef::backward(int64_t autogradContextId, bool retainGraph) {
  const auto& self = rref_;
  backward(autogradContextId, retainGraph, self);
}
void PyRRef::backwardOwnerRRef(
int64_t autogradContextId,
bool retainGraph,
IValue value) {
// If we have a PyObj, retrieve the underlying tensor.
if (value.isPyObject()) {
py::gil_scoped_acquire gil;
py::object obj = torch::jit::toPyObject(value);
try {
value = torch::jit::toIValue(obj, c10::TensorType::get());
} catch (py::cast_error& e) {
TORCH_CHECK(false, "RRef should contain a tensor for .backward()");
}
}
TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()");
auto root = value.toTensor();
if (autogradContextId == -1) {
torch::autograd::backward({root});
} else {
torch::distributed::autograd::backward(
autogradContextId, {root}, retainGraph);
}
}
// Runs backward rooted at `rref`. On the owner this calls backwardOwnerRRef
// with the RRef's value. On a user it requires a valid autograd context id,
// sends an RRefBackwardReq to the owner and waits for the reply, rethrowing
// any error.
void PyRRef::backward(
    int64_t autogradContextId,
    bool retainGraph,
    const c10::intrusive_ptr<RRef>& rref) {
  if (rref->isOwner()) {
    auto ownerValue =
        c10::static_intrusive_pointer_cast<const OwnerRRef>(rref)->getValue();
    backwardOwnerRRef(autogradContextId, retainGraph, ownerValue);
    return;
  }

  TORCH_CHECK(
      autogradContextId != -1,
      "User RRefs require 'dist_autograd_ctx_id' to be specified");

  autograd::RRefBackwardReq request(
      rref->rrefId(), autogradContextId, retainGraph);

  // Invoke distributed backward remotely.
  auto agent = rpc::RpcAgent::getCurrentRpcAgent();
  const auto& ownerInfo = agent->getWorkerInfo(rref->owner());
  auto reply = agent->send(ownerInfo, std::move(request).toMessage());
  reply->waitAndThrow();
}
} // namespace rpc
} // namespace distributed
} // namespace torch