Skip to content

Commit d686b66

Browse files
authored
Fix train loop in trainingyt.py (#2372)
* refactored train loop in trainingyt.py, resolves issue #2230 * Simplified numpy function call, resolves issue #1038
1 parent 4673b14 commit d686b66

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

beginner_source/introyt/trainingyt.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -290,15 +290,19 @@ def train_one_epoch(epoch_index, tb_writer):
290290
model.train(True)
291291
avg_loss = train_one_epoch(epoch_number, writer)
292292

293-
# We don't need gradients on to do reporting
294-
model.train(False)
295-
293+
296294
running_vloss = 0.0
297-
for i, vdata in enumerate(validation_loader):
298-
vinputs, vlabels = vdata
299-
voutputs = model(vinputs)
300-
vloss = loss_fn(voutputs, vlabels)
301-
running_vloss += vloss
295+
# Set the model to evaluation mode, disabling dropout and using population
296+
# statistics for batch normalization.
297+
model.eval()
298+
299+
# Disable gradient computation and reduce memory consumption.
300+
with torch.no_grad():
301+
for i, vdata in enumerate(validation_loader):
302+
vinputs, vlabels = vdata
303+
voutputs = model(vinputs)
304+
vloss = loss_fn(voutputs, vlabels)
305+
running_vloss += vloss
302306

303307
avg_vloss = running_vloss / (i + 1)
304308
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

intermediate_source/torchvision_tutorial.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
145145
num_objs = len(obj_ids)
146146
boxes = []
147147
for i in range(num_objs):
148-
pos = np.where(masks[i])
148+
pos = np.nonzero(masks[i])
149149
xmin = np.min(pos[1])
150150
xmax = np.max(pos[1])
151151
ymin = np.min(pos[0])

0 commit comments

Comments
 (0)