forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_attr.py
68 lines (53 loc) · 2.14 KB
/
test_attr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Owner(s): ["oncall: jit"]
from typing import NamedTuple, Tuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
# Running this module standalone is unsupported: the error below points users
# at test/test_jit.py, which is the intended entry point for these tests.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:"
        "\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )
class TestGetDefaultAttr(JitTestCase):
    """TorchScript tests for the three-argument ``getattr(obj, name, default)`` form."""

    def test_getattr_with_default(self) -> None:
        """Scripted Module getattr: existing attributes lower to ``prim::GetAttr``,
        missing ones are replaced by the default expression."""

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.init_attr_val = 1.0

            def forward(self, x):
                # `x` is unused; it is left unannotated, so TorchScript types it as Tensor.
                y = getattr(self, "init_attr_val")  # noqa: B009
                w: list[float] = [1.0]
                # A has no attribute "missing", so the default `w` is returned.
                z = getattr(self, "missing", w)  # noqa: B009
                z.append(y)
                return z

        # z is the default list [1.0] with y (1.0) appended.
        expected = [1.0, 1.0]

        # Eager mode: check the full contents, not just the length.
        result = A().forward(0.0)
        self.assertEqual(expected, result)

        scripted = torch.jit.script(A())
        graph = scripted.graph
        # The "init_attr_val" attribute exists
        FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
        # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
        FileCheck().check_not("missing").run(graph)
        # instead the getattr call will emit the default value, which is a list with one float element
        FileCheck().check("float[] = prim::ListConstruct").run(graph)

        # The compiled module must also produce the same value as eager mode,
        # not merely emit the expected IR. `x` must be a Tensor here (see above).
        self.assertEqual(expected, scripted(torch.tensor(0.0)))

    def test_getattr_named_tuple(self) -> None:
        """getattr with a default on a NamedTuple: scripted results match eager."""
        # NOTE(review): presumably declared global so TorchScript can resolve the
        # `MyTuple` annotation when scripting `fn` below — confirm before removing.
        global MyTuple

        class MyTuple(NamedTuple):
            x: str
            y: torch.Tensor

        def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
            return (
                # Existing fields: the defaults are ignored.
                getattr(x, "x", "fdsa"),
                getattr(x, "y", torch.ones((3, 3))),
                # MyTuple has no field "z", so the default 7 is returned.
                getattr(x, "z", 7),
            )

        inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
        ref = fn(inp)
        fn_s = torch.jit.script(fn)
        res = fn_s(inp)
        self.assertEqual(res, ref)

    def test_getattr_tuple(self) -> None:
        """getattr on a plain (unnamed) Tuple is rejected at script time."""

        def fn(x: Tuple[str, int]) -> int:
            return getattr(x, "x", 2)

        # A plain Tuple has no named fields; scripting must fail with this message.
        with self.assertRaisesRegex(RuntimeError, "but got a normal Tuple"):
            torch.jit.script(fn)