From 7111087c9505872745d833318f65901ee909f631 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Castro=20Garc=C3=ADa?= Date: Wed, 31 May 2023 15:15:44 -0600 Subject: [PATCH 1/2] refactored train loop in trainingyt.py, resolves issue #2230 --- beginner_source/introyt/trainingyt.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/beginner_source/introyt/trainingyt.py b/beginner_source/introyt/trainingyt.py index 929e06c1b57..d9f585411e8 100644 --- a/beginner_source/introyt/trainingyt.py +++ b/beginner_source/introyt/trainingyt.py @@ -290,15 +290,19 @@ def train_one_epoch(epoch_index, tb_writer): model.train(True) avg_loss = train_one_epoch(epoch_number, writer) - # We don't need gradients on to do reporting - model.train(False) - + running_vloss = 0.0 - for i, vdata in enumerate(validation_loader): - vinputs, vlabels = vdata - voutputs = model(vinputs) - vloss = loss_fn(voutputs, vlabels) - running_vloss += vloss + # Set the model to evaluation mode, disabling dropout and using population + # statistics for batch normalization. + model.eval() + + # Disable gradient computation and reduce memory consumption. + with torch.no_grad(): + for i, vdata in enumerate(validation_loader): + vinputs, vlabels = vdata + voutputs = model(vinputs) + vloss = loss_fn(voutputs, vlabels) + running_vloss += vloss avg_vloss = running_vloss / (i + 1) print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) From c5a30960ae923e6da7d73175224e993fca09b18d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Castro=20Garc=C3=ADa?= Date: Wed, 31 May 2023 17:44:03 -0600 Subject: [PATCH 2/2] Simplified numpy function call, resolves issue #1038 --- intermediate_source/torchvision_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/torchvision_tutorial.rst b/intermediate_source/torchvision_tutorial.rst index 9e3d1b9655c..21d47e258f7 100644 --- a/intermediate_source/torchvision_tutorial.rst +++ b/intermediate_source/torchvision_tutorial.rst @@ -145,7 +145,7 @@ Let’s write a ``torch.utils.data.Dataset`` class for this dataset. num_objs = len(obj_ids) boxes = [] for i in range(num_objs): - pos = np.where(masks[i]) + pos = np.nonzero(masks[i]) xmin = np.min(pos[1]) xmax = np.max(pos[1]) ymin = np.min(pos[0])