Skip to content

Commit 9f2b92b

Browse files
Add functions that return device type and ID for eager.
This addition enables more efficient device handling in S4TF without needing to parse the full device string. As support for devices beyond TF eager is added, this information is needed more often and has a bigger impact on performance. Partial fix for tensorflow/swift#524. PiperOrigin-RevId: 337696655. Change-Id: Ifb576d37c765cced2329b77e0cebb591d8d3a46c
1 parent 0f9acc1 commit 9f2b92b

7 files changed

+215
-0
lines changed

tensorflow/c/eager/c_api_experimental.cc

+16
Original file line numberDiff line numberDiff line change
@@ -638,3 +638,19 @@ void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
638638
TF_Status* status) {
639639
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
640640
}
641+
642+
// Returns the device type string for the handle `h`.
// On a null handle, records an InvalidArgument error in `status` and
// returns nullptr; otherwise delegates to the unwrapped handle.
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
  if (h != nullptr) {
    return tensorflow::unwrap(h)->DeviceType(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return nullptr;
}
649+
650+
// Returns the device ID for the handle `h`.
// On a null handle, records an InvalidArgument error in `status` and
// returns -1; otherwise delegates to the unwrapped handle.
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
  if (h != nullptr) {
    return tensorflow::unwrap(h)->DeviceId(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return -1;
}

tensorflow/c/eager/c_api_experimental.h

+8
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,14 @@ TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
553553
unsigned char enable,
554554
TF_Status* status);
555555

556+
// Returns the device type of the operation that produced `h`.
// If `h` is null, sets `status` to InvalidArgument and returns nullptr.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
    TFE_TensorHandle* h, TF_Status* status);

// Returns the device ID of the operation that produced `h`.
// If `h` is null, sets `status` to InvalidArgument and returns -1.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
                                                   TF_Status* status);
563+
556564
#ifdef __cplusplus
557565
} /* end extern "C" */
558566
#endif

tensorflow/c/eager/c_api_experimental_test.cc

+104
Original file line numberDiff line numberDiff line change
@@ -411,5 +411,109 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
411411
TF_DeleteStatus(status);
412412
}
413413

414+
// Both accessors must reject a null handle: each sets INVALID_ARGUMENT on
// `status` and returns its sentinel value (nullptr for the type, -1 for
// the ID).
TEST(CAPI, TensorHandleNullptr) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* h = nullptr;

  const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_type, nullptr);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));

  // Clear the status so the second call's error is observed independently.
  TF_SetStatus(status.get(), TF_OK, "");

  int device_id = TFE_TensorHandleDeviceID(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(device_id, -1);
  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
431+
432+
// Verifies device type/ID reporting for a handle placed on CPU and, when a
// GPU is available, for the output of an op explicitly placed on that GPU.
TEST(CAPI, TensorHandleDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // A freshly created local handle should report CPU:0.
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Disable the test if no GPU is present.
  string gpu_device_name;
  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
        hcpu, ctx, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // Run a Shape op pinned to the GPU so its output handle is GPU-produced.
    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    TFE_TensorHandle* retvals[1];
    int num_retvals = 1;
    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // The output of the GPU-placed op must report a GPU device type.
    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;

    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
    ASSERT_EQ(0, device_id) << device_id;

    TFE_DeleteOp(shape_op);
    TFE_DeleteTensorHandle(retvals[0]);
    TFE_DeleteTensorHandle(hgpu);
  }

  TFE_DeleteTensorHandle(hcpu);
  // Drain any pending async work before tearing down the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
483+
484+
// Verifies that a handle created without an explicit device reports CPU with
// ID 0, and that an explicit copy to "/device:CPU:0" reports the same.
TEST(CAPI, TensorHandleDefaults) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Default placement: no device requested at creation time.
  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id) << device_id;

  // Explicit placement on CPU:0 should report identically.
  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
      h_default, ctx, "/device:CPU:0", status.get());
  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;

  TFE_DeleteTensorHandle(h_default);
  TFE_DeleteTensorHandle(h_cpu);
  // Drain any pending async work before tearing down the context.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}
517+
414518
} // namespace
415519
} // namespace tensorflow

tensorflow/c/eager/immediate_execution_tensor_handle.h

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
4444
virtual const char* DeviceName(Status* status) const = 0;
4545
// Returns the device where the tensor was placed.
4646
virtual const char* BackingDeviceName(Status* status) const = 0;
47+
// Returns the device type which created the handle.
// On failure, `status` is updated; implementations may return nullptr.
virtual const char* DeviceType(Status* status) const = 0;
// Returns the device ID which created the handle.
// On failure, `status` is updated; implementations may return -1.
virtual int DeviceId(Status* status) const = 0;
4751
// Returns a tensor for the handle. If tensor is remote, it will be copied.
4852
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
4953

tensorflow/core/common_runtime/eager/tensor_handle.cc

+22
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,28 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
11161116
}
11171117
}
11181118

1119+
// Reports the device type of the op that produced this handle.
// Custom devices are not supported (Unimplemented); a handle with no op
// device defaults to "CPU".
const char* TensorHandle::DeviceType(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return nullptr;
  }
  // Block until any pending device assignment is resolved.
  status->Update(WaitUnknownDevice());
  tensorflow::Device* placed_on = op_device();
  if (placed_on == nullptr) {
    return "CPU";
  }
  return placed_on->parsed_name().type.c_str();
}
1129+
1130+
// Reports the device ID of the op that produced this handle.
// Custom devices are not supported (Unimplemented, returns -1); a handle
// with no op device defaults to ID 0.
int TensorHandle::DeviceId(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return -1;
  }
  // Block until any pending device assignment is resolved.
  status->Update(WaitUnknownDevice());
  tensorflow::Device* placed_on = op_device();
  if (placed_on == nullptr) {
    return 0;
  }
  return placed_on->parsed_name().id;
}
1140+
11191141
tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
11201142
Ref();
11211143
return this;

tensorflow/core/common_runtime/eager/tensor_handle.h

+2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
131131

132132
const char* DeviceName(Status* status) const override;
133133
const char* BackingDeviceName(Status* status) const override;
134+
// Device type and ID of the op that produced this handle; see the base
// class ImmediateExecutionTensorHandle for the contract.
const char* DeviceType(Status* status) const override;
int DeviceId(Status* status) const override;
134136
AbstractTensorInterface* Resolve(Status* status) override;
135137

136138
ImmediateExecutionTensorHandle* Copy() override;

tensorflow/core/common_runtime/eager/tensor_handle_test.cc

+59
Original file line numberDiff line numberDiff line change
@@ -408,4 +408,63 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
408408
context->Unref();
409409
}
410410

411+
// Checks DeviceName/BackingDeviceName/DeviceType/DeviceId for local handles
// placed on a fake CPU device and a fake GPU device.
TEST(TensorHandle_DeviceNameTest, OnLocalDevice) {
  std::vector<std::unique_ptr<Device>> devices;
  devices.emplace_back(
      CreateDevice("CPU", "/job:localhost/replica:0/task:0/device:CPU:0"));
  devices.emplace_back(
      CreateDevice("GPU", "/job:localhost/replica:0/task:0/device:GPU:0"));
  StaticDeviceMgr local_device_mgr(std::move(devices));
  auto ctx = new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      false, &local_device_mgr, false, nullptr, nullptr);

  Device* dcpu = local_device_mgr.ListDevices()[0];
  Device* dgpu = local_device_mgr.ListDevices()[1];
  tensorflow::DataType dtype = DT_RESOURCE;
  TensorShape shape = {2};
  Tensor tcpu(dtype, shape);
  Tensor tgpu(dtype, shape);
  Status s;

  // CPU-placed handle: all four accessors should report the CPU device.
  TensorHandle* th_cpu =
      TensorHandle::CreateLocalHandle(std::move(tcpu), dcpu, dcpu, dcpu, ctx);
  const char* device_name = th_cpu->DeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_name, "CPU")) << device_name;
  const char* backing_device_name = th_cpu->BackingDeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU"))
      << backing_device_name;
  const char* device_type = th_cpu->DeviceType(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
  int device_id = th_cpu->DeviceId(&s);
  TF_EXPECT_OK(s);
  ASSERT_EQ(0, device_id) << device_id;

  // GPU-placed handle: all four accessors should report the GPU device.
  // (No debug printing here: the ASSERT already streams the value on
  // failure.)
  TensorHandle* th_gpu =
      TensorHandle::CreateLocalHandle(std::move(tgpu), dgpu, dgpu, dgpu, ctx);
  device_name = th_gpu->DeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_name, "GPU")) << device_name;
  backing_device_name = th_gpu->BackingDeviceName(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(backing_device_name, "GPU"))
      << backing_device_name;
  device_type = th_gpu->DeviceType(&s);
  TF_EXPECT_OK(s);
  ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
  device_id = th_gpu->DeviceId(&s);
  TF_EXPECT_OK(s);
  ASSERT_EQ(0, device_id) << device_id;

  th_cpu->Unref();
  th_gpu->Unref();
  ctx->Unref();
}
469+
411470
} // namespace tensorflow

0 commit comments

Comments
 (0)