// comm.hpp (forked from pytorch/pytorch)
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <utility>

namespace c10d {

// Broadcast many tensors to all processes in the process group.
TORCH_API void broadcast_coalesced(
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
    at::TensorList tensors,
    size_t buffer_size,
    int rank = 0);
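
// Usage sketch (illustrative, not part of the upstream header): broadcasting a
// model's parameters from rank 0. `pg` is assumed to be an already-created
// ProcessGroup, and `buffer_size` is assumed here to be the coalescing buffer
// cap in bytes; both names and values are assumptions for the example.
//
//   std::vector<at::Tensor> params = /* tensors to synchronize */;
//   c10d::broadcast_coalesced(
//       pg, params, /*buffer_size=*/256 * 1024 * 1024, /*rank=*/0);
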
// This class passes the bucket's contents tensor to the DDP communication hook.
class TORCH_API GradBucket {
 public:
  explicit GradBucket(
      size_t index,
      size_t bucket_count,
      at::Tensor tensor,
      std::vector<size_t> offsets,
      std::vector<size_t> lengths,
      std::vector<c10::IntArrayRef> sizes_vec,
      std::vector<at::Tensor> parameters,
      std::optional<at::Tensor> sparse_grad_indices)
      : index_(index),
        bucket_count_(bucket_count),
        buffer_(std::move(tensor)),
        offsets_(std::move(offsets)),
        lengths_(std::move(lengths)),
        sizes_vec_(std::move(sizes_vec)),
        parameters_(std::move(parameters)),
        sparse_grad_indices_(std::move(sparse_grad_indices)) {}

  // Returns the index of the bucket, which is unique across all the buckets.
  size_t getIndex() const {
    return index_;
  }

  const at::Tensor& getBuffer() const {
    return buffer_;
  }

  // Returns a mutable reference to the buffer, unlike the method above.
  at::Tensor& getBufferRef() {
    return buffer_;
  }

  // Overwrites the bucket's buffer tensor.
  void setBuffer(at::Tensor& buffer) {
    buffer_ = buffer;
  }

  // Each tensor in the list returned by getGradients() corresponds to a
  // parameter.
  std::vector<at::Tensor> getGradients() const;

  // Returns the model parameters belonging to this bucket. They are returned in
  // the same order as the gradient tensors via getGradients(); for example,
  // getParameters()[i] will have its gradient stored in getGradients()[i].
  const std::vector<at::Tensor> getParameters() const {
    return parameters_;
  }

  // Returns whether this bucket is the last bucket to allreduce in an iteration.
  bool isLast() const {
    return index_ == bucket_count_ - 1;
  }

  std::optional<at::Tensor>& getSparseGradIndices() {
    return sparse_grad_indices_;
  }

 private:
  size_t index_;
  size_t bucket_count_;
  at::Tensor buffer_;

  // Per-variable info in buffer_.
  std::vector<size_t> offsets_;
  std::vector<size_t> lengths_;
  std::vector<c10::IntArrayRef> sizes_vec_;

  // Model parameters for this bucket.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::vector<at::Tensor> parameters_;

  // Predefined sparse indices for this bucket (only used for sparse tensors).
  // The gradients will be updated to have indices with these tensor values.
  std::optional<at::Tensor> sparse_grad_indices_;
};
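
// Usage sketch (illustrative, not part of the upstream header): inside a
// communication hook, a GradBucket exposes both the flattened buffer and the
// per-parameter gradient views. `bucket` is the GradBucket handed to runHook.
//
//   at::Tensor& flat = bucket.getBufferRef();              // whole flattened buffer
//   std::vector<at::Tensor> grads = bucket.getGradients(); // one view per parameter
//   if (bucket.isLast()) {
//     // The final bucket of this iteration has been reached.
//   }
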
// Base class of both `PythonCommHook` and `CppCommHook`.
// Requires implementing 1) the `runHook` method that communicates gradients
// asynchronously, and 2) the `parseHookResult` method that converts the hook
// result into a tensor.
class TORCH_API CommHookInterface {
 public:
  virtual ~CommHookInterface() = default;

  // Passes the input grad bucket to the registered communication hook.
  // Once the tensors in the bucket are ready, kicks off the hook asynchronously
  // and returns a future that holds the communication results.
  virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
      GradBucket& bucket) = 0;

  // Returns the resulting tensor once the communication hook result is
  // ready. The resulting tensor will then be copied to the grads of
  // individual parameters.
  virtual at::Tensor parseHookResult(const c10::IValue& result) = 0;
};
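
// Implementation sketch (illustrative, not part of the upstream header): a
// typical runHook launches a collective on the bucket buffer and returns the
// resulting future. `pg_` is an assumed ProcessGroup member of the hook.
//
//   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override {
//     std::vector<at::Tensor> tensors = {bucket.getBuffer()};
//     return pg_->allreduce(tensors)->getFuture();
//   }
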
namespace detail {
// This helper function is called both by CppCommHookInterface below and inside
// the reducer.
TORCH_API at::Tensor parseCppCommHookResult(const c10::IValue& result);
} // namespace detail

// This CppCommHook interface only requires implementing the runHook method,
// which potentially uses a state.
template <typename T>
class CppCommHookInterface : public CommHookInterface {
 public:
  explicit CppCommHookInterface(T state) : state_(std::move(state)) {}

  ~CppCommHookInterface() override = default;

  at::Tensor parseHookResult(const c10::IValue& result) override {
    return detail::parseCppCommHookResult(result);
  }

 protected:
  T state_;
};
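
// Derivation sketch (illustrative, not part of the upstream header): a concrete
// C++ hook that keeps a ProcessGroup as its state, averages the bucket by the
// world size, and allreduces it. The class name and details are assumptions,
// loosely mirroring the style of the built-in allreduce hook.
//
//   class AllReduceLikeHook
//       : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
//    public:
//     explicit AllReduceLikeHook(c10::intrusive_ptr<ProcessGroup> pg)
//         : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(std::move(pg)) {}
//
//     c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override {
//       auto tensor = bucket.getBuffer() / state_->getSize();
//       std::vector<at::Tensor> tensors = {tensor};
//       return state_->allreduce(tensors)->getFuture();
//     }
//   };
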
} // namespace c10d