forked from micah35s/Autoencoder-Image-Compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_eval.py
117 lines (95 loc) · 3.92 KB
/
train_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from matplotlib import pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as dset
from torchvision.utils import save_image
import torchvision.utils as vutils
from torchsummary import summary
import argparse
import sys
from math import log10
from Models import autoencoder
from dataloader import DataloaderCompression
from Lossfuncs import mse_loss, parsingLoss
# Number of image channels (RGB). The misspelt name is kept because
# to_img() and the model summary reference it.
nb_channls = 3

parser = argparse.ArgumentParser(
    description='Train the compression autoencoder.')
parser.add_argument(
    '--batch_size', type=int, default=8, help='batch size')
parser.add_argument(
    '--train', required=True, type=str, help='folder of training images')
# --test is still accepted so existing command lines keep working, but no
# code in this script reads it, so it is no longer mandatory.
parser.add_argument(
    '--test', type=str, default=None,
    help='folder of testing images (currently unused)')
parser.add_argument(
    '--max_epochs', type=int, default=50, help='max epochs')
parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
parser.add_argument(
    '--iterations', type=int, default=100,
    help='iteration cap for periodic debug logging (currently unused)')
parser.add_argument(
    '--image_size', type=int, default=150,
    help='size (height = width) at which images are loaded')
parser.add_argument(
    '--checkpoint', type=int, default=20,
    help='checkpoint interval in epochs (currently unused)')
parser.add_argument(
    '--workers', type=int, default=4, help='number of data loader workers')
parser.add_argument(
    '--weight_decay', type=float, default=0.0005,
    help='Adam weight decay (L2 penalty)')
args = parser.parse_args()

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def to_img(x):
    """Turn a model output batch into image tensors suitable for display.

    The affine map ``0.5 * (x + 1)`` sends [-1, 1] onto [0, 1]; anything
    outside that interval is clipped. The result is viewed as
    ``(batch, nb_channls, args.image_size, args.image_size)``.
    Presumably the model emits values in [-1, 1] -- confirm against the
    autoencoder's output activation.
    """
    unit_range = (0.5 * (x + 1)).clamp(0, 1)
    side = args.image_size
    return unit_range.view(unit_range.size(0), nb_channls, side, side)
# Batched loader over the training image folder (project helper).
Dataloader = DataloaderCompression(
    args.train, args.image_size, args.batch_size, args.workers
)

# Autoencoder placed on the selected device, plus its optimiser.
model = autoencoder()
model = model.to(device)
# NOTE(review): `criterion` is not referenced by the training loop, which
# calls Lossfuncs.mse_loss instead.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Print a layer-by-layer summary for a single (C, H, W) input.
summary(model, input_size=(nb_channls, args.image_size, args.image_size))
# Training loop: one optimisation step per batch. A progress line is
# printed once per epoch, and a checkpoint is written every 10 epochs and
# after the last epoch.
itr = 0
training_loss = []  # per-iteration total loss, as Python floats
PSNR_list = []  # per-iteration PSNR in dB, as Python floats
# NOTE(review): a peak of 255 assumes 8-bit pixel scale. to_img() maps
# model output from [-1, 1], so if the loader normalises images that way
# the peak should be 2 (or 1 for [0, 1]). Confirm against
# DataloaderCompression before trusting the absolute PSNR values.
PSNR_PEAK = 255
for epoch in range(args.max_epochs):
    for img, _ in Dataloader:
        # Variable has been a no-op wrapper since PyTorch 0.4.
        img = img.to(device)

        # Forward
        coding, output = model(img)
        cyclicloss, r_loss, g_loss, b_loss = mse_loss(output, img)
        pLoss = parsingLoss(coding, args.image_size)
        loss = 5 * cyclicloss + 10 * pLoss
        mse = cyclicloss.item()
        # A perfect reconstruction (mse == 0) has infinite PSNR.
        PSNR = float('inf') if mse == 0 else 10 * log10(PSNR_PEAK ** 2 / mse)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a detached float: appending `loss` itself would keep every
        # iteration's autograd graph (and its activations) alive.
        training_loss.append(loss.item())
        PSNR_list.append(PSNR)
        itr += 1

    # Reports the values from the last batch of the epoch.
    print('epoch [{}/{}], loss:{:.4f}, cyclic_loss{:.4f} pLoss{:.4f} PSNR{:.4f}'
          .format(epoch + 1, args.max_epochs, loss.item(),
                  5 * cyclicloss.item(), 10 * pLoss.item(), PSNR))
    # Save every 10 epochs and always after the final epoch, so the fully
    # trained model is kept even when max_epochs - 1 is not a multiple of 10.
    if epoch % 10 == 0 or epoch == args.max_epochs - 1:
        torch.save(model, 'Compressing_%d.pth' % epoch)
# Plot the per-iteration training curves and save them to Train.png.
# Loss and PSNR (dB) have very different magnitudes, so each gets its own
# axis; float() accepts both Python floats and 0-d tensors.
fig, (loss_ax, psnr_ax) = plt.subplots(2, 1, sharex=True)
loss_ax.plot([float(v) for v in training_loss], label='Training loss')
loss_ax.legend(frameon=False)
psnr_ax.plot([float(v) for v in PSNR_list], label='PSNR')
psnr_ax.set_xlabel('iteration')
psnr_ax.legend(frameon=False)
fig.savefig("Train.png")
plt.show()