@@ -120,7 +120,6 @@ def __init__(self, config):
         self.reduce_macro_action_appearance_cutoff_throughout_training = self.hyperparameters["reduce_macro_action_appearance_cutoff_throughout_training"]
         self.add_1_macro_action_at_a_time = self.hyperparameters["add_1_macro_action_at_a_time"]
 
-
         self.min_num_episodes_to_play = self.hyperparameters["min_num_episodes_to_play"]
 
 
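For reference, the settings read in the `__init__` hunk above come out of the agent's hyperparameter dict. A minimal sketch of that part of the config, assuming a plain dict: the key names are taken from the diff, but the values are illustrative guesses rather than repository defaults.

```python
# Hypothetical excerpt of the hyperparameter dict consumed in __init__ above.
# Only the keys touched by this hunk are shown; the values are illustrative.
hyperparameters = {
    "reduce_macro_action_appearance_cutoff_throughout_training": True,  # relax the macro-action cutoff as training progresses
    "add_1_macro_action_at_a_time": True,   # promote at most one new macro-action per grammar update
    "min_num_episodes_to_play": 10,         # lower bound on episodes played between grammar inferences
}
```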
@@ -139,80 +138,56 @@ def run_n_episodes(self, num_episodes=None, show_whether_achieved_goal=True, sav
         self.grammar_induction_iteration = 1
 
         while self.episodes_conducted < self.num_episodes:
-
-            self.episode_actions_scores_and_exploration_status, round_of_macro_actions = \
-                self.agent.run_n_episodes(num_episodes=self.calculate_how_many_episodes_to_play(),
-                                          episodes_to_run_with_no_exploration=self.episodes_to_run_with_no_exploration)
-
-
-            self.episodes_conducted += len(self.episode_actions_scores_and_exploration_status)
-            actions_to_infer_grammar_from = self.pick_actions_to_infer_grammar_from(
-                self.episode_actions_scores_and_exploration_status)
-
-            num_actions_before = len(self.global_action_id_to_primitive_action)
-
-            if self.infer_new_grammar:
-                self.update_action_choices(actions_to_infer_grammar_from)
-
-            else:
-                print("NOT inferring new grammar because no better results found")
-
-            print("New actions ", self.global_action_id_to_primitive_action)
-
-            self.new_actions_just_added = list(range(num_actions_before, num_actions_before + len(self.global_action_id_to_primitive_action) - num_actions_before))
-            print("Actions just added ", self.new_actions_just_added)
-
-            assert len(set(self.global_action_id_to_primitive_action.values())) == len(
-                self.global_action_id_to_primitive_action.values()), \
-                "Not all actions are unique anymore: {}".format(self.global_action_id_to_primitive_action)
-
-            for key, value in self.global_action_id_to_primitive_action.items():
-                assert max(value) < self.action_size, "Actions should be in terms of primitive actions"
-
-            self.grammar_induction_iteration += 1
-
-            current_num_actions = len(self.global_action_id_to_primitive_action.keys())
-
-            if self.only_train_new_actions:
-                PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)  # * (len(self.new_actions_just_added) ** 1.25))
-            else:
-                PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)  # * (current_num_actions ** 1.25))
-
-            print(" ")
-
-            print("PRE TRAINING ITERATIONS ", PRE_TRAINING_ITERATIONS)
-
-            print(" ")
-
-            self.agent.update_agent_for_new_actions(self.global_action_id_to_primitive_action,
-                                                    copy_over_hidden_layers=self.copy_over_hidden_layers,
-                                                    change_or_append_final_layer="APPEND")
-
-            if num_actions_before != len(self.global_action_id_to_primitive_action):
-                replay_buffer = self.memory_shaper.put_adapted_experiences_in_a_replay_buffer(
-                    self.global_action_id_to_primitive_action)
-
-                print(" ------ ")
-                print("Length of buffer {} -- Actions {} -- Pre training iterations {}".format(len(replay_buffer),
-                                                                                               current_num_actions,
-                                                                                               PRE_TRAINING_ITERATIONS))
-                print(" ------ ")
-                self.overwrite_replay_buffer_and_pre_train_agent(replay_buffer, PRE_TRAINING_ITERATIONS,
-                                                                 only_train_final_layer=self.only_train_final_layer, only_train_new_actions=self.only_train_new_actions)
-            print("Now there are {} actions: {}".format(current_num_actions, self.global_action_id_to_primitive_action))
-
-
-            episode_actions = [data[1] for data in self.episode_actions_scores_and_exploration_status]
-            flat_episode_actions = [ep_action for ep in episode_actions for ep_action in ep]
-
-            final_actions_count = Counter(round_of_macro_actions)
+            self.play_new_episodes()
+            self.infer_new_grammar()
+            self.update_agent()
+            final_actions_count = Counter(self.round_of_macro_actions)
             print("FINAL EPISODE SET ACTIONS COUNT ", final_actions_count)
-
-
         time_taken = time.time() - start
-
         return self.agent.game_full_episode_scores[:self.num_episodes], self.agent.rolling_results[:self.num_episodes], time_taken
 
+    def play_new_episodes(self):
+        """Plays a new set of episodes using the recently updated agent"""
+        self.episode_actions_scores_and_exploration_status, self.round_of_macro_actions = \
+            self.agent.run_n_episodes(num_episodes=self.calculate_how_many_episodes_to_play(),
+                                      episodes_to_run_with_no_exploration=self.episodes_to_run_with_no_exploration)
+        self.episodes_conducted += len(self.episode_actions_scores_and_exploration_status)
+
+    def infer_new_grammar(self):
+        """Infers a new action grammar and updates the global action set"""
+        actions_to_infer_grammar_from = self.pick_actions_to_infer_grammar_from(self.episode_actions_scores_and_exploration_status)
+        num_actions_before = len(self.global_action_id_to_primitive_action)
+        if self.infer_new_grammar: self.update_action_choices(actions_to_infer_grammar_from)
+        else: print("NOT inferring new grammar because no better results found")
+        self.new_actions_just_added = list(range(num_actions_before, num_actions_before + len(
+            self.global_action_id_to_primitive_action) - num_actions_before))
+        self.check_new_global_actions_valid()
+        self.grammar_induction_iteration += 1
+
+    def check_new_global_actions_valid(self):
+        """Checks that global_action_id_to_primitive_action still only has valid entries"""
+        assert len(set(self.global_action_id_to_primitive_action.values())) == len(
+            self.global_action_id_to_primitive_action.values()), \
+            "Not all actions are unique anymore: {}".format(self.global_action_id_to_primitive_action)
+        for key, value in self.global_action_id_to_primitive_action.items():
+            assert max(value) < self.action_size, "Actions should be in terms of primitive actions"
+
+    def update_agent(self):
+        """Updates the agent according to new action set by changing its action set, creating a new replay buffer
+        and doing any pretraining"""
+        current_num_actions = len(self.global_action_id_to_primitive_action.keys())
+        PRE_TRAINING_ITERATIONS = int(self.pre_training_learning_iterations_multiplier)
+        self.agent.update_agent_for_new_actions(self.global_action_id_to_primitive_action,
+                                                copy_over_hidden_layers=self.copy_over_hidden_layers,
+                                                change_or_append_final_layer="APPEND")
+        if len(self.new_actions_just_added) > 0:
+            replay_buffer = self.memory_shaper.put_adapted_experiences_in_a_replay_buffer(
+                self.global_action_id_to_primitive_action)
+            self.overwrite_replay_buffer_and_pre_train_agent(replay_buffer, PRE_TRAINING_ITERATIONS,
+                                                             only_train_final_layer=self.only_train_final_layer,
+                                                             only_train_new_actions=self.only_train_new_actions)
+        print("Now there are {} actions: {}".format(current_num_actions, self.global_action_id_to_primitive_action))
+
     def calculate_how_many_episodes_to_play(self):
         """Calculates how many episodes the agent should play until we re-infer the grammar"""
         episodes_to_play = self.hyperparameters["epsilon_decay_rate_denominator"] / self.grammar_induction_iteration
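The hunk above only shows the first line of `calculate_how_many_episodes_to_play`: the episode budget between grammar inferences shrinks as `grammar_induction_iteration` grows. A hedged sketch of that schedule, assuming the result is floored by the `min_num_episodes_to_play` setting loaded in `__init__`; the clamping and the helper name are assumptions, not part of the diff.

```python
import math

def episodes_before_next_grammar_inference(epsilon_decay_rate_denominator,
                                           grammar_induction_iteration,
                                           min_num_episodes_to_play):
    # Mirrors the line shown in the hunk: the budget shrinks as more grammar
    # induction rounds are completed.
    episodes_to_play = epsilon_decay_rate_denominator / grammar_induction_iteration
    # Assumed clamping against the configured minimum (not visible in the hunk).
    return max(min_num_episodes_to_play, math.ceil(episodes_to_play))

# Example: with a denominator of 150 and a minimum of 10 the rounds play
# 150, 75, 50, 38, 30, ... episodes.
for iteration in range(1, 6):
    print(iteration, episodes_before_next_grammar_inference(150, iteration, 10))
```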
@@ -291,38 +266,24 @@ def pick_new_macro_actions(self, rules_episode_appearance_count):
291266 """
292267 new_unflattened_actions = {}
293268 cutoff = self .num_top_results_to_use * self .action_frequency_required_in_top_results
294-
295269 if self .reduce_macro_action_appearance_cutoff_throughout_training :
296-
297270 cutoff = cutoff / (self .grammar_induction_iteration ** 0.5 )
298-
299271 print (" " )
300272 print ("Cutoff " , cutoff )
301273 print (" " )
302274 action_id = len (self .global_action_id_to_primitive_action .keys ())
303-
304-
305-
306275 counts = {}
307-
308276 for rule in rules_episode_appearance_count .keys ():
309277 count = rules_episode_appearance_count [rule ]
310-
311278 # count = count * (len(rule))**0.25
312-
313279 print ("Rule {} -- Count {}" .format (rule , count ))
314280 if count >= cutoff :
315281 new_unflattened_actions [action_id ] = rule
316282 counts [action_id ] = count
317283 action_id += 1
318-
319-
320-
321284 new_actions = flatten_action_id_to_actions (new_unflattened_actions , self .global_action_id_to_primitive_action ,
322285 self .action_size )
323-
324286 if self .add_1_macro_action_at_a_time :
325-
326287 max_count = 0
327288 best_rule = None
328289 for action_id , primitive_actions in new_actions .items ():
@@ -331,11 +292,9 @@ def pick_new_macro_actions(self, rules_episode_appearance_count):
                 if count > max_count:
                     max_count = count
                     best_rule = primitive_actions
-
             if best_rule is None: new_actions = {}
             else:
                 new_actions = {len(self.global_action_id_to_primitive_action.keys()): best_rule}
-
         return new_actions
 
     def overwrite_replay_buffer_and_pre_train_agent(self, replay_buffer, training_iterations, only_train_final_layer,
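As a reading aid for the `pick_new_macro_actions` hunks above, here is a hedged, free-standing sketch of the selection logic they implement: rules that appear in enough of the top episodes pass the cutoff, the cutoff optionally relaxes with the square root of the grammar-induction iteration, and `add_1_macro_action_at_a_time` keeps only the most frequent survivor. The function and parameter names are hypothetical, and flattening the surviving rules against the existing action set via `flatten_action_id_to_actions` is omitted.

```python
def select_macro_action_rules(rules_episode_appearance_count, num_top_results_to_use,
                              action_frequency_required_in_top_results,
                              grammar_induction_iteration,
                              reduce_cutoff_throughout_training, add_1_macro_action_at_a_time):
    # Threshold: a rule must show up in this many of the top episodes to be kept.
    cutoff = num_top_results_to_use * action_frequency_required_in_top_results
    if reduce_cutoff_throughout_training:
        cutoff /= grammar_induction_iteration ** 0.5  # relax the bar as training progresses
    selected = {rule: count for rule, count in rules_episode_appearance_count.items() if count >= cutoff}
    if add_1_macro_action_at_a_time and selected:
        best_rule = max(selected, key=selected.get)  # keep only the single most frequent rule
        selected = {best_rule: selected[best_rule]}
    return selected

# Example: with 10 top episodes and a required frequency of 0.8, the cutoff starts at 8
# and drops to about 5.7 on the second grammar-induction iteration.
print(select_macro_action_rules({(0, 1): 9, (2, 0, 1): 6}, 10, 0.8, 2, True, True))
# -> {(0, 1): 9}
```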