Skip to content

Commit f06b1b6

Browse files
committed
feat(NIN): 实现NIN类
1 parent cecfe69 commit f06b1b6

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed

models/NIN.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# @Time : 19-6-21 上午11:00
4+
# @Author : zj
5+
6+
7+
import nn
8+
from .Net import Net
9+
from .utils import load_params
10+
11+
__all__ = ['NIN', 'nin']
12+
13+
model_urls = {
14+
'nin': ''
15+
}
16+
17+
18+
class NIN(Net):
19+
"""
20+
NIN网络
21+
"""
22+
23+
def __init__(self, in_channels=1, out_channels=10, momentum=0, nesterov=False, p_h=1.0):
24+
super(NIN, self).__init__()
25+
self.conv1 = nn.Conv2d(in_channels, 5, 5, 192, stride=1, padding=2, momentum=momentum, nesterov=nesterov)
26+
self.conv2 = nn.Conv2d(96, 5, 5, 192, stride=1, padding=2, momentum=momentum, nesterov=nesterov)
27+
self.conv3 = nn.Conv2d(192, 3, 3, 192, stride=1, padding=1, momentum=momentum, nesterov=nesterov)
28+
29+
self.mlp1 = nn.Conv2d(192, 1, 1, 160, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
30+
self.mlp2 = nn.Conv2d(160, 1, 1, 96, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
31+
32+
self.mlp2_1 = nn.Conv2d(192, 1, 1, 192, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
33+
self.mlp2_2 = nn.Conv2d(192, 1, 1, 192, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
34+
35+
self.mlp3_1 = nn.Conv2d(192, 1, 1, 192, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
36+
self.mlp3_2 = nn.Conv2d(192, 1, 1, out_channels, stride=1, padding=0, momentum=momentum, nesterov=nesterov)
37+
38+
self.maxPool1 = nn.MaxPool(2, 2, 96, stride=2)
39+
self.maxPool2 = nn.MaxPool(2, 2, 192, stride=2)
40+
41+
self.gap = nn.GAP()
42+
43+
self.relu1 = nn.ReLU()
44+
self.relu2 = nn.ReLU()
45+
self.relu3 = nn.ReLU()
46+
self.relu4 = nn.ReLU()
47+
self.relu5 = nn.ReLU()
48+
self.relu6 = nn.ReLU()
49+
self.relu7 = nn.ReLU()
50+
self.relu8 = nn.ReLU()
51+
self.relu9 = nn.ReLU()
52+
53+
self.dropout = nn.Dropout()
54+
55+
self.p_h = p_h
56+
self.U1 = None
57+
self.U2 = None
58+
59+
def __call__(self, inputs):
60+
return self.forward(inputs)
61+
62+
def forward(self, inputs):
63+
# inputs.shape = [N, C, H, W]
64+
assert len(inputs.shape) == 4
65+
x = self.relu1(self.conv1(inputs))
66+
x = self.relu2(self.mlp1(x))
67+
x = self.relu3(self.mlp2(x))
68+
x = self.maxPool1(x)
69+
self.U1 = self.dropout(x.shape, self.p_h)
70+
x *= self.U1
71+
72+
x = self.relu4(self.conv2(x))
73+
x = self.relu5(self.mlp2_1(x))
74+
x = self.relu6(self.mlp2_2(x))
75+
x = self.maxPool2(x)
76+
self.U2 = self.dropout(x.shape, self.p_h)
77+
x *= self.U2
78+
79+
x = self.relu7(self.conv3(x))
80+
x = self.relu8(self.mlp3_1(x))
81+
x = self.relu9(self.mlp3_2(x))
82+
83+
x = self.gap(x)
84+
return x
85+
86+
def backward(self, grad_out):
87+
# grad_out.shape = [N, C]
88+
assert len(grad_out) == 2
89+
da11 = self.gap.backward(grad_out)
90+
91+
dz11 = self.relu9.backward(da11)
92+
da10 = self.mlp3_2.backward(dz11)
93+
dz10 = self.relu8.backward(da10)
94+
da9 = self.mlp3_1.backward(dz10)
95+
dz9 = self.relu7.backward(da9)
96+
da8 = self.conv3.backward(dz9)
97+
98+
da8 *= self.U2
99+
da7 = self.maxPool2.backward(da8)
100+
dz7 = self.relu6.backward(da7)
101+
da6 = self.mlp2_2.backward(dz7)
102+
dz6 = self.relu5.backward(da6)
103+
da5 = self.mlp2_1.backward(dz6)
104+
dz5 = self.relu4.backward(da5)
105+
da4 = self.conv2.backward(dz5)
106+
107+
da4 *= self.U1
108+
da3 = self.maxPool1.backward(da4)
109+
dz3 = self.relu3.backward(da3)
110+
da2 = self.mlp2.backward(dz3)
111+
dz2 = self.relu2.backward(da2)
112+
da1 = self.mlp1.backward(dz2)
113+
dz1 = self.relu1.backward(da1)
114+
da0 = self.conv1.backward(dz1)
115+
116+
def update(self, lr=1e-3, reg=1e-3):
117+
self.mlp3_2.update(learning_rate=lr, regularization_rate=reg)
118+
self.mlp3_1.update(learning_rate=lr, regularization_rate=reg)
119+
self.conv3.update(learning_rate=lr, regularization_rate=reg)
120+
121+
self.mlp2_2.update(learning_rate=lr, regularization_rate=reg)
122+
self.mlp2_1.update(learning_rate=lr, regularization_rate=reg)
123+
self.conv2.update(learning_rate=lr, regularization_rate=reg)
124+
125+
self.mlp2.update(learning_rate=lr, regularization_rate=reg)
126+
self.mlp1.update(learning_rate=lr, regularization_rate=reg)
127+
self.conv1.update(learning_rate=lr, regularization_rate=reg)
128+
129+
def predict(self, inputs):
130+
# inputs.shape = [N, C, H, W]
131+
assert len(inputs.shape) == 4
132+
x = self.relu1(self.conv1(inputs))
133+
x = self.relu2(self.mlp1(x))
134+
x = self.relu3(self.mlp2(x))
135+
x = self.maxPool1(x)
136+
137+
x = self.relu4(self.conv2(x))
138+
x = self.relu5(self.mlp2_1(x))
139+
x = self.relu6(self.mlp2_2(x))
140+
x = self.maxPool2(x)
141+
142+
x = self.relu7(self.conv3(x))
143+
x = self.relu8(self.mlp3_1(x))
144+
x = self.relu9(self.mlp3_2(x))
145+
146+
x = self.gap(x)
147+
return x
148+
149+
def get_params(self):
150+
out = dict()
151+
out['conv1'] = self.conv1.get_params()
152+
out['conv2'] = self.conv2.get_params()
153+
out['conv3'] = self.conv3.get_params()
154+
155+
out['mlp1'] = self.mlp1.get_params()
156+
out['mlp2'] = self.mlp2.get_params()
157+
out['mlp2_1'] = self.mlp2_1.get_params()
158+
out['mlp2_2'] = self.mlp2_2.get_params()
159+
out['mlp3_1'] = self.mlp3_1.get_params()
160+
out['mlp3_2'] = self.mlp3_2.get_params()
161+
162+
out['p_h'] = self.p_h
163+
164+
return out
165+
166+
def set_params(self, params):
167+
self.conv1.set_params(params['conv1'])
168+
self.conv2.set_params(params['conv2'])
169+
self.conv3.set_params(params['conv3'])
170+
171+
self.mlp1.set_params(params['mlp1'])
172+
self.mlp2.set_params(params['mlp2'])
173+
self.mlp2_1.set_params(params['mlp2_1'])
174+
self.mlp2_2.set_params(params['mlp2_1'])
175+
self.mlp3_1.set_params(params['mlp3_1'])
176+
self.mlp3_2.set_params(params['mlp3_1'])
177+
178+
self.p_h = params.get('p_h', 1.0)
179+
180+
181+
def nin(pretrained=False, **kwargs):
182+
"""
183+
创建模型对象
184+
"""
185+
186+
model = NIN(**kwargs)
187+
if pretrained:
188+
params = load_params(model_urls['nin'])
189+
model.set_params(params)
190+
return model

models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .ThreeLayerNet import *
55
from .LeNet5 import *
66
from .AlexNet import *
7+
from .NIN import *

0 commit comments

Comments
 (0)