Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d91e91e

Browse files
authored
Make local backend default for GPU. (#1109)
1 parent 2042ed3 commit d91e91e

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

Sources/x10/xla_client/local_device.cc

+9-5
Original file line numberDiff line numberDiff line change
@@ -539,16 +539,20 @@ std::unique_ptr<ComputationClient::Device> MakeLocalDeviceFromClient(
539539
// Enumerates every local device the given XLA platform exposes, wrapping each
// in a ComputationClient::Device named "<device_prefix>:<ordinal>".
//
// Returns an empty vector if the platform is unavailable or a local client
// cannot be created (e.g. no GPU present), so callers can treat "no devices"
// and "platform missing" uniformly.
//
// The trailing boolean passed to MakeLocalDeviceFromClient is true only for
// the "cpu" platform — presumably a CPU-specific device flag (sync/transfer
// behavior); confirm against MakeLocalDeviceFromClient's declaration.
std::vector<std::unique_ptr<ComputationClient::Device>>
GetAllLocalDevicesForPlatform(const char* platform_name,
                              const char* device_prefix) {
  auto platform = xla::PlatformUtil::GetPlatform(platform_name);
  if (!platform.ok()) return {};
  xla::LocalClientOptions options;
  options.set_platform(platform.ValueOrDie());
  auto local_client_statusor =
      xla::ClientLibrary::GetOrCreateLocalClient(options);
  if (!local_client_statusor.ok()) return {};
  xla::LocalClient* client = local_client_statusor.ValueOrDie();
  // Hoisted out of the loop: both are loop-invariant, and the string
  // comparison otherwise constructs a temporary std::string per iteration.
  const bool is_cpu_platform = std::string(platform_name) == "cpu";
  const int device_count = client->device_count();
  std::vector<std::unique_ptr<ComputationClient::Device>> devices;
  devices.reserve(device_count);
  for (int i = 0; i < device_count; ++i) {
    devices.push_back(MakeLocalDeviceFromClient(
        absl::StrCat(device_prefix, ":", i), client, i, i, is_cpu_platform));
  }
  return devices;
}

Sources/x10/xla_client/xrt_computation_client.cc

+7-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "tensorflow/compiler/xla/xla_client/util.h"
3636
#include "tensorflow/compiler/xla/xla_client/xla_util.h"
3737
#include "tensorflow/compiler/xla/xla_client/xrt_local_service.h"
38+
#include "tensorflow/compiler/xla/xla_client/local_device.h"
3839
#include "tensorflow/compiler/xrt/xrt_util.h"
3940
#include "tensorflow/cc/ops/const_op.h"
4041
#include "tensorflow/compiler/xla/shape_util.h"
@@ -534,7 +535,7 @@ bool ParseEnvDeviceCounts(XrtComputationClient::Options* options) {
534535
}
535536

536537
bool ParseEnvDevices(XrtComputationClient::Options* options) {
537-
std::string device = GpuIsAvailable() ? "GPU" : "CPU";
538+
std::string device = "CPU";
538539
std::string default_device_spec = absl::StrFormat(
539540
"%s:0;/job:localservice/replica:0/task:0/device:XLA_%s:0", device,
540541
device);
@@ -629,6 +630,11 @@ XrtComputationClient::XrtComputationClient(
629630
for (const auto& dev_target : options_.global_device_map) {
630631
AddDevice(std::make_unique<XrtDevice>(dev_target.first, this));
631632
}
633+
634+
for (auto& device : GetAllLocalDevicesForPlatform("gpu", "GPU")) {
635+
options_.default_device = "GPU:0";
636+
AddDevice(std::move(device));
637+
}
632638
}
633639

634640
std::vector<size_t> XrtComputationClient::PartitionTransferToServer(

Sources/x10/xla_client/xrt_local_service.cc

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ XrtLocalService::XrtLocalService(const std::string& cluster_spec,
6767
const std::string& job_name, int task_index) {
6868
tensorflow::ServerDef server_def;
6969
FillServerDef(cluster_spec, job_name, task_index, &server_def);
70+
(*server_def.mutable_default_session_config()
71+
->mutable_device_count())["GPU"] = 0;
7072
TF_CHECK_OK(tensorflow::NewServer(server_def, &server_));
7173
}
7274

0 commit comments

Comments
 (0)