Commit

[RLlib] IMPALA PyTorch GPU fixes (ray-project#8397)
sven1977 authored May 11, 2020
1 parent fdf0e5c commit c7cb2f5
Showing 2 changed files with 12 additions and 7 deletions.
10 changes: 6 additions & 4 deletions rllib/agents/impala/vtrace_torch.py
@@ -201,9 +201,9 @@ def multi_from_logits(behaviour_policy_logits,
         target_policy_logits, device="cpu")
     actions = convert_to_torch_tensor(actions, device="cpu")

-    # Make sure tensor ranks are as expected.
-    # The rest will be checked by from_action_log_probs.
     for i in range(len(behaviour_policy_logits)):
+        # Make sure tensor ranks are as expected.
+        # The rest will be checked by from_action_log_probs.
         assert len(behaviour_policy_logits[i].size()) == 3
         assert len(target_policy_logits[i].size()) == 3

@@ -215,9 +215,11 @@ def multi_from_logits(behaviour_policy_logits,
     # can't use precalculated values, recompute them. Note that
     # recomputing won't work well for autoregressive action dists
     # which may have variables not captured by 'logits'
-    behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
-        behaviour_policy_logits, actions, dist_class, model))
+    behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
+        behaviour_policy_logits, actions, dist_class, model)
+
+    behaviour_action_log_probs = convert_to_torch_tensor(
+        behaviour_action_log_probs, device="cpu")
     behaviour_action_log_probs = force_list(behaviour_action_log_probs)
     log_rhos = get_log_rhos(target_action_log_probs,
                             behaviour_action_log_probs)
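
Note on the change above: `multi_log_probs_from_logits_and_actions` recomputes the behaviour log-probs through the model, so they can come back on the GPU while the target log-probs have already been moved to the CPU; the new `convert_to_torch_tensor(..., device="cpu")` call puts both operands on the same device before `get_log_rhos` subtracts them. A minimal PyTorch sketch of that failure mode and fix (the helper name `get_log_rhos_sketch` is a hypothetical stand-in, not RLlib's API):

import torch

def get_log_rhos_sketch(target_logp, behaviour_logp):
    # Log importance-sampling ratio: log pi(a|s) - log mu(a|s), elementwise.
    return target_logp - behaviour_logp

target_logp = torch.randn(5)  # already converted to CPU, as in vtrace_torch.py
device = "cuda" if torch.cuda.is_available() else "cpu"
behaviour_logp = torch.randn(5, device=device)  # recomputed through the (possibly GPU) model

# Without the conversion, `target_logp - behaviour_logp` raises a device-mismatch
# RuntimeError whenever behaviour_logp lives on the GPU.
behaviour_logp = behaviour_logp.to("cpu")
print(get_log_rhos_sketch(target_logp, behaviour_logp))
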
9 changes: 6 additions & 3 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -77,6 +77,7 @@ def __init__(self,

         # Compute vtrace on the CPU for better perf
         # (devices handled inside `vtrace.multi_from_logits`).
+        device = behaviour_action_logp[0].device
         self.vtrace_returns = vtrace.multi_from_logits(
             behaviour_action_log_probs=behaviour_action_logp,
             behaviour_policy_logits=behaviour_logits,

@@ -90,14 +91,16 @@ def __init__(self,
             model=model,
             clip_rho_threshold=clip_rho_threshold,
             clip_pg_rho_threshold=clip_pg_rho_threshold)
-        self.value_targets = self.vtrace_returns.vs
+        # Move v-trace results back to GPU for actual loss computing.
+        self.value_targets = self.vtrace_returns.vs.to(device)

         # The policy gradients loss
         self.pi_loss = -torch.sum(
-            actions_logp * self.vtrace_returns.pg_advantages * valid_mask)
+            actions_logp * self.vtrace_returns.pg_advantages.to(device) *
+            valid_mask)

         # The baseline loss
-        delta = (values - self.vtrace_returns.vs) * valid_mask
+        delta = (values - self.value_targets) * valid_mask
         self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))

         # The entropy loss
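
Note on the change above: the loss still computes v-trace on the CPU, but now remembers the device the batch came in on and moves `vs` and `pg_advantages` back to it before they are multiplied with GPU-resident tensors such as `actions_logp` and `values`. A self-contained sketch of that round-trip pattern (the function names below are hypothetical stand-ins, not RLlib's API):

import torch

def cpu_side_computation(x):
    # Stand-in for vtrace.multi_from_logits, which operates on CPU tensors.
    return x.cumsum(dim=0)

def loss_on_original_device(logp, advantages):
    device = logp.device  # same idea as `device = behaviour_action_logp[0].device`
    returns = cpu_side_computation(advantages.to("cpu"))
    # `.to(device)` is a no-op if the tensor already lives there, so this also
    # works on CPU-only machines.
    return -(logp * returns.to(device)).sum()

device = "cuda" if torch.cuda.is_available() else "cpu"
logp = torch.randn(4, device=device)
advantages = torch.randn(4, device=device)
print(loss_on_original_device(logp, advantages))
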
