
Commit f719b11

updated memory shaper tests

1 parent ca4a80c

File tree (4 files changed: 61 additions, 94 deletions)

agents/hierarchical_agents/HRL.py
tests/Test_Action_Balanced_Replay_Buffer.py
tests/gesg.py
utilities/data_structures/Action_Balanced_Replay_Buffer.py

agents/hierarchical_agents/HRL.py

Lines changed: 46 additions & 87 deletions
@@ -120,7 +120,6 @@ def __init__(self, config):
         self.reduce_macro_action_appearance_cutoff_throughout_training = self.hyperparameters["reduce_macro_action_appearance_cutoff_throughout_training"]
         self.add_1_macro_action_at_a_time = self.hyperparameters["add_1_macro_action_at_a_time"]

-
         self.min_num_episodes_to_play = self.hyperparameters["min_num_episodes_to_play"]


@@ -139,80 +138,56 @@ def run_n_episodes(self, num_episodes=None, show_whether_achieved_goal=True, sav
         self.grammar_induction_iteration = 1

         while self.episodes_conducted < self.num_episodes:
-
-            self.episode_actions_scores_and_exploration_status, round_of_macro_actions = \
-                self.agent.run_n_episodes(num_episodes=self.calculate_how_many_episodes_to_play(),
-                                          episodes_to_run_with_no_exploration=self.episodes_to_run_with_no_exploration)
-
-
-            self.episodes_conducted += len(self.episode_actions_scores_and_exploration_status)
-            actions_to_infer_grammar_from = self.pick_actions_to_infer_grammar_from(
-                self.episode_actions_scores_and_exploration_status)
-
-            num_actions_before = len(self.global_action_id_to_primitive_action)
-
-            if self.infer_new_grammar:
-                self.update_action_choices(actions_to_infer_grammar_from)
-
-            else:
-                print("NOT inferring new grammar because no better results found")
-
-            print("New actions ", self.global_action_id_to_primitive_action)
-
-            self.new_actions_just_added = list(range(num_actions_before, num_actions_before + len(self.global_action_id_to_primitive_action) - num_actions_before))
-            print("Actions just added ", self.new_actions_just_added)
-
-            assert len(set(self.global_action_id_to_primitive_action.values())) == len(
-                self.global_action_id_to_primitive_action.values()), \
-                "Not all actions are unique anymore: {}".format(self.global_action_id_to_primitive_action)
-
-            for key, value in self.global_action_id_to_primitive_action.items():
-                assert max(value) < self.action_size, "Actions should be in terms of primitive actions"
-
-            self.grammar_induction_iteration += 1
-
-            current_num_actions = len(self.global_action_id_to_primitive_action.keys())
-
-            if self.only_train_new_actions:
-                PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)  # * (len(self.new_actions_just_added) ** 1.25))
-            else:
-                PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)  # * (current_num_actions ** 1.25))
-
-            print(" ")
-
-            print("PRE TRAINING ITERATIONS ", PRE_TRAINING_ITERATIONS)
-
-            print(" ")
-
-            self.agent.update_agent_for_new_actions(self.global_action_id_to_primitive_action,
-                                                    copy_over_hidden_layers=self.copy_over_hidden_layers,
-                                                    change_or_append_final_layer="APPEND")
-
-            if num_actions_before != len(self.global_action_id_to_primitive_action):
-                replay_buffer = self.memory_shaper.put_adapted_experiences_in_a_replay_buffer(
-                    self.global_action_id_to_primitive_action)
-
-                print(" ------ ")
-                print("Length of buffer {} -- Actions {} -- Pre training iterations {}".format(len(replay_buffer),
-                                                                                               current_num_actions,
-                                                                                               PRE_TRAINING_ITERATIONS))
-                print(" ------ ")
-                self.overwrite_replay_buffer_and_pre_train_agent(replay_buffer, PRE_TRAINING_ITERATIONS,
-                    only_train_final_layer=self.only_train_final_layer, only_train_new_actions=self.only_train_new_actions)
-            print("Now there are {} actions: {}".format(current_num_actions, self.global_action_id_to_primitive_action))
-
-
-            episode_actions = [data[1] for data in self.episode_actions_scores_and_exploration_status]
-            flat_episode_actions = [ep_action for ep in episode_actions for ep_action in ep]
-
-            final_actions_count = Counter(round_of_macro_actions)
+            self.play_new_episodes()
+            self.infer_new_grammar()
+            self.update_agent()
+            final_actions_count = Counter(self.round_of_macro_actions)
             print("FINAL EPISODE SET ACTIONS COUNT ", final_actions_count)
-
-
         time_taken = time.time() - start
-
         return self.agent.game_full_episode_scores[:self.num_episodes], self.agent.rolling_results[:self.num_episodes], time_taken

+    def play_new_episodes(self):
+        """Plays a new set of episodes using the recently updated agent"""
+        self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions = \
+            self.agent.run_n_episodes(num_episodes=self.calculate_how_many_episodes_to_play(),
+                                      episodes_to_run_with_no_exploration=self.episodes_to_run_with_no_exploration)
+        self.episodes_conducted += len(self.episode_actions_scores_and_exploration_status)
+
+    def infer_new_grammar(self):
+        """Infers a new action grammar and updates the global action set"""
+        actions_to_infer_grammar_from = self.pick_actions_to_infer_grammar_from(self.episode_actions_scores_and_exploration_status)
+        num_actions_before = len(self.global_action_id_to_primitive_action)
+        if self.infer_new_grammar: self.update_action_choices(actions_to_infer_grammar_from)
+        else: print("NOT inferring new grammar because no better results found")
+        self.new_actions_just_added = list(range(num_actions_before, num_actions_before + len(
+            self.global_action_id_to_primitive_action) - num_actions_before))
+        self.check_new_global_actions_valid()
+        self.grammar_induction_iteration += 1
+
+    def check_new_global_actions_valid(self):
+        """Checks that global_action_id_to_primitive_action still only has valid entries"""
+        assert len(set(self.global_action_id_to_primitive_action.values())) == len(
+            self.global_action_id_to_primitive_action.values()), \
+            "Not all actions are unique anymore: {}".format(self.global_action_id_to_primitive_action)
+        for key, value in self.global_action_id_to_primitive_action.items():
+            assert max(value) < self.action_size, "Actions should be in terms of primitive actions"
+
+    def update_agent(self):
+        """Updates the agent according to new action set by changing its action set, creating a new replay buffer
+        and doing any pretraining"""
+        current_num_actions = len(self.global_action_id_to_primitive_action.keys())
+        PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)
+        self.agent.update_agent_for_new_actions(self.global_action_id_to_primitive_action,
+                                                copy_over_hidden_layers=self.copy_over_hidden_layers,
+                                                change_or_append_final_layer="APPEND")
+        if len(self.new_actions_just_added) > 0:
+            replay_buffer = self.memory_shaper.put_adapted_experiences_in_a_replay_buffer(
+                self.global_action_id_to_primitive_action)
+            self.overwrite_replay_buffer_and_pre_train_agent(replay_buffer, PRE_TRAINING_ITERATIONS,
+                                                             only_train_final_layer=self.only_train_final_layer,
+                                                             only_train_new_actions=self.only_train_new_actions)
+        print("Now there are {} actions: {}".format(current_num_actions, self.global_action_id_to_primitive_action))
+
     def calculate_how_many_episodes_to_play(self):
         """Calculates how many episodes the agent should play until we re-infer the grammar"""
         episodes_to_play = self.hyperparameters["epsilon_decay_rate_denominator"] / self.grammar_induction_iteration
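
Editorial note: after this commit the loop body of run_n_episodes reduces to three helper calls. The outline below simply restates the added lines in one place for readability; it is not code from the commit, and the hrl variable is a stand-in for an HRL instance.

    # Editorial outline of the refactored training loop (restates the added lines above)
    def run_training_loop(hrl):
        while hrl.episodes_conducted < hrl.num_episodes:
            hrl.play_new_episodes()    # roll out episodes with the current macro-action set
            hrl.infer_new_grammar()    # induce new macro-actions from the best recent episodes
            hrl.update_agent()         # append final-layer outputs and pre-train on the reshaped replay buffer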
@@ -291,38 +266,24 @@ def pick_new_macro_actions(self, rules_episode_appearance_count):
291266
"""
292267
new_unflattened_actions = {}
293268
cutoff = self.num_top_results_to_use * self.action_frequency_required_in_top_results
294-
295269
if self.reduce_macro_action_appearance_cutoff_throughout_training:
296-
297270
cutoff = cutoff / (self.grammar_induction_iteration**0.5)
298-
299271
print(" ")
300272
print("Cutoff ", cutoff)
301273
print(" ")
302274
action_id = len(self.global_action_id_to_primitive_action.keys())
303-
304-
305-
306275
counts = {}
307-
308276
for rule in rules_episode_appearance_count.keys():
309277
count = rules_episode_appearance_count[rule]
310-
311278
# count = count * (len(rule))**0.25
312-
313279
print("Rule {} -- Count {}".format(rule, count))
314280
if count >= cutoff:
315281
new_unflattened_actions[action_id] = rule
316282
counts[action_id] = count
317283
action_id += 1
318-
319-
320-
321284
new_actions = flatten_action_id_to_actions(new_unflattened_actions, self.global_action_id_to_primitive_action,
322285
self.action_size)
323-
324286
if self.add_1_macro_action_at_a_time:
325-
326287
max_count = 0
327288
best_rule = None
328289
for action_id, primitive_actions in new_actions.items():
@@ -331,11 +292,9 @@ def pick_new_macro_actions(self, rules_episode_appearance_count):
331292
if count > max_count:
332293
max_count = count
333294
best_rule = primitive_actions
334-
335295
if best_rule is None: new_actions = {}
336296
else:
337297
new_actions = {len(self.global_action_id_to_primitive_action.keys()): best_rule}
338-
339298
return new_actions
340299

341300
def overwrite_replay_buffer_and_pre_train_agent(self, replay_buffer, training_iterations, only_train_final_layer,
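
The cutoff logic in pick_new_macro_actions is untouched here apart from deleted blank lines. As a quick worked illustration of how that threshold behaves (the hyperparameter values below are made up for the example):

    # Hypothetical values, purely to illustrate the cutoff computation shown above
    num_top_results_to_use = 10
    action_frequency_required_in_top_results = 0.8
    grammar_induction_iteration = 4

    cutoff = num_top_results_to_use * action_frequency_required_in_top_results  # 8.0
    # With reduce_macro_action_appearance_cutoff_throughout_training enabled,
    # the cutoff shrinks as more grammar-induction rounds complete:
    cutoff = cutoff / (grammar_induction_iteration ** 0.5)  # 8.0 / 2 = 4.0
    # A candidate rule only becomes a macro-action if its episode-appearance
    # count is at least the cutoff (here 4).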

tests/Test_Action_Balanced_Replay_Buffer.py

Lines changed: 5 additions & 4 deletions
@@ -79,10 +79,10 @@ def test_sample_statistics_correct():
     tries = 5
     for random_seed in range(tries):
        for num_actions in range(1, 7):
-            for buffer_size in [random.randint(55, 9999) for _ in range(20)]:
-                for batch_size in [random.randint(8, 200) for _ in range(20)]:
-                    buffer = Action_Balanced_Replay_Buffer(buffer_size, batch_size, random_seed, num_actions)
-                    for _ in range(2000):
+            for buffer_size in [random.randint(55, 9999) for _ in range(10)]:
+                for batch_size in [random.randint(8, 200) for _ in range(10)]:
+                    buffer = Action_Balanced_Replay_Buffer(buffer_size, batch_size, random.randint(0, 2000000), num_actions)
+                    for _ in range(500):
                         random_action = random.randint(0, num_actions - 1)
                         buffer.add_experience(1, random_action, 1, 0, 0)
                     states, actions, rewards, next_states, dones = buffer.sample()
@@ -95,3 +95,4 @@ def test_sample_statistics_correct():



+
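
The assertions that follow buffer.sample() in this test sit outside the hunks above, so they are not shown. Purely as an illustrative sketch of what an action-balance check could look like (a hypothetical helper, not part of this commit; it assumes the sampled actions can be iterated as one integer-like action id per experience):

    from collections import Counter

    def check_action_balance(actions, batch_size, num_actions, tolerance=2):
        """Illustrative check: in an action-balanced sample, every action id
        should appear roughly batch_size / num_actions times."""
        counts = Counter(int(a) for a in actions)
        expected = batch_size / num_actions
        for action_id, count in counts.items():
            assert abs(count - expected) <= tolerance, \
                "Action {} appears {} times, expected about {:.1f}".format(action_id, count, expected)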

tests/gesg.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+import random
+from collections import Counter
+
+import pytest
+
+from Action_Balanced_Replay_Buffer import Action_Balanced_Replay_Buffer
+
+# def test_sample_statistics_correct():
+"""Tests that sampled experiences correspond to expected statistics"""
+

utilities/data_structures/Action_Balanced_Replay_Buffer.py

Lines changed: 0 additions & 3 deletions
@@ -19,19 +19,16 @@ def __init__(self, buffer_size, batch_size, seed, num_actions):

     def add_experience(self, states, actions, rewards, next_states, dones):
         """Adds experience or list of experiences into the replay buffer"""
-        print("MEMORY ", self.memories)
         if type(dones) == list:
             assert type(dones[0]) != list, "A done shouldn't be a list"
             experiences = [self.experience(state, action, reward, next_state, done)
                            for state, action, reward, next_state, done in
                            zip(states, actions, rewards, next_states, dones)]
             for experience in experiences:
                 action = experience.action
-                print("MEMORY ADDING ACTION {} -- Experience {}".format(action, experience))
                 self.memories[action].append(experience)
         else:
             experience = self.experience(states, actions, rewards, next_states, dones)
-            print("MEMORY ADDING ACTION {} -- Experience {}".format(actions, experience))
             self.memories[actions].append(experience)

     def pick_experiences(self, num_experiences=None):
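
This hunk only strips the debug prints; add_experience still accepts either a single transition or parallel lists of transitions. A minimal usage sketch under that reading (argument values are illustrative, not from the commit; action ids are assumed to be the integers 0..num_actions-1, as in the tests above):

    # Import path as used in tests/gesg.py above
    from Action_Balanced_Replay_Buffer import Action_Balanced_Replay_Buffer

    # Constructor arguments follow the __init__ signature shown in the hunk header
    buffer = Action_Balanced_Replay_Buffer(buffer_size=1000, batch_size=32, seed=1, num_actions=4)

    # Single transition: the scalars are wrapped into one experience and
    # appended to self.memories[action]
    buffer.add_experience(1, 2, 1.0, 0, False)

    # Batched form: parallel lists are zipped into one experience per element
    buffer.add_experience(states=[1, 1], actions=[0, 3], rewards=[1.0, 0.5],
                          next_states=[0, 0], dones=[False, True])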
