# Owner(s): ["module: dynamo"]

import io
import os
import shutil
import sys
import tempfile
import unittest

import torch._dynamo.test_case
from torch._dynamo.repro.after_aot import InputReader, InputWriter, save_graph_repro
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import IS_FBCODE
from torch.utils._traceback import report_compile_source_on_error
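
# These tests cover the "after AOT" repro tooling: save_graph_repro emits a
# standalone script that reproduces a traced graph, and InputWriter/InputReader
# round-trip the tensor inputs that such a script reloads from disk.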


def strip_trailing_whitespace(r):
    return "\n".join([l.rstrip() for l in r.split("\n")])


class TestAfterAot(torch._dynamo.test_case.TestCase):
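    # save_graph_repro should produce a script that executes standalone, both
    # while the saved storages directory exists and after it has been deleted.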
    @unittest.skipIf(IS_FBCODE, "NotImplementedError")
    def test_save_graph_repro(self):
        # TODO: This triggers CUDA context initialization, even though
        # it is CPU only
        buf = io.StringIO()
        args = [torch.randn(4)]

        def f(x):
            return (x * x,)

        gm = make_fx(f)(*args)
        with tempfile.TemporaryDirectory() as d:
            save_graph_repro(buf, gm, args, "inductor_accuracy", save_dir=d)
            r = buf.getvalue()
            with report_compile_source_on_error():
                exec(r, {"__compile_source__": r})

            shutil.rmtree(os.path.join(d, "storages"))

            # Should still work even without the save dir
            with report_compile_source_on_error():
                exec(r, {"__compile_source__": r})
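
    # InputWriter/InputReader round trip: with stable_hash=True the writer should
    # emit deterministic loader code, and executing that code through InputReader
    # should rebuild tensors equal to the originals, preserving dtype and
    # channels_last strides.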
    @unittest.skipIf(sys.byteorder != "little", "checksum depends on endianness")
    def test_dump_tensor(self):
        def test(tensor, expected):
            with tempfile.TemporaryDirectory() as d:
                writer = InputWriter(d, stable_hash=True)
                writer.tensor("x", tensor)
                self.assertExpectedInline("\n".join(writer._lines), expected, skip=1)
                reader = InputReader(d)
                env = {"reader": reader, "torch": torch}
                # TODO: assert no logs
                exec("\n".join(writer._lines), env)
                self.assertEqual(reader.args[0], tensor)

        test(
            torch.zeros(3, 4),
            """\
buf0 = reader.storage('c17fd92682ca5b304ac71074b558dda9e8eb4d66', 48)
reader.tensor(buf0, (3, 4), is_leaf=True) # x""",
        )

        test(
            torch.ones(3, 4, dtype=torch.int32),
            """\
buf0 = reader.storage('7c221e2da0c58c700cc2996644dd13d042bd552e', 48, dtype_hint=torch.int32)
reader.tensor(buf0, (3, 4), dtype=torch.int32, is_leaf=True) # x""",
        )

        test(
            torch.empty((3, 4, 5, 6), memory_format=torch.channels_last).fill_(2),
            """\
buf0 = reader.storage('49ebab3961d6221e64c4c72b0aefd976bdd2afc4', 1440)
reader.tensor(buf0, (3, 4, 5, 6), (120, 1, 24, 4), is_leaf=True) # x""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()