@@ -46,7 +46,7 @@
 import matplotlib.pyplot as plt
 import time
 import os
-import copy
+from tempfile import TemporaryDirectory

 cudnn.benchmark = True
 plt.ion()   # interactive mode
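The point of this first hunk: `copy.deepcopy(model.state_dict())` kept a second full copy of the weights in memory for the whole run, whereas the rewritten loop in the next hunk checkpoints the best weights to a file inside a `TemporaryDirectory`, which is deleted automatically when the `with` block exits. A minimal, self-contained sketch of that pattern (the `nn.Linear` model is an illustrative stand-in, not the tutorial's network):

    import os
    from tempfile import TemporaryDirectory

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for the tutorial's ResNet

    with TemporaryDirectory() as tempdir:
        best_path = os.path.join(tempdir, 'best_model_params.pt')
        torch.save(model.state_dict(), best_path)     # checkpoint the current weights
        # ...a training loop would re-save here whenever validation accuracy improves...
        model.load_state_dict(torch.load(best_path))  # restore the best checkpoint
    # leaving the with-block deletes tempdir, and the checkpoint with it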
@@ -146,67 +146,71 @@ def imshow(inp, title=None):
 def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     since = time.time()

-    best_model_wts = copy.deepcopy(model.state_dict())
-    best_acc = 0.0
-
-    for epoch in range(num_epochs):
-        print(f'Epoch {epoch}/{num_epochs - 1}')
-        print('-' * 10)
-
-        # Each epoch has a training and validation phase
-        for phase in ['train', 'val']:
-            if phase == 'train':
-                model.train()  # Set model to training mode
-            else:
-                model.eval()   # Set model to evaluate mode
-
-            running_loss = 0.0
-            running_corrects = 0
-
-            # Iterate over data.
-            for inputs, labels in dataloaders[phase]:
-                inputs = inputs.to(device)
-                labels = labels.to(device)
-
-                # zero the parameter gradients
-                optimizer.zero_grad()
-
-                # forward
-                # track history if only in train
-                with torch.set_grad_enabled(phase == 'train'):
-                    outputs = model(inputs)
-                    _, preds = torch.max(outputs, 1)
-                    loss = criterion(outputs, labels)
-
-                    # backward + optimize only if in training phase
-                    if phase == 'train':
-                        loss.backward()
-                        optimizer.step()
-
-                # statistics
-                running_loss += loss.item() * inputs.size(0)
-                running_corrects += torch.sum(preds == labels.data)
-            if phase == 'train':
-                scheduler.step()
-
-            epoch_loss = running_loss / dataset_sizes[phase]
-            epoch_acc = running_corrects.double() / dataset_sizes[phase]
-
-            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
-
-            # deep copy the model
-            if phase == 'val' and epoch_acc > best_acc:
-                best_acc = epoch_acc
-                best_model_wts = copy.deepcopy(model.state_dict())
-
-        print()
-
-    time_elapsed = time.time() - since
-    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
-    print(f'Best val Acc: {best_acc:4f}')
-
-    # load best model weights
-    model.load_state_dict(best_model_wts)
+    # Create a temporary directory to save training checkpoints
+    with TemporaryDirectory() as tempdir:
+        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
+
+        torch.save(model.state_dict(), best_model_params_path)
+        best_acc = 0.0
+
+        for epoch in range(num_epochs):
+            print(f'Epoch {epoch}/{num_epochs - 1}')
+            print('-' * 10)
+
+            # Each epoch has a training and validation phase
+            for phase in ['train', 'val']:
+                if phase == 'train':
+                    model.train()  # Set model to training mode
+                else:
+                    model.eval()   # Set model to evaluate mode
+
+                running_loss = 0.0
+                running_corrects = 0
+
+                # Iterate over data.
+                for inputs, labels in dataloaders[phase]:
+                    inputs = inputs.to(device)
+                    labels = labels.to(device)
+
+                    # zero the parameter gradients
+                    optimizer.zero_grad()
+
+                    # forward
+                    # track history if only in train
+                    with torch.set_grad_enabled(phase == 'train'):
+                        outputs = model(inputs)
+                        _, preds = torch.max(outputs, 1)
+                        loss = criterion(outputs, labels)
+
+                        # backward + optimize only if in training phase
+                        if phase == 'train':
+                            loss.backward()
+                            optimizer.step()
+
+                    # statistics
+                    running_loss += loss.item() * inputs.size(0)
+                    running_corrects += torch.sum(preds == labels.data)
+                if phase == 'train':
+                    scheduler.step()
+
+                epoch_loss = running_loss / dataset_sizes[phase]
+                epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
+
+                # save a checkpoint if this is the best validation accuracy so far
+                if phase == 'val' and epoch_acc > best_acc:
+                    best_acc = epoch_acc
+                    torch.save(model.state_dict(), best_model_params_path)
+
+            print()
+
+        time_elapsed = time.time() - since
+        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
+        print(f'Best val Acc: {best_acc:.4f}')
+
+        # load best model weights
+        model.load_state_dict(torch.load(best_model_params_path))
     return model
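For reference, `dataloaders`, `dataset_sizes`, and `device` are module-level globals defined earlier in the tutorial. The function is then driven roughly as in the sketch below, which mirrors the tutorial's fine-tuning section but is not part of this diff:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torchvision import models

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fine-tune an ImageNet-pretrained ResNet-18 for the 2-class ants/bees data
    model_ft = models.resnet18(weights='IMAGENET1K_V1')
    model_ft.fc = nn.Linear(model_ft.fc.in_features, 2)
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    # Decay the learning rate by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                           num_epochs=25)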