This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy pathxrt_local_service.cc
77 lines (67 loc) · 3.03 KB
/
xrt_local_service.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// Copyright 2020 TensorFlow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensorflow/compiler/xla/xla_client/xrt_local_service.h"
#include <vector>
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace xla {
namespace {
// Populates `options` with a gRPC ServerDef for task `task_index` of job
// `job_name`, using the cluster layout described by `cluster_spec`.
//
// `cluster_spec` is a ','-separated list of job entries. Each entry has the
// form "<job name>|<addr>;<addr>;...", and the i-th ';'-separated address
// becomes task i of that job, e.g.:
//   "job_a|host0:1234;host1:1234,job_b|host2:1234"
//
// Fails an XLA_CHECK if an entry does not have exactly two '|'-separated
// pieces, if `job_name` does not name any job in `cluster_spec`, or if
// `task_index` is not smaller than the number of tasks of that job.
void FillServerDef(const std::string& cluster_spec, const std::string& job_name,
                   int task_index, tensorflow::ServerDef* options) {
  options->set_protocol("grpc");
  options->set_job_name(job_name);
  options->set_task_index(task_index);

  // Number of tasks of the job this server belongs to; stays 0 if that job is
  // missing from the cluster spec.
  size_t my_num_tasks = 0;
  tensorflow::ClusterDef* cluster = options->mutable_cluster();
  for (const auto& job_str : absl::StrSplit(cluster_spec, ',')) {
    tensorflow::JobDef* job_def = cluster->add_job();
    // Split each entry into 2 pieces, separated by "|": the job name and its
    // ';'-separated list of task addresses.
    std::vector<std::string> job_pieces = absl::StrSplit(job_str, '|');
    XLA_CHECK_EQ(2, job_pieces.size()) << job_str;
    const std::string& cjob_name = job_pieces[0];
    const std::string& spec = job_pieces[1];
    job_def->set_name(cjob_name);
    std::vector<std::string> host_ports = absl::StrSplit(spec, ';');
    for (size_t i = 0; i < host_ports.size(); ++i) {
      (*job_def->mutable_tasks())[i] = host_ports[i];
    }
    size_t num_tasks = host_ports.size();
    // Match the job entry being parsed against this server's job. (Comparing
    // the `job_name` parameter here would always be true, since it was copied
    // into `options` above.)
    if (cjob_name == options->job_name()) {
      my_num_tasks = num_tasks;
    }
    LOG(INFO) << "Peer " << cjob_name << " " << num_tasks << " {"
              << absl::StrJoin(host_ports, ", ") << "}";
  }
  XLA_CHECK_NE(my_num_tasks, 0) << "Job '" << options->job_name()
                                << "' does not appear in the cluster spec";
  XLA_CHECK_LT(options->task_index(), my_num_tasks)
      << "Task index " << options->task_index() << " is invalid (job '"
      << options->job_name() << "' contains " << my_num_tasks << " tasks)";
}
} // namespace
// Creates (without starting) the local TensorFlow server for task
// `task_index` of job `job_name` in the cluster described by `cluster_spec`.
// The server's default session config requests a device count of 0 for
// "GPU". Fails a TF_CHECK_OK if tensorflow::NewServer() returns a non-OK
// status.
XrtLocalService::XrtLocalService(const std::string& cluster_spec,
                                 const std::string& job_name, int task_index) {
  tensorflow::ServerDef server_def;
  FillServerDef(cluster_spec, job_name, task_index, &server_def);

  auto* session_config = server_def.mutable_default_session_config();
  auto& device_count = *session_config->mutable_device_count();
  device_count["GPU"] = 0;

  TF_CHECK_OK(tensorflow::NewServer(server_def, &server_));
}
// Starts the server created by the constructor. Fails a TF_CHECK_OK if the
// server's Start() returns a non-OK status.
void XrtLocalService::Start() {
  TF_CHECK_OK(server_->Start());
}
} // namespace xla