forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpython_comm.cpp
108 lines (104 loc) · 3.66 KB
/
python_comm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/ATen.h>

#include <cstddef>
#include <optional>
#include <vector>

#include <torch/csrc/profiler/unwind/unwind.h>
namespace torch::cuda::python {
void initCommMethods(PyObject* module) {
auto m = py::cast<py::module>(module);
m.def(
"_broadcast_coalesced",
[](std::vector<at::Tensor>& tensors,
const std::vector<int64_t>& devices,
size_t buffer_size) {
return broadcast_coalesced(tensors, devices, buffer_size);
},
py::arg("tensors"),
py::arg("devices"),
py::arg("buffer_size"),
py::call_guard<py::gil_scoped_release>())
.def(
"_broadcast",
[](at::Tensor& tensor, std::vector<int64_t> devices) {
return broadcast(tensor, devices);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("tensor"),
py::arg("devices"))
.def(
"_broadcast_out",
[](at::Tensor& tensor, std::vector<at::Tensor>& out_tensors) {
return broadcast_out(tensor, out_tensors);
},
py::call_guard<py::gil_scoped_release>(),
py::arg("tensor"),
py::arg("out"))
.def(
"_scatter",
[](at::Tensor& tensor,
std::vector<int64_t>& devices,
std::optional<std::vector<int64_t>> chunk_sizes,
int64_t dim,
std::optional<py::object> py_streams) {
std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
streams;
if (py_streams) {
py::handle handle = *py_streams;
streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
}
// Note: We're holding the GIL up to here.
pybind11::gil_scoped_release no_gil;
return scatter(tensor, devices, chunk_sizes, dim, streams);
},
py::arg("tensor"),
py::arg("devices"),
py::arg("chunk_sizes"),
py::arg("dim"),
py::arg("streams"))
.def(
"_scatter_out",
[](at::Tensor& tensor,
std::vector<at::Tensor>& out_tensors,
int64_t dim,
std::optional<py::object> py_streams) {
std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
streams;
if (py_streams) {
py::handle handle = *py_streams;
streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
}
// Note: We're holding the GIL up to here.
pybind11::gil_scoped_release no_gil;
return scatter_out(tensor, out_tensors, dim, streams);
},
py::arg("tensor"),
py::arg("out"),
py::arg("dim"),
py::arg("streams"))
.def(
"_gather",
[](std::vector<at::Tensor>& tensors,
int64_t dim,
std::optional<int32_t> destination_index) {
return gather(tensors, dim, destination_index);
},
py::arg("tensors"),
py::arg("dim"),
py::arg("destination_index"),
py::call_guard<py::gil_scoped_release>())
.def(
"_gather_out",
[](std::vector<at::Tensor>& tensors,
at::Tensor& out_tensor,
int64_t dim) { return gather_out(tensors, out_tensor, dim); },
py::arg("tensors"),
py::arg("out"),
py::arg("dim"),
py::call_guard<py::gil_scoped_release>());
}
} // namespace torch::cuda::python