Commit ef5cbac

moved HRL to its own folder
1 parent f768dc0 commit ef5cbac

File tree: 18 files changed (+1132 −642 lines)

agents/Base_Agent.py

Lines changed: 6 additions & 4 deletions
@@ -265,12 +265,14 @@ def save_experience(self, memory=None, experience=None):
 
     def take_optimisation_step(self, optimizer, network, loss, clipping_norm=None, retain_graph=False):
         """Takes an optimisation step by calculating gradients given the loss and then updating the parameters"""
+        if not isinstance(network, list): network = [network]
         optimizer.zero_grad() #reset gradients to 0
         loss.backward(retain_graph=retain_graph) #this calculates the gradients
         self.logger.info("Loss -- {}".format(loss.item()))
         if self.debug_mode: self.log_gradient_and_weight_information(network, optimizer)
         if clipping_norm is not None:
-            torch.nn.utils.clip_grad_norm_(network.parameters(), clipping_norm) #clip gradients to help stabilise training
+            for net in network:
+                torch.nn.utils.clip_grad_norm_(net.parameters(), clipping_norm) #clip gradients to help stabilise training
         optimizer.step() #this applies the gradients
 
     def log_gradient_and_weight_information(self, network, optimizer):

@@ -295,10 +297,10 @@ def soft_update_of_target_network(self, local_model, target_model, tau):
         for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
             target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
 
-    def create_NN(self, input_dim, output_dim, key_to_use=None, override_seed=None):
+    def create_NN(self, input_dim, output_dim, key_to_use=None, override_seed=None, hyperparameters=None):
         """Creates a neural network for the agents to use"""
-        if key_to_use: hyperparameters = self.hyperparameters[key_to_use]
-        else: hyperparameters = self.hyperparameters
+        if hyperparameters is None: hyperparameters = self.hyperparameters
+        if key_to_use: hyperparameters = hyperparameters[key_to_use]
         if override_seed: seed = override_seed
         else: seed = self.config.seed

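For reference, a minimal standalone sketch of the pattern the take_optimisation_step change enables (hypothetical networks and sizes, not code from this repository): one optimizer stepping several networks, with gradients clipped per network before the shared step.

import torch
from torch import nn, optim

encoder = nn.Linear(4, 8)
head = nn.Linear(8, 2)
optimizer = optim.Adam(list(encoder.parameters()) + list(head.parameters()), lr=1e-3)

networks = [encoder, head]  # equivalent to passing a list as the `network` argument
loss = head(encoder(torch.randn(16, 4))).pow(2).mean()

optimizer.zero_grad()                    # reset gradients to 0
loss.backward()                          # calculate the gradients
for net in networks:                     # clip each network's gradients separately
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=5.0)
optimizer.step()                         # apply the clipped gradients
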
agents/Trainer.py

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ def create_agent_to_agent_group_dictionary(self):
             "HIRO": "HIRO",
             "SAC": "Actor_Critic_Agents",
             "HRL": "HRL",
+            "Model_HRL": "HRL",
             "DIAYN": "DIAYN",
             "Dueling DDQN": "DQN_Agents"
         }
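
As a purely hypothetical illustration (the real Trainer's lookup logic may differ), a mapping like the one above is what lets the trainer route an agent name to the folder holding its implementation; the new entry sends "Model_HRL" to the same HRL folder:

agent_to_agent_group = {"HRL": "HRL", "Model_HRL": "HRL", "SAC": "Actor_Critic_Agents"}

def group_for(agent_name):
    """Return the folder an agent's implementation is expected to live in."""
    return agent_to_agent_group[agent_name]

assert group_for("Model_HRL") == "HRL"  # the new entry routes Model_HRL to the HRL folder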

agents/hierarchical_agents/HRL.py

Lines changed: 0 additions & 584 deletions
This file was deleted.

agents/hierarchical_agents/HRL/DDQN_HRL.py

Lines changed: 329 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 258 additions & 0 deletions
import copy
import torch
import random
import numpy as np
import torch.nn.functional as F

from collections import Counter
from torch import optim
from Base_Agent import Base_Agent
from Replay_Buffer import Replay_Buffer
from exploration_strategies.Epsilon_Greedy_Exploration import Epsilon_Greedy_Exploration


class DDQN_Wrapper(Base_Agent):

    def __init__(self, config, global_action_id_to_primitive_actions, action_length_reward_bonus, end_of_episode_symbol="/"):
        super().__init__(config)
        self.end_of_episode_symbol = end_of_episode_symbol
        self.global_action_id_to_primitive_actions = global_action_id_to_primitive_actions
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"], self.hyperparameters["batch_size"], config.seed)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

        self.oracle = self.create_oracle()
        self.oracle_optimizer = optim.Adam(self.oracle.parameters(), lr=self.hyperparameters["learning_rate"])

        self.q_network_local = self.create_NN(input_dim=self.state_size + 1, output_dim=self.action_size)
        self.q_network_local.print_model_summary()
        self.q_network_optimizer = optim.Adam(self.q_network_local.parameters(), lr=self.hyperparameters["learning_rate"])
        self.q_network_target = self.create_NN(input_dim=self.state_size + 1, output_dim=self.action_size)
        Base_Agent.copy_model_over(from_model=self.q_network_local, to_model=self.q_network_target)

        self.action_length_reward_bonus = action_length_reward_bonus
        self.abandon_ship = config.hyperparameters["abandon_ship"]

    def create_oracle(self):
        """Creates the network we will use to predict the next state"""
        oracle_hyperparameters = copy.deepcopy(self.hyperparameters)
        oracle_hyperparameters["columns_of_data_to_be_embedded"] = []
        oracle_hyperparameters["embedding_dimensions"] = []
        oracle_hyperparameters["linear_hidden_units"] = [5, 5]
        oracle_hyperparameters["final_layer_activation"] = [None, "tanh"]
        oracle = self.create_NN(input_dim=self.state_size + 2, output_dim=[self.state_size + 1, 1], hyperparameters=oracle_hyperparameters)
        oracle.print_model_summary()
        return oracle

    def run_n_episodes(self, num_episodes, episodes_to_run_with_no_exploration):
        self.turn_on_any_epsilon_greedy_exploration()
        self.round_of_macro_actions = []
        self.episode_actions_scores_and_exploration_status = []
        num_episodes_to_get_to = self.episode_number + num_episodes
        while self.episode_number < num_episodes_to_get_to:
            self.reset_game()
            self.step()
            self.save_and_print_result()
            if num_episodes_to_get_to - self.episode_number == episodes_to_run_with_no_exploration:
                self.turn_off_any_epsilon_greedy_exploration()
        assert len(self.episode_actions_scores_and_exploration_status) == num_episodes, "{} vs. {}".format(len(self.episode_actions_scores_and_exploration_status),
                                                                                                           num_episodes)
        assert len(self.episode_actions_scores_and_exploration_status[0]) == 3
        assert self.episode_actions_scores_and_exploration_status[0][2] in [True, False]
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1], list)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1][0], int)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][0], int) or isinstance(self.episode_actions_scores_and_exploration_status[0][0], float)
        return self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions

    def step(self):
        """Runs a step within a game including a learning step if required"""
        step_number = 0.0
        self.state = np.append(self.state, step_number / 200.0) #Divide by 200 because there are 200 steps in cart pole

        self.total_episode_score_so_far = 0
        episode_macro_actions = []
        while not self.done:
            surprised = False
            macro_action = self.pick_action()
            primitive_actions = self.global_action_id_to_primitive_actions[macro_action]
            primitive_actions_conducted = 0
            for ix, action in enumerate(primitive_actions):
                if self.abandon_ship and primitive_actions_conducted > 0:
                    if self.abandon_macro_action(action):
                        break

                step_number += 1
                self.action = action
                self.next_state, self.reward, self.done, _ = self.environment.step(action)
                self.next_state = np.append(self.next_state, step_number / 200.0) #Divide by 200 because there are 200 steps in cart pole

                self.total_episode_score_so_far += self.reward
                if self.hyperparameters["clip_rewards"]: self.reward = max(min(self.reward, 1.0), -1.0)
                primitive_actions_conducted += 1
                self.track_episodes_data()
                self.save_experience()

                if len(primitive_actions) > 1:
                    surprised = self.am_i_surprised()

                self.state = self.next_state
                if self.time_for_q_network_to_learn():
                    for _ in range(self.hyperparameters["learning_iterations"]):
                        self.q_network_learn()
                        self.oracle_learn()
                if self.done or surprised: break
            episode_macro_actions.append(macro_action)
            self.round_of_macro_actions.append(macro_action)
            if random.random() < 0.1: print(Counter(episode_macro_actions))
        self.save_episode_actions_with_score()
        self.episode_number += 1
        self.logger.info("END OF EPISODE")

    def am_i_surprised(self):
        """Returns boolean indicating whether the next_state was a surprise or not"""
        with torch.no_grad():
            state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
            action = torch.Tensor([[self.action]])

            states_and_actions = torch.cat((state, action), dim=1) #must change this for all games besides cart pole
            predictions = self.oracle(states_and_actions)
            predicted_next_state = predictions[0, :-1]

            difference = F.mse_loss(predicted_next_state, torch.Tensor(self.next_state))
            if difference > 0.5:
                print("Surprise! Loss {} -- {} vs. {}".format(difference, predicted_next_state, self.next_state))
                return True
            else: return False

    def abandon_macro_action(self, action):
        """Returns boolean indicating whether to abandon macro action or not"""
        state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            primitive_q_values = self.calculate_q_values(state, local=True, primitive_actions_only=True)
            q_value_highest = torch.max(primitive_q_values)
            q_values_action = primitive_q_values[:, action]
            if q_value_highest > 0.0: multiplier = 0.7
            else: multiplier = 1.3
            if q_values_action < multiplier * q_value_highest:
                print("BREAKING Action {} -- Q Values {}".format(action, primitive_q_values))
                return True
            else:
                return False

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if state is None: state = self.state
        if isinstance(state, np.int64) or isinstance(state, int): state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval() #puts network in evaluation mode
        with torch.no_grad():
            action_values = self.calculate_q_values(state, local=True, primitive_actions_only=False)
        self.q_network_local.train() #puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes({"action_values": action_values,
                                                                                    "turn_off_exploration": self.turn_off_exploration,
                                                                                    "episode_number": self.episode_number})
        self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def calculate_q_values(self, states, local, primitive_actions_only):
        """Calculates the q values using the local q network"""
        if local:
            primitive_q_values = self.q_network_local(states)
        else:
            primitive_q_values = self.q_network_target(states)

        num_actions = len(self.global_action_id_to_primitive_actions)
        if primitive_actions_only or num_actions <= self.action_size:
            return primitive_q_values

        extra_q_values = self.calculate_macro_action_q_values(states, num_actions)
        extra_q_values = torch.Tensor([extra_q_values])
        all_q_values = torch.cat((primitive_q_values, extra_q_values), dim=1)

        return all_q_values

    def calculate_macro_action_q_values(self, state, num_actions):
        assert state.shape[0] == 1
        q_values = []
        for action_id in range(self.action_size, num_actions):
            macro_action = self.global_action_id_to_primitive_actions[action_id]
            predicted_next_state = state
            cumulated_reward = 0
            action_ix = 0
            for action in macro_action[:-1]:
                predictions = self.oracle(torch.cat((predicted_next_state, torch.Tensor([[action]])), dim=1))
                rewards = predictions[:, -1]
                predicted_next_state = predictions[:, :-1]
                cumulated_reward += (rewards.item() + self.action_length_reward_bonus) * self.hyperparameters["discount_rate"] ** (action_ix)
                action_ix += 1
            final_action = macro_action[-1]
            final_q_value = self.q_network_local(predicted_next_state)[0, final_action]
            total_q_value = cumulated_reward + final_q_value * self.hyperparameters["discount_rate"] ** (action_ix)
            q_values.append(total_q_value)
        return q_values

    def time_for_q_network_to_learn(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin and there are
        enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken() and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters["update_every_n_steps"] == 0

    def q_network_learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None: states, actions, rewards, next_states, dones = self.sample_experiences() #Sample experiences
        else: states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        self.take_optimisation_step(self.q_network_optimizer, self.q_network_local, loss, self.hyperparameters["gradient_clipping_norm"])
        self.soft_update_of_target_network(self.q_network_local, self.q_network_target,
                                           self.hyperparameters["tau"])

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            max_action_indexes = self.calculate_q_values(next_states, local=True, primitive_actions_only=True).detach().argmax(1)
            Q_targets_next = self.calculate_q_values(next_states, local=False, primitive_actions_only=True).gather(1, max_action_indexes.unsqueeze(1))
            Q_targets = rewards + (self.hyperparameters["discount_rate"] * Q_targets_next * (1 - dones))
        Q_expected = self.calculate_q_values(states, local=True, primitive_actions_only=True).gather(1, actions.long()) # must convert actions to long so can be used as index
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def save_episode_actions_with_score(self):
        self.episode_actions_scores_and_exploration_status.append([self.total_episode_score_so_far,
                                                                   self.episode_actions + [self.end_of_episode_symbol],
                                                                   self.turn_off_exploration])

    def oracle_learn(self):
        states, actions, rewards, next_states, _ = self.sample_experiences() # Sample experiences
        states_and_actions = torch.cat((states, actions), dim=1) #must change this for all games besides cart pole
        predictions = self.oracle(states_and_actions)
        loss = F.mse_loss(torch.cat((next_states, rewards), dim=1), predictions) / float(next_states.shape[1] + 1.0)
        self.take_optimisation_step(self.oracle_optimizer, self.oracle,
                                    loss, self.hyperparameters["gradient_clipping_norm"])
        self.logger.info("Oracle Loss {}".format(loss))


    # def create_feature_extractor(self):
    #     """Creates the feature extractor local network and target network. This means that the q_network and oracle
    #     only need 1 layer"""
    #     temp_hyperparameters = copy.deepcopy(self.hyperparameters)
    #     temp_hyperparameters["linear_hidden_units"], output_dim = temp_hyperparameters["linear_hidden_units"][:-1], temp_hyperparameters["linear_hidden_units"][-1]
    #     temp_hyperparameters["final_layer_activation"] = "relu"
    #     feature_extractor_local = self.create_NN(input_dim=self.state_size, output_dim=output_dim, hyperparameters=temp_hyperparameters)
    #     feature_extractor_target = self.create_NN(input_dim=self.state_size, output_dim=output_dim, hyperparameters=temp_hyperparameters)
    #     Base_Agent.copy_model_over(from_model=feature_extractor_local, to_model=feature_extractor_target)
    #     feature_extractor_local.print_model_summary()
    #     return feature_extractor_local, feature_extractor_target, output_dim
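
To make the arithmetic in calculate_macro_action_q_values easier to follow, here is a toy self-contained version of the same discounted rollout, with made-up numbers standing in for the oracle's reward predictions and the final bootstrapped Q value (illustrative only, not the agent's code):

discount = 0.99
length_bonus = 0.1             # plays the role of action_length_reward_bonus
oracle_rewards = [0.5, 0.4]    # pretend oracle reward predictions for all but the last primitive action
final_q_value = 2.0            # pretend q_network_local value of the last primitive action in the predicted state

cumulated_reward = 0.0
for action_ix, reward in enumerate(oracle_rewards):
    cumulated_reward += (reward + length_bonus) * discount ** action_ix
total_q_value = cumulated_reward + final_q_value * discount ** len(oracle_rewards)
print(total_q_value)  # the macro action's estimated Q value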

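compute_loss above follows the double DQN recipe: the local network chooses the argmax action for each next state and the target network evaluates it. A minimal tensor-level sketch of that target computation with dummy Q-value tables (illustrative numbers, not values from the agent):

import torch

q_local_next = torch.tensor([[1.0, 3.0], [2.0, 0.5]])   # local net Q values for two next states
q_target_next = torch.tensor([[0.9, 2.5], [1.8, 0.7]])  # target net Q values for the same states
rewards = torch.tensor([[1.0], [0.0]])
dones = torch.tensor([[0.0], [1.0]])
discount = 0.99

max_action_indexes = q_local_next.argmax(1)                                # local net picks the action
q_targets_next = q_target_next.gather(1, max_action_indexes.unsqueeze(1))  # target net evaluates it
q_targets = rewards + discount * q_targets_next * (1 - dones)
print(q_targets)  # tensor([[3.4750], [0.0000]])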