Skip to content

Commit 2957ccf

Browse files
authored
Add files via upload
0 parents  commit 2957ccf

7 files changed

+294
-0
lines changed

Lossfuncs.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
3+
def mse_loss(input, target):
    """Per-channel MSE between two RGB image batches.

    Both tensors are expected to be (N, 3, H, W). Returns the average of
    the three channel losses, followed by the individual R, G, B losses.
    """
    channel_losses = []
    for c in range(3):
        diff = input[:, c:c + 1, :, :] - target[:, c:c + 1, :, :]
        channel_losses.append(torch.mean(diff ** 2))

    r, g, b = channel_losses
    mean = (r + g + b) / 3
    return mean, r, g, b
15+
16+
def parsingLoss(coding, image_size):
    """L1 sparsity penalty on the latent coding, normalised by pixel count."""
    l1_total = torch.abs(coding).sum()
    return l1_total / (image_size ** 2)

Models.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from torch import nn
2+
3+
class autoencoder(nn.Module):
    """Convolutional autoencoder for RGB images.

    forward() returns ``(coding, output)``: the 16-channel latent tensor and
    the tanh-bounded reconstruction. Pooling indices from the encoder are
    reused by the matching unpooling layers in the decoder.
    """

    def __init__(self):
        super(autoencoder, self).__init__()
        # NOTE: attribute names and creation order are part of the saved
        # checkpoint layout — do not rename or reorder.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=(5, 5))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), return_indices=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), return_indices=True)
        self.unconv1 = nn.ConvTranspose2d(6, 3, kernel_size=(5, 5))
        self.maxunpool1 = nn.MaxUnpool2d(kernel_size=(2, 2))
        self.unmaxunpool2 = nn.MaxUnpool2d(kernel_size=(2, 2))

        # Encoder stage 1: 6 -> 12 channels.
        self.encoder1 = nn.Sequential(
            nn.Tanh(),
            nn.Conv2d(6, 12, kernel_size=(5, 5)),
        )

        # Encoder stage 2: 12 -> 16 channels; output is the latent coding.
        self.encoder2 = nn.Sequential(
            nn.Tanh(),
            nn.Conv2d(12, 16, kernel_size=(5, 5)),
            nn.Tanh()
        )

        # Decoder stage 2: 16 -> 12 channels (mirror of encoder2).
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(16, 12, kernel_size=(5, 5)),
            nn.Tanh()
        )

        # Decoder stage 1: 12 -> 6 channels (mirror of encoder1).
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(12, 6, kernel_size=(5, 5)),
            nn.Tanh(),
        )

    def forward(self, x):
        # --- encode ---
        feat = self.conv1(x)
        feat, idx1 = self.maxpool1(feat)
        feat = self.encoder1(feat)
        feat, idx2 = self.maxpool2(feat)
        coding = self.encoder2(feat)

        # --- decode, reusing the pooling indices for exact unpooling ---
        rec = self.decoder2(coding)
        rec = self.unmaxunpool2(rec, idx2)
        rec = self.decoder1(rec)
        rec = self.maxunpool1(rec, idx1)
        rec = self.unconv1(rec)
        output = nn.Tanh()(rec)
        return coding, output

dataloader.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torch.utils.data import DataLoader
2+
import torchvision.datasets as dset
3+
from torchvision import transforms
4+
5+
def DataloaderCompression(dataroot, image_size, batch_size, workers):
    """Build a shuffled DataLoader over an ImageFolder rooted at `dataroot`.

    Images are converted to tensors and normalised channel-wise to [-1, 1].
    NOTE: `image_size` is accepted for interface compatibility but unused —
    no Resize transform is applied.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = dset.ImageFolder(root=dataroot, transform=preprocess)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

decoding.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from matplotlib import pyplot as plt
2+
import numpy as np
3+
4+
import torch
5+
import argparse
6+
from torch.autograd import Variable
7+
from math import log10
8+
import torchvision.utils as vutils
9+
from PIL import Image
10+
from torchvision import transforms
11+
#import train_eval
12+
#from train_eval import to_img
13+
14+
from Models import autoencoder
15+
from dataloader import DataloaderCompression
16+
from Lossfuncs import mse_loss, parsingLoss
17+
18+
# Number of image channels (RGB).
nb_channls = 3

# Command-line interface.
# NOTE(review): several help strings ('unroll iterations') look copy-pasted
# and do not describe their options — confirm intended meanings.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--batch_size', type=int, default=8, help='batch size')
parser.add_argument(
    '--train', required=True, type=str, help='folder of training images')
parser.add_argument(
    '--test', required=True, type=str, help='folder of testing images')
parser.add_argument(
    '--max_epochs', type=int, default=50, help='max epochs')
parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
# parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')
parser.add_argument(
    '--iterations', type=int, default=100, help='unroll iterations')
parser.add_argument(
    '--image_size', type=int, default=150, help='Load image size')
parser.add_argument('--checkpoint', type=int, default=20, help='save checkpoint after ')
parser.add_argument('--workers', type=int, default=4, help='unroll iterations')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='unroll iterations')
args = parser.parse_args()
39+
40+
def to_img(x, channels=None, size=None):
    """Map a model output batch from [-1, 1] back to a [0, 1] image batch.

    Args:
        x: tensor of (or reshapeable to) shape (N, channels, size, size).
        channels: number of image channels; defaults to the module-level
            ``nb_channls`` for backward compatibility.
        size: image side length; defaults to ``args.image_size``.

    Returns:
        Tensor of shape (N, channels, size, size), values clamped to [0, 1].
    """
    if channels is None:
        channels = nb_channls
    if size is None:
        size = args.image_size
    x = 0.5 * (x + 1)  # undo Normalize((0.5,...), (0.5,...)): [-1, 1] -> [0, 1]
    x = x.clamp(0, 1)
    x = x.view(x.size(0), channels, size, size)
    return x
45+
46+
# Run on GPU when available.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

# Load the fully pickled model (saved via torch.save(model, ...)).
# NOTE(review): train_eval.py saves checkpoints named 'Compressing_{..}.pth';
# confirm that a file literally named 'compressing.pth' exists.
model=torch.load('compressing.pth')
model.eval()

# Evaluate on the test folder.
Dataloader = DataloaderCompression(args.test,args.image_size,args.batch_size,args.workers)

PSNR = []
Compressing_Ratio = []
itr = 0
for data in Dataloader:
    img, _ = data
    img = Variable(img).to(device)

    # Forward pass: latent coding + reconstruction.
    coding, output = model(img)
    cyclicloss,r_loss,g_loss,b_loss = mse_loss(output, img)

    # NOTE(review): inputs are normalised to [-1, 1], so using 255 as the
    # peak value overstates PSNR — confirm the intended dynamic range.
    PSNR_value = 10*log10(255**2/cyclicloss)
    PSNR.append(PSNR_value)

    # NOTE(review): this ratio compares channel counts only (dim 1) and
    # ignores the spatial downsampling of the coding — verify intent.
    Comp_ratio = coding.size()[1]/img.size()[1]
    Compressing_Ratio.append(Comp_ratio)

    # De-normalise the reconstruction for visualisation.
    pic_ = to_img(output.to("cpu").data)
    #pic = transforms.ToPILImage(pic_)

    #pic_color = np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0))
    fig = plt.figure(figsize=(128, 128))

    '''
    ax = plt.imshow(np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    '''

    #plt.show(fig)
    # NOTE(review): nothing is drawn on the figure before saving, so
    # output/<itr>.jpg will be blank; the imshow call above is commented out.
    plt.savefig('output/%d.jpg'%itr)
    itr += 1

print('mean PSNR is %s'%np.mean(PSNR))
print('mean compression ratio is %s'%np.mean(Compressing_Ratio))

run_test.sh

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
python decoding.py \
2+
--batch_size 1 \
3+
--train 'Data' \
4+
--test 'Data_valid' \
5+
--max_epochs 30 \
6+
--lr 0.0005 \
7+
--iterations 30 \
8+
--image_size 128

run_train.sh

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
python train_eval.py \
2+
--batch_size 16 \
3+
--train 'Data' \
4+
--test 'Data_valid' \
5+
--max_epochs 30 \
6+
--lr 0.0005 \
7+
--iterations 30 \
8+
--image_size 128

train_eval.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from matplotlib import pyplot as plt
2+
import numpy as np
3+
4+
import torch
5+
import torchvision
6+
from torch import nn
7+
from torch.autograd import Variable
8+
from torch.utils.data import DataLoader
9+
from torchvision import transforms
10+
import torchvision.datasets as dset
11+
from torchvision.utils import save_image
12+
import torchvision.utils as vutils
13+
from torchsummary import summary
14+
import argparse
15+
import sys
16+
from math import log10
17+
18+
from Models import autoencoder
19+
from dataloader import DataloaderCompression
20+
from Lossfuncs import mse_loss, parsingLoss
21+
22+
# Number of image channels (RGB).
nb_channls = 3

# Command-line interface (kept in sync with decoding.py).
# NOTE(review): several help strings ('unroll iterations') look copy-pasted
# and do not describe their options — confirm intended meanings.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--batch_size', type=int, default=8, help='batch size')
parser.add_argument(
    '--train', required=True, type=str, help='folder of training images')
parser.add_argument(
    '--test', required=True, type=str, help='folder of testing images')
parser.add_argument(
    '--max_epochs', type=int, default=50, help='max epochs')
parser.add_argument('--lr', type=float, default=0.005, help='learning rate')
# parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda')
parser.add_argument(
    '--iterations', type=int, default=100, help='unroll iterations')
parser.add_argument(
    '--image_size', type=int, default=150, help='Load image size')
parser.add_argument('--checkpoint', type=int, default=20, help='save checkpoint after ')
parser.add_argument('--workers', type=int, default=4, help='unroll iterations')
parser.add_argument('--weight_decay', type=float, default=0.0005, help='unroll iterations')
args = parser.parse_args()

# Run on GPU when available.
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
46+
def to_img(x, channels=None, size=None):
    """Map a model output batch from [-1, 1] back to a [0, 1] image batch.

    Args:
        x: tensor of (or reshapeable to) shape (N, channels, size, size).
        channels: number of image channels; defaults to the module-level
            ``nb_channls`` for backward compatibility.
        size: image side length; defaults to ``args.image_size``.

    Returns:
        Tensor of shape (N, channels, size, size), values clamped to [0, 1].
    """
    if channels is None:
        channels = nb_channls
    if size is None:
        size = args.image_size
    x = 0.5 * (x + 1)  # undo Normalize((0.5,...), (0.5,...)): [-1, 1] -> [0, 1]
    x = x.clamp(0, 1)
    x = x.view(x.size(0), channels, size, size)
    return x
51+
52+
# Data, model, and optimiser setup.
Dataloader = DataloaderCompression(args.train,args.image_size,args.batch_size,args.workers)

model = autoencoder().to(device)
criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
summary(model, (nb_channls, args.image_size, args.image_size))

# Training loop: per-epoch stats are printed, a checkpoint is written every
# 10 epochs, and loss/PSNR curves are saved to Train.png at the end.
itr = 0
training_loss = []
PSNR_list = []
for epoch in range(args.max_epochs):
    for data in Dataloader:
        img, _ = data
        img = Variable(img).to(device)

        # Forward: reconstruction loss + L1 sparsity penalty on the coding.
        coding, output = model(img)
        cyclicloss,r_loss,g_loss,b_loss = mse_loss(output, img)
        pLoss = parsingLoss(coding, args.image_size)

        loss = 5*cyclicloss + 10*pLoss

        # NOTE(review): inputs are normalised to [-1, 1]; using 255 as the
        # peak value overstates PSNR — confirm the intended dynamic range.
        PSNR = 10*log10(255**2/cyclicloss)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store plain floats: appending the `loss` tensor directly would keep
        # every iteration's autograd graph alive (memory leak) and break the
        # matplotlib plot at the end of training.
        training_loss.append(loss.item())
        PSNR_list.append(PSNR)
        itr += 1

    print('epoch [{}/{}], loss:{:.4f}, cyclic_loss{:.4f} pLoss{:.4f} PSNR{:.4f}'
          .format(epoch + 1, args.max_epochs, loss.data.item(), 5*cyclicloss.data.item(), 10*pLoss.data.item(), PSNR))

    if epoch % 10 == 0:
        # Fixed checkpoint name: 'Compressing_{%d}.pth' % epoch produced
        # literal braces in the filename (e.g. 'Compressing_{0}.pth').
        torch.save(model, 'Compressing_%d.pth' % epoch)

plt.plot(training_loss, label='Training loss')
# Fixed: plotting `PSNR` (a scalar from the final iteration) drew nothing
# useful; the per-iteration history is in PSNR_list.
plt.plot(PSNR_list, label='PSNR')
plt.legend(frameon=False)
plt.savefig("Train.png")
plt.show()

0 commit comments

Comments
 (0)