forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrequest_callback_no_python.cpp
629 lines (561 loc) · 24.5 KB
/
request_callback_no_python.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <c10/core/StreamGuard.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
namespace distributed {
namespace rpc {
using namespace torch::distributed::autograd;
using namespace torch::autograd::profiler;
// When request message has autograd info, processMessage() will set up valid
// current context id properly. This struct is used to clean up current context
// id after processMessage() is done.
// RAII guard that makes `ctxId` the calling thread's current distributed
// autograd context id for the lifetime of a scope, restoring the previously
// current id on destruction.
struct DistAutogradContextGuard {
  explicit DistAutogradContextGuard(int64_t ctxId) {
    auto& container = DistAutogradContainer::getInstance();
    prevCtxId_ = container.currentContextId();
    container.forceCurrentContextId(ctxId);
  }
  ~DistAutogradContextGuard() {
    auto& container = DistAutogradContainer::getInstance();
    container.forceCurrentContextId(prevCtxId_);
  }

  // Non-copyable/non-movable: a duplicated scope guard would restore the
  // saved context id more than once, clobbering the thread-local state.
  DistAutogradContextGuard(const DistAutogradContextGuard&) = delete;
  DistAutogradContextGuard& operator=(const DistAutogradContextGuard&) = delete;
  DistAutogradContextGuard(DistAutogradContextGuard&&) = delete;
  DistAutogradContextGuard& operator=(DistAutogradContextGuard&&) = delete;

  // Context id that was current when the guard was constructed.
  int64_t prevCtxId_;
};
std::unique_ptr<RpcCommandBase> RequestCallbackNoPython::
    deserializePythonRpcCommand(
        std::unique_ptr<RpcCommandBase> rpc,
        const MessageType& messageType) const {
  // This no-Python callback cannot service Python UDFs; reject such message
  // types eagerly and pass everything else through untouched.
  const bool isPythonMessage = messageType == MessageType::PYTHON_CALL ||
      messageType == MessageType::PYTHON_REMOTE_CALL;
  TORCH_CHECK(!isPythonMessage, "Python calls are not supported!");
  return rpc;
}
// Entry point for an incoming request: deserializes the message, waits for
// any RRef arguments to become confirmed, runs the RPC, and returns a future
// that completes with the response Message stamped with the request id.
// Synchronous failures are converted into an exception-response future.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::vector<c10::Stream> streams) const {
  // We need two futures here because it could pause twice when processing a
  // RPC message:
  //  1) waiting for all RRefs in the arguments to become confirmed;
  //  2) waiting for processRpc to finish.
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();
    auto retFuture = rrefsReadyFuture->thenAsync(
        [this,
         // std::function must be copyable, hence have to cast the unique_ptr
         // to a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         streams = std::move(streams)](JitFuture& /* unused */) mutable {
          // The cost of pre-request check is minimal thanks to
          // std::shared_lock. The cost is in magnitude
          // of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If server global profiler is enabled, we further pay the
          // cost of thread local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }
          auto retFuture =
              processRpcWithErrors(*rpc, messageType, std::move(streams));
          // Response message has been sent at this moment, this post-response
          // work doesn't affect RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            thread_event_lists event_lists = disableProfilerLegacy();
            // Put thread_local event_lists into the process-global profiler
            // state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }
          return retFuture;
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());
    // Tag the response with the original request id so the agent can match it
    // to the caller-side pending future.
    auto retFutureWithMessageId = retFuture->then(
        [id = request.id()](JitFuture& future) {
          c10::intrusive_ptr<Message> message =
              future.value().toCustomClass<Message>();
          message->setId(id);
          return withStorages(message);
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());
    return retFutureWithMessageId;
  } catch (std::exception& e) {
    // Deserialization or RRef recording failed synchronously: clear the
    // thread-local pending-RRef record and reply with an exception response.
    rrefContext.clearRecordedPendingRRefsOnError();
    return asFuture(handleError(e, request.type(), request.id()));
  }
}
// Dispatches to processRpc() and converts any synchronous exception into an
// error-response future instead of letting it propagate to the caller.
// Asynchronous errors are already carried by the returned future itself.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    std::vector<c10::Stream> streams) const {
  try {
    return processRpc(rpc, messageType, std::move(streams));
  } catch (const std::exception& e) { // catch by const& per C++ idiom
    // Pass a dummy message ID since it will be overwritten anyways.
    return asFuture(handleError(e, messageType, -1));
  }
}
// Handles SCRIPT_CALL: runs a builtin JIT operator and wraps its result into
// a ScriptResp message. ScriptCalls without an op (TorchScript functions)
// are rejected.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> streams) const {
  auto& call = static_cast<ScriptCall&>(rpc);
  TORCH_CHECK(
      call.hasOp(), "Only supports the case where ScriptCall has an op");
  auto opFuture =
      runJitOperator(*call.op(), call.stackRef(), std::move(streams));
  // Once the operator result is available, package it as a response message.
  return opFuture->then(
      [](JitFuture& fut) {
        return withStorages(ScriptResp(fut.value()).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
// PYTHON_CALL handler: unconditionally throws in the no-Python callback; the
// Python-aware subclass overrides this to run Python UDFs.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}
// PYTHON_REMOTE_CALL handler: unconditionally throws in the no-Python
// callback; the Python-aware subclass overrides this.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}
// Populates the OwnerRRef identified by rrefId with the outcome of
// valueFuture (value or error), registering a fork when the request came from
// a remote user, and returns a future completing with a RemoteRet ack message
// once the RRef has been filled in.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
    const RRefId& rrefId,
    const RRefId& forkId,
    c10::intrusive_ptr<JitFuture> valueFuture) const {
  auto& ctx = RRefContext::getInstance();
  c10::intrusive_ptr<OwnerRRef> ownerRRef;
  if (rrefId == forkId) {
    // Creating an owner RRef on self, should already exist in owners map
    ownerRRef =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
  } else {
    ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType());
    // Caller is a user and callee is the owner, add fork
    //
    // NB: rrefId == forkId is true if and only if calling remote to self.
    // In that case both the caller and the callee will access the
    // OwnerRRef. Hence, on the callee side (here), it should not call
    // addForkOfOwner as it is not a fork. To allow callee to distinguish
    // when this request is sent to self, the caller will set forkId using
    // rrefId (OwnerRRef does not have a forkId anyway).
    ctx.addForkOfOwner(rrefId, forkId);
  }
  return valueFuture->then(
      [ownerRRef, rrefId, forkId](JitFuture& future) {
        // Forward either the error or the value into the owner RRef before
        // acknowledging with RemoteRet.
        if (future.hasError()) {
          ownerRRef->setError(future.exception_ptr());
        } else {
          ownerRRef->setValue(future.value());
        }
        return withStorages(RemoteRet(rrefId, forkId).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
// Handles SCRIPT_REMOTE_CALL: runs a builtin JIT operator and publishes its
// result into the owner RRef designated by the call.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> streams) const {
  auto& remoteCall = static_cast<ScriptRemoteCall&>(rpc);
  TORCH_CHECK(
      remoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
  auto resultFuture = runJitOperator(
      *remoteCall.op(), remoteCall.stackRef(), std::move(streams));
  // The ack message is produced once the owner RRef has been populated.
  return assignOwnerRRef(
      remoteCall.retRRefId(), remoteCall.retForkId(), std::move(resultFuture));
}
// Returns a future for the value held by the OwnerRRef with the given id,
// chaining through the (possibly not-yet-created) RRef itself.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::retrieveOwnerRRef(
    const RRefId& rrefId) const {
  auto rrefFuture = RRefContext::getInstance().getOwnerRRef(rrefId);
  const at::TypePtr rrefType = rrefFuture->elementType();
  TORCH_INTERNAL_ASSERT(rrefType->kind() == at::RRefType::Kind);
  // The returned future's element type is the RRef's contained value type.
  auto elementType = rrefType->cast<at::RRefType>()->getElementType();
  return rrefFuture->thenAsync(
      [](JitFuture& fut) {
        auto ownerRRef = fromRRefInterface(fut.value().toRRef());
        return ownerRRef->getFuture();
      },
      elementType);
}
// Handles SCRIPT_RREF_FETCH_CALL: fetches the owner RRef's value and wraps it
// into a ScriptRRefFetchRet response message.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processScriptRRefFetchCall(RpcCommandBase& rpc) const {
  auto& fetchCall = static_cast<ScriptRRefFetchCall&>(rpc);
  auto valueFuture = retrieveOwnerRRef(fetchCall.rrefId());
  return valueFuture->then(
      [](JitFuture& fut) {
        return withStorages(ScriptRRefFetchRet({fut.value()}).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
// PYTHON_RREF_FETCH_CALL handler: unconditionally throws in the no-Python
// callback; the Python-aware subclass overrides this.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processPythonRRefFetchCall(RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}
// Handles RREF_USER_DELETE: removes the user's fork record from the owner's
// context and acknowledges with RRefAck.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefUserDelete(
    RpcCommandBase& rpc) const {
  auto& userDelete = static_cast<RRefUserDelete&>(rpc);
  // Deleting the fork may return the RRef whose last fork was just removed.
  auto deletedRRef = RRefContext::getInstance().delForkOfOwner(
      userDelete.rrefId(), userDelete.forkId());
  handleRRefDelete(deletedRRef);
  return asFuture(RRefAck().toMessage());
}
// Hook invoked with an RRef whose fork was just deleted. The no-Python
// implementation only verifies the RRef does not hold a Python object; the
// Python-aware subclass overrides this with additional cleanup.
void RequestCallbackNoPython::handleRRefDelete(
    c10::intrusive_ptr<RRef>& rref) const {
  TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!");
}
// Handles RREF_CHILD_ACCEPT: the owner confirmed the child fork, so stop
// tracking it as pending and acknowledge.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefChildAccept(
    RpcCommandBase& rpc) const {
  auto& childAccept = static_cast<RRefChildAccept&>(rpc);
  RRefContext::getInstance().delPendingChild(childAccept.forkId());
  return asFuture(RRefAck().toMessage());
}
// Handles RREF_FORK_REQUEST: records the fork on the owner (idempotently) and
// acknowledges with RRefAck.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefForkRequest(
    RpcCommandBase& rpc) const {
  auto& forkRequest = static_cast<RRefForkRequest&>(rpc);
  RRefContext::getInstance().addForkOfOwnerIfNotPresent(
      forkRequest.rrefId(), forkRequest.forkId());
  return asFuture(RRefAck().toMessage());
}
// Handles FORWARD_AUTOGRAD_REQ: attaches the 'recv' backward function for the
// incoming tensors, runs the wrapped RPC under the sender's autograd context,
// and wraps the response with autograd information for the return trip.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processForwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  DeviceMap reverseDeviceMap;
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }
  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward(
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on server side, before processRpc(),
  // set current_context_id_ to be context_id passed from client.
  // In this way, if there is nested rpc call in python rpc call, original
  // context_id from client can be passed in the chain calls.
  TORCH_INTERNAL_ASSERT(
      autogradContext != nullptr,
      "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
      "or create valid autogradContext in addRecvRpcBackward.");
  DistAutogradContextGuard ctxGuard(autogradContext->contextId());
  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  auto wrappedRpcResponseFuture = processRpc(
      rpcWithAutograd.wrappedRpc(), wrappedMessageType, std::move(streams));
  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  auto responseFuture = wrappedRpcResponseFuture->then(
      [fromWorkerId, ctxId = autogradContext->contextId()](
          JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread is
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);
        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
        } else {
          // Wrap the nested response with autograd metadata so the caller can
          // hook up its side of the backward graph.
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              wrappedRpcResponseFuture.value().toCustomClass<Message>(),
              MessageType::FORWARD_AUTOGRAD_RESP);
          return withStorages(std::move(msg));
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
  return responseFuture;
}
// Handles BACKWARD_AUTOGRAD_REQ: looks up the matching 'send' function in the
// sender's autograd context, attaches the incoming gradients, and runs the
// distributed autograd engine asynchronously. The response is sent once the
// engine's nested RPCs complete.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processBackwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  // Make the caller-provided streams current for the duration of this call.
  c10::MultiStreamGuard guard(streams);
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId);
  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction =
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);
  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads());
  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
      autogradContext, sendFunction, gradientsCall.retainGraph());
  // Our response is satisfied when the rpcs come back.
  return execFuture->then(
      [](JitFuture& execFuture) {
        if (execFuture.hasError()) {
          std::rethrow_exception(execFuture.exception_ptr());
        } else {
          return withStorages(PropagateGradientsResp().toMessage());
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
// Handles CLEANUP_AUTOGRAD_CONTEXT_REQ: releases the named autograd context
// if it still exists and acknowledges with a cleanup response.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processCleanupAutogradContextReq(RpcCommandBase& rpc) const {
  auto& cleanupReq = static_cast<CleanupAutogradContextReq&>(rpc);
  // release the context if it still exists on this thread. We need to
  // check if it exists since it may have been deleted by an in-flight
  // RPC. This can create nested RPCs if there are other nodes that get
  // notified to clean up their context.
  DistAutogradContainer::getInstance().releaseContextIfPresent(
      cleanupReq.getContextId());
  return asFuture(CleanupAutogradContextResp().toMessage());
}
// Handles RUN_WITH_PROFILING_REQ: runs the wrapped RPC with the legacy
// profiler enabled (using the caller-supplied config, downgraded where
// necessary) and ships the profiled events back inside the response.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processRunWithProfilingReq(RpcCommandBase& rpc) const {
  auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
  auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
  auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();
  // Map Kineto profiler states onto a legacy CPU-profiler config, since the
  // code below drives the legacy profiler (TLSLegacyProfilerGuard /
  // disableProfilerLegacy).
  if (profilingConfig.state == ProfilerState::KINETO ||
      profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK ||
      profilingConfig.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
  }
  // If requested with CUDA from caller but CUDA is not available on this
  // machine, fallback to CPU and log a warning instead of crashing.
  if (profilingConfig.state == ProfilerState::CUDA && !this->cudaAvailable()) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
    LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this "
                    "node, but CUDA is not available. "
                 << "Falling back to CPU profiling only.";
  }
  TORCH_INTERNAL_ASSERT(
      profilingConfig.state != ProfilerState::CUDA || this->cudaAvailable(),
      "Profiler state set to CUDA but CUDA not available.");
  const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
  // Enable the profiler with the config from the sender.
  // When enabling on the main thread, ensure profiler states are cleaned
  // up, but defer consolidation of all profiled events to the continuation
  // below.
  ProfilerDisableOptions requestThreadOptions(
      true /* cleanup TLS state */, false /* consolidate events */);
  {
    TLSLegacyProfilerGuard g(
        profilingConfig, c10::nullopt, requestThreadOptions);
    TORCH_INTERNAL_ASSERT(
        profilerEnabled(), "Expected profiler to be enabled!");
    // Kick off processing for nested work and get Future<T> result in
    // wrappedRpcResponseFuture
    auto wrappedRpcResponseFuture = processRpc(
        rpcWithProfilingReq.wrappedRpc(),
        wrappedMsgType,
        {}); // TODO: https://github.com/pytorch/pytorch/issues/55757
    // The continuation may run on another thread, so propagate the TLS
    // profiler state into it.
    auto responseFuture = wrappedRpcResponseFuture->then(
        at::wrapPropagateTLSState([profilingKeyId, profilingConfig](
                                      JitFuture& wrappedRpcResponseFuture) {
          std::vector<LegacyEvent> profiledEvents;
          // Defer consolidation of profiler events until async work has
          // completed (such as async UDF)
          TORCH_INTERNAL_ASSERT(
              profilerEnabled(), "Expected profiler to be enabled!");
          // On continuation thread, don't clean up profiler states, since
          // they will be cleaned up by main thread, and consolidate all
          // events so we obtain asynchronously run events.
          ProfilerDisableOptions opts(false, true);
          auto event_lists = disableProfilerLegacy(opts);
          if (wrappedRpcResponseFuture.hasError()) {
            // Propagate error
            // No need to propagate remote events in the case of an error.
            std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
          } else {
            populateRemoteProfiledEvents(
                profiledEvents, profilingConfig, event_lists);
            auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
                MessageType::RUN_WITH_PROFILING_RESP,
                wrappedRpcResponseFuture.value().toCustomClass<Message>(),
                profiledEvents,
                profilingKeyId);
            return withStorages(std::move(*rpcWithProfilingResp).toMessage());
          }
        }),
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());
    return responseFuture;
    // Exiting the scope will disable the profiler on this thread with the
    // options specified above.
  }
}
// RREF_BACKWARD_REQ handler: unconditionally throws in the no-Python
// callback; the Python-aware subclass overrides this.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefBackward(
    RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}
// Dispatches an already-deserialized RPC command to its type-specific
// handler based on the message type. Unknown types trigger an internal
// assertion failure.
//
// TODO: RpcCommandBase should have an abstract execute() method that we can
// call here instead of having another switch statement here. Even better we
// could have abstract classes RpcRequest and RpcResp which inherit from
// RpcCommandBase and RpcRequest declares the abstract method execute() that
// we can call here. RpcResponse could have an abstract method to convert it
// to a python object.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    std::vector<c10::Stream> streams) const {
  switch (messageType) {
    case MessageType::SCRIPT_CALL: {
      return processScriptCall(rpc, std::move(streams));
    }
    case MessageType::PYTHON_CALL: {
      return processPythonCall(rpc, std::move(streams));
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return processScriptRemoteCall(rpc, std::move(streams));
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return processPythonRemoteCall(rpc, std::move(streams));
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return processScriptRRefFetchCall(rpc);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return processPythonRRefFetchCall(rpc);
    }
    case MessageType::RREF_USER_DELETE: {
      return processRRefUserDelete(rpc);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return processRRefChildAccept(rpc);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return processRRefForkRequest(rpc);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return processForwardAutogradReq(rpc, std::move(streams));
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      // (removed a stray ';' that followed this case's closing brace)
      return processBackwardAutogradReq(rpc, std::move(streams));
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return processCleanupAutogradContextReq(rpc);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return processRunWithProfilingReq(rpc);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return processRRefBackward(rpc);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", messageType, " not supported.");
    }
  }
}
// Logs the failure and builds an exception-response Message carrying the
// error text prefixed with this node's worker id.
c10::intrusive_ptr<Message> RequestCallbackNoPython::handleError(
    const std::exception& e,
    const MessageType messageType,
    int64_t messageId) const {
  LOG(ERROR) << "Received error while processing request type " << messageType
             << ": " << e.what();
  // Adding node information to the error here since all processed RPC
  // requests should be going through this function.
  const auto workerId = DistAutogradContainer::getInstance().getWorkerId();
  std::string nodeErrorMsg =
      c10::str("Error on Node ", workerId, ": ", e.what());
  return createExceptionResponse(nodeErrorMsg, messageId);
}
// Returns whether this build was compiled with CUDA support (USE_CUDA).
// Compile-time only: does not probe for a usable device at runtime.
bool RequestCallbackNoPython::cudaAvailable() const {
#ifdef USE_CUDA
  return true;
#else
  return false;
#endif
}
// Synchronously runs a builtin JIT operator on the given stack, with the
// caller's streams made current, and returns its single result (or the
// thrown exception) as a completed future.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
    const jit::Operator& op,
    std::vector<at::IValue>& stack,
    std::vector<c10::Stream> streams) const {
  c10::MultiStreamGuard streamGuard(streams);
  try {
    op.getOperation()(stack);
  } catch (const std::exception&) {
    // Surface the failure through the returned future instead of throwing.
    return asFuture(std::current_exception());
  }
  TORCH_INTERNAL_ASSERT(
      stack.size() == 1,
      "Return value of a builtin operator or a TorchScript function should be "
      "a single IValue, got a vector of size ",
      stack.size());
  auto resultType = stack.front().type();
  return asFuture(std::move(stack.front()), std::move(resultType));
}
// Wraps an already-available IValue of the given type into a completed
// JitFuture spanning the current agent's devices.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    IValue value,
    TypePtr type) const {
  auto completed = c10::make_intrusive<JitFuture>(
      std::move(type), RpcAgent::getCurrentRpcAgent()->getDevices());
  completed->markCompleted(std::move(value));
  return completed;
}
// Wraps a Message into a completed JitFuture, recording the message's
// storages alongside the value when marking it complete.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    c10::intrusive_ptr<Message> message) const {
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
      message->getStorages();
  auto completed = c10::make_intrusive<JitFuture>(
      at::getCustomClassType<c10::intrusive_ptr<Message>>(),
      RpcAgent::getCurrentRpcAgent()->getDevices());
  completed->markCompleted(std::move(message), std::move(storages));
  return completed;
}
// Builds a JitFuture that completes immediately with the given exception.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    std::exception_ptr err) const {
  auto errored = c10::make_intrusive<JitFuture>(
      at::NoneType::get(), RpcAgent::getCurrentRpcAgent()->getDevices());
  errored->setError(err);
  return errored;
}
} // namespace rpc
} // namespace distributed
} // namespace torch