Skip to content

Commit bba204e

Browse files
authored
Merge branch 'main' into issue2349
2 parents 939c19e + 789fc09 commit bba204e

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

Diff for: beginner_source/introyt/tensorboardyt_tutorial.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,14 @@ def forward(self, x):
214214
# Check against the validation set
215215
running_vloss = 0.0
216216

217-
net.train(False) # Don't need to track gradents for validation
217+
# In evaluation mode some model specific operations can be omitted eg. dropout layer
218+
net.train(False) # Switching to evaluation mode, eg. turning off regularisation
218219
for j, vdata in enumerate(validation_loader, 0):
219220
vinputs, vlabels = vdata
220221
voutputs = net(vinputs)
221222
vloss = criterion(voutputs, vlabels)
222223
running_vloss += vloss.item()
223-
net.train(True) # Turn gradients back on for training
224+
net.train(True) # Switching back to training mode, eg. turning on regularisation
224225

225226
avg_loss = running_loss / 1000
226227
avg_vloss = running_vloss / len(validation_loader)

Diff for: intermediate_source/mario_rl_tutorial.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __init__(self, env, shape):
199199

200200
def observation(self, observation):
201201
transforms = T.Compose(
202-
[T.Resize(self.shape), T.Normalize(0, 255)]
202+
[T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]
203203
)
204204
observation = transforms(observation).squeeze(0)
205205
return observation

0 commit comments

Comments
 (0)