import copy
import torch
import random
import numpy as np
import torch.nn.functional as F

from collections import Counter
from torch import optim
from Base_Agent import Base_Agent
from Replay_Buffer import Replay_Buffer
from exploration_strategies.Epsilon_Greedy_Exploration import Epsilon_Greedy_Exploration


class DDQN_Wrapper(Base_Agent):

    def __init__(self, config, global_action_id_to_primitive_actions, action_length_reward_bonus, end_of_episode_symbol="/"):
        super().__init__(config)
        self.end_of_episode_symbol = end_of_episode_symbol
        self.global_action_id_to_primitive_actions = global_action_id_to_primitive_actions
        self.memory = Replay_Buffer(self.hyperparameters["buffer_size"], self.hyperparameters["batch_size"], config.seed)
        self.exploration_strategy = Epsilon_Greedy_Exploration(config)

        self.oracle = self.create_oracle()
        self.oracle_optimizer = optim.Adam(self.oracle.parameters(), lr=self.hyperparameters["learning_rate"])

        self.q_network_local = self.create_NN(input_dim=self.state_size + 1, output_dim=self.action_size)
        self.q_network_local.print_model_summary()
        self.q_network_optimizer = optim.Adam(self.q_network_local.parameters(), lr=self.hyperparameters["learning_rate"])
        self.q_network_target = self.create_NN(input_dim=self.state_size + 1, output_dim=self.action_size)
        Base_Agent.copy_model_over(from_model=self.q_network_local, to_model=self.q_network_target)

        self.action_length_reward_bonus = action_length_reward_bonus
        self.abandon_ship = config.hyperparameters["abandon_ship"]

    def create_oracle(self):
        """Creates the oracle network, which takes the augmented state concatenated with a primitive action and
        predicts both the next augmented state and the reward"""
        oracle_hyperparameters = copy.deepcopy(self.hyperparameters)
        oracle_hyperparameters["columns_of_data_to_be_embedded"] = []
        oracle_hyperparameters["embedding_dimensions"] = []
        oracle_hyperparameters["linear_hidden_units"] = [5, 5]
        oracle_hyperparameters["final_layer_activation"] = [None, "tanh"]
        oracle = self.create_NN(input_dim=self.state_size + 2, output_dim=[self.state_size + 1, 1], hyperparameters=oracle_hyperparameters)
        oracle.print_model_summary()
        return oracle

    def run_n_episodes(self, num_episodes, episodes_to_run_with_no_exploration):
        """Runs num_episodes episodes, turning off epsilon greedy exploration for the final
        episodes_to_run_with_no_exploration episodes, and returns the per-episode records along with every macro
        action taken"""
        self.turn_on_any_epsilon_greedy_exploration()
        self.round_of_macro_actions = []
        self.episode_actions_scores_and_exploration_status = []
        num_episodes_to_get_to = self.episode_number + num_episodes
        while self.episode_number < num_episodes_to_get_to:
            self.reset_game()
            self.step()
            self.save_and_print_result()
            if num_episodes_to_get_to - self.episode_number == episodes_to_run_with_no_exploration:
                self.turn_off_any_epsilon_greedy_exploration()
        assert len(self.episode_actions_scores_and_exploration_status) == num_episodes, \
            "{} vs. {}".format(len(self.episode_actions_scores_and_exploration_status), num_episodes)
        assert len(self.episode_actions_scores_and_exploration_status[0]) == 3
        assert self.episode_actions_scores_and_exploration_status[0][2] in [True, False]
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1], list)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][1][0], int)
        assert isinstance(self.episode_actions_scores_and_exploration_status[0][0], (int, float))
        return self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions

    def step(self):
        """Runs a step within a game including a learning step if required"""
        step_number = 0.0
        self.state = np.append(self.state, step_number / 200.0)  # Divide by 200 because there are 200 steps in CartPole

        self.total_episode_score_so_far = 0
        episode_macro_actions = []
        while not self.done:
            surprised = False
            macro_action = self.pick_action()
            primitive_actions = self.global_action_id_to_primitive_actions[macro_action]
            primitive_actions_conducted = 0
            for ix, action in enumerate(primitive_actions):
                if self.abandon_ship and primitive_actions_conducted > 0:
                    if self.abandon_macro_action(action):
                        break

                step_number += 1
                self.action = action
                self.next_state, self.reward, self.done, _ = self.environment.step(action)
                self.next_state = np.append(self.next_state, step_number / 200.0)  # Divide by 200 because there are 200 steps in CartPole

                self.total_episode_score_so_far += self.reward
                if self.hyperparameters["clip_rewards"]: self.reward = max(min(self.reward, 1.0), -1.0)
                primitive_actions_conducted += 1
                self.track_episodes_data()
                self.save_experience()

                if len(primitive_actions) > 1:
                    surprised = self.am_i_surprised()

                self.state = self.next_state
                if self.time_for_q_network_to_learn():
                    for _ in range(self.hyperparameters["learning_iterations"]):
                        self.q_network_learn()
                        self.oracle_learn()
                if self.done or surprised: break
            episode_macro_actions.append(macro_action)
            self.round_of_macro_actions.append(macro_action)
            if random.random() < 0.1: print(Counter(episode_macro_actions))
        self.save_episode_actions_with_score()
        self.episode_number += 1
        self.logger.info("END OF EPISODE")

    def am_i_surprised(self):
        """Returns boolean indicating whether the next_state was a surprise or not"""
        with torch.no_grad():
            # keep all tensors on the same device as the oracle
            state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
            action = torch.Tensor([[self.action]]).float().to(self.device)

            states_and_actions = torch.cat((state, action), dim=1)  # must change this for all games besides CartPole
            predictions = self.oracle(states_and_actions)
            predicted_next_state = predictions[0, :-1]

            difference = F.mse_loss(predicted_next_state, torch.from_numpy(self.next_state).float().to(self.device))
            if difference > 0.5:
                print("Surprise! Loss {} -- {} vs. {}".format(difference, predicted_next_state, self.next_state))
                return True
            else: return False

    def abandon_macro_action(self, action):
        """Returns boolean indicating whether to abandon macro action or not"""
        state = torch.from_numpy(self.state).float().unsqueeze(0).to(self.device)
        with torch.no_grad():
            primitive_q_values = self.calculate_q_values(state, local=True, primitive_actions_only=True)
        q_value_highest = torch.max(primitive_q_values)
        q_values_action = primitive_q_values[:, action]
        if q_value_highest > 0.0: multiplier = 0.7
        else: multiplier = 1.3
        if q_values_action < multiplier * q_value_highest:
            print("BREAKING Action {} -- Q Values {}".format(action, primitive_q_values))
            return True
        else:
            return False

    def pick_action(self, state=None):
        """Uses the local Q network and an epsilon greedy policy to pick an action"""
        if state is None: state = self.state
        if isinstance(state, (int, np.int64)): state = np.array([state])
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        if len(state.shape) < 2: state = state.unsqueeze(0)
        self.q_network_local.eval()  # puts network in evaluation mode
        with torch.no_grad():
            action_values = self.calculate_q_values(state, local=True, primitive_actions_only=False)
        self.q_network_local.train()  # puts network back in training mode
        action = self.exploration_strategy.perturb_action_for_exploration_purposes(
            {"action_values": action_values,
             "turn_off_exploration": self.turn_off_exploration,
             "episode_number": self.episode_number})
        self.logger.info("Q values {} -- Action chosen {}".format(action_values, action))
        return action

    def calculate_q_values(self, states, local, primitive_actions_only):
        """Calculates the q values using either the local or the target q network, appending oracle-based estimates
        for the macro actions unless primitive_actions_only is True"""
        if local:
            primitive_q_values = self.q_network_local(states)
        else:
            primitive_q_values = self.q_network_target(states)

        num_actions = len(self.global_action_id_to_primitive_actions)
        if primitive_actions_only or num_actions <= self.action_size:
            return primitive_q_values

        extra_q_values = self.calculate_macro_action_q_values(states, num_actions)
        extra_q_values = torch.Tensor([extra_q_values]).to(self.device)  # keep on the same device as the primitive q values
        all_q_values = torch.cat((primitive_q_values, extra_q_values), dim=1)

        return all_q_values

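    # Rough formula for the macro-action q values computed below: for a macro action (a_0, ..., a_{n-1}) the oracle
    # is rolled forward through the first n-1 primitive actions and the final primitive action is valued with the
    # local q network, i.e. approximately
    #   Q(s, macro) ~ sum_{k=0}^{n-2} gamma^k * (r_hat_k + length_bonus) + gamma^(n-1) * Q_local(s_hat_{n-1}, a_{n-1})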
    def calculate_macro_action_q_values(self, state, num_actions):
        """Estimates a q value for each macro action using the oracle roll-out described in the comment above"""
        assert state.shape[0] == 1
        q_values = []
        for action_id in range(self.action_size, num_actions):
            macro_action = self.global_action_id_to_primitive_actions[action_id]
            predicted_next_state = state
            cumulated_reward = 0
            action_ix = 0
            for action in macro_action[:-1]:
                predictions = self.oracle(torch.cat((predicted_next_state, torch.Tensor([[action]]).to(self.device)), dim=1))
                rewards = predictions[:, -1]
                predicted_next_state = predictions[:, :-1]
                cumulated_reward += (rewards.item() + self.action_length_reward_bonus) * self.hyperparameters["discount_rate"] ** action_ix
                action_ix += 1
            final_action = macro_action[-1]
            final_q_value = self.q_network_local(predicted_next_state)[0, final_action]
            total_q_value = cumulated_reward + final_q_value * self.hyperparameters["discount_rate"] ** action_ix
            q_values.append(total_q_value)
        return q_values

    def time_for_q_network_to_learn(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin and there are
        enough experiences in the replay buffer to learn from"""
        return self.right_amount_of_steps_taken() and self.enough_experiences_to_learn_from()

    def right_amount_of_steps_taken(self):
        """Returns boolean indicating whether enough steps have been taken for learning to begin"""
        return self.global_step_number % self.hyperparameters["update_every_n_steps"] == 0

    def q_network_learn(self, experiences=None):
        """Runs a learning iteration for the Q network"""
        if experiences is None: states, actions, rewards, next_states, dones = self.sample_experiences()  # Sample experiences
        else: states, actions, rewards, next_states, dones = experiences
        loss = self.compute_loss(states, next_states, rewards, actions, dones)
        self.take_optimisation_step(self.q_network_optimizer, self.q_network_local, loss, self.hyperparameters["gradient_clipping_norm"])
        self.soft_update_of_target_network(self.q_network_local, self.q_network_target,
                                           self.hyperparameters["tau"])

    def sample_experiences(self):
        """Draws a random sample of experience from the memory buffer"""
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences
        return states, actions, rewards, next_states, dones

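    # compute_loss below uses the double DQN target: the local network selects the argmax action for each next
    # state and the target network supplies the q value for that selected action.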
    def compute_loss(self, states, next_states, rewards, actions, dones):
        """Computes the loss required to train the Q network"""
        with torch.no_grad():
            max_action_indexes = self.calculate_q_values(next_states, local=True, primitive_actions_only=True).detach().argmax(1)
            Q_targets_next = self.calculate_q_values(next_states, local=False, primitive_actions_only=True).gather(1, max_action_indexes.unsqueeze(1))
            Q_targets = rewards + (self.hyperparameters["discount_rate"] * Q_targets_next * (1 - dones))
        Q_expected = self.calculate_q_values(states, local=True, primitive_actions_only=True).gather(1, actions.long())  # must convert actions to long so they can be used as an index
        loss = F.mse_loss(Q_expected, Q_targets)
        return loss

    def save_episode_actions_with_score(self):
        """Stores this episode's score, its sequence of primitive actions (terminated by the end-of-episode symbol)
        and whether exploration was turned off"""
        self.episode_actions_scores_and_exploration_status.append([self.total_episode_score_so_far,
                                                                   self.episode_actions + [self.end_of_episode_symbol],
                                                                   self.turn_off_exploration])

    def oracle_learn(self):
        """Runs a learning iteration for the oracle"""
        states, actions, rewards, next_states, _ = self.sample_experiences()  # Sample experiences
        states_and_actions = torch.cat((states, actions), dim=1)  # must change this for all games besides CartPole
        predictions = self.oracle(states_and_actions)
        loss = F.mse_loss(torch.cat((next_states, rewards), dim=1), predictions) / float(next_states.shape[1] + 1.0)
        self.take_optimisation_step(self.oracle_optimizer, self.oracle,
                                    loss, self.hyperparameters["gradient_clipping_norm"])
        self.logger.info("Oracle Loss {}".format(loss))

    # def create_feature_extractor(self):
    #     """Creates the feature extractor local network and target network. This means that the q_network and oracle
    #     only need 1 layer"""
    #     temp_hyperparameters = copy.deepcopy(self.hyperparameters)
    #     temp_hyperparameters["linear_hidden_units"], output_dim = temp_hyperparameters["linear_hidden_units"][:-1], temp_hyperparameters["linear_hidden_units"][-1]
    #     temp_hyperparameters["final_layer_activation"] = "relu"
    #     feature_extractor_local = self.create_NN(input_dim=self.state_size, output_dim=output_dim, hyperparameters=temp_hyperparameters)
    #     feature_extractor_target = self.create_NN(input_dim=self.state_size, output_dim=output_dim, hyperparameters=temp_hyperparameters)
    #     Base_Agent.copy_model_over(from_model=feature_extractor_local, to_model=feature_extractor_target)
    #     feature_extractor_local.print_model_summary()
    #     return feature_extractor_local, feature_extractor_target, output_dim
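

# Minimal usage sketch (hypothetical): the exact shape of `config` and of the macro-action mapping below are
# assumptions based on how this class is used above, not a documented API of this repository.
#
#   global_action_id_to_primitive_actions = {0: [0], 1: [1], 2: [0, 0], 3: [1, 1, 0]}  # ids >= action_size are macros
#   agent = DDQN_Wrapper(config, global_action_id_to_primitive_actions, action_length_reward_bonus=0.1)
#   results, macro_actions = agent.run_n_episodes(num_episodes=100, episodes_to_run_with_no_exploration=10)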