Skip to content

Commit e11f666

Browse files
authored
added config
1 parent f0c54fe commit e11f666

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

main.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as nn
55
from tqdm import tqdm
66
import matplotlib.pyplot as plt
7-
plt.switch_backend('agg')
7+
plt.switch_backend('agg') # for servers not supporting display
88

99
# import necessary libraries for defining the optimizers
1010
import torch.optim as optim
@@ -13,6 +13,7 @@
1313

1414
from unet import UNet
1515
from datasets import DAE_dataset
16+
import config as cfg
1617

1718
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1819
print('device: ', device)
@@ -21,11 +22,13 @@ def q(text = ''):
2122
print('> {}'.format(text))
2223
sys.exit()
2324

24-
models_dir = 'models'
25+
data_dir = cfg.data_dir
26+
27+
models_dir = cfg.model_dir
2528
if not os.path.exists(models_dir):
2629
os.mkdir(models_dir)
2730

28-
losses_dir = 'losses'
31+
losses_dir = cfg.losses_dir
2932
if not os.path.exists(losses_dir):
3033
os.mkdir(losses_dir)
3134

@@ -61,12 +64,12 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
6164
ax4.set_ylabel('batch val loss')
6265
ax4.plot(running_val_loss)
6366

64-
plt.savefig(os.path.join('losses','losses_{}.png'.format(str(epoch + 1).zfill(2))))
67+
plt.savefig(os.path.join(losses_dir,'losses_{}.png'.format(str(epoch + 1).zfill(2))))
6568

6669
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
6770

68-
train_dataset = DAE_dataset(os.path.join('data', 'train'), transform = transform)
69-
val_dataset = DAE_dataset(os.path.join('data', 'val'), transform = transform)
71+
train_dataset = DAE_dataset(os.path.join(data_dir, 'train'), transform = transform)
72+
val_dataset = DAE_dataset(os.path.join(data_dir, 'val'), transform = transform)
7073

7174
print('\nlen(train_dataset) : ', len(train_dataset))
7275
print('len(val_dataset) : ', len(val_dataset))
@@ -124,15 +127,16 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
124127
for batch_idx, (imgs, noisy_imgs) in enumerate(train_loader):
125128
batch_start_time = time.time()
126129
imgs = imgs.to(device)
127-
# print(imgs.shape)
128-
# q()
129130
noisy_imgs = noisy_imgs.to(device)
131+
130132
optimizer.zero_grad()
131133
out = model(noisy_imgs)
134+
132135
loss = loss_fn(out, imgs)
133136
running_train_loss.append(loss.item())
134137
loss.backward()
135138
optimizer.step()
139+
136140
if (batch_idx + 1)%log_interval == 0:
137141
batch_time = time.time() - batch_start_time
138142
m,s = divmod(batch_time, 60)
@@ -153,8 +157,10 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
153157

154158
imgs = imgs.to(device)
155159
noisy_imgs = noisy_imgs.to(device)
160+
156161
out = model(noisy_imgs)
157162
loss = loss_fn(out, imgs)
163+
158164
running_val_loss.append(loss.item())
159165

160166
if (batch_idx + 1)%log_interval == 0:
@@ -168,4 +174,4 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
168174
print('\nepoch val time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s)))
169175

170176
plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch)
171-
torch.save({'model': model, 'losses': {'running_train_loss': running_train_loss, 'running_val_loss': running_val_loss, 'train_epoch_loss': train_epoch_loss, 'val_epoch_loss': val_epoch_loss}, 'epochs_till_now': epoch+1}, os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))
177+
torch.save({'model_state_dict': model.state_dict(), 'losses': {'running_train_loss': running_train_loss, 'running_val_loss': running_val_loss, 'train_epoch_loss': train_epoch_loss, 'val_epoch_loss': val_epoch_loss}, 'epochs_till_now': epoch+1}, os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))

0 commit comments

Comments (0)