-
Notifications
You must be signed in to change notification settings - Fork 258
/
Copy pathpruning_tutorial.py
361 lines (308 loc) ยท 19.9 KB
/
pruning_tutorial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# -*- coding: utf-8 -*-
"""
๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ(Pruning) ํํ ๋ฆฌ์ผ
=====================================
**Author**: `Michela Paganini <https://github.com/mickypaganini>`_
**๋ฒ์ญ**: `์์์ค <https://github.com/Justin-A>`_
์ต์ฒจ๋จ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๋ค์ ๊ต์ฅํ ๋ง์ ์์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค๋ก ๊ตฌ์ฑ๋๊ธฐ ๋๋ฌธ์, ์ฝ๊ฒ ๋ฐฐํฌํ๊ธฐ๊ฐ ์ด๋ ต์ต๋๋ค.
์ด์ ๋ฐ๋๋ก, ์๋ฌผํ์ ์ ๊ฒฝ๋ง๋ค์ ํจ์จ์ ์ผ๋ก ํฌ์ํ๊ฒ ์ฐ๊ฒฐ๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค.
๋ชจ๋ธ์ ์ ํ๋๋ฅผ ํผ์ํ์ง ์์ผ๋ฉด์ ๋ชจ๋ธ์ ํฌํจ๋ ํ๋ผ๋ฏธํฐ ์๋ฅผ ์ค์ฌ ์์ถํ๋ ์ต์ ์ ๊ธฐ๋ฒ์ ํ์
ํ๋ ๊ฒ์
๋ฉ๋ชจ๋ฆฌ, ๋ฐฐํฐ๋ฆฌ, ํ๋์จ์ด ์๋น๋์ ์ค์ผ ์ ์๊ธฐ ๋๋ฌธ์ ์ค์ํฉ๋๋ค. ๊ทธ๋ผ์ผ๋ก์ ๊ธฐ๊ธฐ์ ๊ฒฝ๋ํ๋ ๋ชจ๋ธ์ ๋ฐฐํฌํ์ฌ
๊ฐ๊ฐ์ธ์ด ์ฌ์ฉํ๊ณ ์๋ ๊ธฐ๊ธฐ์์ ์ฐ์ฐ์ ์ํํ์ฌ ํ๋ผ์ด๋ฒ์๋ฅผ ๋ณด์ฅํ ์ ์๊ธฐ ๋๋ฌธ์
๋๋ค.
์ฐ๊ตฌ ์ธก๋ฉด์์๋, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๊ต์ฅํ ๋ง์ ์์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค๋ก ๊ตฌ์ฑ๋ ๋ชจ๋ธ๊ณผ
๊ต์ฅํ ์ ์ ์์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค๋ก ๊ตฌ์ฑ๋ ๋ชจ๋ธ ๊ฐ ํ์ต ์ญํ ์ฐจ์ด๋ฅผ ์กฐ์ฌํ๋๋ฐ ์ฃผ๋ก ์ด์ฉ๋๊ธฐ๋ ํ๋ฉฐ,
ํ์ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๊ณผ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ ์ด๊ธฐํ๊ฐ ์ด์ด ์ข๊ฒ ์ ๋ ์ผ์ด์ค๋ฅผ ๋ฐํ์ผ๋ก
("`lottery tickets <https://arxiv.org/abs/1803.03635>`_") ์ ๊ฒฝ๋ง ๊ตฌ์กฐ๋ฅผ ์ฐพ๋ ๊ธฐ์ ๋ค์ ๋ํด ๋ฐ๋ ์๊ฒฌ์ ์ ์ํ๊ธฐ๋ ํฉ๋๋ค.
์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋, ``torch.nn.utils.prune`` ์ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ๋ถ์ด ์ค๊ณํ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํด๋ณด๋ ๊ฒ์ ๋ฐฐ์๋ณด๊ณ ,
์ฌํ์ ์ผ๋ก ์ฌ๋ฌ๋ถ์ ๋ง์ถคํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ๋ฐฐ์๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
์๊ตฌ์ฌํญ
------------
``"torch>=1.4"``
"""
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
######################################################################
#
# ๋ฅ๋ฌ๋ ๋ชจ๋ธ ์์ฑ
# -----------------------
#
# ์ด๋ฒ ํํ ๋ฆฌ์ผ์์๋, ์ ๋ฅด์ฟค ๊ต์๋์ ์ฐ๊ตฌ์ง๋ค์ด 1998๋
๋์ ๋ฐํํ `LeNet
# <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_ ์ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ์ด์ฉํฉ๋๋ค.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1๊ฐ ์ฑ๋ ์์ ์ด๋ฏธ์ง๋ฅผ ์
๋ ฅ๊ฐ์ผ๋ก ์ด์ฉํ์ฌ 6๊ฐ ์ฑ๋ ์์ ์ถ๋ ฅ๊ฐ์ ๊ณ์ฐํ๋ ๋ฐฉ์
# Convolution ์ฐ์ฐ์ ์งํํ๋ ์ปค๋(ํํฐ)์ ํฌ๊ธฐ๋ 5x5 ์ ์ด์ฉ
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # Convolution ์ฐ์ฐ ๊ฒฐ๊ณผ 5x5 ํฌ๊ธฐ์ 16 ์ฑ๋ ์์ ์ด๋ฏธ์ง
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
######################################################################
# ๋ชจ๋ ์ ๊ฒ
# -----------------
#
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์ LeNet ๋ชจ๋ธ์ ``conv1`` ์ธต์ ์ ๊ฒํด๋ด
์๋ค.
# ์ฌ๊ธฐ์๋ 2๊ฐ์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ธ ``๊ฐ์ค์น`` ๊ฐ๊ณผ ``ํธํฅ`` ๊ฐ์ด ํฌํจ๋ ๊ฒ์ด๋ฉฐ, ๋ฒํผ๋ ์กด์ฌํ์ง ์์ ๊ฒ์
๋๋ค.
#
module = model.conv1
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
# ๋ชจ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ ์ ์ฉ ์์
# -----------------------------------
#
# ๋ชจ๋์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ธฐ ์ํด (์ด๋ฒ ์์ ์์๋, LeNet ๋ชจ๋ธ์ ``conv1`` ์ธต)
# ์ฒซ ๋ฒ์งธ๋ก๋, ``torch.nn.utils.prune`` (๋๋ ``BasePruningMethod`` ์ ์๋ธ ํด๋์ค๋ก ์ง์
# `๊ตฌํ <torch-nn-utils-prune>`_ )
# ๋ด ์กด์ฌํ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ํํฉ๋๋ค.
# ๊ทธ ํ, ํด๋น ๋ชจ๋ ๋ด์์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ณ ์ ํ๋ ๋ชจ๋๊ณผ ํ๋ผ๋ฏธํฐ๋ฅผ ์ง์ ํฉ๋๋ค.
# ๋ง์ง๋ง์ผ๋ก, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ๋นํ ํค์๋ ์ธ์๊ฐ์ ์ด์ฉํ์ฌ ๊ฐ์ง์น๊ธฐ ๋งค๊ฐ๋ณ์๋ฅผ ์ง์ ํฉ๋๋ค.
# ์ด๋ฒ ์์ ์์๋, ``conv1`` ์ธต์ ๊ฐ์ค์น์ 30%๊ฐ๋ค์ ๋๋ค์ผ๋ก ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํด๋ณด๊ฒ ์ต๋๋ค.
# ๋ชจ๋์ ํจ์์ ๋ํ ์ฒซ ๋ฒ์งธ ์ธ์๊ฐ์ผ๋ก ์ ๋ฌ๋๋ฉฐ, ``name`` ์ ๋ฌธ์์ด ์๋ณ์๋ฅผ ์ด์ฉํ์ฌ ํด๋น ๋ชจ๋ ๋ด ๋งค๊ฐ๋ณ์๋ฅผ ๊ตฌ๋ถํฉ๋๋ค.
# ๊ทธ๋ฆฌ๊ณ , ``amount`` ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ธฐ ์ํ ๋์ ๊ฐ์ค์น๊ฐ๋ค์ ๋ฐฑ๋ถ์จ (0๊ณผ 1์ฌ์ด์ ์ค์๊ฐ),
# ํน์ ๊ฐ์ค์น๊ฐ์ ์ฐ๊ฒฐ์ ๊ฐ์ (์์๊ฐ ์๋ ์ ์) ๋ฅผ ์ง์ ํฉ๋๋ค.
prune.random_unstructured(module, name="weight", amount=0.3)
######################################################################
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๊ฐ์ค์น๊ฐ๋ค์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค๋ก๋ถํฐ ์ ๊ฑฐํ๊ณ ``weight_orig`` (์ฆ, ์ด๊ธฐ ๊ฐ์ค์น ์ด๋ฆ์ "_orig"์ ๋ถ์ธ) ์ด๋ผ๋
# ์๋ก์ด ํ๋ผ๋ฏธํฐ๊ฐ์ผ๋ก ๋์ฒดํ๋ ๊ฒ์ผ๋ก ์คํ๋ฉ๋๋ค.
# ``weight_orig`` ์ ํ
์๊ฐ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์ ์ํ๋ฅผ ์ ์ฅํฉ๋๋ค.
# ``bias`` ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์๊ธฐ ๋๋ฌธ์ ๊ทธ๋๋ก ๋จ์ ์์ต๋๋ค.
print(list(module.named_parameters()))
######################################################################
# ์์์ ์ ํํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ํด ์์ฑ๋๋ ๊ฐ์ง์น๊ธฐ ๋ง์คํฌ๋ ์ด๊ธฐ ํ๋ผ๋ฏธํฐ ``name`` ์ ``weight_mask``
# (์ฆ, ์ด๊ธฐ ๊ฐ์ค์น ์ด๋ฆ์ "_mask"๋ฅผ ๋ถ์ธ) ์ด๋ฆ์ ๋ชจ๋ ๋ฒํผ๋ก ์ ์ฅ๋ฉ๋๋ค.
print(list(module.named_buffers()))
######################################################################
# ์์ ์ด ๋์ง ์์ ์ํ์์ ์์ ํ๋ฅผ ์งํํ๊ธฐ ์ํด์๋ ``๊ฐ์ค์น`` ๊ฐ ์์ฑ์ด ์กด์ฌํด์ผ ํฉ๋๋ค.
# ``torch.nn.utils.prune`` ๋ด ๊ตฌํ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๊ฐ์ค์น๊ฐ๋ค์ ์ด์ฉํ์ฌ
# (๊ธฐ์กด์ ๊ฐ์ค์น๊ฐ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋) ์์ ํ๋ฅผ ์งํํ๊ณ , ``weight`` ์์ฑ๊ฐ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๊ฐ์ค์น๊ฐ๋ค์ ์ ์ฅํฉ๋๋ค.
# ์ด์ ๊ฐ์ค์น๊ฐ๋ค์ ``module`` ์ ๋งค๊ฐ๋ณ์๊ฐ ์๋๋ผ ํ๋์ ์์ฑ๊ฐ์ผ๋ก ์ทจ๊ธ๋๋ ์ ์ ์ฃผ์ํ์ธ์.
print(module.weight)
######################################################################
# ์ต์ข
์ ์ผ๋ก, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ํ์ดํ ์น์ ``forward_pre_hooks`` ๋ฅผ ์ด์ฉํ์ฌ ๊ฐ ์์ ํ๊ฐ ์งํ๋๊ธฐ ์ ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ฉ๋๋ค.
# ๊ตฌ์ฒด์ ์ผ๋ก, ์ง๊ธ๊น์ง ์งํํ ๊ฒ ์ฒ๋ผ, ๋ชจ๋์ด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์์ ๋,
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๊ฐ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ด ``forward_pre_hook`` ๋ฅผ ์ป๊ฒ๋ฉ๋๋ค.
# ์ด๋ฌํ ๊ฒฝ์ฐ, ``weight`` ์ด๋ฆ์ธ ๊ธฐ์กด ํ๋ผ๋ฏธํฐ๊ฐ์ ๋ํด์๋ง ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ์๊ธฐ ๋๋ฌธ์,
# ํ
์ ์ค์ง 1๊ฐ๋ง ์กด์ฌํ ๊ฒ์
๋๋ค.
print(module._forward_pre_hooks)
######################################################################
# ์๊ฒฐ์ฑ์ ์ํด, ํธํฅ๊ฐ์ ๋ํด์๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ ์ ์์ผ๋ฉฐ,
# ๋ชจ๋์ ํ๋ผ๋ฏธํฐ, ๋ฒํผ, ํ
, ์์ฑ๊ฐ๋ค์ด ์ด๋ป๊ฒ ๋ณ๊ฒฝ๋๋์ง ํ์ธํ ์ ์์ต๋๋ค.
# ๋ ๋ค๋ฅธ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํด๋ณด๊ธฐ ์ํด, ``l1_unstructured`` ๊ฐ์ง์น๊ธฐ ํจ์์์ ๊ตฌํ๋ ๋ด์ฉ๊ณผ ๊ฐ์ด,
# L1 Norm ๊ฐ์ด ๊ฐ์ฅ ์์ ํธํฅ๊ฐ 3๊ฐ๋ฅผ ๊ฐ์ง์น๊ธฐ๋ฅผ ์๋ํด๋ด
์๋ค.
prune.l1_unstructured(module, name="bias", amount=3)
######################################################################
# ์ด์ ์์ ์ค์ตํ ๋ด์ฉ์ ํ ๋๋ก, ๋ช
๋ช
๋ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ด ``weight_orig``, ``bias_orig`` 2๊ฐ๋ฅผ ๋ชจ๋ ํฌํจํ ๊ฒ์ด๋ผ ์์ํ ์ ์์ต๋๋ค.
# ๋ฒํผ๋ค์ ``weight_mask``, ``bias_mask`` 2๊ฐ๋ฅผ ํฌํจํ ๊ฒ์
๋๋ค.
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ 2๊ฐ์ ํ
์๊ฐ๋ค์ ๋ชจ๋์ ์์ฑ๊ฐ์ผ๋ก ์กด์ฌํ ๊ฒ์ด๋ฉฐ, ๋ชจ๋์ 2๊ฐ์ ``forward_pre_hooks`` ์ ๊ฐ๊ฒ ๋ ๊ฒ์
๋๋ค.
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
print(module.bias)
######################################################################
print(module._forward_pre_hooks)
######################################################################
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ ๋ฐ๋ณต ์ ์ฉ
# ------------------------------------
#
# ๋ชจ๋ ๋ด ๊ฐ์ ํ๋ผ๋ฏธํฐ๊ฐ์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ฌ๋ฌ๋ฒ ์ ์ฉ๋ ์ ์์ผ๋ฉฐ, ๋ค์ํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์กฐํฉ์ด ์ ์ฉ๋ ๊ฒ๊ณผ ๋์ผํ๊ฒ ์ ์ฉ๋ ์ ์์ต๋๋ค.
# ์๋ก์ด ๋ง์คํฌ์ ์ด์ ์ ๋ง์คํฌ์ ๊ฒฐํฉ์ ``PruningContainer`` ์ ``compute_mask`` ๋ฉ์๋๋ฅผ ํตํด ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
#
# ์๋ฅผ ๋ค์ด, ๋ง์ฝ ``module.weight`` ๊ฐ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ณ ์ถ์ ๋, ํ
์์ 0๋ฒ์งธ ์ถ์ L2 norm๊ฐ์ ๊ธฐ์ค์ผ๋ก ๊ตฌ์กฐํ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํฉ๋๋ค.
# (์ฌ๊ธฐ์ 0๋ฒ์งธ ์ถ์ด๋, ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ํตํด ๊ณ์ฐ๋ ์ถ๋ ฅ๊ฐ์ ๋ํด ๊ฐ ์ฑ๋๋ณ๋ก ์ ์ฉ๋๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.)
# ์ด ๋ฐฉ์์ ``ln_structured`` ํจ์์ ``n=2`` ์ ``dim=0`` ์ ์ธ์๊ฐ์ ๋ฐํ์ผ๋ก ๊ตฌํ๋ ์ ์์ต๋๋ค.
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
############################################################################
# ์ฐ๋ฆฌ๊ฐ ํ์ธํ ์ ์๋ฏ์ด, ์ด์ ๋ง์คํฌ์ ์์ฉ์ ์ ์งํ๋ฉด์ ์ฑ๋์ 50% (6๊ฐ ์ค 3๊ฐ) ์ ํด๋น๋๋ ๋ชจ๋ ์ฐ๊ฒฐ์ 0์ผ๋ก ๋ณ๊ฒฝํฉ๋๋ค.
print(module.weight)
############################################################################
# ์ด์ ํด๋นํ๋ ํ
์ ``torch.nn.utils.prune.PruningContainer`` ํํ๋ก ์กด์ฌํ๋ฉฐ, ๊ฐ์ค์น์ ์ ์ฉ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ด๋ ฅ์ ์ ์ฅํฉ๋๋ค.
for hook in module._forward_pre_hooks.values():
if hook._tensor_name == "weight": # ๊ฐ์ค์น์ ํด๋นํ๋ ํ
์ ์ ํ
break
print(list(hook)) # ์ปจํ
์ด๋ ๋ด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ด๋ ฅ
######################################################################
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๋ชจ๋ธ์ ์ง๋ ฌํ
# ---------------------------------------------
# ๋ง์คํฌ ๋ฒํผ๋ค๊ณผ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ํ
์ ๊ณ์ฐ์ ์ฌ์ฉ๋ ๊ธฐ์กด์ ํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํ์ฌ ๊ด๋ จ๋ ๋ชจ๋ ํ
์๊ฐ๋ค์
# ํ์ํ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ``state_dict`` ์ ์ ์ฅ๋๊ธฐ ๋๋ฌธ์, ์ฝ๊ฒ ์ง๋ ฌํํ์ฌ ์ ์ฅํ ์ ์์ต๋๋ค.
print(model.state_dict().keys())
######################################################################
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ฌ-ํ๋ผ๋ฏธํฐํ ์ ๊ฑฐ
# -----------------------------------------
#
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๊ฒ์ ์๊ตฌ์ ์ผ๋ก ๋ง๋ค๊ธฐ ์ํด์, ์ฌ-ํ๋ผ๋ฏธํฐํ ๊ด์ ์
# ``weight_orig`` ์ ``weight_mask`` ๊ฐ์ ์ ๊ฑฐํ๊ณ , ``forward_pre_hook`` ๊ฐ์ ์ ๊ฑฐํฉ๋๋ค.
# ์ ๊ฑฐํ๊ธฐ ์ํด ``torch.nn.utils.prune`` ๋ด ``remove`` ํจ์๋ฅผ ์ด์ฉํ ์ ์์ต๋๋ค.
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์ ๊ฒ์ฒ๋ผ ์คํ๋๋ ๊ฒ์ด ์๋ ์ ์ ์ฃผ์ํ์ธ์.
# ์ด๋ ๋จ์ง ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ์ํ์์ ๊ฐ์ค์น ํ๋ผ๋ฏธํฐ๊ฐ์ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๊ฐ์ผ๋ก ์ฌํ ๋นํ๋ ๊ฒ์ ํตํด ์๊ตฌ์ ์ผ๋ก ๋ง๋๋ ๊ฒ์ผ ๋ฟ์
๋๋ค.
######################################################################
# ์ฌ-ํ๋ผ๋ฏธํฐํ๋ฅผ ์ ๊ฑฐํ๊ธฐ ์ ์ํ
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
print(module.weight)
######################################################################
# ์ฌ-ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๊ฑฐํ ํ ์ํ
prune.remove(module, 'weight')
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
# ๋ชจ๋ธ ๋ด ์ฌ๋ฌ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ ๋ํ์ฌ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ ์ ์ฉ
# ----------------------------------------------------------
#
# ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ณ ์ถ์ ํ๋ผ๋ฏธํฐ๊ฐ๋ค์ ์ง์ ํจ์ผ๋ก์จ, ์ด๋ฒ ์์ ์์ ๋ณผ ์ ์๋ ๊ฒ ์ฒ๋ผ,
# ์ ๊ฒฝ๋ง ๋ชจ๋ธ ๋ด ์ฌ๋ฌ ํ
์๊ฐ๋ค์ ๋ํด์ ์ฝ๊ฒ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ ์ ์์ต๋๋ค.
new_model = LeNet()
for name, module in new_model.named_modules():
# ๋ชจ๋ 2D-conv ์ธต์ 20% ์ฐ๊ฒฐ์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉ
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
# ๋ชจ๋ ์ ํ ์ธต์ 40% ์ฐ๊ฒฐ์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉ
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys()) # ์กด์ฌํ๋ ๋ชจ๋ ๋ง์คํฌ๋ค์ ํ์ธ
######################################################################
# ์ ์ญ ๋ฒ์์ ๋ํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ ์ ์ฉ
# ----------------------------------------------
#
# ์ง๊ธ๊น์ง, "์ง์ญ ๋ณ์" ์ ๋ํด์๋ง ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด์์ต๋๋ค.
# (์ฆ, ๊ฐ์ค์น ๊ท๋ชจ, ํ์ฑํ ์ ๋, ๊ฒฝ์ฌ๊ฐ ๋ฑ์ ๊ฐ ํญ๋ชฉ์ ํต๊ณ๋์ ๋ฐํ์ผ๋ก ๋ชจ๋ธ ๋ด ํ
์๊ฐ ํ๋์ฉ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋ ๋ฐฉ์)
# ๊ทธ๋ฌ๋, ๋ฒ์ฉ์ ์ด๊ณ ์๋ง ๋ ๊ฐ๋ ฅํ ๋ฐฉ๋ฒ์ ๊ฐ ์ธต์์ ๊ฐ์ฅ ๋ฎ์ 20%์ ์ฐ๊ฒฐ์ ์ ๊ฑฐํ๋ ๊ฒ ๋์ ์, ์ ์ฒด ๋ชจ๋ธ์ ๋ํด์ ๊ฐ์ฅ ๋ฎ์ 20% ์ฐ๊ฒฐ์ ํ๋ฒ์ ์ ๊ฑฐํ๋ ๊ฒ์
๋๋ค.
# ์ด๊ฒ์ ๊ฐ ์ธต์ ๋ํด์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋ ์ฐ๊ฒฐ์ ๋ฐฑ๋ถ์จ๊ฐ์ ๋ค๋ฅด๊ฒ ๋ง๋ค ๊ฐ๋ฅ์ฑ์ด ์์ต๋๋ค.
# ``torch.nn.utils.prune`` ๋ด ``global_unstructured`` ์ ์ด์ฉํ์ฌ ์ด๋ป๊ฒ ์ ์ญ ๋ฒ์์ ๋ํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋์ง ์ดํด๋ด
์๋ค.
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
######################################################################
# ์ด์ ๊ฐ ์ธต์ ์กด์ฌํ๋ ์ฐ๊ฒฐ๋ค์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ์ ๋๊ฐ 20%๊ฐ ์๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
# ๊ทธ๋ฌ๋, ์ ์ฒด ๊ฐ์ง์น๊ธฐ ์ ์ฉ ๋ฒ์๋ ์ฝ 20%๊ฐ ๋ ๊ฒ์
๋๋ค.
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
######################################################################
# ``torch.nn.utils.prune`` ์์ ํ์ฅ๋ ๋ง์ถคํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ
# ------------------------------------------------------------------
# ๋ง์ถคํ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์, ๋ค๋ฅธ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋ ๊ฒ๊ณผ ๊ฐ์ ๋ฐฉ์์ผ๋ก,
# ``BasePruningMethod`` ์ ๊ธฐ๋ณธ ํด๋์ค์ธ ``nn.utils.prune`` ๋ชจ๋์ ํ์ฉํ์ฌ ๊ตฌํํ ์ ์์ต๋๋ค.
# ๊ธฐ๋ณธ ํด๋์ค๋ ``__call__``, ``apply_mask``, ``apply``, ``prune``, ``remove`` ๋ฉ์๋๋ค์ ๋ดํฌํ๊ณ ์์ต๋๋ค.
# ํน๋ณํ ์ผ์ด์ค๊ฐ ์๋ ๊ฒฝ์ฐ, ๊ธฐ๋ณธ์ ์ผ๋ก ๊ตฌ์ฑ๋ ๋ฉ์๋๋ค์ ์ฌ๊ตฌ์ฑํ ํ์๊ฐ ์์ต๋๋ค.
# ๊ทธ๋ฌ๋, ``__init__`` (๊ตฌ์ฑ์์), ``compute_mask``
# (๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๋
ผ๋ฆฌ์ ๋ฐ๋ผ ์ฃผ์ด์ง ํ
์๊ฐ์ ๋ง์คํฌ๋ฅผ ์ ์ฉํ๋ ๋ฐฉ๋ฒ) ์ ๊ณ ๋ คํ์ฌ ๊ตฌ์ฑํด์ผ ํฉ๋๋ค.
# ๊ฒ๋ค๊ฐ, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ด๋ ํ ๋ฐฉ์์ผ๋ก ์ ์ฉํ๋์ง ๋ช
ํํ๊ฒ ๊ตฌ์ฑํด์ผ ํฉ๋๋ค.
# (์ง์๋๋ ์ต์
์ ``global``, ``structured``, ``unstructured`` ์
๋๋ค.)
# ์ด๋ฌํ ๋ฐฉ์์, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๋ฐ๋ณต์ ์ผ๋ก ์ ์ฉํด์ผ ํ๋ ๊ฒฝ์ฐ ๋ง์คํฌ๋ฅผ ๊ฒฐํฉํ๋ ๋ฐฉ๋ฒ์ ๊ฒฐ์ ํ๊ธฐ ์ํด ํ์ํฉ๋๋ค.
# ์ฆ, ์ด๋ฏธ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๋ชจ๋ธ์ ๋ํด์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ ๋,
# ๊ธฐ์กด์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์ ํ๋ผ๋ฏธํฐ ๊ฐ์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ํฅ์ ๋ฏธ์น ๊ฒ์ผ๋ก ์์๋ฉ๋๋ค.
# ``PRUNING_TYPE`` ์ ์ง์ ํ๋ค๋ฉด, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๊ธฐ ์ํด ํ๋ผ๋ฏธํฐ ๊ฐ์ ์ฌ๋ฐ๋ฅด๊ฒ ์ ๊ฑฐํ๋
# ``PruningContainer`` (๋ง์คํฌ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๋ฐ๋ณต์ ์ผ๋ก ์ ์ฉํ๋ ๊ฒ์ ์ฒ๋ฆฌํ๋)๋ฅผ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
# ์๋ฅผ ๋ค์ด, ๋ค๋ฅธ ๋ชจ๋ ํญ๋ชฉ์ด ์กด์ฌํ๋ ํ
์๋ฅผ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ๊ตฌํํ๊ณ ์ถ์ ๋,
# (๋๋, ํ
์๊ฐ ์ด์ ์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ํด ์ ๊ฑฐ๋์๊ฑฐ๋ ๋จ์์๋ ํ
์์ ๋ํด)
# ํ ์ธต์ ๊ฐ๋ณ ์ฐ๊ฒฐ์ ์์ฉํ๋ฉฐ ์ ์ฒด ์ ๋/์ฑ๋ (``'structured'``), ๋๋ ๋ค๋ฅธ ํ๋ผ๋ฏธํฐ ๊ฐ
# (``'global'``) ์ฐ๊ฒฐ์๋ ์์ฉํ์ง ์๊ธฐ ๋๋ฌธ์ ``PRUNING_TYPE='unstructured'`` ๋ฐฉ์์ผ๋ก ์งํ๋ฉ๋๋ค.
class FooBarPruningMethod(prune.BasePruningMethod):
"""
ํ
์ ๋ด ๋ค๋ฅธ ํญ๋ชฉ๋ค์ ๋ํด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉ
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
######################################################################
# ``nn.Module`` ์ ๋งค๊ฐ๋ณ์์ ์ ์ฉํ๊ธฐ ์ํด ์ธ์คํด์คํํ๊ณ ์ ์ฉํ๋ ๊ฐ๋จํ ๊ธฐ๋ฅ์ ๊ตฌํํด๋ด
๋๋ค.
def foobar_unstructured(module, name):
"""
ํ
์ ๋ด ๋ค๋ฅธ ๋ชจ๋ ํญ๋ชฉ๋ค์ ์ ๊ฑฐํ์ฌ `module` ์์ `name` ์ด๋ผ๋ ํ๋ผ๋ฏธํฐ์ ๋ํด ๊ฐ์์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉ
๋ค์ ๋ด์ฉ์ ๋ฐ๋ผ ๋ชจ๋์ ์์ (๋๋ ์์ ๋ ๋ชจ๋์ ๋ฐํ):
1) ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ํด ๋งค๊ฐ๋ณ์ `name` ์ ์ ์ฉ๋ ์ด์ง ๋ง์คํฌ์ ํด๋นํ๋ ๋ช
๋ช
๋ ๋ฒํผ `name+'_mask'` ๋ฅผ ์ถ๊ฐํฉ๋๋ค.
`name` ํ๋ผ๋ฏธํฐ๋ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๊ฒ์ผ๋ก ๋์ฒด๋๋ฉฐ, ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋์ง ์์
๊ธฐ์กด์ ํ๋ผ๋ฏธํฐ๋ `name+'_orig'` ๋ผ๋ ์ด๋ฆ์ ์๋ก์ด ๋งค๊ฐ๋ณ์์ ์ ์ฅ๋ฉ๋๋ค.
์ธ์๊ฐ:
module (nn.Module): ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ ์ ์ฉํด์ผ ํ๋ ํ
์๋ฅผ ํฌํจํ๋ ๋ชจ๋
name (string): ๋ชจ๋ ๋ด ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ํ๋ผ๋ฏธํฐ์ ์ด๋ฆ
๋ฐํ๊ฐ:
module (nn.Module): ์
๋ ฅ ๋ชจ๋์ ๋ํด์ ๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ์ด ์ ์ฉ๋ ๋ชจ๋
์์:
>>> m = nn.Linear(3, 4)
>>> foobar_unstructured(m, name='bias')
"""
FooBarPruningMethod.apply(module, name)
return module
######################################################################
# ํ๋ฒ ํด๋ด
์๋ค!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)