
Commit a76c954

Fix amp
Signed-off-by: Onur Berk Töre <onurberk_t@hotmail.com>
1 parent b1a589d commit a76c954

File tree

1 file changed: +4 −4 lines


Diff for: recipes_source/recipes/amp_recipe.py

+4 −4
@@ -119,7 +119,7 @@ def make_model(in_size, out_size, num_layers):
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
         # Runs the forward pass under ``autocast``.
-        with torch.autocast(dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             # output is float16 because linear layers ``autocast`` to float16.
             assert output.dtype is torch.float16
@@ -154,7 +154,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.autocast(dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
 
@@ -187,7 +187,7 @@ def make_model(in_size, out_size, num_layers):
 start_timer()
 for epoch in range(epochs):
     for input, target in zip(data, targets):
-        with torch.autocast(dtype=torch.float16, enabled=use_amp):
+        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
             output = net(input)
             loss = loss_fn(output, target)
             scaler.scale(loss).backward()
@@ -205,7 +205,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.autocast(dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
             scaler.scale(loss).backward()
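
All four hunks make the same change: each ``torch.autocast`` context now passes ``device_type`` explicitly instead of relying on a default. The following is a minimal sketch of the resulting AMP training loop, assuming a CUDA device is available; the model, optimizer, and data (``net``, ``opt``, ``data``, ``targets``) are placeholders here that only mirror the names used in the recipe.

import torch

device = "cuda"  # passed explicitly to torch.autocast via device_type
net = torch.nn.Linear(512, 512).to(device)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

data = [torch.randn(64, 512, device=device) for _ in range(4)]
targets = [torch.randn(64, 512, device=device) for _ in range(4)]

for epoch in range(1):
    for input, target in zip(data, targets):
        # Forward pass and loss run under autocast with an explicit device_type.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        # Backward and optimizer step stay outside autocast; the scaler
        # handles gradient scaling for the float16 forward pass.
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()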
