Skip to content

Commit

Permalink
update ddpg
Browse files Browse the repository at this point in the history
  • Loading branch information
liber145 committed Dec 18, 2022
1 parent bca1b2d commit 3c7d561
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 320 deletions.
20 changes: 10 additions & 10 deletions 06_doubledqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,15 @@ def train(args, env, agent):
agent.target_model.train()
agent.model.zero_grad()
agent.target_model.zero_grad()
state = env.reset()
state, _ = env.reset()
for i in range(args.max_steps):
if np.random.rand() < epsilon or i < args.warmup_steps:
action = env.action_space.sample()
else:
action = agent.get_action(torch.from_numpy(state))
action = action.item()
next_state, reward, done, info = env.step(action)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_reward += reward

# 修改奖励,加速训练。
Expand All @@ -136,7 +137,7 @@ def train(args, env, agent):

epsilon = max(epsilon * args.epsilon_decay, 1e-3)

print(f"{i=}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log_losses[-1]:.1e}, {epsilon=:.3f}")
print(f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log_losses[-1]:.1e}, epsilon={epsilon:.3f}")

if episode_length < 180 and episode_reward > max_episode_reward:
save_path = os.path.join(args.output_dir, "model.bin")
Expand All @@ -145,7 +146,7 @@ def train(args, env, agent):

episode_reward = 0
episode_length = 0
state = env.reset()
state, _ = env.reset()

if i > args.warmup_steps:
bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
Expand Down Expand Up @@ -184,18 +185,18 @@ def eval(args, env, agent):
episode_reward = 0

agent.model.eval()
state = env.reset()
state, _ = env.reset()
for i in range(5000):
episode_length += 1
action = agent.get_action(torch.from_numpy(state)).item()
next_state, reward, done, info = env.step(action)
env.render()
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
episode_reward += reward

state = next_state
if done is True:
print(f"{episode_reward=}, {episode_length=}")
state = env.reset()
print(f"episode reward={episode_reward}, episode length={episode_length}")
state, _ = env.reset()
episode_length = 0
episode_reward = 0

Expand All @@ -222,7 +223,6 @@ def main():
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

env = gym.make(args.env)
env.seed(args.seed)
set_seed(args)
agent = DoubleDQN(dim_obs=args.dim_obs, num_act=args.num_act, discount=args.discount)
agent.model.to(args.device)
Expand Down
Loading

0 comments on commit 3c7d561

Please sign in to comment.