Fix bug
acezsq committed May 30, 2022
1 parent 075ad9b commit 56d9788
Showing 1 changed file with 8 additions and 4 deletions: eight.py
@@ -70,10 +70,12 @@ def update(self, transition_dict):
 
         q_values = self.q_net(states).gather(1, actions)  # Q values of the taken actions
         # max Q value of the next state
-        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(
-            -1, 1)
-        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones
-                                                                )  # TD error target
+        if self.dqn_type == 'DoubleDQN':  # this is where DQN and Double DQN differ
+            max_action = self.q_net(next_states).max(1)[1].view(-1, 1)
+            max_next_q_values = self.target_q_net(next_states).gather(1, max_action)
+        else:  # the vanilla DQN case
+            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
+        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD error target
         dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss
         self.optimizer.zero_grad()  # PyTorch accumulates gradients by default; zero them explicitly here
         dqn_loss.backward()  # back-propagate to update the parameters
@@ -84,6 +86,8 @@ def update(self, transition_dict):
                 self.q_net.state_dict())  # update the target network
         self.count += 1
 
+lr = 1e-2
+
 lr = 1e-2
 num_episodes = 200
 hidden_dim = 128
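For reference, the two TD targets that the new branch switches between can be written as follows, in standard Double DQN notation with omega the online parameters and omega-minus the target parameters (this summary is added here and is not part of the commit):

    y_t^{\mathrm{DQN}}    = r_t + \gamma \, (1 - d_t) \, \max_{a'} Q_{\omega^-}(s_{t+1}, a')
    y_t^{\mathrm{Double}} = r_t + \gamma \, (1 - d_t) \, Q_{\omega^-}\!\left(s_{t+1}, \arg\max_{a'} Q_{\omega}(s_{t+1}, a')\right)

Selecting a' with the online network while scoring it with the target network is what mitigates the overestimation (maximization) bias of vanilla DQN.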
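To see the branch in isolation, here is a minimal self-contained sketch of the update's target computation on random data. The Linear stand-in networks, tensor shapes, and hyperparameter values are assumptions for the demo, not the repository's code; one deliberate deviation is the .detach() on q_targets, which keeps gradients from flowing into the target network.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, state_dim, n_actions = 4, 3, 2
gamma = 0.98
dqn_type = 'DoubleDQN'  # flip to 'DQN' to get the vanilla target

# Stand-ins for the repository's networks: any module mapping
# [batch, state_dim] -> [batch, n_actions] will do for this demo.
q_net = torch.nn.Linear(state_dim, n_actions)
target_q_net = torch.nn.Linear(state_dim, n_actions)

states = torch.randn(batch_size, state_dim)
actions = torch.randint(n_actions, (batch_size, 1))
rewards = torch.randn(batch_size, 1)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size, 1)

q_values = q_net(states).gather(1, actions)  # Q values of the taken actions

if dqn_type == 'DoubleDQN':
    # Double DQN: the online network selects the greedy action ...
    max_action = q_net(next_states).max(1)[1].view(-1, 1)
    # ... and the target network evaluates it.
    max_next_q_values = target_q_net(next_states).gather(1, max_action)
else:
    # Vanilla DQN: the target network both selects and evaluates.
    max_next_q_values = target_q_net(next_states).max(1)[0].view(-1, 1)

q_targets = rewards + gamma * max_next_q_values * (1 - dones)  # TD target
dqn_loss = F.mse_loss(q_values, q_targets.detach())  # detach: no gradient through the target
print(dqn_loss.item())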
