forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_call.cpp
49 lines (42 loc) · 1.65 KB
/
python_call.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <torch/csrc/distributed/rpc/python_call.h>
namespace torch {
namespace distributed {
namespace rpc {

// Wire format written by PythonCall::toMessageImpl() and read back by
// PythonCall::fromMessage():
//   payload byte 0     : async-execution flag, 1 if isAsyncExecution_ else 0
//   payload bytes 1..N : serializedPyObj_.payload_, copied verbatim
// The tensors of serializedPyObj_ travel in the Message's tensor list.

/// Takes ownership (by move) of the serialized Python object and records
/// whether the call is flagged for asynchronous execution.
PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution)
    : serializedPyObj_(std::move(serializedPyObj)),
      isAsyncExecution_(isAsyncExecution) {}

/// Encodes this call into a MessageType::PYTHON_CALL Message.
/// Rvalue-qualified: serializedPyObj_.tensors_ is moved into the Message, so
/// this PythonCall should be treated as consumed afterwards.
c10::intrusive_ptr<Message> PythonCall::toMessageImpl() && {
  std::vector<char> payload;
  // One extra byte for the leading async-execution flag.
  payload.reserve(serializedPyObj_.payload_.length() + 1);
  payload.push_back(isAsyncExecution_ ? 1 : 0);
  payload.insert(
      payload.end(),
      serializedPyObj_.payload_.begin(),
      serializedPyObj_.payload_.end());
  return c10::make_intrusive<Message>(
      std::move(payload),
      std::move(serializedPyObj_.tensors_),
      MessageType::PYTHON_CALL);
}

/// Decodes a Message laid out as written by toMessageImpl() back into a
/// PythonCall. The payload bytes and tensors are copied; `message` is left
/// unchanged.
/// Fails a TORCH_INTERNAL_ASSERT if the payload is empty or if its first byte
/// is not a valid async-execution flag (0 or 1).
std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
  const auto& bytes = message.payload();
  TORCH_INTERNAL_ASSERT(
      !bytes.empty(),
      "Failed to convert an RPC message to PythonCall, the payload should at "
      "least contain one byte indicating whether this is an async function, "
      "but got payload of size ",
      bytes.size());
  // Read the flag by value; a reference to a single char buys nothing.
  const char asyncFlag = bytes[0];
  // This byte comes off the wire, so report what was actually received.
  // Cast to int so a corrupt value prints as a number, not a raw character.
  TORCH_INTERNAL_ASSERT(
      asyncFlag == 0 || asyncFlag == 1,
      "Failed to convert an RPC message to PythonCall, the first payload byte "
      "should be 0 or 1 indicating whether this is an async function, but got ",
      static_cast<int>(asyncFlag));
  const bool isAsyncExecution = (asyncFlag == 1);
  // Everything after the flag byte is the serialized Python object.
  std::string payload(bytes.begin() + 1, bytes.end());
  // `message` is const, so its tensors are copied rather than moved.
  std::vector<Tensor> tensors = message.tensors();
  SerializedPyObj serializedPyObj(std::move(payload), std::move(tensors));
  return std::make_unique<PythonCall>(
      std::move(serializedPyObj), isAsyncExecution);
}

/// Read-only access to the wrapped serialized Python object.
const SerializedPyObj& PythonCall::serializedPyObj() const {
  return serializedPyObj_;
}

} // namespace rpc
} // namespace distributed
} // namespace torch