@@ -411,5 +411,112 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
   TF_DeleteStatus(status);
 }
 
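+// Both introspection calls should reject a null handle with TF_INVALID_ARGUMENT.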
+TEST(CAPI, TensorHandleNullptr) {
+  TFE_TensorHandle* h = nullptr;
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+  ASSERT_EQ(device_type, nullptr);
+  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
+
+  TF_SetStatus(status.get(), TF_OK, "");
+
+  int device_id = TFE_TensorHandleDeviceID(h, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+  ASSERT_EQ(device_id, -1);
+  ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
+}
+
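+// Device type and ID introspection on CPU and, when available, GPU handles.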
+TEST(CAPI, TensorHandleDevices) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
+  const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
+  int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id) << device_id;
+
+  // Skip the GPU checks below if no GPU is present.
+  string gpu_device_name;
+  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
+
+    device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_EQ(0, device_id) << device_id;
+
+    TFE_DeleteOp(shape_op);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
+  TFE_DeleteContext(ctx);
+}
+
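+// With no explicit placement, a new handle and its CPU copy both report CPU:0.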
+TEST(CAPI, TensorHandleDefaults) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
+  const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
+  int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id) << device_id;
+
+  TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
+      h_default, ctx, "/device:CPU:0", status.get());
+  const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
+  int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
+
+  TFE_DeleteTensorHandle(h_default);
+  TFE_DeleteTensorHandle(h_cpu);
+  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
+  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteExecutor(executor);
+  TFE_DeleteContext(ctx);
+}
+
 }  // namespace
 }  // namespace tensorflow