
Commit d4bdedf

fixed alpha SAC discrete error

1 parent 77c47fa

6 files changed: +93 −7 lines

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -19,4 +19,7 @@ Random_Junkyard/
 *to_do_list
 Notebook.ipynb
 Results/Notebook.ipynb
-*.ipynb_checkpoints
+*.ipynb_checkpoints
+*.drive_access_key.json
+drive_access_key.json
+drive_access_key

agents/Base_Agent.py

Lines changed: 9 additions & 2 deletions

@@ -134,15 +134,22 @@ def log_game_info(self):
 
     def set_random_seeds(self, random_seed):
         """Sets all possible random seeds so results can be reproduced"""
+        os.environ['PYTHONHASHSEED'] = str(random_seed)
         torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
         torch.manual_seed(random_seed)
+        tf.set_random_seed(random_seed)
         random.seed(random_seed)
         np.random.seed(random_seed)
-        if torch.cuda.is_available(): torch.cuda.manual_seed_all(random_seed)
-        self.config.seed = random_seed
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(random_seed)
+            torch.cuda.manual_seed(random_seed)
+        if hasattr(gym.spaces, 'prng'):
+            gym.spaces.prng.seed(random_seed)
 
     def reset_game(self):
         """Resets the game information so we are ready to play a new episode"""
+        self.environment.seed(self.config.seed)
         self.state = self.environment.reset()
         self.next_state = None
         self.action = None
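Two things change in practice here: set_random_seeds no longer writes the seed back onto the config, so config.seed must be set by the caller before training, and reset_game now re-seeds the environment with that same seed before every reset, making episode starting conditions repeatable. (tf.set_random_seed is the TensorFlow 1.x API, and the hasattr guard covers gym versions that have dropped gym.spaces.prng.) A minimal sketch of the reset behaviour, assuming the older gym API this codebase targets (env.seed() rather than reset(seed=...)):

import gym

# Illustrative only: what reset_game() now does at the start of each episode.
env = gym.make("CartPole-v0")
seed = 1                 # stands in for config.seed, set by the caller up front
env.seed(seed)           # re-seed the environment's RNG...
state = env.reset()      # ...so the initial state is the same every episode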

agents/actor_critic_agents/SAC_Discrete.py

Lines changed: 2 additions & 2 deletions

@@ -36,7 +36,7 @@ def __init__(self, config):
                                           lr=self.hyperparameters["Actor"]["learning_rate"])
         self.automatic_entropy_tuning = self.hyperparameters["automatically_tune_entropy_hyperparameter"]
         if self.automatic_entropy_tuning:
-            self.target_entropy = -torch.prod(torch.Tensor(self.environment.action_space.shape).to(self.device)).item() # heuristic value from the paper
+            self.target_entropy = - self.environment.unwrapped.action_space.n / 4.0 # heuristic value from the paper
             self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
             self.alpha = self.log_alpha.exp()
             self.alpha_optim = Adam([self.log_alpha], lr=self.hyperparameters["Actor"]["learning_rate"])
@@ -80,7 +80,7 @@ def calculate_actor_loss(self, state_batch):
         qf1_pi = self.critic_local(state_batch)
         qf2_pi = self.critic_local_2(state_batch)
         min_qf_pi = torch.min(qf1_pi, qf2_pi)
-        inside_term = log_action_probabilities - min_qf_pi
+        inside_term = self.alpha * log_action_probabilities - min_qf_pi
         policy_loss = torch.sum(action_probabilities * inside_term)
         policy_loss = policy_loss.mean()
         log_action_probabilities = log_action_probabilities.gather(1, action.unsqueeze(-1).long())
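For context on the alpha factor that now multiplies log_action_probabilities above: when automatic_entropy_tuning is enabled, the temperature is typically learned by minimising the standard SAC temperature objective against the target_entropy set in __init__. A minimal sketch, not taken from this commit, assuming log_pi holds the per-sample log-probability of the chosen actions (the shape that the gather on the last line above produces):

# Sketch only; log_alpha, target_entropy and alpha_optim are the attributes
# created in __init__ above, while log_pi is an assumed input.
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.exp()  # refresh alpha for the next actor-loss computation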

exploration_strategies/Epsilon_Greedy_Exploration.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def perturb_action_for_exploration_purposes(self, action_info):
 
         if (random.random() > epsilon or turn_off_exploration) and (episode_number >= self.random_episodes_to_run):
             return torch.argmax(action_values).item()
-        return random.randint(0, action_values.shape[1] - 1)
+        return np.random.randint(0, action_values.shape[1])
 
     def get_updated_epsilon_exploration(self, action_info, epsilon=1.0):
         """Gets the probability that we just pick a random action. This probability decays the more episodes we have seen"""

utilities/Deepmind_RMS_Prop.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+import torch
+from torch.optim import Optimizer
+
+
+class DM_RMSprop(Optimizer):
+    """Implements the form of RMSProp used in DM 2015 Atari paper.
+    Inspired by https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/updates.py"""
+
+    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= momentum:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= alpha:
+            raise ValueError("Invalid alpha value: {}".format(alpha))
+
+        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
+        super(DM_RMSprop, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(DM_RMSprop, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('momentum', 0)
+            group.setdefault('centered', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+        for group in self.param_groups:
+            momentum = group['momentum']
+            sq_momentum = group['alpha']
+            epsilon = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('RMSprop does not support sparse gradients')
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['square_avg'] = torch.zeros_like(p.data)
+                    if momentum > 0:
+                        state['momentum_buffer'] = torch.zeros_like(p.data)
+
+                mom_buffer = state['momentum_buffer']
+                square_avg = state['square_avg']
+
+
+                state['step'] += 1
+
+                mom_buffer.mul_(momentum)
+                mom_buffer.add_((1 - momentum) * grad)
+
+                square_avg.mul_(sq_momentum).addcmul_(1 - sq_momentum, grad, grad)
+
+                avg = (square_avg - mom_buffer**2 + epsilon).sqrt()
+
+                p.data.addcdiv_(-group['lr'], grad, avg)
+
+        return loss
+
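A minimal usage sketch for the new optimiser; the hyperparameter values below are illustrative, not taken from this commit. Two caveats visible in the code above: step() reads state['momentum_buffer'] unconditionally but only creates it when momentum > 0, so the optimiser should be constructed with a positive momentum, and the addcmul_/addcdiv_ calls use the older positional-value signatures, so this assumes a PyTorch version contemporaneous with the commit.

import torch
import torch.nn as nn
from utilities.Deepmind_RMS_Prop import DM_RMSprop

model = nn.Linear(4, 2)
optimiser = DM_RMSprop(model.parameters(), lr=2.5e-4, alpha=0.95, eps=0.01, momentum=0.95)

# One gradient step on a dummy squared-output loss.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimiser.zero_grad()
loss.backward()
optimiser.step()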

utilities/data_structures/Replay_Buffer.py

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def separate_out_data_types(self, experiences):
         return states, actions, rewards, next_states, dones
 
     def pick_experiences(self, num_experiences=None):
-        if num_experiences: batch_size = num_experiences
+        if num_experiences is not None: batch_size = num_experiences
         else: batch_size = self.batch_size
         return random.sample(self.memory, k=batch_size)
 
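The switch to `is not None` matters because 0 is falsy: with the old check, an explicit num_experiences=0 (or any falsy value) silently fell back to self.batch_size, whereas now only an omitted argument does. A small illustration:

num_experiences = 0
batch_size = 64                                                # default
if num_experiences: batch_size = num_experiences               # old check: not taken for 0
if num_experiences is not None: batch_size = num_experiences   # new check: honours the explicit 0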
