Skip to content

Commit ebc1a2a

Browse files
authored
Update train.py
1 parent 533d258 commit ebc1a2a

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

train.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1919
print('device: ', device)
2020

21+
script_time = time.time()
22+
2123
def q(text = ''):
2224
print('> {}'.format(text))
2325
sys.exit()
@@ -70,7 +72,6 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
7072

7173
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
7274

73-
7475
train_dataset = DAE_dataset(os.path.join(data_dir, train_dir), transform = transform)
7576
val_dataset = DAE_dataset(os.path.join(data_dir, val_dir), transform = transform)
7677

@@ -127,7 +128,7 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
127128
###
128129

129130
for epoch in range(epochs_till_now, epochs_till_now+epochs):
130-
print('\n===== EPOCH {}/{} ====='.format(epochs_till_now + 1, epochs_till_now + epochs))
131+
print('\n===== EPOCH {}/{} ====='.format(epoch + 1, epochs_till_now + epochs))
131132
print('\nTRAINING...')
132133
epoch_train_start_time = time.time()
133134
model.train()
@@ -189,3 +190,10 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
189190
'val_epoch_loss': val_epoch_loss},
190191
'epochs_till_now': epoch+1},
191192
os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))
193+
194+
total_script_time = time.time() - script_time
195+
m, s = divmod(total_script_time, 60)
196+
h, m = divmod(m, 60)
197+
print(f'\ntotal time taken for running this script: {int(h)} hrs {int(m)} mins {int(s)} secs')
198+
199+
print('\nFin.')

0 commit comments

Comments
 (0)