@@ -144,10 +144,12 @@ def learn(self):
144144 """Runs a learning iteration for the actor, both critics and (if specified) the temperature parameter"""
145145 state_batch , action_batch , reward_batch , next_state_batch , mask_batch = self .sample_experiences ()
146146 qf1_loss , qf2_loss = self .calculate_critic_losses (state_batch , action_batch , reward_batch , next_state_batch , mask_batch )
147+ self .update_critic_parameters (qf1_loss , qf2_loss )
148+
147149 policy_loss , log_pi = self .calculate_actor_loss (state_batch )
148150 if self .automatic_entropy_tuning : alpha_loss = self .calculate_entropy_tuning_loss (log_pi )
149151 else : alpha_loss = None
150- self .update_all_parameters ( qf1_loss , qf2_loss , policy_loss , alpha_loss )
152+ self .update_actor_parameters ( policy_loss , alpha_loss )
151153
152154 def sample_experiences (self ):
153155 return self .memory .sample ()
@@ -182,18 +184,21 @@ def calculate_entropy_tuning_loss(self, log_pi):
182184 alpha_loss = - (self .log_alpha * (log_pi + self .target_entropy ).detach ()).mean ()
183185 return alpha_loss
184186
185- def update_all_parameters (self , critic_loss_1 , critic_loss_2 , actor_loss , alpha_loss ):
186- """Updates the parameters for the actor, both critics and (if specified) the temperature parameter """
187+ def update_critic_parameters (self , critic_loss_1 , critic_loss_2 ):
188+ """Updates the parameters for both critics and soft-updates both critic target networks"""
187189 self .take_optimisation_step (self .critic_optimizer , self .critic_local , critic_loss_1 ,
188190 self .hyperparameters ["Critic" ]["gradient_clipping_norm" ])
189191 self .take_optimisation_step (self .critic_optimizer_2 , self .critic_local_2 , critic_loss_2 ,
190192 self .hyperparameters ["Critic" ]["gradient_clipping_norm" ])
191- self .take_optimisation_step (self .actor_optimizer , self .actor_local , actor_loss ,
192- self .hyperparameters ["Actor" ]["gradient_clipping_norm" ])
193193 self .soft_update_of_target_network (self .critic_local , self .critic_target ,
194194 self .hyperparameters ["Critic" ]["tau" ])
195195 self .soft_update_of_target_network (self .critic_local_2 , self .critic_target_2 ,
196196 self .hyperparameters ["Critic" ]["tau" ])
197+
198+ def update_actor_parameters (self , actor_loss , alpha_loss ):
199+ """Updates the parameters for the actor and (if specified) the temperature parameter"""
200+ self .take_optimisation_step (self .actor_optimizer , self .actor_local , actor_loss ,
201+ self .hyperparameters ["Actor" ]["gradient_clipping_norm" ])
197202 if alpha_loss is not None :
198203 self .take_optimisation_step (self .alpha_optim , None , alpha_loss , None )
199204 self .alpha = self .log_alpha .exp ()
0 commit comments