Fix actor_critic example (pytorch#301)
* smooth_l1_loss now requires shapes to match
* once scalars are enabled we must torch.stack() instead of
  torch.cat() a list of scalars
colesbury authored Feb 8, 2018
1 parent 963f7d1 commit 4ef2d4d
Showing 1 changed file with 4 additions and 4 deletions.
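
Both points in the commit message can be illustrated outside the example. The sketch below is not part of the commit: it assumes a PyTorch version with 0-dim (scalar) tensors enabled (0.4-style semantics), uses made-up toy values, and uses the torch.tensor constructor rather than the Variable wrapper that appears in the diff.

# Illustrative sketch only, not from the repository.
import torch
import torch.nn.functional as F

# 1) smooth_l1_loss wants input and target of the same shape:
#    here both the predicted state value and the return are shape (1,).
value = torch.tensor([0.7])                  # toy predicted state value
target = torch.tensor([1.0])                 # toy observed return, wrapped to shape (1,)
value_loss = F.smooth_l1_loss(value, target)

# 2) Once per-step losses are 0-dim scalars, torch.cat() can no longer
#    concatenate them, but torch.stack() builds a 1-D tensor from them.
scalar_losses = [torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3)]
total = torch.stack(scalar_losses).sum() + value_loss   # a single scalar loss
# torch.cat(scalar_losses) would raise an error: cat needs tensors with at
# least one dimension, while stack adds a new leading dimension.

The diff below applies the same idea inside the example: dropping .unsqueeze(0) leaves the predicted value with shape (1,), matching torch.Tensor([r]) in the smooth_l1_loss call, and the per-step losses are stacked rather than concatenated before summing.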
8 changes: 4 additions & 4 deletions reinforcement_learning/actor_critic.py
@@ -46,15 +46,15 @@ def forward(self, x):
         x = F.relu(self.affine1(x))
         action_scores = self.action_head(x)
         state_values = self.value_head(x)
-        return F.softmax(action_scores, dim=1), state_values
+        return F.softmax(action_scores, dim=-1), state_values
 
 
 model = Policy()
 optimizer = optim.Adam(model.parameters(), lr=3e-2)
 
 
 def select_action(state):
-    state = torch.from_numpy(state).float().unsqueeze(0)
+    state = torch.from_numpy(state).float()
     probs, state_value = model(Variable(state))
     m = Categorical(probs)
     action = m.sample()
@@ -74,11 +74,11 @@ def finish_episode():
     rewards = torch.Tensor(rewards)
     rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
     for (log_prob, value), r in zip(saved_actions, rewards):
-        reward = r - value.data[0, 0]
+        reward = r - value.data[0]
         policy_losses.append(-log_prob * reward)
         value_losses.append(F.smooth_l1_loss(value, Variable(torch.Tensor([r]))))
     optimizer.zero_grad()
-    loss = torch.cat(policy_losses).sum() + torch.cat(value_losses).sum()
+    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
     loss.backward()
     optimizer.step()
     del model.rewards[:]
