@@ -60,10 +60,10 @@ limitations under the License.
60
60
#include " xla/shape_util.h"
61
61
#include " xla/stream_executor/platform.h"
62
62
#include " xla/tests/client_library_test_runner_mixin.h"
63
- #include " xla/tests/hlo_test_base.h"
63
+ #include " xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
64
+ #include " xla/tests/hlo_pjrt_test_base.h"
64
65
#include " xla/tests/literal_test_util.h"
65
66
#include " xla/tests/test_macros.h"
66
- #include " xla/tests/test_utils.h"
67
67
#include " xla/tsl/lib/core/status_test_util.h"
68
68
#include " xla/tsl/platform/statusor.h"
69
69
#include " xla/tsl/platform/test.h"
@@ -189,7 +189,7 @@ std::ostream& operator<<(std::ostream& os, InitMethod op) {
189
189
190
190
using ::testing::HasSubstr;
191
191
192
- class CustomCallTest : public HloTestBase {
192
+ class CustomCallTest : public HloPjRtTestBase {
193
193
protected:
194
194
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
195
195
Shape r2f32_ = ShapeUtil::MakeShape(F32, {2 , 2 });
@@ -208,7 +208,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
208
208
module->AddEntryComputation (builder.Build ());
209
209
210
210
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
211
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
211
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
212
212
}
213
213
214
214
XLA_TEST_F (CustomCallTest, CustomCallR0F32Add2Aliased) {
@@ -227,7 +227,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2Aliased) {
227
227
module->AddEntryComputation (builder.Build ());
228
228
229
229
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
230
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
230
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
231
231
}
232
232
233
233
XLA_TEST_F (CustomCallTest, CustomCallR2F32Reduce) {
@@ -249,7 +249,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
249
249
module->AddEntryComputation (builder.Build ());
250
250
251
251
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
252
- LiteralTestUtil::ExpectR0Near<float >(10 .0f , result, error_spec_ );
252
+ LiteralTestUtil::ExpectR0Near<float >(10 .0f , result, kDefaultErrorSpec );
253
253
}
254
254
255
255
XLA_TEST_F (CustomCallTest, ReportsSuccess) {
@@ -265,7 +265,7 @@ XLA_TEST_F(CustomCallTest, ReportsSuccess) {
265
265
module->AddEntryComputation (builder.Build ());
266
266
267
267
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
268
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
268
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
269
269
}
270
270
271
271
XLA_TEST_F (CustomCallTest, ReportsFailure) {
@@ -350,7 +350,8 @@ XLA_TEST_F(CustomCallTest, FillStatusMsgWithBackendConfigStr) {
350
350
}
351
351
352
352
class CustomCallClientAPITest
353
- : public ClientLibraryTestRunnerMixin<HloTestBase> {};
353
+ : public ClientLibraryTestRunnerMixin<
354
+ HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
354
355
355
356
// When using the client API, CustomCall targets can't begin with '$' -- these
356
357
// are reserved for internal use.
@@ -865,8 +866,7 @@ XLA_TEST_F(FfiCustomCallTest, Tokens) {
865
866
866
867
module->AddEntryComputation (builder.Build ());
867
868
868
- auto status = Execute (std::move (module), {}).status ();
869
- EXPECT_EQ (status, absl::OkStatus ());
869
+ TF_EXPECT_OK (Execute (std::move (module), {}).status ());
870
870
}
871
871
872
872
XLA_TEST_F (FfiCustomCallTest, FfiUnknownTarget) {
@@ -1018,7 +1018,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleTypedBuffers) {
1018
1018
module->AddEntryComputation (builder.Build ());
1019
1019
1020
1020
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1021
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
1021
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
1022
1022
}
1023
1023
1024
1024
XLA_TEST_F (FfiCustomCallTest, FfiHandleInputAsParameters) {
@@ -1036,7 +1036,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleInputAsParameters) {
1036
1036
Literal argument = LiteralUtil::CreateR0<float >(42 .0f );
1037
1037
1038
1038
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {&argument}));
1039
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
1039
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
1040
1040
}
1041
1041
1042
1042
XLA_TEST_F (FfiCustomCallTest, FfiHandleBufferBaseFloat) {
@@ -1052,7 +1052,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseFloat) {
1052
1052
module->AddEntryComputation (builder.Build ());
1053
1053
1054
1054
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1055
- LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, error_spec_ );
1055
+ LiteralTestUtil::ExpectR0Near<float >(44 .0f , result, kDefaultErrorSpec );
1056
1056
}
1057
1057
1058
1058
XLA_TEST_F (FfiCustomCallTest, FfiHandleBufferBaseDouble) {
@@ -1069,7 +1069,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseDouble) {
1069
1069
module->AddEntryComputation (builder.Build ());
1070
1070
1071
1071
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1072
- LiteralTestUtil::ExpectR0Near<double >(44 .0f , result, error_spec_ );
1072
+ LiteralTestUtil::ExpectR0Near<double >(44 .0f , result, kDefaultErrorSpec );
1073
1073
}
1074
1074
1075
1075
XLA_TEST_F (FfiCustomCallTest, FfiHandleAttr) {
@@ -1086,7 +1086,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleAttr) {
1086
1086
module->AddEntryComputation (builder.Build ());
1087
1087
1088
1088
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1089
- LiteralTestUtil::ExpectR0Near<float >(45 .0f , result, error_spec_ );
1089
+ LiteralTestUtil::ExpectR0Near<float >(45 .0f , result, kDefaultErrorSpec );
1090
1090
}
1091
1091
1092
1092
XLA_TEST_F (FfiCustomCallTest, FfiHandleAttrPointer) {
@@ -1105,7 +1105,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleAttrPointer) {
1105
1105
module->AddEntryComputation (builder.Build ());
1106
1106
1107
1107
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1108
- LiteralTestUtil::ExpectR0Near<float >(46 .0f , result, error_spec_ );
1108
+ LiteralTestUtil::ExpectR0Near<float >(46 .0f , result, kDefaultErrorSpec );
1109
1109
}
1110
1110
1111
1111
XLA_TEST_F (FfiCustomCallTest, FfiHandleR2Vector) {
@@ -1128,7 +1128,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleR2Vector) {
1128
1128
module->AddEntryComputation (builder.Build ());
1129
1129
1130
1130
TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1131
- LiteralTestUtil::ExpectR0Near<float >(10 .0f , result, error_spec_ );
1131
+ LiteralTestUtil::ExpectR0Near<float >(10 .0f , result, kDefaultErrorSpec );
1132
1132
}
1133
1133
1134
1134
XLA_TEST_F (FfiCustomCallTest, FfiWrongEnumType) {
@@ -1204,7 +1204,7 @@ XLA_TEST_P(FfiCustomCallEnumTest, FfiHandleEnumAttr) {
1204
1204
break ;
1205
1205
}
1206
1206
1207
- LiteralTestUtil::ExpectR0Near<float >(expected, result, error_spec_ );
1207
+ LiteralTestUtil::ExpectR0Near<float >(expected, result, kDefaultErrorSpec );
1208
1208
}
1209
1209
1210
1210
INSTANTIATE_TEST_SUITE_P (
@@ -1330,22 +1330,26 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleOutput) {
1330
1330
c1 = f32[] constant(42.0)
1331
1331
c2 = f32[] constant(8.0)
1332
1332
c3 = f32[] constant(43.0)
1333
- ROOT custom-call = ((f32[], f32[]), (f32[], f32[])) custom-call(c0, c1, c2, c3), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1333
+ custom-call = ((f32[], f32[]), (f32[], f32[])) custom-call(c0, c1, c2, c3), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1334
+ t0x = (f32[], f32[]) get-tuple-element(custom-call), index=0
1335
+ t00 = f32[] get-tuple-element(t0x), index=0
1336
+ t01 = f32[] get-tuple-element(t0x), index=1
1337
+ t1x = (f32[], f32[]) get-tuple-element(custom-call), index=1
1338
+ t10 = f32[] get-tuple-element(t1x), index=0
1339
+ t11 = f32[] get-tuple-element(t1x), index=1
1340
+ ROOT tuple = (f32[], f32[], f32[], f32[]) tuple(t00, t01, t10, t11)
1334
1341
})" ;
1335
1342
1336
1343
TF_ASSERT_OK_AND_ASSIGN (auto module,
1337
1344
ParseAndReturnVerifiedModule (kModuleStr ));
1338
1345
1339
- Literal arg0 = LiteralUtil::CreateR0<float >(7 .f );
1340
- Literal arg1 = LiteralUtil::CreateR0<float >(42 .f );
1341
- Literal arg2 = LiteralUtil::CreateR0<float >(8 .f );
1342
- Literal arg3 = LiteralUtil::CreateR0<float >(43 .f );
1343
-
1344
- Literal tuple0 = LiteralUtil::MakeTuple ({&arg1, &arg2});
1345
- Literal tuple1 = LiteralUtil::MakeTuple ({&arg3, &arg0});
1346
+ const Literal arg0 = LiteralUtil::CreateR0<float >(7 .f );
1347
+ const Literal arg1 = LiteralUtil::CreateR0<float >(42 .f );
1348
+ const Literal arg2 = LiteralUtil::CreateR0<float >(8 .f );
1349
+ const Literal arg3 = LiteralUtil::CreateR0<float >(43 .f );
1346
1350
1347
- Literal expected = LiteralUtil::MakeTuple ({&tuple0 , &tuple1 });
1348
- TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1351
+ const Literal expected = LiteralUtil::MakeTuple ({&arg1 , &arg2, &arg3, &arg0 });
1352
+ TF_ASSERT_OK_AND_ASSIGN (const Literal result, Execute (std::move (module), {}));
1349
1353
EXPECT_EQ (result, expected);
1350
1354
}
1351
1355
@@ -1417,7 +1421,11 @@ XLA_TEST_F(FfiCustomCallTest, IgnoresEmptyTupleParameter) {
1417
1421
HloModule m
1418
1422
1419
1423
ENTRY test {
1420
- p0 = (u32[], s16[], ((), ())) parameter(0)
1424
+ t0 = u32[] parameter(0)
1425
+ t1 = s16[] parameter(1)
1426
+ t2 = () tuple()
1427
+ t3 = ((), ()) tuple(t2, t2)
1428
+ p0 = (u32[], s16[], ((), ())) tuple(t0, t1, t3)
1421
1429
ROOT custom-call = (s16[], u32[]) custom-call(p0), custom_call_target="__xla_test$$SwapTupleAnyBuffersToS16U32", api_version=API_VERSION_TYPED_FFI
1422
1430
})" ;
1423
1431
@@ -1426,12 +1434,10 @@ XLA_TEST_F(FfiCustomCallTest, IgnoresEmptyTupleParameter) {
1426
1434
1427
1435
Literal arg0 = LiteralUtil::CreateR0<uint32_t >(0xDEADC0DE );
1428
1436
Literal arg1 = LiteralUtil::CreateR0<int16_t >(29 );
1429
- Literal empty_tuple = LiteralUtil::MakeTuple ({});
1430
- Literal nested_tuple = LiteralUtil::MakeTuple ({&empty_tuple, &empty_tuple});
1431
- Literal argument = LiteralUtil::MakeTuple ({&arg0, &arg1, &nested_tuple});
1432
- Literal expected = LiteralUtil::MakeTuple ({&arg1, &arg0});
1437
+ const Literal expected = LiteralUtil::MakeTuple ({&arg1, &arg0});
1433
1438
1434
- TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {&argument}));
1439
+ TF_ASSERT_OK_AND_ASSIGN (const Literal result,
1440
+ Execute (std::move (module), {&arg0, &arg1}));
1435
1441
EXPECT_EQ (result, expected);
1436
1442
}
1437
1443
@@ -1486,7 +1492,13 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
1486
1492
HloModule m
1487
1493
1488
1494
ENTRY test {
1489
- p0 = ((u32[], s16[5]), (f32[2, 2], f32[4, 2, 2])) parameter(0)
1495
+ t00 = u32[] parameter(0)
1496
+ t01 = s16[5] parameter(1)
1497
+ t0x = (u32[], s16[5]) tuple(t00, t01)
1498
+ t10 = f32[2, 2] parameter(2)
1499
+ t11 = f32[4, 2, 2] parameter(3)
1500
+ t1x = (f32[2, 2], f32[4, 2, 2]) tuple(t10, t11)
1501
+ p0 = ((u32[], s16[5]), (f32[2, 2], f32[4, 2, 2])) tuple(t0x, t1x)
1490
1502
ROOT custom-call = (s32[5], f32[5, 2, 2]) custom-call(p0), custom_call_target="__xla_test$$HandleTupleDifferentRanks", api_version=API_VERSION_TYPED_FFI
1491
1503
})" ;
1492
1504
@@ -1500,12 +1512,10 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
1500
1512
{{5 .f , 6 .f }, {7 .f , 8 .f }},
1501
1513
{{9 .f , 10 .f }, {11 .f , 12 .f }},
1502
1514
{{13 .f , 14 .f }, {15 .f , 16 .f }}});
1503
- Literal tuple_arg_0 = LiteralUtil::MakeTuple ({&arg_0, &arg_1});
1504
- Literal tuple_arg_1 = LiteralUtil::MakeTuple ({&arg_2, &arg_3});
1505
- Literal tuple_arg = LiteralUtil::MakeTuple ({&tuple_arg_0, &tuple_arg_1});
1506
1515
1507
- TF_ASSERT_OK_AND_ASSIGN (auto result,
1508
- Execute (std::move (module), {&tuple_arg}));
1516
+ TF_ASSERT_OK_AND_ASSIGN (
1517
+ const Literal result,
1518
+ Execute (std::move (module), {&arg_0, &arg_1, &arg_2, &arg_3}));
1509
1519
1510
1520
Literal expected_0 =
1511
1521
LiteralUtil::CreateR1<int32_t >({2900 , 3000 , 3100 , 3200 , 3300 });
@@ -1516,7 +1526,8 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
1516
1526
{{13 .f , 14 .f }, {15 .f , 16 .f }},
1517
1527
{{17 .f , 18 .f }, {19 .f , 20 .f }}});
1518
1528
1519
- Literal expected_tuple = LiteralUtil::MakeTuple ({&expected_0, &expected_1});
1529
+ const Literal expected_tuple =
1530
+ LiteralUtil::MakeTuple ({&expected_0, &expected_1});
1520
1531
EXPECT_EQ (result, expected_tuple);
1521
1532
}
1522
1533
@@ -1526,7 +1537,13 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) {
1526
1537
1527
1538
ENTRY test {
1528
1539
c0 = ((f32[], f32[]), (f32[], f32[])) constant(((7.0, 42.0), (8.0, 43.0)))
1529
- ROOT custom-call = (f32[], (f32[], f32[]), f32[]) custom-call(c0), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1540
+ custom-call = (f32[], (f32[], f32[]), f32[]) custom-call(c0), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1541
+ t00 = f32[] get-tuple-element(custom-call), index=0
1542
+ t1x = (f32[], f32[]) get-tuple-element(custom-call), index=1
1543
+ t10 = f32[] get-tuple-element(t1x), index=0
1544
+ t11 = f32[] get-tuple-element(t1x), index=1
1545
+ t20 = f32[] get-tuple-element(custom-call), index=2
1546
+ ROOT result = (f32[], f32[], f32[], f32[]) tuple(t00, t10, t11, t20)
1530
1547
})" ;
1531
1548
1532
1549
TF_ASSERT_OK_AND_ASSIGN (auto module,
@@ -1537,9 +1554,8 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) {
1537
1554
Literal arg2 = LiteralUtil::CreateR0<float >(8 .f );
1538
1555
Literal arg3 = LiteralUtil::CreateR0<float >(43 .f );
1539
1556
1540
- Literal inner_tuple = LiteralUtil::MakeTuple ({&arg2, &arg3});
1541
- Literal expected = LiteralUtil::MakeTuple ({&arg1, &inner_tuple, &arg0});
1542
- TF_ASSERT_OK_AND_ASSIGN (auto result, Execute (std::move (module), {}));
1557
+ const Literal expected = LiteralUtil::MakeTuple ({&arg1, &arg2, &arg3, &arg0});
1558
+ TF_ASSERT_OK_AND_ASSIGN (const Literal result, Execute (std::move (module), {}));
1543
1559
EXPECT_EQ (result, expected);
1544
1560
}
1545
1561
0 commit comments