Skip to content

Commit a82444d

Browse files
nvgrwtensorflower-gardener
authored andcommitted
Migrate custom_call_test to always use PjRt for its test backend.
PiperOrigin-RevId: 723698246
1 parent 662b097 commit a82444d

File tree

2 files changed

+64
-47
lines changed

2 files changed

+64
-47
lines changed

third_party/xla/xla/tests/BUILD

+3-2
Original file line numberDiff line numberDiff line change
@@ -2208,12 +2208,13 @@ xla_test(
22082208
name = "custom_call_test",
22092209
srcs = ["custom_call_test.cc"],
22102210
backends = ["cpu"],
2211+
tags = ["test_migrated_to_hlo_runner_pjrt"],
22112212
deps = [
22122213
":client_library_test_runner_mixin",
2213-
":hlo_test_base",
2214+
":hlo_pjrt_interpreter_reference_mixin",
2215+
":hlo_pjrt_test_base",
22142216
":literal_test_util",
22152217
":test_macros_header",
2216-
":test_utils",
22172218
":xla_internal_test_main", # fixdeps: keep
22182219
"//xla:array2d",
22192220
"//xla:array3d",

third_party/xla/xla/tests/custom_call_test.cc

+61-45
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ limitations under the License.
6060
#include "xla/shape_util.h"
6161
#include "xla/stream_executor/platform.h"
6262
#include "xla/tests/client_library_test_runner_mixin.h"
63-
#include "xla/tests/hlo_test_base.h"
63+
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
64+
#include "xla/tests/hlo_pjrt_test_base.h"
6465
#include "xla/tests/literal_test_util.h"
6566
#include "xla/tests/test_macros.h"
66-
#include "xla/tests/test_utils.h"
6767
#include "xla/tsl/lib/core/status_test_util.h"
6868
#include "xla/tsl/platform/statusor.h"
6969
#include "xla/tsl/platform/test.h"
@@ -189,7 +189,7 @@ std::ostream& operator<<(std::ostream& os, InitMethod op) {
189189

190190
using ::testing::HasSubstr;
191191

192-
class CustomCallTest : public HloTestBase {
192+
class CustomCallTest : public HloPjRtTestBase {
193193
protected:
194194
Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
195195
Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
@@ -208,7 +208,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
208208
module->AddEntryComputation(builder.Build());
209209

210210
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
211-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
211+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
212212
}
213213

214214
XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2Aliased) {
@@ -227,7 +227,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2Aliased) {
227227
module->AddEntryComputation(builder.Build());
228228

229229
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
230-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
230+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
231231
}
232232

233233
XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
@@ -249,7 +249,7 @@ XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
249249
module->AddEntryComputation(builder.Build());
250250

251251
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
252-
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
252+
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, kDefaultErrorSpec);
253253
}
254254

255255
XLA_TEST_F(CustomCallTest, ReportsSuccess) {
@@ -265,7 +265,7 @@ XLA_TEST_F(CustomCallTest, ReportsSuccess) {
265265
module->AddEntryComputation(builder.Build());
266266

267267
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
268-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
268+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
269269
}
270270

271271
XLA_TEST_F(CustomCallTest, ReportsFailure) {
@@ -350,7 +350,8 @@ XLA_TEST_F(CustomCallTest, FillStatusMsgWithBackendConfigStr) {
350350
}
351351

352352
class CustomCallClientAPITest
353-
: public ClientLibraryTestRunnerMixin<HloTestBase> {};
353+
: public ClientLibraryTestRunnerMixin<
354+
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {};
354355

355356
// When using the client API, CustomCall targets can't begin with '$' -- these
356357
// are reserved for internal use.
@@ -865,8 +866,7 @@ XLA_TEST_F(FfiCustomCallTest, Tokens) {
865866

866867
module->AddEntryComputation(builder.Build());
867868

868-
auto status = Execute(std::move(module), {}).status();
869-
EXPECT_EQ(status, absl::OkStatus());
869+
TF_EXPECT_OK(Execute(std::move(module), {}).status());
870870
}
871871

872872
XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) {
@@ -1018,7 +1018,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleTypedBuffers) {
10181018
module->AddEntryComputation(builder.Build());
10191019

10201020
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1021-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
1021+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
10221022
}
10231023

10241024
XLA_TEST_F(FfiCustomCallTest, FfiHandleInputAsParameters) {
@@ -1036,7 +1036,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleInputAsParameters) {
10361036
Literal argument = LiteralUtil::CreateR0<float>(42.0f);
10371037

10381038
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument}));
1039-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
1039+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
10401040
}
10411041

10421042
XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseFloat) {
@@ -1052,7 +1052,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseFloat) {
10521052
module->AddEntryComputation(builder.Build());
10531053

10541054
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1055-
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
1055+
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, kDefaultErrorSpec);
10561056
}
10571057

10581058
XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseDouble) {
@@ -1069,7 +1069,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleBufferBaseDouble) {
10691069
module->AddEntryComputation(builder.Build());
10701070

10711071
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1072-
LiteralTestUtil::ExpectR0Near<double>(44.0f, result, error_spec_);
1072+
LiteralTestUtil::ExpectR0Near<double>(44.0f, result, kDefaultErrorSpec);
10731073
}
10741074

10751075
XLA_TEST_F(FfiCustomCallTest, FfiHandleAttr) {
@@ -1086,7 +1086,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleAttr) {
10861086
module->AddEntryComputation(builder.Build());
10871087

10881088
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1089-
LiteralTestUtil::ExpectR0Near<float>(45.0f, result, error_spec_);
1089+
LiteralTestUtil::ExpectR0Near<float>(45.0f, result, kDefaultErrorSpec);
10901090
}
10911091

10921092
XLA_TEST_F(FfiCustomCallTest, FfiHandleAttrPointer) {
@@ -1105,7 +1105,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleAttrPointer) {
11051105
module->AddEntryComputation(builder.Build());
11061106

11071107
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1108-
LiteralTestUtil::ExpectR0Near<float>(46.0f, result, error_spec_);
1108+
LiteralTestUtil::ExpectR0Near<float>(46.0f, result, kDefaultErrorSpec);
11091109
}
11101110

11111111
XLA_TEST_F(FfiCustomCallTest, FfiHandleR2Vector) {
@@ -1128,7 +1128,7 @@ XLA_TEST_F(FfiCustomCallTest, FfiHandleR2Vector) {
11281128
module->AddEntryComputation(builder.Build());
11291129

11301130
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1131-
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
1131+
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, kDefaultErrorSpec);
11321132
}
11331133

11341134
XLA_TEST_F(FfiCustomCallTest, FfiWrongEnumType) {
@@ -1204,7 +1204,7 @@ XLA_TEST_P(FfiCustomCallEnumTest, FfiHandleEnumAttr) {
12041204
break;
12051205
}
12061206

1207-
LiteralTestUtil::ExpectR0Near<float>(expected, result, error_spec_);
1207+
LiteralTestUtil::ExpectR0Near<float>(expected, result, kDefaultErrorSpec);
12081208
}
12091209

12101210
INSTANTIATE_TEST_SUITE_P(
@@ -1330,22 +1330,26 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleOutput) {
13301330
c1 = f32[] constant(42.0)
13311331
c2 = f32[] constant(8.0)
13321332
c3 = f32[] constant(43.0)
1333-
ROOT custom-call = ((f32[], f32[]), (f32[], f32[])) custom-call(c0, c1, c2, c3), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1333+
custom-call = ((f32[], f32[]), (f32[], f32[])) custom-call(c0, c1, c2, c3), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1334+
t0x = (f32[], f32[]) get-tuple-element(custom-call), index=0
1335+
t00 = f32[] get-tuple-element(t0x), index=0
1336+
t01 = f32[] get-tuple-element(t0x), index=1
1337+
t1x = (f32[], f32[]) get-tuple-element(custom-call), index=1
1338+
t10 = f32[] get-tuple-element(t1x), index=0
1339+
t11 = f32[] get-tuple-element(t1x), index=1
1340+
ROOT tuple = (f32[], f32[], f32[], f32[]) tuple(t00, t01, t10, t11)
13341341
})";
13351342

13361343
TF_ASSERT_OK_AND_ASSIGN(auto module,
13371344
ParseAndReturnVerifiedModule(kModuleStr));
13381345

1339-
Literal arg0 = LiteralUtil::CreateR0<float>(7.f);
1340-
Literal arg1 = LiteralUtil::CreateR0<float>(42.f);
1341-
Literal arg2 = LiteralUtil::CreateR0<float>(8.f);
1342-
Literal arg3 = LiteralUtil::CreateR0<float>(43.f);
1343-
1344-
Literal tuple0 = LiteralUtil::MakeTuple({&arg1, &arg2});
1345-
Literal tuple1 = LiteralUtil::MakeTuple({&arg3, &arg0});
1346+
const Literal arg0 = LiteralUtil::CreateR0<float>(7.f);
1347+
const Literal arg1 = LiteralUtil::CreateR0<float>(42.f);
1348+
const Literal arg2 = LiteralUtil::CreateR0<float>(8.f);
1349+
const Literal arg3 = LiteralUtil::CreateR0<float>(43.f);
13461350

1347-
Literal expected = LiteralUtil::MakeTuple({&tuple0, &tuple1});
1348-
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1351+
const Literal expected = LiteralUtil::MakeTuple({&arg1, &arg2, &arg3, &arg0});
1352+
TF_ASSERT_OK_AND_ASSIGN(const Literal result, Execute(std::move(module), {}));
13491353
EXPECT_EQ(result, expected);
13501354
}
13511355

@@ -1417,7 +1421,11 @@ XLA_TEST_F(FfiCustomCallTest, IgnoresEmptyTupleParameter) {
14171421
HloModule m
14181422
14191423
ENTRY test {
1420-
p0 = (u32[], s16[], ((), ())) parameter(0)
1424+
t0 = u32[] parameter(0)
1425+
t1 = s16[] parameter(1)
1426+
t2 = () tuple()
1427+
t3 = ((), ()) tuple(t2, t2)
1428+
p0 = (u32[], s16[], ((), ())) tuple(t0, t1, t3)
14211429
ROOT custom-call = (s16[], u32[]) custom-call(p0), custom_call_target="__xla_test$$SwapTupleAnyBuffersToS16U32", api_version=API_VERSION_TYPED_FFI
14221430
})";
14231431

@@ -1426,12 +1434,10 @@ XLA_TEST_F(FfiCustomCallTest, IgnoresEmptyTupleParameter) {
14261434

14271435
Literal arg0 = LiteralUtil::CreateR0<uint32_t>(0xDEADC0DE);
14281436
Literal arg1 = LiteralUtil::CreateR0<int16_t>(29);
1429-
Literal empty_tuple = LiteralUtil::MakeTuple({});
1430-
Literal nested_tuple = LiteralUtil::MakeTuple({&empty_tuple, &empty_tuple});
1431-
Literal argument = LiteralUtil::MakeTuple({&arg0, &arg1, &nested_tuple});
1432-
Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0});
1437+
const Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0});
14331438

1434-
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {&argument}));
1439+
TF_ASSERT_OK_AND_ASSIGN(const Literal result,
1440+
Execute(std::move(module), {&arg0, &arg1}));
14351441
EXPECT_EQ(result, expected);
14361442
}
14371443

@@ -1486,7 +1492,13 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
14861492
HloModule m
14871493
14881494
ENTRY test {
1489-
p0 = ((u32[], s16[5]), (f32[2, 2], f32[4, 2, 2])) parameter(0)
1495+
t00 = u32[] parameter(0)
1496+
t01 = s16[5] parameter(1)
1497+
t0x = (u32[], s16[5]) tuple(t00, t01)
1498+
t10 = f32[2, 2] parameter(2)
1499+
t11 = f32[4, 2, 2] parameter(3)
1500+
t1x = (f32[2, 2], f32[4, 2, 2]) tuple(t10, t11)
1501+
p0 = ((u32[], s16[5]), (f32[2, 2], f32[4, 2, 2])) tuple(t0x, t1x)
14901502
ROOT custom-call = (s32[5], f32[5, 2, 2]) custom-call(p0), custom_call_target="__xla_test$$HandleTupleDifferentRanks", api_version=API_VERSION_TYPED_FFI
14911503
})";
14921504

@@ -1500,12 +1512,10 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
15001512
{{5.f, 6.f}, {7.f, 8.f}},
15011513
{{9.f, 10.f}, {11.f, 12.f}},
15021514
{{13.f, 14.f}, {15.f, 16.f}}});
1503-
Literal tuple_arg_0 = LiteralUtil::MakeTuple({&arg_0, &arg_1});
1504-
Literal tuple_arg_1 = LiteralUtil::MakeTuple({&arg_2, &arg_3});
1505-
Literal tuple_arg = LiteralUtil::MakeTuple({&tuple_arg_0, &tuple_arg_1});
15061515

1507-
TF_ASSERT_OK_AND_ASSIGN(auto result,
1508-
Execute(std::move(module), {&tuple_arg}));
1516+
TF_ASSERT_OK_AND_ASSIGN(
1517+
const Literal result,
1518+
Execute(std::move(module), {&arg_0, &arg_1, &arg_2, &arg_3}));
15091519

15101520
Literal expected_0 =
15111521
LiteralUtil::CreateR1<int32_t>({2900, 3000, 3100, 3200, 3300});
@@ -1516,7 +1526,8 @@ XLA_TEST_F(FfiCustomCallTest, HandleTupleDifferentRanks) {
15161526
{{13.f, 14.f}, {15.f, 16.f}},
15171527
{{17.f, 18.f}, {19.f, 20.f}}});
15181528

1519-
Literal expected_tuple = LiteralUtil::MakeTuple({&expected_0, &expected_1});
1529+
const Literal expected_tuple =
1530+
LiteralUtil::MakeTuple({&expected_0, &expected_1});
15201531
EXPECT_EQ(result, expected_tuple);
15211532
}
15221533

@@ -1526,7 +1537,13 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) {
15261537
15271538
ENTRY test {
15281539
c0 = ((f32[], f32[]), (f32[], f32[])) constant(((7.0, 42.0), (8.0, 43.0)))
1529-
ROOT custom-call = (f32[], (f32[], f32[]), f32[]) custom-call(c0), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1540+
custom-call = (f32[], (f32[], f32[]), f32[]) custom-call(c0), custom_call_target="__xla_test$$FfiTupleRotate", api_version=API_VERSION_TYPED_FFI
1541+
t00 = f32[] get-tuple-element(custom-call), index=0
1542+
t1x = (f32[], f32[]) get-tuple-element(custom-call), index=1
1543+
t10 = f32[] get-tuple-element(t1x), index=0
1544+
t11 = f32[] get-tuple-element(t1x), index=1
1545+
t20 = f32[] get-tuple-element(custom-call), index=2
1546+
ROOT result = (f32[], f32[], f32[], f32[]) tuple(t00, t10, t11, t20)
15301547
})";
15311548

15321549
TF_ASSERT_OK_AND_ASSIGN(auto module,
@@ -1537,9 +1554,8 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) {
15371554
Literal arg2 = LiteralUtil::CreateR0<float>(8.f);
15381555
Literal arg3 = LiteralUtil::CreateR0<float>(43.f);
15391556

1540-
Literal inner_tuple = LiteralUtil::MakeTuple({&arg2, &arg3});
1541-
Literal expected = LiteralUtil::MakeTuple({&arg1, &inner_tuple, &arg0});
1542-
TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {}));
1557+
const Literal expected = LiteralUtil::MakeTuple({&arg1, &arg2, &arg3, &arg0});
1558+
TF_ASSERT_OK_AND_ASSIGN(const Literal result, Execute(std::move(module), {}));
15431559
EXPECT_EQ(result, expected);
15441560
}
15451561

0 commit comments

Comments
 (0)