Skip to content

Commit db17499

Browse files
committed
2 parents 925bd86 + 19784e0 commit db17499

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,8 @@ Results/Notebook.ipynb
2222
*.ipynb_checkpoints
2323
*.drive_access_key.json
2424
drive_access_key.json
25-
drive_access_key
25+
drive_access_key
26+
settings.json
27+
launch.json
28+
results/data_and_graphs/Cart_Pole_Results_Data.pkl
29+
results/data_and_graphs/Cart_Pole_Results_Graph.png

agents/Base_Agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, config):
2121
self.set_random_seeds(config.seed)
2222
self.environment = config.environment
2323
self.environment_title = self.get_environment_title()
24-
self.action_types = "DISCRETE" if self.environment.action_space.dtype == int else "CONTINUOUS"
24+
self.action_types = "DISCRETE" if self.environment.action_space.dtype == np.int64 else "CONTINUOUS"
2525
self.action_size = int(self.get_action_size())
2626
self.config.action_size = self.action_size
2727

@@ -106,14 +106,18 @@ def get_score_required_to_win(self):
106106

107107
def get_trials(self):
108108
"""Gets the number of trials to average a score over"""
109-
if self.environment_title in ["AntMaze", "FetchReach", "Hopper", "Walker2d"]: return 100
109+
if self.environment_title in ["AntMaze", "FetchReach", "Hopper", "Walker2d", "CartPole"]: return 100
110110
try: return self.environment.unwrapped.trials
111111
except AttributeError: return self.environment.spec.trials
112112

113113
def setup_logger(self):
114114
"""Sets up the logger"""
115115
filename = "Training.log"
116-
if os.path.isfile(filename): os.remove(filename)
116+
try:
117+
if os.path.isfile(filename):
118+
os.remove(filename)
119+
except: pass
120+
117121
logger = logging.getLogger(__name__)
118122
logger.setLevel(logging.INFO)
119123
# create a file handler

results/Cart_Pole.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import os
2+
import sys
3+
from os.path import dirname, abspath
4+
sys.path.append(dirname(dirname(abspath(__file__))))
5+
16
import gym
27

3-
from A2C import A2C
4-
from Dueling_DDQN import Dueling_DDQN
5-
from SAC_Discrete import SAC_Discrete
8+
from agents.actor_critic_agents.A2C import A2C
9+
from agents.DQN_agents.Dueling_DDQN import Dueling_DDQN
10+
from agents.actor_critic_agents.SAC_Discrete import SAC_Discrete
611
from agents.actor_critic_agents.A3C import A3C
712
from agents.policy_gradient_agents.PPO import PPO
813
from agents.Trainer import Trainer
@@ -16,8 +21,8 @@
1621
config.seed = 1
1722
config.environment = gym.make("CartPole-v0")
1823
config.num_episodes_to_run = 450
19-
config.file_to_save_data_results = "data_and_graphs/Cart_Pole_Results_Data.pkl"
20-
config.file_to_save_results_graph = "data_and_graphs/Cart_Pole_Results_Graph.png"
24+
config.file_to_save_data_results = "results/data_and_graphs/Cart_Pole_Results_Data.pkl"
25+
config.file_to_save_results_graph = "results/data_and_graphs/Cart_Pole_Results_Graph.png"
2126
config.show_solution_score = False
2227
config.visualise_individual_results = False
2328
config.visualise_overall_agent_results = True

0 commit comments

Comments
 (0)