This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0f09ac9

Pull latest C++ x10 changes (#966)
1 parent: 6e2507d

16 files changed: +588 -61 lines

Sources/x10/xla_client/mesh_service.cc (+55 -30)

@@ -26,7 +26,9 @@
 #include <atomic>
 #include <chrono>
 #include <iostream>
+#include <map>
 #include <mutex>
+#include <set>
 #include <unordered_map>
 
 #include "absl/strings/str_cat.h"
@@ -91,21 +93,14 @@ class MeshServiceImpl : public grpc::MeshService::Service {
  private:
   class RendezvousData {
    public:
-    explicit RendezvousData(size_t count)
-        : mwait_(count), release_count_(0), payloads_(count) {}
+    explicit RendezvousData(size_t count, const std::set<int64>& replicas)
+        : count_(count),
+          replicas_(replicas),
+          mwait_(count),
+          release_count_(0) {}
 
     bool Release() { return release_count_.fetch_add(1) == 0; }
 
-    void SetPayload(size_t ordinal, std::string payload) {
-      std::lock_guard<std::mutex> lock(lock_);
-      if (ordinal >= payloads_.size()) {
-        status_ = ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
-                                 absl::StrCat("Invalid ordinal: ", ordinal));
-      } else {
-        payloads_[ordinal] = std::move(payload);
-      }
-    }
-
     ::grpc::Status Wait() {
       ::grpc::Status status =
           ToGrpcStatus(xla::util::CheckedCall([&]() { mwait_.Wait(); }));
@@ -116,25 +111,50 @@ class MeshServiceImpl : public grpc::MeshService::Service {
       return status;
     }
 
-    void Done() { mwait_.Done(); }
+    void Complete(int64 ordinal, std::string payload,
+                  const std::set<int64>& replicas) {
+      std::lock_guard<std::mutex> lock(lock_);
+      if ((!replicas_.empty() && replicas_.count(ordinal) == 0) ||
+          (replicas_.empty() && ordinal >= count_)) {
+        status_ = ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
+                                 absl::StrCat("Invalid ordinal: ", ordinal));
+      } else if (replicas != replicas_) {
+        status_ = ::grpc::Status(
+            ::grpc::StatusCode::INVALID_ARGUMENT,
+            absl::StrCat("Mismatching replicas: (",
+                         absl::StrJoin(replicas_, ", "), ") vs. (",
+                         absl::StrJoin(replicas, ", "), ")"));
+      } else {
+        auto insert_result = payloads_.emplace(ordinal, std::move(payload));
+        if (!insert_result.second) {
+          status_ =
+              ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
+                             absl::StrCat("Duplicate ordinal: ", ordinal));
+        }
+      }
+      mwait_.Done();
+    }
 
-    const std::vector<std::string>& Payloads() const { return payloads_; };
+    const std::map<int64, std::string>& Payloads() const { return payloads_; };
 
    private:
+    size_t count_;
+    std::set<int64> replicas_;
     std::mutex lock_;
     util::MultiWait mwait_;
     std::atomic<size_t> release_count_;
-    std::vector<std::string> payloads_;
+    std::map<int64, std::string> payloads_;
     ::grpc::Status status_;
   };
 
-  std::shared_ptr<RendezvousData> GetRendezvous(const std::string& tag) {
+  std::shared_ptr<RendezvousData> GetRendezvous(
+      const std::string& tag, const std::set<int64>& replicas) {
    std::lock_guard<std::mutex> lock(lock_);
    auto it = rendezvous_map_.find(tag);
    if (it == rendezvous_map_.end()) {
+      size_t count = replicas.empty() ? config_.mesh_size() : replicas.size();
      it = rendezvous_map_
-               .emplace(tag,
-                        std::make_shared<RendezvousData>(config_.mesh_size()))
+               .emplace(tag, std::make_shared<RendezvousData>(count, replicas))
               .first;
    }
    return it->second;
@@ -165,18 +185,19 @@ ::grpc::Status MeshServiceImpl::GetConfig(::grpc::ServerContext* context,
 ::grpc::Status MeshServiceImpl::Rendezvous(
     ::grpc::ServerContext* context, const grpc::RendezvousRequest* request,
     grpc::RendezvousResponse* response) {
-  auto rendezvous = GetRendezvous(request->tag());
-  rendezvous->SetPayload(request->ordinal(), request->payload());
-  rendezvous->Done();
+  std::set<int64> replicas(request->replicas().begin(),
+                           request->replicas().end());
+  auto rendezvous = GetRendezvous(request->tag(), replicas);
+  rendezvous->Complete(request->ordinal(), request->payload(), replicas);
   TF_VLOG(3) << "Entering rendezvous: ordinal=" << request->ordinal()
-             << " tag=" << request->tag() << " peer=" << context->peer();
+             << ", tag=" << request->tag() << ", peer=" << context->peer();
   ::grpc::Status status = rendezvous->Wait();
   TF_VLOG(3) << "Exiting rendezvous: ordinal=" << request->ordinal()
-             << " tag=" << request->tag() << " peer=" << context->peer()
-             << " status=" << status;
+             << ", tag=" << request->tag() << ", peer=" << context->peer()
+             << ", status=" << status;
   if (status.ok()) {
-    for (auto& payload : rendezvous->Payloads()) {
-      response->add_payloads(payload);
+    for (auto& ordinal_payload : rendezvous->Payloads()) {
+      response->add_payloads(ordinal_payload.second);
     }
   }
   ReleaseRendezvous(request->tag(), rendezvous);
@@ -267,13 +288,17 @@ grpc::Config MeshClient::GetConfig() const {
 }
 
 std::vector<std::string> MeshClient::Rendezvous(
-    int ordinal, const std::string& tag, const std::string& payload) const {
+    int ordinal, const std::string& tag, const std::string& payload,
+    absl::Span<const int64> replicas) const {
   ::grpc::ClientContext context;
   grpc::RendezvousRequest request;
   grpc::RendezvousResponse response;
   request.set_tag(tag);
   request.set_payload(payload);
   request.set_ordinal(ordinal);
+  for (auto& replica : replicas) {
+    request.add_replicas(replica);
+  }
   TF_VLOG(3) << "Waiting for rendezvous: ordinal=" << ordinal << " tag=" << tag;
   ::grpc::Status status = impl_->stub->Rendezvous(&context, request, &response);
   TF_VLOG(3) << "Rendezvous wait complete: " << tag;
@@ -290,16 +315,16 @@ std::vector<std::string> MeshClient::Rendezvous(
 std::string MeshClient::GetNcclUniqueUid(
     absl::Span<const int64> replicas) const {
   ::grpc::ClientContext context;
-  grpc::GetNcclUniqueUidRequest reqeust;
+  grpc::GetNcclUniqueUidRequest request;
   grpc::GetNcclUniqueUidResponse response;
   for (auto& replica : replicas) {
-    reqeust.add_replicas(replica);
+    request.add_replicas(replica);
   }
 
   TF_VLOG(3) << "Waiting for NCCL UID: replicas=("
              << absl::StrJoin(replicas, ", ") << ")";
   ::grpc::Status status =
-      impl_->stub->GetNcclUniqueUid(&context, reqeust, &response);
+      impl_->stub->GetNcclUniqueUid(&context, request, &response);
   TF_VLOG(3) << "NCCL UID wait complete: " << absl::StrJoin(replicas, ", ")
              << ")";
   if (!status.ok()) {
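
For orientation, here is a hedged sketch of how a caller might use the extended client API. Only the Rendezvous signature and the replica semantics (an empty set means the whole mesh; a non-empty set must match across all participants) come from the diff above; the helper function, include path, and namespace assumptions are purely illustrative.

// Hypothetical call-site sketch, not part of this commit. Assumes the include
// path below and that this code lives in the same namespace as MeshClient, so
// MeshClient and the int64 alias resolve unqualified.
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/xla_client/mesh_service.h"  // path assumed

std::vector<std::string> ExchangeWithPeers(const MeshClient& client, int ordinal,
                                           const std::string& payload,
                                           absl::Span<const int64> replicas) {
  // An empty `replicas` span keeps the pre-change behaviour: the server sizes
  // the rendezvous by config_.mesh_size(), so every core in the mesh joins.
  // A non-empty span restricts the barrier to the listed ordinals, and every
  // participant must pass the same set or Complete() records a
  // "Mismatching replicas" INVALID_ARGUMENT status.
  return client.Rendezvous(ordinal, /*tag=*/"payload-exchange", payload, replicas);
}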

Sources/x10/xla_client/mesh_service.h (+2 -1)

@@ -51,7 +51,8 @@ class MeshClient {
   grpc::Config GetConfig() const;
 
   std::vector<std::string> Rendezvous(int ordinal, const std::string& tag,
-                                      const std::string& payload) const;
+                                      const std::string& payload,
+                                      absl::Span<const int64> replicas) const;
 
   std::string GetNcclUniqueUid(absl::Span<const int64> replicas) const;
 

Sources/x10/xla_client/mesh_service.proto (+2 -1)

@@ -46,14 +46,15 @@ message RendezvousRequest {
   required string tag = 1;
   required bytes payload = 2;
   required uint32 ordinal = 3;
+  repeated uint32 replicas = 4;
 }
 
 message RendezvousResponse {
   repeated bytes payloads = 1;
 }
 
 message GetNcclUniqueUidRequest {
-  repeated int64 replicas = 1;
+  repeated uint32 replicas = 1;
 }
 
 message GetNcclUniqueUidResponse {
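
As a sketch of what the schema change means on the generated C++ side: the wrapper below is hypothetical, but the setter and adder names mirror the client code in mesh_service.cc above, and the grpc:: qualifier and int64 alias are assumed to resolve as they do in that file.

// Hypothetical helper, for illustration only.
grpc::RendezvousRequest MakeRendezvousRequest(int ordinal, const std::string& tag,
                                              const std::string& payload,
                                              absl::Span<const int64> replicas) {
  grpc::RendezvousRequest request;
  request.set_tag(tag);
  request.set_payload(payload);
  request.set_ordinal(ordinal);
  for (int64 replica : replicas) {
    // Field 4 is `repeated uint32`, so 64-bit replica ordinals are narrowed on
    // the wire; leaving the field empty means "all replicas in the mesh".
    request.add_replicas(replica);
  }
  return request;
}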

Sources/x10/xla_tensor/aten_compat.h (+1)

@@ -774,6 +774,7 @@
   _(xla, generic_slice) \
   _(xla, get_dimensions_size) \
   _(xla, moving_average) \
+  _(xla, nms) \
   _(xla, not_supported) \
   _(xla, replication_pad) \
   _(xla, replication_pad_backward) \

Sources/x10/xla_tensor/cross_replica_reduces.cpp (+10 -22)

@@ -22,6 +22,7 @@
 #include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/token_handler.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 
 namespace swift_xla {
@@ -94,14 +95,6 @@ std::vector<xla::ReplicaGroup> CreateReduceGroups(
   return reduce_groups;
 }
 
-xla::XlaOp SliceOneToken(xla::XlaOp input) {
-  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  if (input_shape.rank() == 0) {
-    return input;
-  }
-  return xla::SliceInDim(input, 0, 1, 1, 0);
-}
-
 }  // namespace
 
 std::vector<xla::XlaOp> BuildAllReduce(
@@ -151,32 +144,27 @@ AllToAllResult BuildAllToAll(
     const std::vector<std::vector<xla::int64>>& groups) {
   std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp affine_token = MaybeConvertTo(token, input_shape.element_type());
   // TODO: This is missing layout pinning ATM. If XLA scheduling is not exactly
   // the same (graphs on cores differ), XLA could assign different layouts and
   // things will break.
-  xla::XlaOp reduce_result =
-      xla::AllToAll(input + affine_token, split_dimension, concat_dimension,
-                    split_count, reduce_groups);
-  xla::XlaOp chained_token =
-      MaybeConvertTo(affine_token * SliceOneToken(reduce_result),
-                     XlaHelpers::TypeOfXlaOp(token));
-  return {reduce_result, chained_token};
+  TokenHandler token_handler(token);
+  xla::XlaOp reduce_result = xla::AllToAll(
+      token_handler.GetInput(input, &input_shape), split_dimension,
+      concat_dimension, split_count, reduce_groups);
+  return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<xla::int64, xla::int64>>& source_target_pairs) {
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp affine_token = MaybeConvertTo(token, input_shape.element_type());
+  TokenHandler token_handler(token);
   // TODO: This is missing layout pinning ATM. If XLA scheduling is not exactly
   // the same (graphs on cores differ), XLA could assign different layouts and
   // things will break.
-  xla::XlaOp result =
-      xla::CollectivePermute(input + affine_token, source_target_pairs);
-  xla::XlaOp chained_token = MaybeConvertTo(
-      affine_token * SliceOneToken(result), XlaHelpers::TypeOfXlaOp(token));
-  return {result, chained_token};
+  xla::XlaOp result = xla::CollectivePermute(
+      token_handler.GetInput(input, &input_shape), source_target_pairs);
+  return {result, token_handler.GetNewToken(result)};
 }
 
 }  // namespace swift_xla
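
The TokenHandler used by the new call sites comes from token_handler.h, which is added elsewhere in this commit and not shown in this hunk. As a rough sketch, the class below re-expresses the removed inline token chaining (convert the token to the input's element type, fold it into the operand, then derive the next token from a one-element slice of the result) behind the GetInput/GetNewToken interface the call sites use; the real implementation may differ.

// Illustrative sketch only, not the actual token_handler.h added by this
// commit. Relies on the surrounding file's helpers (MaybeConvertTo,
// XlaHelpers) being in scope.
class TokenHandler {
 public:
  explicit TokenHandler(xla::XlaOp token) : token_(token) {}

  // Folds the token into the collective's operand so the data dependency on
  // the token is preserved through the op.
  xla::XlaOp GetInput(xla::XlaOp input, const xla::Shape* input_shape) {
    affine_token_ = MaybeConvertTo(token_, input_shape->element_type());
    return input + affine_token_;
  }

  // Derives the next token from the collective's result, keeping the chain
  // alive for the following collective op.
  xla::XlaOp GetNewToken(xla::XlaOp result) {
    return MaybeConvertTo(affine_token_ * SliceOneToken(result),
                          XlaHelpers::TypeOfXlaOp(token_));
  }

 private:
  // Same role as the SliceOneToken() helper removed above: one element of the
  // result is enough to carry the dependency.
  static xla::XlaOp SliceOneToken(xla::XlaOp input) {
    const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
    return input_shape.rank() == 0 ? input
                                   : xla::SliceInDim(input, 0, 1, 1, 0);
  }

  xla::XlaOp token_;
  xla::XlaOp affine_token_;
};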
