forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_input_attr_tracking.py
403 lines (310 loc) · 13.7 KB
/
test_input_attr_tracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Owner(s): ["module: dynamo"]
# flake8: noqa
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
normalize_gm,
)
class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
def test_tensor_property_on_tensor(self):
def fn(x):
return x * x.y
x_ = torch.randn([2, 2])
y_ = torch.randn([2, 2])
x_.y = y_
eager_result = fn(x_)
graph = None
def grab_graph_backend(gm, inps):
nonlocal graph
graph = gm
return gm
fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
compile_result = fn(x_)
self.assertEqual(eager_result, compile_result)
placeholder_cnt = 0
for node in graph.graph.nodes:
if node.op == "placeholder":
placeholder_cnt += 1
# We want to be very sure that this lifts y to inputs!
self.assertEqual(placeholder_cnt, 2)
def test_tensor_property_assigned_on_tensor(self):
def fn(x, y):
x.y = y
return x * x.y
x_ = torch.randn([2, 2])
y_ = torch.randn([2, 2])
eager_result = fn(x_, y_)
graph = None
def grab_graph_backend(gm, inps):
nonlocal graph
graph = gm
return gm
fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
compile_result = fn(x_, y_)
self.assertEqual(eager_result, compile_result)
placeholder_cnt = 0
for node in graph.graph.nodes:
if node.op == "placeholder":
placeholder_cnt += 1
# y is already an input
self.assertEqual(placeholder_cnt, 2)
def test_const_property_on_tensor(self):
def fn(x):
return x * x.y
x_ = torch.randn([2, 2])
y_ = 4
x_.y = y_
eager_result = fn(x_)
graph = None
def grab_graph_backend(gm, inps):
nonlocal graph
graph = gm
return gm
fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
compile_result = fn(x_)
self.assertEqual(eager_result, compile_result)
placeholder_cnt = 0
for node in graph.graph.nodes:
if node.op == "placeholder":
placeholder_cnt += 1
# We want to be very sure that this does not lifts y to inputs, as its a const
self.assertEqual(placeholder_cnt, 1)
def test_const_property_assigned_on_tensor(self):
def fn(x, y):
x.y = y
return x * x.y
x_ = torch.randn([2, 2])
y_ = 4
eager_result = fn(x_, y_)
fn = torch._dynamo.optimize("eager", nopython=True)(fn)
compile_result = fn(x_, y_)
self.assertEqual(eager_result, compile_result)
def test_guards_correctly_property_assigned_on_tensor_type_change(self):
def fn(x, y):
x.y = y
return x * x.y
x_ = torch.randn([2, 2])
fn = torch._dynamo.optimize("eager", nopython=True)(fn)
compile_result_const = fn(x_, 4)
self.assertEqual(compile_result_const, x_ * 4)
y = torch.randn([2, 2])
compile_result_tensor = fn(x_, y)
self.assertEqual(compile_result_tensor, x_ * y)
def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self):
def fn(x, y):
x.y = y
return x * x.y
x_ = torch.randn([2, 2])
fn = torch._dynamo.optimize("inductor", nopython=True)(fn)
compile_result_const = fn(x_, 4)
self.assertEqual(compile_result_const, x_ * 4)
y = torch.randn([2, 2])
compile_result_tensor = fn(x_, y)
self.assertEqual(compile_result_tensor, x_ * y)
def test_complex_attr_access_without_graph_breaks(self):
def fn(x, y, z):
for t in x:
t.y = y
t.z = y * z
new_y = 1
new_z = 1
for t in x:
new_y = t.y * new_y
new_z = t.z * new_z
return new_y, new_z
x_0 = torch.randn([2, 2])
x_1 = torch.randn([2, 2])
x_2 = torch.randn([2, 2])
x = [x_0, x_1, x_2]
y = torch.randn([2, 2])
z = 5
eager_result = fn(x, y, z)
counter = CompileCounter()
fn = torch._dynamo.optimize(counter, nopython=True)(fn)
compile_result = fn(x, y, z)
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
# Graph for reference
# ------------- ------ ----------------------- ------------------------------------ --------
# placeholder l_y_ L_y_ () {}
# call_function mul <built-in function mul> (l_y_, 5) {}
# call_function mul_1 <built-in function mul> (l_y_, 5) {}
# call_function mul_2 <built-in function mul> (l_y_, 5) {}
# call_function mul_3 <built-in function mul> (l_y_, 1) {}
# call_function mul_4 <built-in function mul> (mul, 1) {}
# call_function mul_5 <built-in function mul> (l_y_, mul_3) {}
# call_function mul_6 <built-in function mul> (mul_1, mul_4) {}
# call_function mul_7 <built-in function mul> (l_y_, mul_5) {}
# call_function mul_8 <built-in function mul> (mul_2, mul_6) {}
# output output output ((mul_7, mul_8, mul, mul_1, mul_2),) {}
def test_complex_attr_access_with_graph_breaks(self):
def fn(x, y, z):
for t in x:
t.y = y
t.z = y * z
print("Break!")
new_y = 1
new_z = 1
for t in x:
new_y = t.y * new_y
new_z = t.z * new_z
return new_y, new_z
x_0 = torch.randn([2, 2])
x_1 = torch.randn([2, 2])
x_2 = torch.randn([2, 2])
x = [x_0, x_1, x_2]
y = torch.randn([2, 2])
z = 5
eager_result = fn(x, y, z)
counter = CompileCounter()
fn = torch._dynamo.optimize(counter, nopython=False)(fn)
compile_result = fn(x, y, z)
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 2)
self.assertEqual(counter.op_count, 9)
# Graph for reference
# ------------- ------ ----------------------- ---------------------- --------
# placeholder l_y_ L_y_ () {}
# call_function mul <built-in function mul> (l_y_, 5) {}
# call_function mul_1 <built-in function mul> (l_y_, 5) {}
# call_function mul_2 <built-in function mul> (l_y_, 5) {}
# output output output ((mul, mul_1, mul_2),) {}
# [GRAPH BREAK!]
# ------------- ------- ----------------------- ----------------- --------
# placeholder l_x_0_y L_x_0_y () {}
# placeholder l_x_0_z L_x_0_z () {}
# placeholder l_x_1_y L_x_1_y () {}
# placeholder l_x_1_z L_x_1_z () {}
# placeholder l_x_2_y L_x_2_y () {}
# placeholder l_x_2_z L_x_2_z () {}
# call_function mul <built-in function mul> (l_x_0_y, 1) {}
# call_function mul_1 <built-in function mul> (l_x_0_z, 1) {}
# call_function mul_2 <built-in function mul> (l_x_1_y, mul) {}
# call_function mul_3 <built-in function mul> (l_x_1_z, mul_1) {}
# call_function mul_4 <built-in function mul> (l_x_2_y, mul_2) {}
# call_function mul_5 <built-in function mul> (l_x_2_z, mul_3) {}
# output output output ((mul_4, mul_5),) {}
def test_complex_attr_access_with_inline_reconstruct(self):
def inline_test_fn(x, y, z):
print("f")
return x.a + y.a + z.a
def fn(x, y, z):
x.a = 1
y.a = 2
z.a = 3
mult = inline_test_fn(x, y, z)
y = y * mult
x = x * mult
return x, y
x = torch.randn([2, 2])
y = torch.randn([2, 2])
z = torch.randn([2, 2])
eager_result = fn(x, y, z)
counter = CompileCounter()
fn = torch._dynamo.optimize(counter, nopython=False)(fn)
compile_result = fn(x, y, z)
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 2)
# Graph for reference
# __compiled_fn_2 <eval_with_key>.0 opcode name target args kwargs
# ------------- ------ ----------------------- --------------- --------
# placeholder l_x_ L_x_ () {}
# placeholder l_y_ L_y_ () {}
# call_function mul <built-in function mul> (l_y_, 6) {}
# call_function mul_1 <built-in function mul> (l_x_, 6) {}
# output output output ((mul_1, mul),) {}
def test_set_data_on_input_tensor(self):
def fn(x, y):
x.data = y.data
if x.size() == y.size():
return x * y
else:
return y * y
x = torch.randn([5, 5])
y = torch.randn([2, 2])
eager_result = fn(x, y)
eager_and_record = EagerAndRecordGraphs()
counter = CompileCounterWithBackend(eager_and_record)
fn = torch._dynamo.optimize(counter, nopython=True)(fn)
compile_result = fn(x, y)
graph = eager_and_record.graphs[0]
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 6)
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_y_: "f32[2, 2]", L_x_: "f32[2, 2]"):
l_y_ = L_y_
l_x_ = L_x_
detach: "f32[2, 2]" = l_y_.detach()
_set_grad_enabled = torch._C._set_grad_enabled(False)
set_: "f32[2, 2]" = torch_Tensor_set_(l_x_, detach); detach = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True)
_lower_version_count_by_1 = torch__dynamo_variables_builtin__lower_version_count_by_1(set_); set_ = None
mul: "f32[2, 2]" = l_x_ * l_y_; l_x_ = l_y_ = None
return (mul,)
""",
)
# Note - this does not actually get captured in the graph yet.
# The plan of record is to introduce a set_data op, entirely subsume the operation into a call_function
# in the fx graph, and let aot_autograd handle it.
def test_set_data_on_scoped_tensor(self):
def fn(x):
z = torch.zeros([4, 4])
z.data = x.data
if x.size() == z.size():
return z * x
else:
return x
x = torch.randn([5, 5])
eager_result = fn(x)
counter = CompileCounter()
fn = torch._dynamo.optimize(counter, nopython=False)(fn)
compile_result = fn(x)
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 2)
self.assertEqual(counter.op_count, 3)
def test_set_data_on_user_defined_class_input_tensor(self):
class MyUserDefinedClass:
def __init__(self, x, y):
self.x = x
self.y = y
def do_some_setattr_stuff(self):
self.z = x * y
self.a = x + x
return self.z * self.a
x = torch.randn([5, 5])
y = torch.randn([5, 5])
mudc_1 = MyUserDefinedClass(x, y)
eager_result = mudc_1.do_some_setattr_stuff()
counter = CompileCounter()
mudc_2 = MyUserDefinedClass(x, y)
do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)(
mudc_2.do_some_setattr_stuff
)
compile_result = do_some_setattr_stuff()
self.assertEqual(compile_result, eager_result)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 3)
# Graph for reference
# __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs
# ------------- ------ ----------------------- -------------------- --------
# placeholder l_x_ L_x_ () {}
# placeholder l_y_ L_y_ () {}
# call_function mul <built-in function mul> (l_x_, l_y_) {}
# call_function add <built-in function add> (l_x_, l_x_) {}
# call_function mul_1 <built-in function mul> (mul, add) {}
# output output output ((mul_1, mul, add),) {}
if __name__ == "__main__":
    # Run this module's tests through the dynamo test harness when executed
    # directly; nothing happens on import.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()