forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_custom_lowering.py
151 lines (122 loc) · 5.22 KB
/
test_custom_lowering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# These tests check issues for lowerings that aren't in the main pytorch repo
class TestCustomLowering(InductorTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.test_inductor_ops = torch.library.Library( # noqa: TOR901
"test_inductor_ops", "DEF"
)
cls.impl_cuda = torch.library.Library( # noqa: TOR901
"test_inductor_ops", "IMPL", "CUDA"
)
cls.impl_meta = torch.library.Library( # noqa: TOR901
"test_inductor_ops", "IMPL", "Meta"
)
cls._register_jagged_to_padded_dense()
@classmethod
def tearDown(cls):
super().tearDownClass()
@classmethod
def _register_jagged_to_padded_dense(cls):
# Approximation of fbgemm.jagged_to_padded_dense_forward
cls.test_inductor_ops.define(
"jagged_to_padded_dense(Tensor input, Tensor offsets, SymInt max_seq_len, Scalar pad_value) -> Tensor"
)
def j2pd_meta(inp, offsets, max_seq_len, pad_value):
return torch.empty(
(offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
device=inp.device,
dtype=inp.dtype,
)
def j2pd_cuda(inp, offsets, max_seq_len, pad_value):
res = torch.full(
(offsets.shape[0] - 1, max_seq_len, inp.shape[1]),
pad_value,
device=inp.device,
dtype=inp.dtype,
)
for b in range(offsets.shape[0] - 1):
for r in range(offsets[b + 1] - offsets[b]):
res[b][r] = inp[offsets[b] + r]
return res
def j2pd_lowering(inp, offsets, max_seq_len, pad_value):
offsets_loader = offsets.make_loader()
inp_loader = inp.make_loader()
jagged_len = inp.get_size()[0]
offsets_dtype = offsets.get_dtype()
def inner_fn(index):
batch_idx, seq_idx, emb_idx = index
begin_idx = ops.indirect_indexing(
offsets_loader([batch_idx]),
jagged_len + 1,
)
end_idx = offsets_loader([batch_idx + 1])
jagged_idx = begin_idx + seq_idx
return ops.masked(
ops.lt(
ops.index_expr(jagged_idx, offsets_dtype),
end_idx,
),
lambda: inp_loader([jagged_idx, emb_idx]),
pad_value,
)
return Pointwise.create(
device=inp.get_device(),
dtype=inp.get_dtype(),
inner_fn=inner_fn,
ranges=[offsets.get_size()[0] - 1, max_seq_len, inp.get_size()[1]],
)
register_lowering(
torch.ops.test_inductor_ops.jagged_to_padded_dense, type_promotion_kind=None
)(j2pd_lowering)
cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta)
cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda)
@unittest.skipIf(not HAS_CUDA, "CUDA needed")
def test_jagged_to_padded_dense_sanity_cuda(self):
def fn(inp, offsets, max_seq_len):
return torch.ops.test_inductor_ops.jagged_to_padded_dense(
inp, offsets, max_seq_len, 60.0
)
inp = torch.rand((9, 96), device="cuda")
offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda")
max_seq_len = 4
res = fn(inp, offsets, max_seq_len)
self.assertEqual(inp[0], res[0][0])
self.assertEqual(inp[1], res[0][1])
self.assertEqual(inp[2], res[1][0])
self.assertEqual(inp[3], res[1][1])
self.assertEqual(inp[5], res[2][0])
self.assertEqual(inp[8], res[2][3])
fn_opt = torch.compile(fn)
self.assertEqual(
fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
)
@unittest.skipIf(not HAS_CUDA, "CUDA needed")
def test_jagged_to_padded_dense_zero_size(self):
# Previously, the masking was being completely stripped for the
# masked load of the input value. That would lead to an IMA
# because cuda was trying to read index 0 of a zero-size tensor.
def fn(inp, offsets, max_seq_len):
inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1))
return torch.ops.test_inductor_ops.jagged_to_padded_dense(
inp, offsets, max_seq_len, 60.0
)
inp = torch.rand((1, 0, 96), device="cuda")
offsets = torch.zeros(1025, device="cuda", dtype=torch.int32)
max_seq_len = 20
fn_opt = torch.compile(fn)
self.assertEqual(
fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CPU or HAS_CUDA:
run_tests(needs="filelock")