Skip to content

Commit 8cc5c59

Browse files
committed
fix errors in SAC and SAC-Discrete caused by torch>=1.4.0
1 parent bc6ee5f commit 8cc5c59

File tree

1 file changed

+10
-5
lines changed
  • agents/actor_critic_agents

1 file changed

+10
-5
lines changed

agents/actor_critic_agents/SAC.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,12 @@ def learn(self):
144144
"""Runs a learning iteration for the actor, both critics and (if specified) the temperature parameter"""
145145
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = self.sample_experiences()
146146
qf1_loss, qf2_loss = self.calculate_critic_losses(state_batch, action_batch, reward_batch, next_state_batch, mask_batch)
147+
self.update_critic_parameters(qf1_loss, qf2_loss)
148+
147149
policy_loss, log_pi = self.calculate_actor_loss(state_batch)
148150
if self.automatic_entropy_tuning: alpha_loss = self.calculate_entropy_tuning_loss(log_pi)
149151
else: alpha_loss = None
150-
self.update_all_parameters(qf1_loss, qf2_loss, policy_loss, alpha_loss)
152+
self.update_actor_parameters(policy_loss, alpha_loss)
151153

152154
def sample_experiences(self):
153155
return self.memory.sample()
@@ -182,18 +184,21 @@ def calculate_entropy_tuning_loss(self, log_pi):
182184
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
183185
return alpha_loss
184186

185-
def update_all_parameters(self, critic_loss_1, critic_loss_2, actor_loss, alpha_loss):
186-
"""Updates the parameters for the actor, both critics and (if specified) the temperature parameter"""
187+
def update_critic_parameters(self, critic_loss_1, critic_loss_2):
188+
"""Updates the parameters for both critics"""
187189
self.take_optimisation_step(self.critic_optimizer, self.critic_local, critic_loss_1,
188190
self.hyperparameters["Critic"]["gradient_clipping_norm"])
189191
self.take_optimisation_step(self.critic_optimizer_2, self.critic_local_2, critic_loss_2,
190192
self.hyperparameters["Critic"]["gradient_clipping_norm"])
191-
self.take_optimisation_step(self.actor_optimizer, self.actor_local, actor_loss,
192-
self.hyperparameters["Actor"]["gradient_clipping_norm"])
193193
self.soft_update_of_target_network(self.critic_local, self.critic_target,
194194
self.hyperparameters["Critic"]["tau"])
195195
self.soft_update_of_target_network(self.critic_local_2, self.critic_target_2,
196196
self.hyperparameters["Critic"]["tau"])
197+
198+
def update_actor_parameters(self, actor_loss, alpha_loss):
199+
"""Updates the parameters for the actor and (if specified) the temperature parameter"""
200+
self.take_optimisation_step(self.actor_optimizer, self.actor_local, actor_loss,
201+
self.hyperparameters["Actor"]["gradient_clipping_norm"])
197202
if alpha_loss is not None:
198203
self.take_optimisation_step(self.alpha_optim, None, alpha_loss, None)
199204
self.alpha = self.log_alpha.exp()

0 commit comments

Comments (0)