Skip to content

Commit ebd3d6a

Browse files
authored
import config as cfg
1 parent d5041a2 commit ebd3d6a

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

main.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def q(text = ''):
2323
sys.exit()
2424

2525
data_dir = cfg.data_dir
26+
train_dir = cfg.train_dir
27+
val_dir = cfg.val_dir
2628

2729
models_dir = cfg.model_dir
2830
if not os.path.exists(models_dir):
@@ -68,55 +70,58 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
6870

6971
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
7072

71-
train_dataset = DAE_dataset(os.path.join(data_dir, 'train'), transform = transform)
72-
val_dataset = DAE_dataset(os.path.join(data_dir, 'val'), transform = transform)
73+
train_dataset = DAE_dataset(os.path.join(data_dir, train_dir), transform = transform)
74+
val_dataset = DAE_dataset(os.path.join(data_dir, val_dir), transform = transform)
7375

7476
print('\nlen(train_dataset) : ', len(train_dataset))
7577
print('len(val_dataset) : ', len(val_dataset))
7678

77-
batch_size = 8
79+
batch_size = cfg.batch_size
7880

7981
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
8082
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = not True)
8183

8284
print('\nlen(train_loader): {} @bs={}'.format(len(train_loader), batch_size))
8385
print('len(val_loader) : {} @bs={}'.format(len(val_loader), batch_size))
8486

85-
resume = not False
87+
# defining the model
88+
model = UNet(n_classes = 1, depth = 3, padding = True).to(device)
89+
90+
resume = cfg.resume
8691
if not resume:
8792
print('\nfrom scratch')
88-
model = UNet(n_classes = 1, depth = 3, padding = True).to(device)
8993
train_epoch_loss = []
9094
val_epoch_loss = []
9195
running_train_loss = []
9296
running_val_loss = []
9397
epochs_till_now = 0
9498
else:
95-
ckpt_path = os.path.join('models', 'model01.pth')
99+
ckpt_path = os.path.join(models_dir, cfg.ckpt)
96100
ckpt = torch.load(ckpt_path)
97-
model = ckpt['model'].to(device)
98101
print(f'\nckpt loaded: {ckpt_path}')
102+
model_state_dict = ckpt['model_state_dict']
103+
model.load_state_dict(model_state_dict)
104+
model.to(device)
99105
losses = ckpt['losses']
100106
running_train_loss = losses['running_train_loss']
101107
running_val_loss = losses['running_val_loss']
102108
train_epoch_loss = losses['train_epoch_loss']
103109
val_epoch_loss = losses['val_epoch_loss']
104-
105110
epochs_till_now = ckpt['epochs_till_now']
106111

107-
lr = 3e-5
112+
lr = cfg.lr
108113
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
109114
loss_fn = nn.MSELoss()
110115

111116
log_interval = 25
112-
epochs = 1
117+
epochs = cfg.epochs
113118

114119
###
115120
print('\nmodel has {} M parameters'.format(count_parameters(model)))
116-
print(f'loss_fn : {loss_fn}')
121+
print(f'\nloss_fn : {loss_fn}')
117122
print(f'lr : {lr}')
118123
print(f'epochs_till_now: {epochs_till_now}')
119-
print(f'epochs : {epochs}')
124+
print(f'epochs from now: {epochs}')
120125
###
121126

122127
for epoch in range(epochs_till_now, epochs_till_now+epochs):
@@ -174,4 +179,11 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
174179
print('\nepoch val time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s)))
175180

176181
plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch)
177-
torch.save({'model_state_dict': model.state_dict(), 'losses': {'running_train_loss': running_train_loss, 'running_val_loss': running_val_loss, 'train_epoch_loss': train_epoch_loss, 'val_epoch_loss': val_epoch_loss}, 'epochs_till_now': epoch+1}, os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))
182+
183+
torch.save({'model_state_dict': model.state_dict(),
184+
'losses': {'running_train_loss': running_train_loss,
185+
'running_val_loss': running_val_loss,
186+
'train_epoch_loss': train_epoch_loss,
187+
'val_epoch_loss': val_epoch_loss},
188+
'epochs_till_now': epoch+1},
189+
os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))

0 commit comments

Comments
 (0)