1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torchvision
6
+ from torch import nn
7
+ from torch .autograd import Variable
8
+ from torch .utils .data import DataLoader
9
+ from torchvision import transforms
10
+ import torchvision .datasets as dset
11
+ from torchvision .utils import save_image
12
+ import torchvision .utils as vutils
13
+ from torchsummary import summary
14
+ import argparse
15
+ import sys
16
+ from math import log10
17
+
18
+ from Models import autoencoder
19
+ from dataloader import DataloaderCompression
20
+ from Lossfuncs import mse_loss , parsingLoss
21
+
22
+ nb_channls = 3
23
+
24
+ parser = argparse .ArgumentParser ()
25
+ parser .add_argument (
26
+ '--batch_size' , type = int , default = 8 , help = 'batch size' )
27
+ parser .add_argument (
28
+ '--train' , required = True , type = str , help = 'folder of training images' )
29
+ parser .add_argument (
30
+ '--test' , required = True , type = str , help = 'folder of testing images' )
31
+ parser .add_argument (
32
+ '--max_epochs' , type = int , default = 50 , help = 'max epochs' )
33
+ parser .add_argument ('--lr' , type = float , default = 0.005 , help = 'learning rate' )
34
+ # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')
35
+ parser .add_argument (
36
+ '--iterations' , type = int , default = 100 , help = 'unroll iterations' )
37
+ parser .add_argument (
38
+ '--image_size' , type = int , default = 150 , help = 'Load image size' )
39
+ parser .add_argument ('--checkpoint' , type = int , default = 20 , help = 'save checkpoint after ' )
40
+ parser .add_argument ('--workers' , type = int , default = 4 , help = 'unroll iterations' )
41
+ parser .add_argument ('--weight_decay' , type = float , default = 0.0005 , help = 'unroll iterations' )
42
+ args = parser .parse_args ()
43
+
44
+ device = torch .device ("cuda:0" if (torch .cuda .is_available ()) else "cpu" )
45
+
46
+ def to_img (x ):
47
+ x = 0.5 * (x + 1 )
48
+ x = x .clamp (0 , 1 )
49
+ x = x .view (x .size (0 ), nb_channls , args .image_size , args .image_size )
50
+ return x
51
+
52
+ Dataloader = DataloaderCompression (args .train ,args .image_size ,args .batch_size ,args .workers )
53
+
54
+ model = autoencoder ().to (device )
55
+ criterion = nn .MSELoss ()
56
+
57
+ optimizer = torch .optim .Adam (model .parameters (), lr = args .lr , weight_decay = args .weight_decay )
58
+ summary (model , (nb_channls , args .image_size , args .image_size ))
59
+
60
+ # Training Loop. Results will appear every 10th iteration.
61
+ itr = 0
62
+ training_loss = []
63
+ PSNR_list = []
64
+ for epoch in range (args .max_epochs ):
65
+ for data in Dataloader :
66
+ img , _ = data
67
+ img = Variable (img ).to (device )
68
+
69
+ # Forward
70
+ coding , output = model (img )
71
+ cyclicloss ,r_loss ,g_loss ,b_loss = mse_loss (output , img )
72
+ pLoss = parsingLoss (coding , args .image_size )
73
+
74
+ loss = 5 * cyclicloss + 10 * pLoss
75
+
76
+ PSNR = 10 * log10 (255 ** 2 / cyclicloss )
77
+
78
+ # Backprop
79
+ optimizer .zero_grad ()
80
+ loss .backward ()
81
+ optimizer .step ()
82
+
83
+ '''
84
+ if itr % 10 == 0 and itr < args.iterations:
85
+ # Log
86
+ print('iter [{}], whole_loss:{:.4f} cyclic_loss{:.4f} pLoss{:.4f} comp_ratio{:.4f}'
87
+ .format(itr, loss.data.item(), 5*cyclicloss.data.item(), 10*pLoss.data.item(), PSNR))
88
+ '''
89
+ '''
90
+ if itr % 30 == 0 and itr < args.iterations:
91
+ pic = to_img(output.to("cpu").data)
92
+
93
+ fig = plt.figure(figsize=(16, 16))
94
+
95
+ ax = plt.imshow(np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
96
+ ax.axes.get_xaxis().set_visible(False)
97
+ ax.axes.get_yaxis().set_visible(False)
98
+ plt.show(fig)
99
+
100
+
101
+ compress_ratio.append(comp_ratio)
102
+ '''
103
+ training_loss .append (loss )
104
+ PSNR_list .append (PSNR )
105
+ itr += 1
106
+
107
+ print ('epoch [{}/{}], loss:{:.4f}, cyclic_loss{:.4f} pLoss{:.4f} PSNR{:.4f}'
108
+ .format (epoch + 1 , args .max_epochs , loss .data .item (), 5 * cyclicloss .data .item (), 10 * pLoss .data .item (), PSNR ))
109
+
110
+ if epoch % 10 == 0 :
111
+ torch .save (model , 'Compressing_{%d}.pth' % epoch )
112
+
113
+ plt .plot (training_loss , label = 'Training loss' )
114
+ plt .plot (PSNR , label = 'PSNR' )
115
+ plt .legend (frameon = False )
116
+ plt .savefig ("Train.png" )
117
+ plt .show ()
0 commit comments