# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.parallel as dp

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
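

# TestDataParallel exercises torch.nn.parallel.replicate on module hierarchies
# that mix plain Python nn.Modules with TorchScript modules (ScriptModules and
# traced modules), checking that replicas land on the expected GPUs, produce
# the same outputs as the original module, and share or copy storage as
# appropriate.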
class TestDataParallel(JitTestCase):
    class Mpy(torch.nn.Module):
        def __init__(self):
            super(TestDataParallel.Mpy, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.ignore
        def forward(self, input):
            return self.m(input)

    class Mpy1(torch.nn.Module):
        def __init__(self, block):
            super(TestDataParallel.Mpy1, self).__init__()
            self.m = block

        @torch.jit.ignore
        def forward(self, input):
            return self.m.forward(input)

    class Mpy2(torch.nn.Module):
        def __init__(self, block1, block2):
            super(TestDataParallel.Mpy2, self).__init__()
            self.m1 = block1
            self.m2 = block2

        @torch.jit.ignore
        def forward(self, input):
            x = self.m1.forward(input)
            return self.m2(x)

    class Msm(torch.jit.ScriptModule):
        __constants__ = ["m"]

        def __init__(self):
            super(TestDataParallel.Msm, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.script_method
        def forward(self, input):
            return self.m(input)

    class Msm1(torch.jit.ScriptModule):
        def __init__(self, block):
            super(TestDataParallel.Msm1, self).__init__()
            self.block = block

        @torch.jit.script_method
        def forward(self, input):
            x = self.block(input)
            return x
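
    # Helper: verify that each replica lives entirely on its own device and,
    # given the same input, produces the same output as the original module.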
    def check_replicas(self, module, replicas, input_shape=(2, 2)):
        input = torch.randn(input_shape).cuda()
        expected_output = module(input).data
        for i, replica in enumerate(replicas):
            for p in replica.parameters():
                self.assertEqual(p.get_device(), i)
            for b in replica.buffers():
                self.assertEqual(b.get_device(), i)
            replica_input = input.cuda(i)
            self.assertEqual(replica(replica_input).data, expected_output)
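
    # A ScriptModule nested inside a plain Python module should replicate
    # across both GPUs.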
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_python_submodule_script(self):
        module = self.Mpy1(self.Msm()).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)
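
    # The same ScriptModule instance appears twice in the hierarchy (once
    # directly and once wrapped in Mpy1); replication should still work.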
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_shared_module(self):
        s = self.Msm()
        p1 = self.Mpy1(s)
        module = self.Mpy2(p1, s).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)
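
    # A traced (rather than scripted) module hierarchy should also replicate.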
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_traced_module(self):
        module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)
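
    # The replica on the source GPU shares storage with the original module's
    # parameters and buffers; the replica on the other GPU gets its own copies.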
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})

        def assert_share_data(t1, t2):
            # Only checks that they point to the same memory on the same device.
            if t1.device != t2.device:
                return False
            if t1.storage().data_ptr() != t2.storage().data_ptr():
                return False
            return True

        for p1, p2 in zip(module.parameters(), replica[0].parameters()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[0].buffers()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.parameters(), replica[1].parameters()):
            self.assertFalse(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[1].buffers()):
            self.assertFalse(assert_share_data(p1, p2))
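
    # Because replica[0] shares storage with the original module, in-place
    # updates to the original parameters show up in replica[0]'s forward,
    # while replica[1] (a copy made at replication time) still reproduces the
    # pre-update output.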
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing_with_forward(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})
        x = torch.ones(2, 2, requires_grad=True).cuda()
        first_forward = module(x)
        first_forward.sum().backward()
        with torch.no_grad():
            for p in module.parameters():
                # Use .data here to avoid a version counter bump.
                # The graph created by the following forward will be wrong, but
                # we never backprop through it, so that's fine.
                p.data -= 1.0 * p.grad
        second_forward = module(x)

        # The replica on the same GPU has a shallow copy of the original
        # params and buffers.
        r0_forward = replica[0](x)
        self.assertEqual(second_forward, r0_forward)

        # The replica on a different GPU has a deep copy of the original
        # params and buffers.
        x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
        r1_forward = replica[1](x1)
        self.assertEqual(first_forward, r1_forward)