@@ -52,7 +52,7 @@ def produce_action_and_action_info(self, state):
         """Given the state, produces an action, the probability of the action, the log probability of the action, and
         the argmax action"""
         action_probabilities = self.actor_local(state)
-        max_probability_action = torch.argmax(action_probabilities).unsqueeze(0)
+        max_probability_action = torch.argmax(action_probabilities, dim=1)
         action_distribution = create_actor_distribution(self.action_types, action_probabilities, self.action_size)
         action = action_distribution.sample().cpu()
         # Have to deal with situation of 0.0 probabilities because we can't do log 0
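The first change matters because `action_probabilities` is batched with shape `(batch_size, num_actions)`: without `dim=1`, `torch.argmax` flattens the tensor and returns a single index across the whole batch instead of one greedy action per state. A minimal sketch (not part of this commit, with illustrative values) showing the difference:

```python
import torch

# Illustrative batch of action probabilities, shape (batch_size=2, num_actions=3).
action_probabilities = torch.tensor([[0.1, 0.7, 0.2],
                                     [0.6, 0.3, 0.1]])

# Old behaviour: argmax over the flattened tensor -> one scalar index for the whole batch.
flat_argmax = torch.argmax(action_probabilities)               # tensor(1)

# New behaviour: one greedy action index per batch element.
per_sample_argmax = torch.argmax(action_probabilities, dim=1)  # tensor([1, 0])
```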
@@ -69,7 +69,7 @@ def calculate_critic_losses(self, state_batch, action_batch, reward_batch, next_
         qf1_next_target = self.critic_target(next_state_batch)
         qf2_next_target = self.critic_target_2(next_state_batch)
         min_qf_next_target = action_probabilities * (torch.min(qf1_next_target, qf2_next_target) - self.alpha * log_action_probabilities)
-        min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)
+        min_qf_next_target = min_qf_next_target.sum(dim=1).unsqueeze(-1)
         next_q_value = reward_batch + (1.0 - mask_batch) * self.hyperparameters["discount_rate"] * (min_qf_next_target)

         qf1 = self.critic_local(state_batch).gather(1, action_batch.long())
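In discrete SAC the soft value of the next state is an exact expectation over the action set, V(s') = Σ_a π(a|s') [min(Q1, Q2)(s', a) − α log π(a|s')], so the probability-weighted terms must be summed over the action dimension; taking `.mean(dim=1)` instead divides that expectation by the number of actions and shrinks the bootstrapped target by a constant factor. A minimal sketch of the corrected target computation, with illustrative tensor names and random values standing in for the policy and target-critic outputs:

```python
import torch

batch_size, num_actions, alpha, gamma = 4, 3, 0.2, 0.99
next_probs = torch.softmax(torch.randn(batch_size, num_actions), dim=1)  # pi(.|s')
next_log_probs = torch.log(next_probs)                                   # log pi(.|s')
q1 = torch.randn(batch_size, num_actions)                                # target critic 1
q2 = torch.randn(batch_size, num_actions)                                # target critic 2
rewards = torch.randn(batch_size, 1)
dones = torch.zeros(batch_size, 1)

# Exact expectation over the discrete action set:
# V(s') = sum_a pi(a|s') * (min(Q1, Q2)(s', a) - alpha * log pi(a|s'))
v_next = (next_probs * (torch.min(q1, q2) - alpha * next_log_probs)).sum(dim=1, keepdim=True)

# Using .mean(dim=1) here would divide the expectation by num_actions.
target = rewards + (1.0 - dones) * gamma * v_next
```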
@@ -85,7 +85,6 @@ def calculate_actor_loss(self, state_batch):
         qf2_pi = self.critic_local_2(state_batch)
         min_qf_pi = torch.min(qf1_pi, qf2_pi)
         inside_term = self.alpha * log_action_probabilities - min_qf_pi
-        policy_loss = action_probabilities * inside_term
-        policy_loss = policy_loss.mean()
+        policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
         log_action_probabilities = torch.sum(log_action_probabilities * action_probabilities, dim=1)
         return policy_loss, log_action_probabilities
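The actor loss follows the same principle: the inner expectation over actions, Σ_a π(a|s) (α log π(a|s) − min Q(s, a)), is a per-state sum, and only the outer average over the batch of states should be a mean. A plain `.mean()` over all elements (the old code) again divides the per-state expectation by the number of actions. A minimal self-contained sketch with illustrative tensors, not the repository's actual batch:

```python
import torch

batch_size, num_actions, alpha = 4, 3, 0.2
probs = torch.softmax(torch.randn(batch_size, num_actions), dim=1)  # actor output pi(.|s)
log_probs = torch.log(probs)
min_q = torch.min(torch.randn(batch_size, num_actions),             # element-wise min of
                  torch.randn(batch_size, num_actions))              # the two local critics

# Per-state expectation over the discrete action set, then a mean over states:
# J_pi = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - min_Q(s, a)) ]
policy_loss = (probs * (alpha * log_probs - min_q)).sum(dim=1).mean()
```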