@@ -119,7 +119,7 @@ def make_model(in_size, out_size, num_layers):
119
119
for epoch in range (0 ): # 0 epochs, this section is for illustration only
120
120
for input , target in zip (data , targets ):
121
121
# Runs the forward pass under ``autocast``.
122
- with torch .autocast (dtype = torch .float16 ):
122
+ with torch .autocast (device_type = device , dtype = torch .float16 ):
123
123
output = net (input )
124
124
# output is float16 because linear layers ``autocast`` to float16.
125
125
assert output .dtype is torch .float16
@@ -154,7 +154,7 @@ def make_model(in_size, out_size, num_layers):
154
154
155
155
for epoch in range (0 ): # 0 epochs, this section is for illustration only
156
156
for input , target in zip (data , targets ):
157
- with torch .autocast (dtype = torch .float16 ):
157
+ with torch .autocast (device_type = device , dtype = torch .float16 ):
158
158
output = net (input )
159
159
loss = loss_fn (output , target )
160
160
@@ -187,7 +187,7 @@ def make_model(in_size, out_size, num_layers):
187
187
start_timer ()
188
188
for epoch in range (epochs ):
189
189
for input , target in zip (data , targets ):
190
- with torch .autocast (dtype = torch .float16 , enabled = use_amp ):
190
+ with torch .autocast (device_type = device , dtype = torch .float16 , enabled = use_amp ):
191
191
output = net (input )
192
192
loss = loss_fn (output , target )
193
193
scaler .scale (loss ).backward ()
@@ -205,7 +205,7 @@ def make_model(in_size, out_size, num_layers):
205
205
206
206
for epoch in range (0 ): # 0 epochs, this section is for illustration only
207
207
for input , target in zip (data , targets ):
208
- with torch .autocast (dtype = torch .float16 ):
208
+ with torch .autocast (device_type = device , dtype = torch .float16 ):
209
209
output = net (input )
210
210
loss = loss_fn (output , target )
211
211
scaler .scale (loss ).backward ()
0 commit comments