forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_inplacing_pass.py
69 lines (51 loc) · 1.57 KB
/
test_inplacing_pass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Owner(s): ["module: inductor"]
import torch
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
# Fill value written into the indexed positions by every index_put /
# index_put_ call in the tests below.
const = torch.tensor(0.0)
# Device on which _test creates all of its inputs; these tests need CUDA.
device = "cuda"
# Shorthand for the ATen operator namespace (used for aten.alias below).
aten = torch.ops.aten
class TestReinplacingPassCorrectness(TestCase):
    """Correctness tests for re-inplacing under ``torch.compile``.

    Each test builds a small function around ``index_put`` / ``index_put_``
    and hands it to ``_test``. ``_test`` checks that the compiled function
    matches eager in its return value and in how it mutates its inputs.
    The ``dont_*`` tests cover cases where writing in place would clobber a
    value that is still observable. The ``should_*`` tests cover cases where
    the write is safe or required.
    """

    def _test(self, f):
        """Run ``f`` eagerly and compiled on identical inputs and compare.

        ``f`` is called as ``f(x, y)``. ``x`` is a 4-element random float
        tensor. ``y`` is a 2-element ``int`` tensor of ones, used as an index.
        The eager and compiled runs each get their own clone of the inputs.
        Comparing the input tuples after both runs therefore checks that the
        compiled function mutated its inputs exactly as eager did.
        """
        # The inputs are created on CUDA, and the module's __main__ guard only
        # runs these tests on Linux with CUDA. Skip (rather than fail) when a
        # test runner collects this class and bypasses that guard.
        if not (IS_LINUX and HAS_CUDA):
            self.skipTest("requires Linux and CUDA")
        nf = torch.compile(f)
        inp = (
            torch.randn(4, device=device),
            torch.ones(2, device=device, dtype=torch.int),
        )
        inp2 = (inp[0].clone(), inp[1].clone())
        self.assertEqual(f(*inp), nf(*inp2))
        self.assertEqual(inp, inp2)

    def test_dont_modify_live(self):
        # ``x`` (the cos result) is returned alongside ``x2``, so it is still
        # live after index_put. Writing into ``x`` in place would corrupt it.
        def f(x, y):
            x = x.cos()
            x2 = x.index_put((y,), const)
            return x2, x

        self._test(f)

    def test_dont_modify_view_of_live(self):
        # index_put is applied to an alias of ``x``, and ``x`` is read again
        # afterwards (x.cos()). An in-place write through the alias would
        # change that later read.
        def f(x, y):
            x = x.cos()
            x2 = aten.alias(x)
            x2 = x2.index_put((y,), const)
            out = x2 + x.cos()
            return out

        self._test(f)

    def test_dont_modify_input(self):
        # Out-of-place index_put on a function input. _test's input
        # comparison fails if the compiled version writes into ``x``.
        def f(x, y):
            return x.index_put((y,), const)

        self._test(f)

    def test_should_modify_inner(self):
        # The cos result is an intermediate that is not used after
        # index_put, so updating it in place is unobservable.
        def f(x, y):
            x = x.cos()
            x = x.index_put((y,), const)
            return x

        self._test(f)

    def test_should_modify_input(self):
        # Explicit in-place index_put_ on an input. The compiled version must
        # reproduce the eager mutation of ``x``.
        def f(x, y):
            x = x.index_put_((y,), const)
            return x

        self._test(f)
# Run the suite only when executed as a script on a Linux machine with CUDA;
# the inputs in these tests are created on the "cuda" device.
if __name__ == "__main__" and IS_LINUX and HAS_CUDA:
    run_tests()