Commit 7b9fac7

Update DDPG.py
Fine tune
1 parent 521889b commit 7b9fac7

1 file changed: Char05 DDPG/DDPG.py (+21 -27 lines)
@@ -28,29 +28,28 @@
 parser.add_argument('--target_update_interval', default=1, type=int)
 parser.add_argument('--test_iteration', default=10, type=int)

-parser.add_argument('--learning_rate', default=1e-3, type=float)
+parser.add_argument('--learning_rate', default=1e-4, type=float)
 parser.add_argument('--gamma', default=0.99, type=int) # discounted factor
-parser.add_argument('--capacity', default=50000, type=int) # replay buffer size
-parser.add_argument('--batch_size', default=64, type=int) # mini batch size
+parser.add_argument('--capacity', default=1000000, type=int) # replay buffer size
+parser.add_argument('--batch_size', default=100, type=int) # mini batch size
 parser.add_argument('--seed', default=False, type=bool)
 parser.add_argument('--random_seed', default=9527, type=int)
 # optional parameters

-parser.add_argument('--sample_frequency', default=256, type=int)
+parser.add_argument('--sample_frequency', default=2000, type=int)
 parser.add_argument('--render', default=False, type=bool) # show UI or not
 parser.add_argument('--log_interval', default=50, type=int) #
 parser.add_argument('--load', default=False, type=bool) # load model
 parser.add_argument('--render_interval', default=100, type=int) # after render_interval, the env.render() will work
 parser.add_argument('--exploration_noise', default=0.1, type=float)
 parser.add_argument('--max_episode', default=100000, type=int) # num of games
-parser.add_argument('--max_length_of_trajectory', default=2000, type=int) # num of games
 parser.add_argument('--print_log', default=5, type=int)
-parser.add_argument('--update_iteration', default=10, type=int)
+parser.add_argument('--update_iteration', default=200, type=int)
 args = parser.parse_args()

 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 script_name = os.path.basename(__file__)
-env = gym.make(args.env_name).unwrapped
+env = gym.make(args.env_name)

 if args.seed:
     env.seed(args.random_seed)
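Two of these changes interact: because env = gym.make(args.env_name) now keeps Gym's TimeLimit wrapper (instead of stripping it with .unwrapped), episodes terminate with done=True at the environment's registered step limit, which is why the explicit --max_length_of_trajectory cutoff could be dropped. A minimal sketch of how one might confirm that limit; the env id below is only an illustrative assumption, not taken from this file:

    import gym

    # With the wrapped env, TimeLimit enforces the registered episode length,
    # so the training loop no longer needs a manual trajectory cutoff.
    env = gym.make('Pendulum-v0')          # example env id, assumed for illustration
    print(env.spec.max_episode_steps)      # 200 for Pendulum in classic gym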
@@ -134,14 +133,15 @@ def __init__(self, state_dim, action_dim, max_action):
         self.actor = Actor(state_dim, action_dim, max_action).to(device)
         self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
         self.actor_target.load_state_dict(self.actor.state_dict())
-        self.actor_optimizer = optim.Adam(self.actor.parameters(), args.learning_rate)
+        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)

         self.critic = Critic(state_dim, action_dim).to(device)
         self.critic_target = Critic(state_dim, action_dim).to(device)
         self.critic_target.load_state_dict(self.critic.state_dict())
-        self.critic_optimizer = optim.Adam(self.critic.parameters(), args.learning_rate)
+        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
         self.replay_buffer = Replay_buffer()
         self.writer = SummaryWriter(directory)
+
         self.num_critic_update_iteration = 0
         self.num_actor_update_iteration = 0
         self.num_training = 0
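The optimizers now use hard-coded, per-network learning rates (actor 1e-4, critic 1e-3) instead of the shared args.learning_rate, following the common DDPG convention of updating the policy more slowly than the value function. A self-contained sketch of that split, with nn.Sequential stubs standing in for the Actor/Critic classes defined in DDPG.py (shapes are illustrative only):

    import copy
    import torch.nn as nn
    import torch.optim as optim

    # Stand-ins for the Actor/Critic networks defined in DDPG.py.
    actor = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1), nn.Tanh())
    critic = nn.Sequential(nn.Linear(3 + 1, 64), nn.ReLU(), nn.Linear(64, 1))

    # Target networks start as exact copies, as load_state_dict does above.
    actor_target = copy.deepcopy(actor)
    critic_target = copy.deepcopy(critic)

    # The split introduced by this commit: slow actor, faster critic.
    actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4)
    critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)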
@@ -158,12 +158,12 @@ def update(self):
             state = torch.FloatTensor(x).to(device)
             action = torch.FloatTensor(u).to(device)
             next_state = torch.FloatTensor(y).to(device)
-            done = torch.FloatTensor(d).to(device)
+            done = torch.FloatTensor(1-d).to(device)
             reward = torch.FloatTensor(r).to(device)

             # Compute the target Q value
             target_Q = self.critic_target(next_state, self.actor_target(next_state))
-            target_Q = reward + ((1 - done) * args.gamma * target_Q).detach()
+            target_Q = reward + (done * args.gamma * target_Q).detach()

             # Get current Q estimate
             current_Q = self.critic(state, action)
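Note that the sampled d is the raw done flag, so after this change the tensor named done actually holds the "not terminal" mask 1-d; the new target reward + done * gamma * target_Q is therefore the same Bellman backup r + gamma*(1-d)*Q'(s', mu'(s')) as before, with the inversion simply moved into the buffer-to-tensor conversion. A small worked example with made-up numbers:

    import torch

    gamma = 0.99
    d = torch.tensor([[0.0], [1.0]])            # second transition is terminal
    reward = torch.tensor([[1.0], [1.0]])
    next_Q = torch.tensor([[10.0], [10.0]])     # stand-in for the critic_target output

    not_done = 1 - d                            # what the code now stores as `done`
    target_Q = reward + (not_done * gamma * next_Q).detach()
    print(target_Q)                             # [[10.9], [1.0]]: no bootstrapping at terminal states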
@@ -228,39 +228,33 @@ def main():
                 state = next_state

     elif args.mode == 'train':
-        print("====================================")
-        print("Collection Experience...")
-        print("====================================")
         if args.load: agent.load()
+        total_step = 0
         for i in range(args.max_episode):
+            total_reward = 0
+            step =0
             state = env.reset()
             for t in count():
                 action = agent.select_action(state)
-
-                # issue 3 add noise to action
                 action = (action + np.random.normal(0, args.exploration_noise, size=env.action_space.shape[0])).clip(
                     env.action_space.low, env.action_space.high)

                 next_state, reward, done, info = env.step(action)
-                ep_r += reward
                 if args.render and i >= args.render_interval : env.render()
                 agent.replay_buffer.push((state, next_state, action, reward, np.float(done)))
-                # if (i+1) % 10 == 0:
-                #     print('Episode {}, The memory size is {} '.format(i, len(agent.replay_buffer.storage)))

                 state = next_state
-                if done or t >= args.max_length_of_trajectory:
-                    agent.writer.add_scalar('ep_r', ep_r, global_step=i)
-                    if i % args.print_log == 0:
-                        print("Ep_i \t{}, the ep_r is \t{:0.2f}, the step is \t{}".format(i, ep_r, t))
-                    ep_r = 0
+                if done:
                     break
+                step += 1
+                total_reward += reward
+            total_step += step+1
+            print("Total T:{} Episode: \t{} Total Reward: \t{:0.2f}".format(total_step, i, total_reward))
+            agent.update()
+            # "Total T: %d Episode Num: %d Episode T: %d Reward: %f

             if i % args.log_interval == 0:
                 agent.save()
-            if len(agent.replay_buffer.storage) >= args.capacity-1:
-                agent.update()
-
     else:
         raise NameError("mode wrong!!!")

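This hunk also changes the update schedule: previously agent.update() only ran once the replay buffer was nearly full (len(storage) >= args.capacity-1), whereas now it runs at the end of every episode, and with --update_iteration raised to 200 each call performs 200 gradient steps. Episode statistics are tracked as total_reward and total_step instead of the old ep_r logging. A condensed paraphrase of the new control flow, with agent and env left as stand-ins and exploration noise/rendering omitted; this mirrors the diff above rather than being an independent implementation:

    from itertools import count

    def train(agent, env, max_episode, log_interval):
        # Collect one full episode, then run a batch of gradient updates; repeat.
        total_step = 0
        for i in range(max_episode):
            total_reward, step = 0, 0
            state = env.reset()
            for t in count():
                action = agent.select_action(state)
                next_state, reward, done, info = env.step(action)
                agent.replay_buffer.push((state, next_state, action, reward, float(done)))
                state = next_state
                if done:
                    break
                step += 1
                total_reward += reward
            total_step += step + 1
            agent.update()              # args.update_iteration (now 200) gradient steps per call
            if i % log_interval == 0:
                agent.save()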
