26
26
#include < atomic>
27
27
#include < chrono>
28
28
#include < iostream>
29
+ #include < map>
29
30
#include < mutex>
31
+ #include < set>
30
32
#include < unordered_map>
31
33
32
34
#include " absl/strings/str_cat.h"
@@ -91,21 +93,14 @@ class MeshServiceImpl : public grpc::MeshService::Service {
91
93
private:
92
94
class RendezvousData {
93
95
public:
94
- explicit RendezvousData (size_t count)
95
- : mwait_(count), release_count_(0 ), payloads_(count) {}
96
+ explicit RendezvousData (size_t count, const std::set<int64>& replicas)
97
+ : count_(count),
98
+ replicas_(replicas),
99
+ mwait_(count),
100
+ release_count_(0 ) {}
96
101
97
102
// Atomically increments the release counter and reports whether this call
// was the first one, i.e. whether the counter was still zero beforehand.
bool Release() {
  const size_t prior_releases = release_count_.fetch_add(1);
  return prior_releases == 0;
}
98
103
99
- void SetPayload (size_t ordinal, std::string payload) {
100
- std::lock_guard<std::mutex> lock (lock_);
101
- if (ordinal >= payloads_.size ()) {
102
- status_ = ::grpc::Status (::grpc::StatusCode::INVALID_ARGUMENT,
103
- absl::StrCat (" Invalid ordinal: " , ordinal));
104
- } else {
105
- payloads_[ordinal] = std::move (payload);
106
- }
107
- }
108
-
109
104
::grpc::Status Wait () {
110
105
::grpc::Status status =
111
106
ToGrpcStatus (xla::util::CheckedCall ([&]() { mwait_.Wait (); }));
@@ -116,25 +111,50 @@ class MeshServiceImpl : public grpc::MeshService::Service {
116
111
return status;
117
112
}
118
113
119
- void Done () { mwait_.Done (); }
114
+ void Complete (int64 ordinal, std::string payload,
115
+ const std::set<int64>& replicas) {
116
+ std::lock_guard<std::mutex> lock (lock_);
117
+ if ((!replicas_.empty () && replicas_.count (ordinal) == 0 ) ||
118
+ (replicas_.empty () && ordinal >= count_)) {
119
+ status_ = ::grpc::Status (::grpc::StatusCode::INVALID_ARGUMENT,
120
+ absl::StrCat (" Invalid ordinal: " , ordinal));
121
+ } else if (replicas != replicas_) {
122
+ status_ = ::grpc::Status (
123
+ ::grpc::StatusCode::INVALID_ARGUMENT,
124
+ absl::StrCat (" Mismatching replicas: (" ,
125
+ absl::StrJoin (replicas_, " , " ), " ) vs. (" ,
126
+ absl::StrJoin (replicas, " , " ), " )" ));
127
+ } else {
128
+ auto insert_result = payloads_.emplace (ordinal, std::move (payload));
129
+ if (!insert_result.second ) {
130
+ status_ =
131
+ ::grpc::Status (::grpc::StatusCode::INVALID_ARGUMENT,
132
+ absl::StrCat (" Duplicate ordinal: " , ordinal));
133
+ }
134
+ }
135
+ mwait_.Done();
136
+ }
120
137
121
- const std::vector< std::string>& Payloads () const { return payloads_; };
138
+ const std::map<int64, std::string>& Payloads () const { return payloads_; };
122
139
123
140
private:
141
+ size_t count_;
142
+ std::set<int64> replicas_;
124
143
std::mutex lock_;
125
144
util::MultiWait mwait_;
126
145
std::atomic<size_t > release_count_;
127
- std::vector< std::string> payloads_;
146
+ std::map<int64, std::string> payloads_;
128
147
::grpc::Status status_;
129
148
};
130
149
131
- std::shared_ptr<RendezvousData> GetRendezvous (const std::string& tag) {
150
+ std::shared_ptr<RendezvousData> GetRendezvous (
151
+ const std::string& tag, const std::set<int64>& replicas) {
132
152
std::lock_guard<std::mutex> lock (lock_);
133
153
auto it = rendezvous_map_.find (tag);
134
154
if (it == rendezvous_map_.end ()) {
155
+ size_t count = replicas.empty () ? config_.mesh_size () : replicas.size ();
135
156
it = rendezvous_map_
136
- .emplace (tag,
137
- std::make_shared<RendezvousData>(config_.mesh_size ()))
157
+ .emplace (tag, std::make_shared<RendezvousData>(count, replicas))
138
158
.first ;
139
159
}
140
160
return it->second ;
@@ -165,18 +185,19 @@ ::grpc::Status MeshServiceImpl::GetConfig(::grpc::ServerContext* context,
165
185
::grpc::Status MeshServiceImpl::Rendezvous (
166
186
::grpc::ServerContext* context, const grpc::RendezvousRequest* request,
167
187
grpc::RendezvousResponse* response) {
168
- auto rendezvous = GetRendezvous (request->tag ());
169
- rendezvous->SetPayload (request->ordinal (), request->payload ());
170
- rendezvous->Done ();
188
+ std::set<int64> replicas (request->replicas ().begin (),
189
+ request->replicas ().end ());
190
+ auto rendezvous = GetRendezvous (request->tag (), replicas);
191
+ rendezvous->Complete (request->ordinal (), request->payload (), replicas);
171
192
TF_VLOG (3 ) << " Entering rendezvous: ordinal=" << request->ordinal ()
172
- << " tag=" << request->tag () << " peer=" << context->peer ();
193
+ << " , tag=" << request->tag () << " , peer=" << context->peer ();
173
194
::grpc::Status status = rendezvous->Wait ();
174
195
TF_VLOG (3 ) << " Exiting rendezvous: ordinal=" << request->ordinal ()
175
- << " tag=" << request->tag () << " peer=" << context->peer ()
176
- << " status=" << status;
196
+ << " , tag=" << request->tag () << " , peer=" << context->peer ()
197
+ << " , status=" << status;
177
198
if (status.ok ()) {
178
- for (auto & payload : rendezvous->Payloads ()) {
179
- response->add_payloads (payload );
199
+ for (auto & ordinal_payload : rendezvous->Payloads ()) {
200
+ response->add_payloads (ordinal_payload. second );
180
201
}
181
202
}
182
203
ReleaseRendezvous (request->tag (), rendezvous);
@@ -267,13 +288,17 @@ grpc::Config MeshClient::GetConfig() const {
267
288
}
268
289
269
290
std::vector<std::string> MeshClient::Rendezvous (
270
- int ordinal, const std::string& tag, const std::string& payload) const {
291
+ int ordinal, const std::string& tag, const std::string& payload,
292
+ absl::Span<const int64> replicas) const {
271
293
::grpc::ClientContext context;
272
294
grpc::RendezvousRequest request;
273
295
grpc::RendezvousResponse response;
274
296
request.set_tag (tag);
275
297
request.set_payload (payload);
276
298
request.set_ordinal (ordinal);
299
+ for (auto & replica : replicas) {
300
+ request.add_replicas (replica);
301
+ }
277
302
TF_VLOG (3 ) << " Waiting for rendezvous: ordinal=" << ordinal << " tag=" << tag;
278
303
::grpc::Status status = impl_->stub ->Rendezvous (&context, request, &response);
279
304
TF_VLOG (3 ) << " Rendezvous wait complete: " << tag;
@@ -290,16 +315,16 @@ std::vector<std::string> MeshClient::Rendezvous(
290
315
std::string MeshClient::GetNcclUniqueUid (
291
316
absl::Span<const int64> replicas) const {
292
317
::grpc::ClientContext context;
293
- grpc::GetNcclUniqueUidRequest reqeust ;
318
+ grpc::GetNcclUniqueUidRequest request ;
294
319
grpc::GetNcclUniqueUidResponse response;
295
320
for (auto & replica : replicas) {
296
- reqeust .add_replicas (replica);
321
+ request .add_replicas (replica);
297
322
}
298
323
299
324
TF_VLOG (3 ) << " Waiting for NCCL UID: replicas=("
300
325
<< absl::StrJoin (replicas, " , " ) << " )" ;
301
326
::grpc::Status status =
302
- impl_->stub ->GetNcclUniqueUid (&context, reqeust , &response);
327
+ impl_->stub ->GetNcclUniqueUid (&context, request , &response);
303
328
TF_VLOG (3 ) << " NCCL UID wait complete: " << absl::StrJoin (replicas, " , " )
304
329
<< " )" ;
305
330
if (!status.ok ()) {
0 commit comments