forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_remove_mutation.py
318 lines (264 loc) · 10.2 KB
/
test_remove_mutation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# Owner(s): ["oncall: jit"]
import os
import sys
from typing import List
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestRemoveMutation(JitTestCase):
def test_aten_inplace(self):
def test_not_new_alias(x):
y = x[0]
y.add_(2)
return y
fn = torch.jit.script(test_not_new_alias)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))
def test_no_lowering():
x = torch.tensor([2, 2])
x[0] = 3
return x
# there is no functional equivalent of x[0] = ...
fn = torch.jit.script(test_no_lowering)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::copy_").run(graph)
self.assertEqual(fn(), test_no_lowering())
def test_move_before_not_valid():
y = torch.tensor([2, 2])
z = y + 2
y.add_(2)
return y, z
fn = torch.jit.script(test_move_before_not_valid)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(), test_move_before_not_valid())
def test_successful():
x = torch.tensor([2, 2])
x.add_(1)
x.add_(3)
y = x + 4
return x, y
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::add_").run(graph)
self.assertEqual(test_successful(), fn())
def test_intermediary_use():
x = torch.tensor([2, 2])
x.add_(1)
y = x + 4
x.add_(3)
return x, y
fn = torch.jit.script(test_intermediary_use)
graph = fn.graph
FileCheck().check_count("aten::add_", 2).run(graph)
self.run_pass("remove_mutation", graph)
# Unable to remove the second add_ because of the y = x + 4 use
# In the future we could duplicating the value of x as a temporary and replacing
# its intermediary use (so long as aliasing is safe)
FileCheck().check_count("aten::add_", 1).run(graph)
self.assertEqual(test_intermediary_use(), fn())
def test_if_output(self):
def foo(x, cond: bool):
if cond:
y = x + 5
else:
y = x + 2
y.add_(4)
return y
out_eager = foo(torch.tensor(5), True)
foo_script = torch.jit.script(foo)
FileCheck().check("aten::add_").run(foo_script.graph)
self.run_pass("remove_mutation", foo_script.graph)
FileCheck().check_not("aten::add_").run(foo_script.graph)
self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
def test_if_output_fail(self):
@torch.jit.script
def foo(cond: bool):
li = []
if cond:
x = torch.tensor(1)
li.append(x)
else:
x = torch.tensor(2)
y = x.add_(2)
return y, li
self.run_pass("inline", foo.graph)
self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
@torch.jit.script
def foo(cond: bool, y):
if cond:
x = y
else:
x = torch.tensor(2)
z = x.add_(2)
return z
self.run_pass("inline", foo.graph)
self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
def test_special_mapped_op(self):
def test_successful():
x = torch.tensor([2, 2])
y = torch.tensor([2, 4])
x.zero_()
y.fill_(3)
return x, y
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
self.assertEqual(test_successful(), fn())
# full_like is not implemented for a tensor fill value
def test_successful():
x = torch.tensor([2, 2])
y = torch.tensor([2, 4])
x.fill_(y)
return x + x
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::fill_").run(graph)
def normal():
# NOTE: For some unknown reason, the
# `torch._C._jit_pass_remove_mutation` call within `self.run_pass`
# replaces `torch.randn(..., dtype=None).normal_()` with an
# `aten::normal` call with dtype double, even if the default dtype
# is float. So we must explicitly set the dtype here
return torch.rand(2, 1, 3, 4, dtype=torch.float).normal_()
fn = torch.jit.script(normal)
graph = fn.graph
self.run_pass("remove_mutation", graph)
FileCheck().check_not("normal_").run(graph)
with freeze_rng_state():
out_eager = normal()
with freeze_rng_state():
out_script = fn()
self.assertEqual(out_eager, out_script)
def test_lists_append(self):
def successful_remove():
return [i for i in range(5)] # noqa: C416
fn = torch.jit.script(successful_remove)
graph = fn.graph
self.run_pass("loop_unrolling", graph)
self.run_pass("remove_mutation", graph)
self.run_pass("constant_propagation", graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), successful_remove())
def intermediary_use():
a = [1, 2]
b = len(a)
a.append(3)
return a
fn = torch.jit.script(intermediary_use)
graph = fn.graph
FileCheck().check("append").run(graph)
self.run_pass("remove_mutation", graph)
# it is possible to remove the append here but don't currently have the logic for it
FileCheck().check_not("append").run(graph)
self.assertEqual(intermediary_use(), fn())
def test_lists_insert(self):
def successful_remove():
a: List[int] = []
a.insert(0, 1)
a.insert(0, 2)
a.insert(-10, 3)
a.insert(-9, 4)
a.insert(10, 5)
return a
fn = torch.jit.script(successful_remove)
graph = fn.graph
torch._C._jit_pass_remove_mutation(graph)
torch._C._jit_pass_constant_propagation(graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), fn())
def test_list_indexing_removal(self):
@torch.jit.script
def out_of_bounds():
x = [1, 2]
x[4] = 3
return x
torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
FileCheck().check("set_item").run(out_of_bounds.graph)
@torch.jit.script
def unknown(y: int):
x = [1, 2]
x[y] = 3
return x
torch._C._jit_pass_remove_mutation(out_of_bounds.graph)
FileCheck().check("set_item").run(out_of_bounds.graph)
def successful():
x = [1, 2, 3]
x[0] = 4
x[-1] = 0
return x
scripted_fn = torch.jit.script(successful)
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
FileCheck().check_not("set_item").run(scripted_fn.graph)
self.checkScript(successful, ())
def successful():
x = [1, 2, 3]
x[0] = 4
x[-1] = 0
return x
scripted_fn = torch.jit.script(successful)
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
FileCheck().check_not("set_item").run(scripted_fn.graph)
self.checkScript(successful, ())
def successful():
x = [1]
x[-1] = 3
return x
scripted_fn = torch.jit.script(successful)
torch._C._jit_pass_remove_mutation(scripted_fn.graph)
FileCheck().check_not("set_item").run(scripted_fn.graph)
self.checkScript(successful, ())
def test_common_pytorch_list_ops(self):
for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
class OpMod(torch.nn.Module):
def __init__(self, op):
super().__init__()
self.op = torch_op
def forward(self):
x = torch.tensor([1, 2, 3, 4])
x.add_(3)
y = [x, x]
return self.op(y) + 3
torch_op = getattr(torch, op)
mod = OpMod(torch_op)
mod_script = torch.jit.script(mod)
self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
self.assertEqual(mod(), mod_script())
# test that the output doesnt alias the input
for inputs in [torch.rand(2, 2)], [torch.rand(2, 2) for _ in range(2)]:
result = torch_op(inputs)
sums = [ten.sum() for ten in result]
for inp in inputs:
inp.fill_(10)
self.assertEqual(sums, [ten.sum() for ten in result])
@torch.jit.script
def test_multiple_uses():
x = torch.tensor([1, 2, 3, 4])
x.add_(3)
y = [x, x]
return torch.cat(y), y
self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check("aten::add_").run(test_multiple_uses.graph)