-
Notifications
You must be signed in to change notification settings - Fork 258
/
Copy pathensembling.py
167 lines (130 loc) Β· 7.55 KB
/
ensembling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# -*- coding: utf-8 -*-
"""
λͺ¨λΈ μμλΈ
================
**λ²μ**: `μ‘°νμ <https://github.com/ChoHyoungSeo/>`_
λ³Έ νν 리μΌμμλ ``torch.vmap`` μ νμ©νμ¬ λͺ¨λΈ μμλΈμ 벑ν°ννλ λ°©λ²μ μ€λͺ
ν©λλ€.
λͺ¨λΈ μμλΈμ΄λ?
-------------------------
λͺ¨λΈ μμλΈμ μ¬λ¬ λͺ¨λΈμ μμΈ‘κ°μ ν¨κ» κ²°ν©νλ κ²μ μλ―Έν©λλ€.
μΌλ°μ μΌλ‘ μ΄ μμ
μ μΌλΆ μ
λ ₯κ°μ λν΄ κ° λͺ¨λΈμ κ°λ³μ μΌλ‘ μ€νν λ€μ μμΈ‘μ κ²°ν©νλ λ°©μμΌλ‘ μ€νλ©λλ€.
νμ§λ§ λμΌν μν€ν
μ²λ‘ λͺ¨λΈμ μ€ννλ κ²½μ°, ``torch.vmap`` μ νμ©νμ¬ ν¨κ» κ²°ν©ν μ μμ΅λλ€.
``vmap`` μ μ
λ ₯ tensorμ μ¬λ¬ μ°¨μμ κ±Έμ³ ν¨μλ₯Ό λ§€ννλ ν¨μ λ³νμ
λλ€. μ΄ ν¨μμ
μ¬μ© μ¬λ‘ μ€ νλλ for λ¬Έμ μ κ±°νκ³ λ²‘ν°νλ₯Ό ν΅ν΄ μλλ₯Ό λμ΄λ κ²μ
λλ€.
κ°λ¨ν MLP μμλΈμ νμ©νμ¬ μ΄λ₯Ό μννλ λ°©λ²μ μ΄ν΄λ³΄κ² μ΅λλ€.
.. note::
μ΄ νν 리μΌμ μ€νμ μν΄μλ PyTorch 2.0 λλ μ΄μμ λ²μ μ΄ νμν©λλ€.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# λ€μμ κ°λ¨ν MLP μ
λλ€.
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.flatten(1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
######################################################################
# λλ―Έ λ°μ΄ν°λ₯Ό μμ±νκ³ MNIST λ°μ΄ν° μ
μΌλ‘ μμ
νλ€κ³ κ°μ ν΄ λ³΄κ² μ΅λλ€.
# λ°λΌμ μ΄λ―Έμ§λ 28x28 μ¬μ΄μ¦μ΄λ©° λ―Έλ λ°°μΉ ν¬κΈ°λ 64μ
λλ€.
# λ λμκ° 10κ°μ μλ‘ λ€λ₯Έ λͺ¨λΈμμ λμ¨ μμΈ‘κ°μ κ²°ν©νκ³ μΆλ€κ³ κ°μ ν΄ λ³΄κ² μ΅λλ€.
device = 'cuda'
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
models = [SimpleMLP().to(device) for _ in range(num_models)]
######################################################################
# μμΈ‘κ°μ μμ±νλ λ°λ λͺ κ°μ§ μ΅μ
μ΄ μμ΅λλ€.
# κ°κ°μ λͺ¨λΈμ λ€λ₯Έ 무μμ λ―Έλ λ°°μΉ λ°μ΄ν°λ₯Ό μ€ μ μκ³
# κ°κ°μ λͺ¨λΈμ λμΌν λ―Έλ λ°°μΉμ λ°μ΄ν°λ₯Ό μ€ μ μμ΅λλ€.
# (μλ₯Ό λ€μ΄, λ€λ₯Έ λͺ¨λΈ μ΄κΈ°κ°μ μν₯μ ν
μ€νΈν κ²½μ°)
######################################################################
# μ΅μ
1: κ°κ°μ λͺ¨λΈμ λ€λ₯Έ λ―Έλ λ°°μΉλ₯Ό μ£Όλ κ²½μ°
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]
######################################################################
# μ΅μ
2: κ°μ λ―Έλ λ°°μΉλ₯Ό μ£Όλ κ²½μ°
minibatch = data[0]
predictions2 = [model(minibatch) for model in models]
######################################################################
# ``vmap`` μ νμ©νμ¬ μμλΈ λ²‘ν°ννκΈ°
# -------------------------------------------
#
# ``vmap`` μ μ¬μ©νμ¬ for λ¬Έμ μλλ₯Ό λμ¬λ³΄κ² μ΅λλ€. λ¨Όμ ``vmap`` κ³Ό ν¨κ» μ¬μ©ν λͺ¨λΈμ μ€λΉν΄μΌ ν©λλ€.
#
#
# λ¨Όμ , κ° λ§€κ°λ³μλ₯Ό μμ λͺ¨λΈμ μνλ₯Ό κ²°ν©ν΄ λ³΄κ² μ΅λλ€.
# μλ₯Ό λ€μ΄, ``model[i].fc1.weight`` μ shapeμ ``[784, 128]`` μ
λλ€.
# μ΄ 10κ°μ λͺ¨λΈ κ°κ°μ λν΄ ``.fc1.weight`` λ₯Ό μμ ``[10, 784, 128]`` shapeμ ν° κ°μ€μΉλ₯Ό μμ±ν μ μμ΅λλ€.
#
# νμ΄ν μΉμμλ μ΄λ₯Ό μν΄ ``torch.func.stack_module_state`` λΌλ ν¨μλ₯Ό μ 곡νκ³ μμ΅λλ€.
#
from torch.func import stack_module_state
params, buffers = stack_module_state(models)
######################################################################
# λ€μμΌλ‘, ``vmap`` μ λν ν¨μλ₯Ό μ μν΄μΌ ν©λλ€. μ΄ ν¨μλ νλΌλ―Έν°, λ²νΌ, μ
λ ₯κ°μ΄ μ£Όμ΄μ§λ©΄ λͺ¨λΈμ μ€νν©λλ€.
# μ¬κΈ°μλ ``torch.func.functional_call`` μ νμ©νκ² μ΅λλ€.
from torch.func import functional_call
import copy
# λͺ¨λΈ μ€ νλμ "stateless" λ²μ μ ꡬμΆν©λλ€.
# "stateless"λ λ§€κ°λ³μκ° λ©ν tensorμ΄λ©° μ μ₯μκ° μλ€λ κ²μ μλ―Έν©λλ€.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))
######################################################################
# μ΅μ
1: κ° λͺ¨λΈμ λν΄ μλ‘ λ€λ₯Έ λ―Έλ λ°°μΉλ₯Ό νμ©νμ¬ μμΈ‘ν©λλ€.
#
# κΈ°λ³Έμ μΌλ‘, ``vmap`` μ λͺ¨λ μ
λ ₯μ 첫 λ²μ§Έ μ°¨μμ κ±Έμ³ ν¨μμ λ§€νν©λλ€.
# ``stack_module_state`` λ₯Ό μ¬μ©νλ©΄ κ° ``params`` μ λ²νΌλ μμͺ½μ 'num_models'
# ν¬κΈ°μ μΆκ° μ°¨μμ κ°μ§λ©°, λ―Έλ λ°°μΉλ 'num_models' ν¬κΈ°κ° λ©λλ€.
print([p.size(0) for p in params.values()]) # μ ν 'num_models' μ°¨μ νμ
assert minibatches.shape == (num_models, 64, 1, 28, 28) # λ―Έλ λ°°μΉμ μ ν μ°¨μμ΄ 'num_models' ν¬κΈ°μΈμ§ νμΈν©λλ€.
from torch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
# ``vmap`` μμΈ‘μ΄ λ§λμ§ νμΈν©λλ€.
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
######################################################################
# μ΅μ
2: λμΌν λ―Έλ λ°°μΉ λ°μ΄ν°λ₯Ό νμ©νμ¬ μμΈ‘ν©λλ€.
#
# ``vmap`` μλ λ§€νν μ°¨μμ μ§μ νλ ``in_dims`` λΌλ μΈμκ° μμ΅λλ€.
# ``None`` μ μ¬μ©νλ©΄ 10κ° λͺ¨λΈμ λͺ¨λ λμΌν λ―Έλ λ°°μΉλ₯Ό μ μ©νλλ‘
# ``vmap`` μ μλ €μ€ μ μμ΅λλ€.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
######################################################################
# μ°Έκ³ μ¬ν: ``vmap`` μΌλ‘ λ³νν μ μλ ν¨μ μ νμλ μ νμ΄ μμ΅λλ€.
# λ³ννκΈ°μ κ°μ₯ μ’μ ν¨μλ μ
λ ₯κ°μ μν΄μλ§ μΆλ ₯μ΄ κ²°μ λκ³
# λ€λ₯Έ λΆμμ© (μ. λ³μ΄) μ΄ μλ μμ ν¨μ(pure function) μ
λλ€.
# ``vmap`` μ μμμ λ³μ΄λ νμ΄μ¬ μλ£κ΅¬μ‘°λ μ²λ¦¬ν μ μμ§λ§,
# λ€μν λ΄μ₯λ νμ΄ν μΉ μ°μ°μ μ²λ¦¬ν μ μμ΅λλ€.
######################################################################
# μ±λ₯
# -----------
# μ±λ₯ μμΉκ° κΆκΈνμ κ°μ? μμΉλ λ€μκ³Ό κ°μ΅λλ€.
from torch.utils.benchmark import Timer
without_vmap = Timer(
stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
globals=globals())
with_vmap = Timer(
stmt="vmap(fmodel)(params, buffers, minibatches)",
globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
######################################################################
# ``vmap`` μ μ¬μ©νλ©΄ μλκ° ν¬κ² ν₯μλ©λλ€!
#
# μΌλ°μ μΌλ‘, ``vmap`` μ μ¬μ©ν 벑ν°νλ for λ¬Έμμ ν¨μλ₯Ό μ€ννλ κ²λ³΄λ€
# λΉ λ₯΄λ©° μλ μΌκ΄ μ²λ¦¬μ λΉμ·ν μλλ₯Ό λ
λλ€. νμ§λ§ νΉμ μ°μ°μ λν΄ ``vmap`` κ·μΉμ
# ꡬννμ§ μμκ±°λ κΈ°λ³Έ 컀λμ΄ κ΅¬ν νλμ¨μ΄(GPUs)μ μ΅μ νλμ§ μμ κ²½μ°μ κ°μ΄
# λͺ κ°μ§ μμΈκ° μμ΅λλ€. μ΄λ¬ν κ²½μ°κ° λ°κ²¬λλ©΄, GitHubμ μ΄μλ₯Ό μμ±ν΄μ μλ €μ£ΌμκΈ° λ°λλλ€.