File tree 1 file changed +8
-2
lines changed
1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 9
9
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
10
10
on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
11
11
12
+ You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper
13
+
12
14
**Task**
13
15
14
16
The agent has to decide between two actions - moving the cart left or
83
85
plt .ion ()
84
86
85
87
# if GPU is to be used
86
- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
88
+ device = torch .device (
89
+ "cuda" if torch .cuda .is_available () else
90
+ "mps" if torch .backends .mps .is_available () else
91
+ "cpu"
92
+ )
87
93
88
94
89
95
######################################################################
@@ -397,7 +403,7 @@ def optimize_model():
397
403
# can produce better results if convergence is not observed.
398
404
#
399
405
400
- if torch .cuda .is_available ():
406
+ if torch .cuda .is_available () or torch . backends . mps . is_available () :
401
407
num_episodes = 600
402
408
else :
403
409
num_episodes = 50
You can’t perform that action at this time.
0 commit comments