@@ -4,7 +4,7 @@
 import numpy as np
 from agents.Base_Agent import Base_Agent
 from utilities.data_structures.Replay_Buffer import Replay_Buffer
-from .SAC import SAC
+from agents.actor_critic_agents.SAC import SAC
 from utilities.Utility_Functions import create_actor_distribution

 class SAC_Discrete(SAC):
@@ -20,9 +20,9 @@ def __init__(self, config):
         self.critic_local_2 = self.create_NN(input_dim=self.state_size, output_dim=self.action_size,
                                              key_to_use="Critic", override_seed=self.config.seed + 1)
         self.critic_optimizer = torch.optim.Adam(self.critic_local.parameters(),
-                                                 lr=self.hyperparameters["Critic"]["learning_rate"])
+                                                 lr=self.hyperparameters["Critic"]["learning_rate"], eps=1e-4)
         self.critic_optimizer_2 = torch.optim.Adam(self.critic_local_2.parameters(),
-                                                   lr=self.hyperparameters["Critic"]["learning_rate"])
+                                                   lr=self.hyperparameters["Critic"]["learning_rate"], eps=1e-4)
         self.critic_target = self.create_NN(input_dim=self.state_size, output_dim=self.action_size,
                                             key_to_use="Critic")
         self.critic_target_2 = self.create_NN(input_dim=self.state_size, output_dim=self.action_size,
@@ -34,14 +34,14 @@ def __init__(self, config):
 
         self.actor_local = self.create_NN(input_dim=self.state_size, output_dim=self.action_size, key_to_use="Actor")
         self.actor_optimizer = torch.optim.Adam(self.actor_local.parameters(),
-                                                lr=self.hyperparameters["Actor"]["learning_rate"])
+                                                lr=self.hyperparameters["Actor"]["learning_rate"], eps=1e-4)
         self.automatic_entropy_tuning = self.hyperparameters["automatically_tune_entropy_hyperparameter"]
         if self.automatic_entropy_tuning:
             # we set the max possible entropy as the target entropy
             self.target_entropy = -np.log((1.0 / self.action_size)) * 0.98
             self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
             self.alpha = self.log_alpha.exp()
-            self.alpha_optim = Adam([self.log_alpha], lr=self.hyperparameters["Actor"]["learning_rate"])
+            self.alpha_optim = Adam([self.log_alpha], lr=self.hyperparameters["Actor"]["learning_rate"], eps=1e-4)
         else:
             self.alpha = self.hyperparameters["entropy_term_weight"]
         assert not self.hyperparameters["add_extra_noise"], "There is no add extra noise option for the discrete version of SAC at moment"
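A note on the __init__ hunks above: the functional change is that every Adam optimizer now passes eps=1e-4 (the stability constant added to the denominator of Adam's update, which PyTorch defaults to 1e-8). The unchanged target-entropy line sets the target to 98% of the maximum entropy of a uniform policy over the discrete actions, i.e. -log(1/action_size). A quick numeric sketch of that value, assuming a hypothetical environment with 4 discrete actions:

import numpy as np

action_size = 4                                      # hypothetical number of discrete actions
max_entropy = np.log(action_size)                    # entropy of a uniform policy, ln(4) ~= 1.386
target_entropy = -np.log(1.0 / action_size) * 0.98   # same quantity scaled by 0.98, ~= 1.359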
@@ -65,11 +65,11 @@ def calculate_critic_losses(self, state_batch, action_batch, reward_batch, next_state_batch, mask_batch):
         """Calculates the losses for the two critics. This is the ordinary Q-learning loss except the additional entropy
         term is taken into account"""
         with torch.no_grad():
-            next_state_action, (_, log_action_probabilities), _ = self.produce_action_and_action_info(next_state_batch)
-            next_state_log_pi = log_action_probabilities.gather(1, next_state_action.unsqueeze(-1).long())
-            qf1_next_target = self.critic_target(next_state_batch).gather(1, next_state_action.unsqueeze(-1).long())
-            qf2_next_target = self.critic_target_2(next_state_batch).gather(1, next_state_action.unsqueeze(-1).long())
-            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
+            next_state_action, (action_probabilities, log_action_probabilities), _ = self.produce_action_and_action_info(next_state_batch)
+            qf1_next_target = self.critic_target(next_state_batch)
+            qf2_next_target = self.critic_target_2(next_state_batch)
+            min_qf_next_target = action_probabilities * (torch.min(qf1_next_target, qf2_next_target) - self.alpha * log_action_probabilities)
+            min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)
             next_q_value = reward_batch + (1.0 - mask_batch) * self.hyperparameters["discount_rate"] * (min_qf_next_target)
             self.critic_target(next_state_batch).gather(1, next_state_action.unsqueeze(-1).long())
 
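The substantive change in calculate_critic_losses is how the soft target value for the critics is built: instead of gathering the Q-values and log-probability of a single sampled next action, the new code weights min(Q1, Q2) - alpha * log pi by the full action distribution and reduces over the action dimension before discounting. A minimal standalone sketch of the new target computation, using hypothetical shapes and values (batch of 3 states, 4 actions, alpha = 0.2, discount 0.99) in place of the real networks and hyperparameters:

import torch

# Hypothetical stand-ins just to make the computation concrete
batch_size, n_actions, alpha, gamma = 3, 4, 0.2, 0.99
action_probabilities = torch.softmax(torch.randn(batch_size, n_actions), dim=1)   # pi(.|s')
log_action_probabilities = torch.log(action_probabilities)
qf1_next_target = torch.randn(batch_size, n_actions)   # stand-in for critic_target(next_state_batch)
qf2_next_target = torch.randn(batch_size, n_actions)   # stand-in for critic_target_2(next_state_batch)

# New formulation from the hunk above: weight each action's soft value by its
# probability, then reduce over the action dimension
min_qf_next_target = action_probabilities * (
    torch.min(qf1_next_target, qf2_next_target) - alpha * log_action_probabilities)
min_qf_next_target = min_qf_next_target.mean(dim=1).unsqueeze(-1)   # shape (batch_size, 1)

reward_batch = torch.randn(batch_size, 1)
mask_batch = torch.zeros(batch_size, 1)   # 1.0 where the episode terminated
next_q_value = reward_batch + (1.0 - mask_batch) * gamma * min_qf_next_target

Note that the committed hunk reduces with .mean(dim=1) rather than .sum(dim=1), so the probability-weighted values are additionally divided by the number of actions; the sketch mirrors the committed code on this point.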