Skip to content

Commit 0740801

Browse files
alperenunlusvekars
andauthored
Improve DQN Tutorial (#2934)
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent be898cb commit 0740801

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

intermediate_source/reinforcement_q_learning.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
1010
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
1111
12+
You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper
13+
1214
**Task**
1315
1416
The agent has to decide between two actions - moving the cart left or
@@ -83,7 +85,11 @@
8385
plt.ion()
8486

8587
# if GPU is to be used
86-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88+
device = torch.device(
89+
"cuda" if torch.cuda.is_available() else
90+
"mps" if torch.backends.mps.is_available() else
91+
"cpu"
92+
)
8793

8894

8995
######################################################################
@@ -397,7 +403,7 @@ def optimize_model():
397403
# can produce better results if convergence is not observed.
398404
#
399405

400-
if torch.cuda.is_available():
406+
if torch.cuda.is_available() or torch.backends.mps.is_available():
401407
num_episodes = 600
402408
else:
403409
num_episodes = 50

0 commit comments

Comments
 (0)