Commit fbc84b8

fix bugs
1 parent 31a67d3 commit fbc84b8


agents/actor_critic_agents/SAC_Discrete.py

Lines changed: 3 additions & 4 deletions
@@ -52,7 +52,7 @@ def produce_action_and_action_info(self, state):
         """Given the state, produces an action, the probability of the action, the log probability of the action, and
         the argmax action"""
         action_probabilities = self.actor_local(state)
-        max_probability_action = torch.argmax(action_probabilities).unsqueeze(0)
+        max_probability_action = torch.argmax(action_probabilities, dim=1)
         action_distribution = create_actor_distribution(self.action_types, action_probabilities, self.action_size)
         action = action_distribution.sample().cpu()
         # Have to deal with situation of 0.0 probabilities because we can't do log 0
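Why the dim argument matters here: the actor produces one row of action probabilities per state, so the greedy action has to be taken per row. A minimal sketch with a made-up 2x3 batch (the shape [batch_size, num_actions] is an assumption based on how the tensor is used in the hunks below):

import torch

# Hypothetical batch: 2 states, 3 discrete actions.
action_probabilities = torch.tensor([[0.1, 0.7, 0.2],
                                     [0.6, 0.3, 0.1]])

# Old code: argmax over the flattened tensor gives a single index for the whole batch.
old = torch.argmax(action_probabilities).unsqueeze(0)    # tensor([1])

# Fixed code: argmax along the action dimension gives one greedy action per state.
new = torch.argmax(action_probabilities, dim=1)          # tensor([1, 0])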
@@ -69,7 +69,7 @@ def calculate_critic_losses(self, state_batch, action_batch, reward_batch, next_
         qf1_next_target = self.critic_target(next_state_batch)
         qf2_next_target = self.critic_target_2(next_state_batch)
         min_qf_next_target = action_probabilities * (torch.min(qf1_next_target, qf2_next_target) - self.alpha * log_action_probabilities)
-        min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)
+        min_qf_next_target = min_qf_next_target.sum(dim=1).unsqueeze(-1)
         next_q_value = reward_batch + (1.0 - mask_batch) * self.hyperparameters["discount_rate"] * (min_qf_next_target)

         qf1 = self.critic_local(state_batch).gather(1, action_batch.long())
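The critic target in discrete SAC is an expectation over actions under the current policy, E_{a~pi}[min Q - alpha * log pi], and the probability weights are already applied in the min_qf_next_target product, so the reduction over the action dimension should be a sum; mean(dim=1) silently divides that expectation by the number of actions. A small sketch with stand-in tensors (shapes [batch_size, num_actions] assumed):

import torch

batch_size, num_actions = 4, 3
action_probabilities = torch.softmax(torch.randn(batch_size, num_actions), dim=1)
log_action_probabilities = torch.log(action_probabilities)
min_q = torch.randn(batch_size, num_actions)   # stand-in for torch.min(qf1_next_target, qf2_next_target)
alpha = 0.2

weighted = action_probabilities * (min_q - alpha * log_action_probabilities)

# Expectation under the policy: probability-weighted sum over actions, shape [batch_size, 1].
target_sum = weighted.sum(dim=1).unsqueeze(-1)

# The old mean(dim=1) is the same quantity scaled down by 1 / num_actions.
target_mean = weighted.mean(dim=1).unsqueeze(-1)
assert torch.allclose(target_mean * num_actions, target_sum)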
@@ -85,7 +85,6 @@ def calculate_actor_loss(self, state_batch):
         qf2_pi = self.critic_local_2(state_batch)
         min_qf_pi = torch.min(qf1_pi, qf2_pi)
         inside_term = self.alpha * log_action_probabilities - min_qf_pi
-        policy_loss = action_probabilities * inside_term
-        policy_loss = policy_loss.mean()
+        policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
         log_action_probabilities = torch.sum(log_action_probabilities * action_probabilities, dim=1)
         return policy_loss, log_action_probabilities
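The actor-loss fix follows the same pattern: the loss is the batch average of the per-state expectation sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s,a)), so only the batch dimension should be averaged; a plain .mean() over both dimensions adds another unintended 1 / num_actions factor. Sketch with stand-in tensors (same assumed shapes):

import torch

batch_size, num_actions = 4, 3
action_probabilities = torch.softmax(torch.randn(batch_size, num_actions), dim=1)
log_action_probabilities = torch.log(action_probabilities)
min_qf_pi = torch.randn(batch_size, num_actions)   # stand-in for torch.min(qf1_pi, qf2_pi)
alpha = 0.2

inside_term = alpha * log_action_probabilities - min_qf_pi

# Fixed: expectation over actions per state, then mean over the batch only.
policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()

# Old: mean over batch AND action dimensions, i.e. the fixed loss divided by num_actions.
old_policy_loss = (action_probabilities * inside_term).mean()
assert torch.allclose(old_policy_loss * num_actions, policy_loss)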
