forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparallel_train.py
150 lines (113 loc) · 4.67 KB
/
parallel_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad_and_value, stack_module_state, vmap
# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a
# tutorial on Model Ensembling with JAX by Will Whitney.
#
# The original code comes with the following citation:
# @misc{Whitney2021Parallelizing,
# author = {William F. Whitney},
# title = { {Parallelizing neural networks on one GPU with JAX} },
# year = {2021},
# url = {http://willwhitney.com/parallel-training-jax.html},
# }
# GOAL: Demonstrate that it is possible to use eager-mode vmap
# to parallelize training over models.
# Command-line interface: the single --device option selects the torch device
# used for the data and models created in this script.
parser = argparse.ArgumentParser(description="Functorch Ensembled Models")
parser.add_argument(
    "--device",
    default="cpu",
    type=str,
    help="CPU or GPU ID for this process (default: 'cpu')",
)
# Parsed once at module load; DEVICE is read by the definitions that follow.
args = parser.parse_args()
DEVICE = args.device
# Step 1: Make some spirals
def make_spirals(n_samples, noise_std=0.0, rotations=1.0, device=None):
    """Generate a noisy two-arm spiral dataset for binary classification.

    Each sample lies on a spiral whose radius grows as the square root of an
    evenly spaced parameter in [0, 1]; a random sign per sample mirrors the
    point through the origin, which selects one of the two arms.

    Args:
        n_samples: Number of points to generate.
        noise_std: Standard deviation of the Gaussian noise added
            independently to each coordinate.
        rotations: Number of full turns an arm makes between radius 0 and 1.
        device: Torch device on which all tensors are created. When ``None``
            (the default) the module-level ``DEVICE`` is used.

    Returns:
        A ``(points, labels)`` pair: ``points`` has shape ``(n_samples, 2)``
        and ``labels`` is a ``torch.long`` tensor of shape ``(n_samples,)``
        holding 1 for the positive-sign arm and 0 for the mirrored arm.
    """
    if device is None:
        # Resolved lazily so callers that pass a device never touch the global.
        device = DEVICE

    ts = torch.linspace(0, 1, n_samples, device=device)
    rs = ts**0.5
    thetas = rs * rotations * 2 * math.pi

    # Map {0, 1} draws to {-1, +1}; the sign decides which arm a sample is on.
    signs = torch.randint(0, 2, (n_samples,), device=device) * 2 - 1
    # `signs` already lives on `device`, so no extra transfer is needed.
    labels = (signs > 0).to(torch.long)

    xs = (
        rs * signs * torch.cos(thetas)
        + torch.randn(n_samples, device=device) * noise_std
    )
    ys = (
        rs * signs * torch.sin(thetas)
        + torch.randn(n_samples, device=device) * noise_std
    )
    points = torch.stack([xs, ys], dim=1)
    return points, labels
# Module-level training set used by step4 and step6: 100 spiral points with
# coordinate noise of standard deviation 0.05, plus their 0/1 class labels.
points, labels = make_spirals(100, noise_std=0.05)
# Step 2: Define two-layer MLP and loss function
class MLPClassifier(nn.Module):
    """Two-layer perceptron that maps 2-D points to per-class log-probabilities.

    Args:
        hidden_dim: Width of the single hidden layer.
        n_classes: Number of output classes.
    """

    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.hidden_dim, self.n_classes = hidden_dim, n_classes
        # fc1 is built before fc2 so parameter order and init draws stay the same.
        self.fc1 = nn.Linear(2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return log-softmax scores over the last dimension of the output."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, -1)
# Negative log-likelihood loss; it pairs with the log-softmax output that
# MLPClassifier.forward returns.
loss_fn = nn.NLLLoss()
# Single module instance that train_step_fn passes to functional_call together
# with an explicit dict of weights.
model = MLPClassifier().to(DEVICE)
def train_step_fn(weights, batch, targets, lr=0.2):
    """Take one SGD step for one set of weights.

    Args:
        weights: Dict mapping parameter names to tensors; handed to
            ``functional_call`` in place of the parameters stored on ``model``.
        batch: Inputs forwarded through the model.
        targets: Targets passed to ``loss_fn`` along with the model output.
        lr: Step size multiplied into every gradient.

    Returns:
        A ``(loss, new_weights)`` pair. ``new_weights`` is a new dict with the
        same keys as the gradients; the ``weights`` argument is not mutated.
    """

    def compute_loss(params, inputs, expected):
        predictions = functional_call(model, params, inputs)
        return loss_fn(predictions, expected)

    # Differentiates with respect to the first argument (the weights dict).
    gradients, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    with torch.no_grad():
        updated = {
            name: weights[name] - gradient * lr
            for name, gradient in gradients.items()
        }
    return loss, updated
# Step 4: Let's verify this actually trains.
# We should see the loss decrease.
def step4():
    """Train the single module-level ``model`` with hand-written SGD.

    Runs 2000 calls of ``train_step_fn`` on the module-level ``points`` and
    ``labels`` and prints the loss every 100 steps. The trained weights are
    left in the module-level ``weights`` global; ``train_step_fn`` builds new
    tensors rather than updating in place, so the parameters stored on
    ``model`` keep their initial values.
    """
    global weights
    # Start from the model's parameters once, then feed each step's result into
    # the next step. Re-reading model.named_parameters() on every iteration
    # would throw the update away and the loss would never decrease.
    weights = dict(model.named_parameters())
    for i in range(2000):
        loss, weights = train_step_fn(weights, points, labels)
        if i % 100 == 0:
            print(loss)
# Run the single-model training check when the script is loaded.
step4()
# Step 5: We're ready for multiple models. Let's define an init_fn
# that, given a number of models, returns to us all of the weights.
def init_fn(num_models):
    """Create ``num_models`` fresh classifiers and return their stacked parameters.

    Args:
        num_models: How many independently initialised ``MLPClassifier``
            instances to build on ``DEVICE``.

    Returns:
        The parameter part of ``stack_module_state``'s result for those
        models; the buffer part is discarded.
    """
    ensemble = []
    for _ in range(num_models):
        ensemble.append(MLPClassifier().to(DEVICE))
    stacked_params, _buffers = stack_module_state(ensemble)
    return stacked_params
# Step 6: Now, can we try multiple models at the same time?
# The answer is: yes! `loss` now holds one value per model (two here), and we
# can see that both values keep on decreasing.
def step6():
    """Train two models at once by vmapping ``train_step_fn`` over their weights.

    The weights are mapped over their leading dimension while the module-level
    ``points`` and ``labels`` are shared by every model (``in_dims`` of
    ``None``). Runs 2000 steps and prints the losses every 200 steps.
    """
    ensemble_weights = init_fn(num_models=2)
    batched_step = vmap(train_step_fn, in_dims=(0, None, None))
    for iteration in range(2000):
        losses, ensemble_weights = batched_step(ensemble_weights, points, labels)
        if iteration % 200 != 0:
            continue
        print(losses)
# Run the two-model ensemble training when the script is loaded.
step6()
# Step 7: Now, the flaw with step 6 is that we were training on the same exact
# data. This can lead to all of the models in the ensemble overfitting in the
# same way. The solution that http://willwhitney.com/parallel-training-jax.html
# applies is to randomly subset the data in a way that the models do not receive
# exactly the same data in each training step!
# Because the goal of this doc is to show that we can use eager-mode vmap to
# achieve similar things as JAX, the rest of this is left as an exercise to the reader.
# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html
# does, we relied on the following items beyond classic eager-mode PyTorch:
# 1. An NN module functional API that runs a module with externally supplied
#    state (here `functional_call` and `stack_module_state` from `torch.func`)
# 2. Functional optimizers (not available, so SGD is re-implemented by hand
#    in `train_step_fn`)
# 3. A "functional" grad API (that effectively wraps autograd.grad); here
#    `grad_and_value` from `torch.func`
# 4. Composability between the functional grad API and `vmap`.