forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_python_bindings.py
115 lines (93 loc) · 3.59 KB
/
test_python_bindings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
# This module is only a collection of test cases; it is meant to be driven by
# test/test_jit.py, so refuse direct execution with usage instructions.
if __name__ == "__main__":
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TestPythonBindings\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
class TestPythonBindings(JitTestCase):
    """Tests for the Python bindings over TorchScript IR objects.

    Exercises ``CompilationUnit`` lookup and creation, invalidation of
    ``Node``/``Value`` handles after a graph pass, iteration over graphs
    obtained from ``inlined_graph``, alias-db printing, the ``Graph``
    construction helpers, and ``torch._C._jit_pass_canonicalize``.
    """

    def test_cu_get_functions(self):
        """A scripted function is listed by ``torch.jit._state._python_cu``."""

        @torch.jit.script
        def test_get_python_cu_fn(x: torch.Tensor):
            return 2 * x

        cu = torch.jit._state._python_cu
        # Materialize the names so that, on failure, assertIn reports the
        # registered names instead of an opaque generator repr.
        fn_names = [str(fn.name) for fn in cu.get_functions()]
        self.assertIn("test_get_python_cu_fn", fn_names)

    def test_cu_create_function(self):
        """Functions added via ``create_function`` are callable by lookup and attribute."""

        @torch.jit.script
        def fn(x: torch.Tensor):
            return 2 * x

        cu = torch._C.CompilationUnit()
        cu.create_function("test_fn", fn.graph)

        inp = torch.randn(5)
        self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
        # find_function yields None (it does not raise) for an unknown name.
        self.assertIsNone(cu.find_function("doesnt_exist"))

        # Registered functions are also reachable as attributes of the unit;
        # an unknown attribute name raises AttributeError.
        self.assertEqual(inp * 2, cu.test_fn(inp))
        with self.assertRaises(AttributeError):
            cu.doesnt_exist(inp)

    def test_invalidation(self):
        """Node/Value handles raise "invalidated" once DCE has run on their graph."""

        @torch.jit.script
        def test_invalidation_fn(x: torch.Tensor):
            return 2 * x

        gr = test_invalidation_fn.graph.copy()
        n = gr.insertNode(gr.create("prim::profile"))
        v = n.output()
        # Both handles are printable while the node is still in the graph.
        str((n, v))
        # The inserted node's output has no uses, so DCE is expected to
        # remove it; afterwards the Python handles must refuse to be used.
        torch._C._jit_pass_dce(gr)
        with self.assertRaisesRegex(RuntimeError, "invalidated"):
            str(n)
        with self.assertRaisesRegex(RuntimeError, "invalidated"):
            str(v)

    def test_graph_iterator_keepalive(self):
        """Iterators over a temporary ``inlined_graph`` stay valid while consumed."""

        @torch.jit.script
        def test_iterator_keepalive_fn(x: torch.Tensor):
            return 2 * x

        # the list would segfault before because inlined_graph
        # is temporary and had been deleted (see issue #50454)
        n = test_iterator_keepalive_fn.inlined_graph.nodes()
        list(n)
        i = test_iterator_keepalive_fn.inlined_graph.inputs()
        list(i)
        o = test_iterator_keepalive_fn.inlined_graph.outputs()
        list(o)

    def test_aliasdb(self):
        """A graph's alias db prints a WILDCARD entry and a graphviz digraph."""

        @torch.jit.script
        def test_aliasdb_fn(x: torch.Tensor):
            return 2 * x

        gr = test_aliasdb_fn.graph.copy()
        alias_db = gr.alias_db()
        self.assertIn("WILDCARD", str(alias_db))
        self.assertIn("digraph alias_db", alias_db.to_graphviz_str())

    def test_graph_create(self):
        """``Graph.create`` rejects ``None`` as an input value with ValueError."""
        gr = torch._C.Graph()
        with self.assertRaises(ValueError):
            gr.create("prim::Constant", [None])

    def test_add_input(self):
        """A value returned by ``addInput`` appears among the graph's inputs."""
        gr = torch._C.Graph()
        foo_value = gr.addInput("foo")
        # Use a unittest assertion rather than a bare ``assert``: bare asserts
        # are stripped under ``python -O`` and would let this test pass silently.
        self.assertIn(foo_value, list(gr.inputs()))

    def test_canonicalize(self):
        """``_jit_pass_canonicalize`` keeps or drops original value names per its flag."""
        ir = """
graph(%p207 : Tensor,
      %1 : Tensor,
      %p407 : int):
  %11 : Tensor = aten::view_expand_placeholder(%1)
  %12 : Tensor = aten::pointwise_placeholder(%11, %p207, %p407)
  %13 : Tensor = aten::view_expand_placeholder(%12)
  %14 : Tensor = aten::pointwise_placeholder(%13)
  return (%14)
"""
        # Passing True explicitly must print identically to the default call.
        graph1 = torch._C.parse_ir(ir)
        graph1 = torch._C._jit_pass_canonicalize(graph1, True)
        graph2 = torch._C.parse_ir(ir)
        graph2 = torch._C._jit_pass_canonicalize(graph2)
        self.assertEqual(str(graph1), str(graph2))
        # With True, the name %p207 survives while %14 is renumbered away.
        FileCheck().check("%p207").check_not("%14").run(graph1)

        # With False, %p207 is no longer present in the printed graph.
        graph3 = torch._C.parse_ir(ir)
        graph3 = torch._C._jit_pass_canonicalize(graph3, False)
        FileCheck().check_not("%p207").run(graph3)