Skip to content

Commit f6a00a8

Browse files
peterbell10pytorchmergebot
authored andcommitted
[inductor] Add abs to index_propagation (pytorch#124616)
Pull Request resolved: pytorch#124616 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#124119
1 parent c30ea33 commit f6a00a8

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,22 @@ def repeat(x, n):
12061206
self.assertEqual(expect, actual)
12071207
self.assertEqual(actual, repeat(x, 3))
12081208

1209+
def test_index_propagation_abs(self):
1210+
def reflection_pad_left(x, n):
1211+
# e.g. x=[1, 2, 3], n=2 => returns [3, 2, 1, 2, 3]
1212+
i = torch.arange(x.shape[0] + n, device=x.device)
1213+
return x[(i - n).abs()]
1214+
1215+
x = torch.randn(8, device=self.device)
1216+
opt_fn = torch._dynamo.optimize("inductor")(reflection_pad_left)
1217+
1218+
# this should be collapsed to direct indexing
1219+
actual = _run_and_assert_no_indirect_indexing(
1220+
self, opt_fn, x, 3, has_wrapping=False
1221+
)
1222+
expect = reflection_pad_left(x, 3)
1223+
self.assertEqual(expect, actual)
1224+
12091225
@skipIfRocm
12101226
@config.patch(debug_index_asserts=False)
12111227
def test_neg_index(self):

torch/_inductor/index_propagation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def to_dtype(
8282
) -> TypedExpr:
8383
return TypedExpr(value.expr, dtype)
8484

85+
@staticmethod
86+
def abs(x: TypedExpr) -> TypedExpr:
87+
return TypedExpr(abs(x.expr), x.dtype) # type: ignore[arg-type]
88+
8589
@staticmethod
8690
def square(x: TypedExpr) -> TypedExpr:
8791
return TypedExpr(x.expr * x.expr, x.dtype)

0 commit comments

Comments
 (0)