@@ -23,6 +23,8 @@ def q(text = ''):
23
23
sys .exit ()
24
24
25
25
data_dir = cfg .data_dir
26
+ train_dir = cfg .train_dir
27
+ val_dir = cfg .val_dir
26
28
27
29
models_dir = cfg .model_dir
28
30
if not os .path .exists (models_dir ):
@@ -68,55 +70,58 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
68
70
69
71
transform = transforms .Compose ([transforms .ToPILImage (), transforms .ToTensor ()])
70
72
71
- train_dataset = DAE_dataset (os .path .join (data_dir , 'train' ), transform = transform )
72
- val_dataset = DAE_dataset (os .path .join (data_dir , 'val' ), transform = transform )
73
+ train_dataset = DAE_dataset (os .path .join (data_dir , train_dir ), transform = transform )
74
+ val_dataset = DAE_dataset (os .path .join (data_dir , val_dir ), transform = transform )
73
75
74
76
print ('\n len(train_dataset) : ' , len (train_dataset ))
75
77
print ('len(val_dataset) : ' , len (val_dataset ))
76
78
77
- batch_size = 8
79
+ batch_size = cfg . batch_size
78
80
79
81
train_loader = torch .utils .data .DataLoader (train_dataset , batch_size = batch_size , shuffle = True )
80
82
val_loader = torch .utils .data .DataLoader (val_dataset , batch_size = batch_size , shuffle = not True )
81
83
82
84
print ('\n len(train_loader): {} @bs={}' .format (len (train_loader ), batch_size ))
83
85
print ('len(val_loader) : {} @bs={}' .format (len (val_loader ), batch_size ))
84
86
85
- resume = not False
87
+ # defining the model
88
+ model = UNet (n_classes = 1 , depth = 3 , padding = True ).to (device )
89
+
90
+ resume = cfg .resume
86
91
if not resume :
87
92
print ('\n from scratch' )
88
- model = UNet (n_classes = 1 , depth = 3 , padding = True ).to (device )
89
93
train_epoch_loss = []
90
94
val_epoch_loss = []
91
95
running_train_loss = []
92
96
running_val_loss = []
93
97
epochs_till_now = 0
94
98
else :
95
- ckpt_path = os .path .join ('models' , 'model01.pth' )
99
+ ckpt_path = os .path .join (models_dir , cfg . ckpt )
96
100
ckpt = torch .load (ckpt_path )
97
- model = ckpt ['model' ].to (device )
98
101
print (f'\n ckpt loaded: { ckpt_path } ' )
102
+ model_state_dict = ckpt ['model_state_dict' ]
103
+ model .load_state_dict (model_state_dict )
104
+ model .to (device )
99
105
losses = ckpt ['losses' ]
100
106
running_train_loss = losses ['running_train_loss' ]
101
107
running_val_loss = losses ['running_val_loss' ]
102
108
train_epoch_loss = losses ['train_epoch_loss' ]
103
109
val_epoch_loss = losses ['val_epoch_loss' ]
104
-
105
110
epochs_till_now = ckpt ['epochs_till_now' ]
106
111
107
- lr = 3e-5
112
+ lr = cfg . lr
108
113
optimizer = optim .Adam (filter (lambda p : p .requires_grad , model .parameters ()), lr = lr )
109
114
loss_fn = nn .MSELoss ()
110
115
111
116
log_interval = 25
112
- epochs = 1
117
+ epochs = cfg . epochs
113
118
114
119
###
115
120
print ('\n model has {} M parameters' .format (count_parameters (model )))
116
- print (f'loss_fn : { loss_fn } ' )
121
+ print (f'\n loss_fn : { loss_fn } ' )
117
122
print (f'lr : { lr } ' )
118
123
print (f'epochs_till_now: { epochs_till_now } ' )
119
- print (f'epochs : { epochs } ' )
124
+ print (f'epochs from now : { epochs } ' )
120
125
###
121
126
122
127
for epoch in range (epochs_till_now , epochs_till_now + epochs ):
@@ -174,4 +179,11 @@ def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoc
174
179
print ('\n epoch val time: {} hrs {} mins {} secs' .format (int (h ), int (m ), int (s )))
175
180
176
181
plot_losses (running_train_loss , running_val_loss , train_epoch_loss , val_epoch_loss , epoch )
177
- torch .save ({'model_state_dict' : model .state_dict (), 'losses' : {'running_train_loss' : running_train_loss , 'running_val_loss' : running_val_loss , 'train_epoch_loss' : train_epoch_loss , 'val_epoch_loss' : val_epoch_loss }, 'epochs_till_now' : epoch + 1 }, os .path .join (models_dir , 'model{}.pth' .format (str (epoch + 1 ).zfill (2 ))))
182
+
183
+ torch .save ({'model_state_dict' : model .state_dict (),
184
+ 'losses' : {'running_train_loss' : running_train_loss ,
185
+ 'running_val_loss' : running_val_loss ,
186
+ 'train_epoch_loss' : train_epoch_loss ,
187
+ 'val_epoch_loss' : val_epoch_loss },
188
+ 'epochs_till_now' : epoch + 1 },
189
+ os .path .join (models_dir , 'model{}.pth' .format (str (epoch + 1 ).zfill (2 ))))
0 commit comments