
Commit 8a0dd27

Merge pull request p-christ#60 from ku2482/sac-discrete/bugfix
Bug fix for SAC-discrete.
2 parents 31a67d3 + 8cc5c59 commit 8a0dd27

File tree: 2 files changed (+14, -10 lines)
  agents/actor_critic_agents/SAC.py
  agents/actor_critic_agents/SAC_Discrete.py


agents/actor_critic_agents/SAC.py (10 additions, 5 deletions)

@@ -144,10 +144,12 @@ def learn(self):
         """Runs a learning iteration for the actor, both critics and (if specified) the temperature parameter"""
         state_batch, action_batch, reward_batch, next_state_batch, mask_batch = self.sample_experiences()
         qf1_loss, qf2_loss = self.calculate_critic_losses(state_batch, action_batch, reward_batch, next_state_batch, mask_batch)
+        self.update_critic_parameters(qf1_loss, qf2_loss)
+
         policy_loss, log_pi = self.calculate_actor_loss(state_batch)
         if self.automatic_entropy_tuning: alpha_loss = self.calculate_entropy_tuning_loss(log_pi)
         else: alpha_loss = None
-        self.update_all_parameters(qf1_loss, qf2_loss, policy_loss, alpha_loss)
+        self.update_actor_parameters(policy_loss, alpha_loss)

     def sample_experiences(self):
         return self.memory.sample()
@@ -182,18 +184,21 @@ def calculate_entropy_tuning_loss(self, log_pi):
         alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
         return alpha_loss

-    def update_all_parameters(self, critic_loss_1, critic_loss_2, actor_loss, alpha_loss):
-        """Updates the parameters for the actor, both critics and (if specified) the temperature parameter"""
+    def update_critic_parameters(self, critic_loss_1, critic_loss_2):
+        """Updates the parameters for both critics"""
         self.take_optimisation_step(self.critic_optimizer, self.critic_local, critic_loss_1,
                                     self.hyperparameters["Critic"]["gradient_clipping_norm"])
         self.take_optimisation_step(self.critic_optimizer_2, self.critic_local_2, critic_loss_2,
                                     self.hyperparameters["Critic"]["gradient_clipping_norm"])
-        self.take_optimisation_step(self.actor_optimizer, self.actor_local, actor_loss,
-                                    self.hyperparameters["Actor"]["gradient_clipping_norm"])
         self.soft_update_of_target_network(self.critic_local, self.critic_target,
                                            self.hyperparameters["Critic"]["tau"])
         self.soft_update_of_target_network(self.critic_local_2, self.critic_target_2,
                                            self.hyperparameters["Critic"]["tau"])
+
+    def update_actor_parameters(self, actor_loss, alpha_loss):
+        """Updates the parameters for the actor and (if specified) the temperature parameter"""
+        self.take_optimisation_step(self.actor_optimizer, self.actor_local, actor_loss,
+                                    self.hyperparameters["Actor"]["gradient_clipping_norm"])
         if alpha_loss is not None:
             self.take_optimisation_step(self.alpha_optim, None, alpha_loss, None)
             self.alpha = self.log_alpha.exp()
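The SAC.py change splits the old update_all_parameters into update_critic_parameters and update_actor_parameters, so that in each learn() call both critics are stepped and their targets soft-updated before the actor loss is computed, with the actor and (if enabled) the temperature updated afterwards. Below is a minimal, self-contained sketch of that ordering only; the tiny networks, learning rates, fixed alpha of 0.2 and the soft_update helper are illustrative stand-ins, not the repository's API.

# Sketch of the decoupled update order introduced above (illustrative, not repository code).
import torch
import torch.nn as nn

state_dim, n_actions, batch = 4, 2, 8
critic_1, critic_2 = nn.Linear(state_dim, n_actions), nn.Linear(state_dim, n_actions)
critic_1_target, critic_2_target = nn.Linear(state_dim, n_actions), nn.Linear(state_dim, n_actions)
critic_1_target.load_state_dict(critic_1.state_dict())
critic_2_target.load_state_dict(critic_2.state_dict())
actor = nn.Sequential(nn.Linear(state_dim, n_actions), nn.Softmax(dim=1))

critic_opt_1 = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
critic_opt_2 = torch.optim.Adam(critic_2.parameters(), lr=3e-4)
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)

def soft_update(local, target, tau=0.005):
    # Polyak-average the target network towards the local network.
    for p_local, p_target in zip(local.parameters(), target.parameters()):
        p_target.data.copy_(tau * p_local.data + (1.0 - tau) * p_target.data)

states = torch.randn(batch, state_dim)
td_target = torch.randn(batch, n_actions)  # stand-in for the Bellman target

# 1) Critic step first (what update_critic_parameters now does): step both critics,
#    then soft-update their target networks.
qf1_loss = nn.functional.mse_loss(critic_1(states), td_target)
qf2_loss = nn.functional.mse_loss(critic_2(states), td_target)
for opt, loss in ((critic_opt_1, qf1_loss), (critic_opt_2, qf2_loss)):
    opt.zero_grad()
    loss.backward()
    opt.step()
soft_update(critic_1, critic_1_target)
soft_update(critic_2, critic_2_target)

# 2) Actor step afterwards (what update_actor_parameters now does), so the policy loss
#    is computed against the freshly updated critics; only the actor optimiser steps here.
probs = actor(states)
min_q = torch.min(critic_1(states), critic_2(states))
policy_loss = (probs * (0.2 * torch.log(probs + 1e-8) - min_q)).sum(dim=1).mean()
actor_opt.zero_grad()
policy_loss.backward()
actor_opt.step()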

agents/actor_critic_agents/SAC_Discrete.py (4 additions, 5 deletions)

@@ -30,7 +30,7 @@ def __init__(self, config):
         Base_Agent.copy_model_over(self.critic_local, self.critic_target)
         Base_Agent.copy_model_over(self.critic_local_2, self.critic_target_2)
         self.memory = Replay_Buffer(self.hyperparameters["Critic"]["buffer_size"], self.hyperparameters["batch_size"],
-                                    self.config.seed)
+                                    self.config.seed, device=self.device)

         self.actor_local = self.create_NN(input_dim=self.state_size, output_dim=self.action_size, key_to_use="Actor")
         self.actor_optimizer = torch.optim.Adam(self.actor_local.parameters(),
@@ -52,7 +52,7 @@ def produce_action_and_action_info(self, state):
         """Given the state, produces an action, the probability of the action, the log probability of the action, and
         the argmax action"""
         action_probabilities = self.actor_local(state)
-        max_probability_action = torch.argmax(action_probabilities).unsqueeze(0)
+        max_probability_action = torch.argmax(action_probabilities, dim=1)
         action_distribution = create_actor_distribution(self.action_types, action_probabilities, self.action_size)
         action = action_distribution.sample().cpu()
         # Have to deal with situation of 0.0 probabilities because we can't do log 0
@@ -69,7 +69,7 @@ def calculate_critic_losses(self, state_batch, action_batch, reward_batch, next_
         qf1_next_target = self.critic_target(next_state_batch)
         qf2_next_target = self.critic_target_2(next_state_batch)
         min_qf_next_target = action_probabilities * (torch.min(qf1_next_target, qf2_next_target) - self.alpha * log_action_probabilities)
-        min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)
+        min_qf_next_target = min_qf_next_target.sum(dim=1).unsqueeze(-1)
         next_q_value = reward_batch + (1.0 - mask_batch) * self.hyperparameters["discount_rate"] * (min_qf_next_target)

         qf1 = self.critic_local(state_batch).gather(1, action_batch.long())
@@ -85,7 +85,6 @@ def calculate_actor_loss(self, state_batch):
         qf2_pi = self.critic_local_2(state_batch)
         min_qf_pi = torch.min(qf1_pi, qf2_pi)
         inside_term = self.alpha * log_action_probabilities - min_qf_pi
-        policy_loss = action_probabilities * inside_term
-        policy_loss = policy_loss.mean()
+        policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
         log_action_probabilities = torch.sum(log_action_probabilities * action_probabilities, dim=1)
         return policy_loss, log_action_probabilities
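The sum(dim=1) and argmax(..., dim=1) fixes above both treat dim 1 as the action dimension of a batched discrete distribution: the expectation of a per-action quantity under the policy is the probability-weighted sum over actions, so the old mean(dim=1) silently divided that expectation by the number of actions, and the old argmax without a dim argument collapsed the whole batch to a single index. A small sketch with made-up numbers (not repository code) shows what the corrected lines compute:

import torch

action_probabilities = torch.tensor([[0.7, 0.2, 0.1],
                                     [0.1, 0.1, 0.8]])   # 2 states, 3 actions
min_qf_next_target = torch.tensor([[1.0, 2.0, 3.0],
                                   [0.0, 1.0, 2.0]])     # per-action soft values

weighted = action_probabilities * min_qf_next_target
expected = weighted.sum(dim=1).unsqueeze(-1)   # expectation under the policy, shape (batch, 1)
wrong = weighted.mean(dim=1).unsqueeze(-1)     # old code: expectation divided by num_actions
print(expected.squeeze())  # tensor([1.4000, 1.7000])
print(wrong.squeeze())     # tensor([0.4667, 0.5667])

# Batched greedy action: one argmax per state, taken over the action dimension.
max_probability_action = torch.argmax(action_probabilities, dim=1)
print(max_probability_action)  # tensor([0, 2])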
