2
2
import torch .nn as nn
3
3
import torch .nn .functional as F
4
4
from torchvision import datasets , transforms
5
+ from torch .utils .tensorboard import SummaryWriter
6
+ # from DynamicRELU import DYReLU2
7
+
8
# Global switch for TensorBoard logging. The module-level SummaryWriter is
# created only when enabled, so the rest of the script must guard every
# writer.* call with this flag (and does).
tensorboard_on = False
if tensorboard_on:
    writer = SummaryWriter()
5
12
6
13
class MyConvNet(nn.Module):
    """Small LeNet-style CNN for MNIST with an injectable activation function.

    Args:
        relu: activation class, e.g. ``nn.ReLU``, or a DyReLU-style class
            constructed with channel counts.
        relustr: name of the activation; the value ``'dyrelu'`` selects the
            channel-parameterized construction, anything else a plain
            no-argument construction.
        **kwargs: extra keyword arguments forwarded to the activation
            constructor.  (Previously accepted but silently dropped, even
            though ``main`` passes ``**relu_kwargs[i]`` — fixed here;
            backward compatible since all existing call sites pass ``{}``.)
    """

    def __init__(self, relu, relustr, **kwargs):
        super(MyConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        if relustr == 'dyrelu':
            # DyReLU-style activations are parameterized by channel count,
            # so each layer gets its own channel-sized instance.
            self.relu1 = relu(20, 20, **kwargs)
            self.relu2 = relu(50, 50, **kwargs)
            self.relu3 = relu(500, 500, **kwargs)
        else:
            self.relu1 = relu(**kwargs)
            self.relu2 = relu(**kwargs)
            self.relu3 = relu(**kwargs)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        x = self.conv1(x)            # 28x28x1  -> 24x24x20
        x = self.relu1(x)            # 24x24x20
        x = F.max_pool2d(x, 2, 2)    # 24x24x20 -> 12x12x20
        x = self.conv2(x)            # 12x12x20 -> 8x8x50
        x = self.relu2(x)            # 8x8x50
        x = F.max_pool2d(x, 2, 2)    # 8x8x50   -> 4x4x50
        x = torch.flatten(x, 1)      # 4x4x50   -> 4*4*50
        x = self.fc1(x)              # 4*4*50   -> 500
        # NOTE(review): relu3 is constructed but deliberately disabled here,
        # leaving fc1 -> fc2 as two stacked linear layers (equivalent to one
        # linear map) — confirm this is an intentional experiment.
        # x = self.relu3(x)          # 500 -> 500
        x = self.fc2(x)              # 500 -> 10

        return F.log_softmax(x, dim=1)
26
41
27
42
@@ -38,12 +53,16 @@ def train(model, device, train_loader, optimizer, epoch):
38
53
loss .backward ()
39
54
optimizer .step ()
40
55
41
- if batch_idx % 100 == 0 :
42
- print ("Train epoch: {}, iteration: {}, Loss: {}" .format (
43
- epoch , batch_idx , loss .item ()
44
- ))
56
+ if tensorboard_on :
57
+ writer .add_scalar ('Loss/train' ,
58
+ loss .item (),
59
+ epoch * len (train_loader ) + batch_idx )
60
+
61
+ # if batch_idx % 100 == 0:
62
+ print ("Epoch: {}, train loss: {}, " .format (epoch , loss .item ()), end = '' )
63
+
45
64
46
- def test (model , device , test_loader ):
65
+ def test (model , device , test_loader , epoch ):
47
66
model .eval ()
48
67
total_loss = 0
49
68
correct = 0.
@@ -57,40 +76,67 @@ def test(model, device, test_loader):
57
76
58
77
total_loss /= len (test_loader .dataset )
59
78
acc = correct / len (test_loader .dataset ) * 100.
60
- print ("Test loss: {}, accuracy: {}" .format (total_loss , acc ))
79
+ print ("test loss: {}, accuracy: {}" .format (total_loss , acc ))
80
+
81
+ if tensorboard_on :
82
+ writer .add_scalar ('Loss/test' , total_loss , epoch )
83
+ writer .add_scalar ('Accuracy/test' , acc , epoch )
84
+
61
85
62
86
def main():
    """Train the MNIST CNN once per activation function and report metrics."""
    # Hyper-parameters shared by every run.
    batch_size = 128
    lr = 0.01
    momentum = 0.9
    epochs = 15
    schd_step = 7

    # Activation zoo: display name -> constructor (DyReLU2 kept disabled).
    relus = {
        'relu': nn.ReLU,
        'lrelu': nn.LeakyReLU,
        'rrelu': nn.RReLU,
        'prelu': nn.PReLU,
        'relu6': nn.ReLU6,
        'elu': nn.ELU,
        'selu': nn.SELU,
        # dyrelu': DYReLU2
    }
    # One (currently empty) kwargs dict per activation slot.
    relu_kwargs = [{} for _ in range(8)]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kwargs = ({'num_workers': 1, 'pin_memory': True}
              if torch.cuda.is_available() else {})

    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True, **kwargs)
    test_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True, **kwargs)

    # One full training run per activation function.
    for idx, (relustr, relu) in enumerate(relus.items()):
        print('--------------------- {} ---------------------'.format(relustr))
        model = MyConvNet(relu, relustr, **relu_kwargs[idx]).to(device)
        optimizer = torch.optim.SGD(
            model.parameters(), lr=lr, momentum=momentum)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schd_step)

        for epoch in range(epochs):
            train(model, device, train_dataloader, optimizer, epoch)
            test(model, device, test_dataloader, epoch)
            scheduler.step()

        # torch.save(model.state_dict(), 'mnist_cnn.pt')
138
# Script entry point: run every activation experiment, then flush the
# TensorBoard writer if logging was enabled.
if __name__ == '__main__':
    main()

    if tensorboard_on:
        writer.close()
0 commit comments