
Commit d5d3dec

Author: Svetlana Karslioglu

    Merge branch 'main' into issue_995

2 parents: 0e3639d + 5b804b8

2 files changed (+69, −65)

beginner_source/data_loading_tutorial.py (+3, −3)
@@ -266,8 +266,8 @@ def __call__(self, sample):
         h, w = image.shape[:2]
         new_h, new_w = self.output_size
 
-        top = np.random.randint(0, h - new_h)
-        left = np.random.randint(0, w - new_w)
+        top = np.random.randint(0, h - new_h + 1)
+        left = np.random.randint(0, w - new_w + 1)
 
         image = image[top: top + new_h,
                       left: left + new_w]
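The `+ 1` matters because `np.random.randint(low, high)` draws from the half-open interval `[low, high)`. Without it, a crop exactly the size of the image makes the range empty and raises an error instead of returning the only valid offset, 0. A quick sketch (the sizes here are made up for illustration):

    import numpy as np

    h, new_h = 224, 224                       # crop height equals image height

    # Before the fix: an empty range
    # np.random.randint(0, h - new_h)         # ValueError: low >= high

    # After the fix: the single valid offset can be drawn
    top = np.random.randint(0, h - new_h + 1)
    print(top)                                # always 0 in this case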
@@ -292,7 +292,7 @@ def __call__(self, sample):
 
 ######################################################################
 # .. note::
-#     In the example above, `RandomCrop` uses an external library's random number generator
+#     In the example above, `RandomCrop` uses an external library's random number generator
 #     (in this case, Numpy's `np.random.int`). This can result in unexpected behavior with `DataLoader`
 #     (see `here <https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers>`_).
 #     In practice, it is safer to stick to PyTorch's random number generator, e.g. by using `torch.randint` instead.
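Per the note's suggestion, the crop offsets could instead be drawn from PyTorch's generator, which cooperates with `DataLoader` worker seeding. A sketch of what that swap might look like (this is not part of the commit; it assumes `h`, `w`, `new_h`, `new_w` as defined in `RandomCrop.__call__` above):

    import torch

    # torch.randint's upper bound is exclusive as well, so the + 1 still applies
    top = torch.randint(0, h - new_h + 1, size=(1,)).item()
    left = torch.randint(0, w - new_w + 1, size=(1,)).item()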

beginner_source/transfer_learning_tutorial.py (+66, −62)
@@ -46,7 +46,7 @@
 import matplotlib.pyplot as plt
 import time
 import os
-import copy
+from tempfile import TemporaryDirectory
 
 cudnn.benchmark = True
 plt.ion()   # interactive mode
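The import swap mirrors the change below: `copy.deepcopy` is no longer used, and `TemporaryDirectory` supplies a scratch folder that is removed automatically when its `with` block exits. A quick standalone illustration of that cleanup behavior:

    import os
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, 'scratch.txt')
        open(path, 'w').close()           # create a file inside the temp dir
        print(os.path.exists(path))       # True while the block is active
    print(os.path.exists(tempdir))        # False: directory removed on exit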
@@ -146,67 +146,71 @@ def imshow(inp, title=None):
 def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     since = time.time()
 
-    best_model_wts = copy.deepcopy(model.state_dict())
-    best_acc = 0.0
-
-    for epoch in range(num_epochs):
-        print(f'Epoch {epoch}/{num_epochs - 1}')
-        print('-' * 10)
-
-        # Each epoch has a training and validation phase
-        for phase in ['train', 'val']:
-            if phase == 'train':
-                model.train()  # Set model to training mode
-            else:
-                model.eval()   # Set model to evaluate mode
-
-            running_loss = 0.0
-            running_corrects = 0
-
-            # Iterate over data.
-            for inputs, labels in dataloaders[phase]:
-                inputs = inputs.to(device)
-                labels = labels.to(device)
-
-                # zero the parameter gradients
-                optimizer.zero_grad()
-
-                # forward
-                # track history if only in train
-                with torch.set_grad_enabled(phase == 'train'):
-                    outputs = model(inputs)
-                    _, preds = torch.max(outputs, 1)
-                    loss = criterion(outputs, labels)
-
-                    # backward + optimize only if in training phase
-                    if phase == 'train':
-                        loss.backward()
-                        optimizer.step()
-
-                # statistics
-                running_loss += loss.item() * inputs.size(0)
-                running_corrects += torch.sum(preds == labels.data)
-            if phase == 'train':
-                scheduler.step()
-
-            epoch_loss = running_loss / dataset_sizes[phase]
-            epoch_acc = running_corrects.double() / dataset_sizes[phase]
-
-            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
-
-            # deep copy the model
-            if phase == 'val' and epoch_acc > best_acc:
-                best_acc = epoch_acc
-                best_model_wts = copy.deepcopy(model.state_dict())
-
-        print()
-
-    time_elapsed = time.time() - since
-    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
-    print(f'Best val Acc: {best_acc:4f}')
-
-    # load best model weights
-    model.load_state_dict(best_model_wts)
+    # Create a temporary directory to save training checkpoints
+    with TemporaryDirectory() as tempdir:
+        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
+
+        torch.save(model.state_dict(), best_model_params_path)
+        best_acc = 0.0
+
+        for epoch in range(num_epochs):
+            print(f'Epoch {epoch}/{num_epochs - 1}')
+            print('-' * 10)
+
+            # Each epoch has a training and validation phase
+            for phase in ['train', 'val']:
+                if phase == 'train':
+                    model.train()  # Set model to training mode
+                else:
+                    model.eval()   # Set model to evaluate mode
+
+                running_loss = 0.0
+                running_corrects = 0
+
+                # Iterate over data.
+                for inputs, labels in dataloaders[phase]:
+                    inputs = inputs.to(device)
+                    labels = labels.to(device)
+
+                    # zero the parameter gradients
+                    optimizer.zero_grad()
+
+                    # forward
+                    # track history if only in train
+                    with torch.set_grad_enabled(phase == 'train'):
+                        outputs = model(inputs)
+                        _, preds = torch.max(outputs, 1)
+                        loss = criterion(outputs, labels)
+
+                        # backward + optimize only if in training phase
+                        if phase == 'train':
+                            loss.backward()
+                            optimizer.step()
+
+                    # statistics
+                    running_loss += loss.item() * inputs.size(0)
+                    running_corrects += torch.sum(preds == labels.data)
+                if phase == 'train':
+                    scheduler.step()
+
+                epoch_loss = running_loss / dataset_sizes[phase]
+                epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
+
+                # deep copy the model
+                if phase == 'val' and epoch_acc > best_acc:
+                    best_acc = epoch_acc
+                    torch.save(model.state_dict(), best_model_params_path)
+
+            print()
+
+        time_elapsed = time.time() - since
+        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
+        print(f'Best val Acc: {best_acc:4f}')
+
+        # load best model weights
+        model.load_state_dict(torch.load(best_model_params_path))
     return model
 
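The refactor trades the in-memory `copy.deepcopy` of the best weights for an on-disk checkpoint inside a `TemporaryDirectory`, so a second full copy of the parameters no longer sits in RAM and the scratch file is deleted automatically when training finishes. A minimal sketch of the pattern in isolation (the toy model and accuracy values here are hypothetical):

    import os
    from tempfile import TemporaryDirectory

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)                   # stand-in for the fine-tuned network

    with TemporaryDirectory() as tempdir:
        best_params_path = os.path.join(tempdir, 'best_model_params.pt')
        torch.save(model.state_dict(), best_params_path)  # seed with initial weights
        best_acc = 0.0

        for epoch_acc in [0.61, 0.83, 0.79]:  # stand-in validation accuracies
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_params_path)  # overwrite checkpoint

        # restore the best weights before the directory is cleaned up on exit
        model.load_state_dict(torch.load(best_params_path))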