4
4
import torch .nn as nn
5
5
from tqdm import tqdm
6
6
import matplotlib .pyplot as plt
7
- plt .switch_backend ('agg' )
7
+ plt .switch_backend ('agg' ) # for servers not supporting display
8
8
9
9
# import neccesary libraries for defining the optimizers
10
10
import torch .optim as optim
13
13
14
14
from unet import UNet
15
15
from datasets import DAE_dataset
16
+ import config as cfg
16
17
17
18
device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
18
19
print ('device: ' , device )
@@ -21,11 +22,13 @@ def q(text = ''):
21
22
print ('> {}' .format (text ))
22
23
sys .exit ()
23
24
24
- models_dir = 'models'
25
+ data_dir = cfg .data_dir
26
+
27
+ models_dir = cfg .model_dir
25
28
if not os .path .exists (models_dir ):
26
29
os .mkdir (models_dir )
27
30
28
- losses_dir = 'losses'
31
+ losses_dir = cfg . losses_dir
29
32
if not os .path .exists (losses_dir ):
30
33
os .mkdir (losses_dir )
31
34
@@ -61,12 +64,12 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
61
64
ax4 .set_ylabel ('batch val loss' )
62
65
ax4 .plot (running_val_loss )
63
66
64
- plt .savefig (os .path .join ('losses' ,'losses_{}.png' .format (str (epoch + 1 ).zfill (2 ))))
67
+ plt .savefig (os .path .join (losses_dir ,'losses_{}.png' .format (str (epoch + 1 ).zfill (2 ))))
65
68
66
69
transform = transforms .Compose ([transforms .ToPILImage (), transforms .ToTensor ()])
67
70
68
- train_dataset = DAE_dataset (os .path .join ('data' , 'train' ), transform = transform )
69
- val_dataset = DAE_dataset (os .path .join ('data' , 'val' ), transform = transform )
71
+ train_dataset = DAE_dataset (os .path .join (data_dir , 'train' ), transform = transform )
72
+ val_dataset = DAE_dataset (os .path .join (data_dir , 'val' ), transform = transform )
70
73
71
74
print ('\n len(train_dataset) : ' , len (train_dataset ))
72
75
print ('len(val_dataset) : ' , len (val_dataset ))
@@ -124,15 +127,16 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
124
127
for batch_idx , (imgs , noisy_imgs ) in enumerate (train_loader ):
125
128
batch_start_time = time .time ()
126
129
imgs = imgs .to (device )
127
- # print(imgs.shape)
128
- # q()
129
130
noisy_imgs = noisy_imgs .to (device )
131
+
130
132
optimizer .zero_grad ()
131
133
out = model (noisy_imgs )
134
+
132
135
loss = loss_fn (out , imgs )
133
136
running_train_loss .append (loss .item ())
134
137
loss .backward ()
135
138
optimizer .step ()
139
+
136
140
if (batch_idx + 1 )% log_interval == 0 :
137
141
batch_time = time .time () - batch_start_time
138
142
m ,s = divmod (batch_time , 60 )
@@ -153,8 +157,10 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
153
157
154
158
imgs = imgs .to (device )
155
159
noisy_imgs = noisy_imgs .to (device )
160
+
156
161
out = model (noisy_imgs )
157
162
loss = loss_fn (out , imgs )
163
+
158
164
running_val_loss .append (loss .item ())
159
165
160
166
if (batch_idx + 1 )% log_interval == 0 :
@@ -168,4 +174,4 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
168
174
print ('\n epoch val time: {} hrs {} mins {} secs' .format (int (h ), int (m ), int (s )))
169
175
170
176
plot_losses (running_train_loss , running_val_loss , train_epoch_loss , val_epoch_loss , epoch )
171
- torch .save ({'model ' : model , 'losses' : {'running_train_loss' : running_train_loss , 'running_val_loss' : running_val_loss , 'train_epoch_loss' : train_epoch_loss , 'val_epoch_loss' : val_epoch_loss }, 'epochs_till_now' : epoch + 1 }, os .path .join (models_dir , 'model{}.pth' .format (str (epoch + 1 ).zfill (2 ))))
177
+ torch .save ({'model_state_dict ' : model . state_dict () , 'losses' : {'running_train_loss' : running_train_loss , 'running_val_loss' : running_val_loss , 'train_epoch_loss' : train_epoch_loss , 'val_epoch_loss' : val_epoch_loss }, 'epochs_till_now' : epoch + 1 }, os .path .join (models_dir , 'model{}.pth' .format (str (epoch + 1 ).zfill (2 ))))
0 commit comments