-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathmain.py
253 lines (204 loc) · 8.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import os
import threading
import time
from functools import wraps
import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torchvision.models.resnet import Bottleneck
#########################################################
# Define Model Parallel ResNet50 #
#########################################################
# In order to split the ResNet50 and place it on two different workers, we
# implement it as two model shards. The ResNetBase class defines the common
# attributes and methods shared by the two shards. ResNetShard1 and ResNetShard2
# each contain one partition of the model layers.
# Number of output classes: passed to ResNetBase by both shards, used as the
# output width of ResNetShard2's final linear layer, and used by run_master to
# build the one-hot labels.
num_classes = 1000
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 2-D convolution with a 1x1 kernel.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        The newly constructed ``nn.Conv2d`` module.
    """
    projection = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return projection
class ResNetBase(nn.Module):
    """Common state and helpers shared by the ResNet50 model shards.

    Args:
        block: Residual block class used by ``_make_layer``; it must expose an
            ``expansion`` attribute.
        inplanes: Number of channels entering the next layer group that
            ``_make_layer`` builds.
        num_classes: Accepted for signature compatibility; not used by this
            base class.
        groups: Forwarded to every block as its ``groups`` argument.
        width_per_group: Forwarded to every block as its ``base_width``.
        norm_layer: Callable that builds a normalization layer from a channel
            count. Defaults to ``nn.BatchNorm2d`` when ``None``.
    """

    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        # The shard subclasses hold this lock around their forward computation.
        self._lock = threading.Lock()
        self._block = block
        # Honor a caller-supplied norm layer instead of silently discarding it.
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        """Build one layer group made of ``blocks`` residual blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + norm downsample branch. As a
        side effect ``self.inplanes`` is updated to the group's output width
        (``planes * block.expansion``).

        Args:
            planes: Base channel width of the blocks in this group.
            blocks: Total number of blocks in the group.
            stride: Stride applied by the first block only.

        Returns:
            An ``nn.Sequential`` holding the constructed blocks.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        out_planes = planes * self._block.expansion
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        # Every following block consumes the expanded width.
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def parameter_rrefs(self):
        r"""
        Create one RRef for each parameter in the given local module, and return a
        list of RRefs.
        """
        return [RRef(p) for p in self.parameters()]
class ResNetShard1(ResNetBase):
    """
    First partition of the ResNet: the input convolution, normalization,
    ReLU and max-pool, followed by the first two layer groups.
    """

    def __init__(self, device, *args, **kwargs):
        super().__init__(Bottleneck, 64, *args, num_classes=num_classes, **kwargs)

        self.device = device
        # Modules are created in forward order and then moved to the device
        # as a single nn.Sequential.
        stages = [
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2),
        ]
        self.seq = nn.Sequential(*stages).to(self.device)

        # Kaiming-normal init for convolutions; unit weight / zero bias for
        # batch-norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x_rref):
        """Fetch the input behind ``x_rref``, run this shard, return a CPU tensor."""
        local_input = x_rref.to_here().to(self.device)
        # Only the computation through self.seq is done under the lock.
        with self._lock:
            result = self.seq(local_input)
        return result.cpu()
class ResNetShard2(ResNetBase):
    """
    Second partition of the ResNet: the last two layer groups, adaptive
    average pooling, and the final linear layer.
    """

    def __init__(self, device, *args, **kwargs):
        super().__init__(Bottleneck, 512, *args, num_classes=num_classes, **kwargs)

        self.device = device
        stages = [
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.seq = nn.Sequential(*stages).to(self.device)

        fc_in_features = 512 * self._block.expansion
        self.fc = nn.Linear(fc_in_features, num_classes).to(self.device)

    def forward(self, x_rref):
        """Fetch the input behind ``x_rref``, run this shard, return a CPU tensor."""
        local_input = x_rref.to_here().to(self.device)
        with self._lock:
            features = self.seq(local_input)
            # Flatten everything after the batch dimension before the fc layer.
            flattened = torch.flatten(features, 1)
            result = self.fc(flattened)
        return result.cpu()
class DistResNet50(nn.Module):
    """
    Wrap the two remote shards in a single nn.Module and implement the
    micro-batch pipelining across them.
    """

    def __init__(self, split_size, workers, *args, **kwargs):
        super().__init__()

        self.split_size = split_size

        # The first part of the ResNet50 lives on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetShard1,
            args=("cuda:0",) + args,
            kwargs=kwargs,
        )

        # The second part of the ResNet50 lives on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetShard2,
            args=("cuda:1",) + args,
            kwargs=kwargs,
        )

    def forward(self, xs):
        """Run ``xs`` through both shards in chunks of ``split_size`` along dim 0."""
        pending = []
        for micro_batch in xs.split(self.split_size, dim=0):
            # Shard 1 is invoked via remote() and its result RRef is handed to
            # shard 2, whose call returns a future.
            stage1_out = self.p1_rref.remote().forward(RRef(micro_batch))
            pending.append(self.p2_rref.rpc_async().forward(stage1_out))

        # Wait for every micro-batch and concatenate the outputs.
        return torch.cat(torch.futures.wait_all(pending))

    def parameter_rrefs(self):
        """Return the parameter RRefs of shard 1 followed by those of shard 2."""
        collected = []
        for shard_rref in (self.p1_rref, self.p2_rref):
            collected.extend(shard_rref.remote().parameter_rrefs().to_here())
        return collected
#########################################################
# Run RPC Processes #
#########################################################
# Number of training iterations executed by run_master.
num_batches = 3
# Number of samples in each randomly generated input batch.
batch_size = 120
# Last two dimensions of the random input tensor built in run_master.
image_w = 128
image_h = 128
def run_master(split_size):
    """Train the distributed ResNet50 for ``num_batches`` steps on random data.

    Args:
        split_size: Micro-batch size handed to ``DistResNet50``.
    """
    # put the two model parts on worker1 and worker2 respectively
    worker_names = ["worker1", "worker2"]
    model = DistResNet50(split_size, worker_names)
    criterion = nn.MSELoss()
    optimizer = DistributedOptimizer(optim.SGD, model.parameter_rrefs(), lr=0.05)

    # Class indices are drawn once, so every batch shares the same labels.
    class_ids = torch.LongTensor(batch_size).random_(0, num_classes)
    one_hot_indices = class_ids.view(batch_size, 1)

    for step in range(num_batches):
        print(f"Processing batch {step}")
        # fresh random inputs; one-hot labels rebuilt from the fixed indices
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes)
        labels.scatter_(1, one_hot_indices, 1)

        # The distributed autograd context is the dedicated scope for the
        # distributed backward pass to store gradients, which can later be
        # retrieved using the context_id by the distributed optimizer.
        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            dist_autograd.backward(context_id, [loss])
            optimizer.step(context_id)
def run_worker(rank, world_size, num_split):
    """Join the RPC group as master (rank 0) or as ``worker{rank}``.

    Rank 0 drives training through ``run_master``; every other rank only
    initializes RPC. All ranks then call ``rpc.shutdown()``.

    Args:
        rank: Rank of this process in the RPC group.
        world_size: Total number of processes in the RPC group.
        num_split: Value forwarded to ``run_master`` on rank 0.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='29500')

    # Higher timeout is added to accommodate for kernel compilation time in case of ROCm.
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256, rpc_timeout=300)

    is_master = rank == 0
    node_name = "master" if is_master else f"worker{rank}"
    rpc.init_rpc(
        node_name,
        rank=rank,
        world_size=world_size,
        rpc_backend_options=options,
    )
    if is_master:
        run_master(num_split)

    # block until all rpcs finish
    rpc.shutdown()
if __name__ == "__main__":
    # Rank 0 becomes the master; the remaining two ranks become the workers.
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        tik = time.perf_counter()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.perf_counter()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")