
Commit 5b98abf

committed: Implement and test the NIN model with NumPy and PyTorch

1 parent f06b1b6 commit 5b98abf

File tree

9 files changed: +297 -6 lines changed


models/NIN.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def __init__(self, in_channels=1, out_channels=10, momentum=0, nesterov=False, p
         self.relu8 = nn.ReLU()
         self.relu9 = nn.ReLU()
 
-        self.dropout = nn.Dropout()
+        self.dropout = nn.Dropout2d()
 
         self.p_h = p_h
         self.U1 = None
@@ -85,7 +85,7 @@ def forward(self, inputs):
 
     def backward(self, grad_out):
         # grad_out.shape = [N, C]
-        assert len(grad_out) == 2
+        assert len(grad_out.shape) == 2
         da11 = self.gap.backward(grad_out)
 
         dz11 = self.relu9.backward(da11)
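A quick check of why the second fix matters: for a NumPy gradient of shape [N, C], len(grad_out) returns the batch size N, not the number of dimensions, so the old assertion only held when N happened to be 2. A minimal sketch:

import numpy as np

grad_out = np.zeros((128, 10))   # upstream gradient, shape [N, C]
print(len(grad_out))             # 128, the batch size, not the rank
print(len(grad_out.shape))       # 2, which is what the assertion should verify
assert len(grad_out.shape) == 2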

models/pytorch/NIN.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+
+# @Time : 19-6-21 2:56 PM
+# @Author : zj
+
+import torch
+import torch.nn as nn
+
+__all__ = ['NIN', 'nin']
+
+model_urls = {
+    'nin': ''
+}
+
+
+class NIN(nn.Module):
+
+    def __init__(self, in_channels=1, out_channels=10):
+        super(NIN, self).__init__()
+
+        self.features1 = nn.Sequential(
+            nn.Conv2d(in_channels, 192, (5, 5), stride=1, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(192, 160, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(160, 96, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Dropout2d()
+        )
+        self.features2 = nn.Sequential(
+            nn.Conv2d(96, 192, (5, 5), stride=1, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+            nn.MaxPool2d(2, stride=2),
+            nn.Dropout2d()
+        )
+        self.features3 = nn.Sequential(
+            nn.Conv2d(192, 192, (3, 3), stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(192, out_channels, (1, 1), stride=1, padding=0),
+            nn.ReLU(),
+        )
+
+        self.gap = nn.AvgPool2d(8)
+
+    def forward(self, inputs):
+        x = self.features1(inputs)
+        x = self.features2(x)
+        x = self.features3(x)
+        x = self.gap(x)
+
+        return x.view(x.shape[0], x.shape[1])
+
+
+def nin(pretrained=False, **kwargs):
+    """
+    Create a model instance
+    """
+    model = NIN(**kwargs)
+    # if pretrained:
+    #     params = load_params(model_urls['nin'])
+    #     model.set_params(params)
+    return model
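A minimal smoke test of the new module, assuming 32x32 CIFAR-10 input (the size for which the two MaxPool2d(2) stages leave an 8x8 map that AvgPool2d(8) collapses to 1x1); the import path follows the package layout added in this commit:

import torch
from models.pytorch import nin

model = nin(in_channels=3, out_channels=10)
model.eval()                        # disable Dropout2d for the shape check
dummy = torch.randn(4, 3, 32, 32)   # four fake CIFAR-10 images
out = model(dummy)
print(out.shape)                    # expected: torch.Size([4, 10])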

models/pytorch/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+
+
+from .NIN import *

nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from .FC import FC
 from .ReLU import ReLU
 from .Dropout import Dropout
+from .Dropout2d import Dropout2d
 from .CrossEntropyLoss import CrossEntropyLoss
 from .Conv2d import Conv2d
 from .MaxPool import MaxPool
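The nn.Dropout2d module exported here is not shown in this diff; for reference, a minimal NumPy sketch of the idea behind the name (channel-wise dropout, where whole feature maps are zeroed rather than individual activations). The function name and signature are illustrative only, not the package's actual API:

import numpy as np

def dropout2d_forward(x, p_h=0.5):
    # x.shape = [N, C, H, W]; keep each (sample, channel) map with
    # probability p_h and rescale by 1/p_h (inverted dropout).
    mask = (np.random.rand(x.shape[0], x.shape[1], 1, 1) < p_h) / p_h
    return x * mask, mask

out, mask = dropout2d_forward(np.ones((2, 3, 4, 4)), p_h=0.5)
print(out.shape)  # (2, 3, 4, 4); each channel is either all 0 or all 2.0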

src/3_nn_cifar10.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 data_path = '/home/lab305/Documents/data/decompress_cifar_10'
 
 if __name__ == '__main__':
-    x_train, x_test, y_train, y_test = vision.data.load_cifar10(data_path, shuffle=True)
+    x_train, x_test, y_train, y_test = vision.data.load_cifar10(data_path, shuffle=True, is_flatten=True)
 
     x_train = x_train / 255 - 0.5
     x_test = x_test / 255 - 0.5

src/nin_cifar10.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+
+# @Time : 19-6-21 2:45 PM
+# @Author : zj
+
+import nn
+import models
+import models.utils as utils
+import vision.data
+import numpy as np
+import time
+
+data_path = '/home/lab305/Documents/data/decompress_cifar_10'
+
+epochs = 100
+batch_size = 128
+momentum = 0.9
+learning_rate = 1e-3
+reg = 1e-3
+p_h = 0.5
+
+
+def nin_train():
+    x_train, x_test, y_train, y_test = vision.data.load_cifar10(data_path, shuffle=True)
+
+    # Normalize
+    x_train = x_train / 255.0 - 0.5
+    x_test = x_test / 255.0 - 0.5
+
+    net = models.nin(in_channels=3, p_h=p_h)
+    criterion = nn.CrossEntropyLoss()
+
+    accuracy = vision.Accuracy()
+
+    loss_list = []
+    train_list = []
+    test_list = []
+    best_train_accuracy = 0.995
+    best_test_accuracy = 0.995
+
+    range_list = np.arange(0, x_train.shape[0] - batch_size, step=batch_size)
+    for i in range(epochs):
+        total_loss = 0
+        num = 0
+        start = time.time()
+        for j in range_list:
+            data = x_train[j:j + batch_size]
+            labels = y_train[j:j + batch_size]
+
+            scores = net(data)
+            loss = criterion(scores, labels)
+            total_loss += loss
+            num += 1
+
+            grad_out = criterion.backward()
+            net.backward(grad_out)
+            net.update(lr=learning_rate, reg=reg)
+        end = time.time()
+        print('one epoch need time: %.3f' % (end - start))
+        print('epoch: %d loss: %f' % (i + 1, total_loss / num))
+        loss_list.append(total_loss / num)
+
+        if (i % 20) == 19:
+            # # Reduce the learning rate every 20 epochs
+            # learning_rate *= 0.5
+
+            train_accuracy = accuracy.compute_v2(x_train, y_train, net, batch_size=batch_size)
+            test_accuracy = accuracy.compute_v2(x_test, y_test, net, batch_size=batch_size)
+            train_list.append(train_accuracy)
+            test_list.append(test_accuracy)
+
+            print(loss_list)
+            print(train_list)
+            print(test_list)
+            if train_accuracy > best_train_accuracy and test_accuracy > best_test_accuracy:
+                path = 'nin-epochs-%d.pkl' % (i + 1)
+                utils.save_params(net.get_params(), path=path)
+                break
+
+    draw = vision.Draw()
+    draw(loss_list, xlabel='per 20 iterations')
+    draw.multi_plot((train_list, test_list), ('train set', 'test set'), title='accuracy', xlabel='per 20 iterations', ylabel='accuracy')
+
+
+if __name__ == '__main__':
+    start = time.time()
+    nin_train()
+    end = time.time()
+    print('training need time: %.3f' % (end - start))
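The script saves its best parameters with utils.save_params(). A sketch of loading them back, assuming utils also provides a load_params() counterpart and the network exposes set_params(); both names appear only in commented-out code elsewhere in this commit, so treat them as assumptions:

import models
import models.utils as utils

# Hypothetical checkpoint reload; get_params()/save_params() are used above,
# load_params()/set_params() are assumed mirror functions.
net = models.nin(in_channels=3, p_h=0.5)
params = utils.load_params('nin-epochs-100.pkl')
net.set_params(params)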

src/nin_cifar10_pytorch.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+
+# @Time : 19-6-21 3:41 PM
+# @Author : zj
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import time
+import vision.data
+import models.pytorch
+
+epochs = 100
+batch_size = 128
+lr = 1e-3
+momentum = 0.9
+
+data_path = '/home/lab305/Documents/data/cifar_10'
+
+
+def train():
+    train_loader, test_loader = vision.data.load_cifar10_pytorch(data_path, batch_size=batch_size, shuffle=True)
+
+    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cpu")
+
+    net = models.pytorch.nin(in_channels=3).to(device)
+    criterion = nn.CrossEntropyLoss().to(device)
+    optimer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, nesterov=True)
+    # stepLR = StepLR(optimer, 100, 0.5)
+
+    best_train_accuracy = 0.995
+    best_test_accuracy = 0
+
+    accuracy = vision.Accuracy()
+
+    loss_list = []
+    train_list = []
+    for i in range(epochs):
+        num = 0
+        total_loss = 0
+        start = time.time()
+        # Training phase
+        net.train()
+        for j, item in enumerate(train_loader, 0):
+            data, labels = item
+            data = data.to(device)
+            labels = labels.to(device)
+
+            scores = net.forward(data)
+            loss = criterion.forward(scores, labels)
+            total_loss += loss.item()
+
+            optimer.zero_grad()
+            loss.backward()
+            optimer.step()
+            num += 1
+        end = time.time()
+        # stepLR.step()
+
+        avg_loss = total_loss / num
+        loss_list.append(float('%.8f' % avg_loss))
+        print('epoch: %d time: %.2f loss: %.8f' % (i + 1, end - start, avg_loss))
+
+        if i % 20 == 19:
+            # Evaluation phase
+            net.eval()
+            train_accuracy = accuracy.compute_pytorch(train_loader, net, device)
+            train_list.append(float('%.4f' % train_accuracy))
+            if best_train_accuracy < train_accuracy:
+                best_train_accuracy = train_accuracy
+
+            test_accuracy = accuracy.compute_pytorch(test_loader, net, device)
+            if best_test_accuracy < test_accuracy:
+                best_test_accuracy = test_accuracy
+
+            print('best train accuracy: %.2f %% best test accuracy: %.2f %%' % (
+                best_train_accuracy * 100, best_test_accuracy * 100))
+            print(loss_list)
+            print(train_list)
+
+
+if __name__ == '__main__':
+    start = time.time()
+    train()
+    end = time.time()
+    print('training need time: %.3f' % (end - start))
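StepLR is referenced only in commented-out lines and its import is missing from this file. A minimal, self-contained sketch of how the scheduler would be wired in, using a stand-in model since this is illustration rather than the script's code:

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

net = nn.Linear(10, 10)  # stand-in for models.pytorch.nin(in_channels=3)
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, nesterov=True)
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)  # halve lr every 100 epochs

for epoch in range(3):
    # ... one training epoch over train_loader would run here ...
    optimizer.step()    # parameter update(s) for the epoch
    scheduler.step()    # advance the schedule once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])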

vision/Accuracy.py

Lines changed: 15 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 
 import numpy as np
+import torch
 
 
 class Accuracy(object):
@@ -34,3 +35,17 @@ def compute_v2(self, data_array, labels_array, net, batch_size=128):
             num += 1
 
         return total_accuracy / num
+
+    def compute_pytorch(self, loader, net, device):
+        total_accuracy = 0
+        num = 0
+        for item in loader:
+            data, labels = item
+            data = data.to(device)
+            labels = labels.to(device)
+
+            scores = net.forward(data)
+            predicted = torch.argmax(scores, dim=1)
+            total_accuracy += torch.mean((predicted == labels).float()).item()
+            num += 1
+        return total_accuracy / num
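A usage sketch for the new compute_pytorch() helper, with a tiny stand-in classifier and random tensors in place of the real CIFAR-10 loader; wrapping the call in eval() plus no_grad() keeps evaluation free of gradient bookkeeping:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import vision

device = torch.device('cpu')
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)

data = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(data, labels), batch_size=128)

accuracy = vision.Accuracy()
net.eval()
with torch.no_grad():   # evaluation only, no gradients needed
    acc = accuracy.compute_pytorch(loader, net, device)
print('accuracy: %.2f %%' % (acc * 100))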

vision/data/cifar10.py

Lines changed: 29 additions & 3 deletions
@@ -3,6 +3,10 @@
 # @Time : 19-6-20 7:22 PM
 # @Author : zj
 
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+from torch.utils.data import DataLoader
+
 import numpy as np
 import os
 from .utils import *
@@ -14,7 +18,7 @@
 dst_size = (32, 32)
 
 
-def load_cifar10(cifar10_path, shuffle=True):
+def load_cifar10(cifar10_path, shuffle=True, is_flatten=False):
     """
     Load CIFAR-10
     """
@@ -41,8 +45,11 @@ def load_cifar10(cifar10_path, shuffle=True):
             file_path = os.path.join(data_dir, filename)
             img = read_image(file_path)
             if img is not None:
-                x_test.append(img.reshape(-1))
                 y_test.append(i)
+                if is_flatten:
+                    x_test.append(img.reshape(-1))
+                else:
+                    x_test.append(np.transpose(img, (2, 0, 1)))
 
     train_file_list = np.array(train_file_list)
     if shuffle:
@@ -52,7 +59,26 @@ def load_cifar10(cifar10_path, shuffle=True):
     for file_path in train_file_list:
         img = read_image(file_path)
         if img is not None:
-            x_train.append(img.reshape(-1))
             y_train.append(int(os.path.split(file_path)[0].split('/')[-1]))
+            if is_flatten:
+                x_train.append(img.reshape(-1))
+            else:
+                x_train.append(np.transpose(img, (2, 0, 1)))
 
     return np.array(x_train), np.array(x_test), np.array(y_train), np.array(y_test)
+
+
+def load_cifar10_pytorch(cifar10_path, batch_size=128, shuffle=False):
+    transform = transforms.Compose([
+        transforms.Resize((227, 227)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    ])
+
+    train_data_set = datasets.CIFAR10(root=cifar10_path, train=True, download=True, transform=transform)
+    test_data_set = datasets.CIFAR10(root=cifar10_path, train=False, download=True, transform=transform)
+
+    train_loader = DataLoader(train_data_set, batch_size=batch_size, shuffle=shuffle, num_workers=2)
+    test_loader = DataLoader(test_data_set, batch_size=batch_size, shuffle=shuffle, num_workers=2)
+
+    return train_loader, test_loader
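A quick shape check of the new is_flatten switch, using the data path from src/; the expected shapes assume 32x32 RGB images as resized by this loader:

import vision.data

data_path = '/home/lab305/Documents/data/decompress_cifar_10'

# Flattened vectors for the fully-connected scripts ...
x_train, _, _, _ = vision.data.load_cifar10(data_path, shuffle=True, is_flatten=True)
print(x_train.shape)  # expected (N, 3072)

# ... and CHW arrays for the convolutional NIN.
x_train, _, _, _ = vision.data.load_cifar10(data_path, shuffle=True)
print(x_train.shape)  # expected (N, 3, 32, 32)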
