[RLlib]: Fix FQE Policy call (ray-project#26671)
Rohan138 authored Jul 19, 2022
1 parent adf24bf commit 4fded80
Showing 2 changed files with 72 additions and 14 deletions.
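
Background on the fix: FQETorchModel previously always called policy.action_distribution_fn with the TorchPolicyV1 calling convention (policy, model, input_dict=...), which breaks for policies built on TorchPolicyV2, whose overridden method expects (model, obs_batch=...). The commit tries the V2 signature first and falls back to the V1 signature on TypeError. The following is a minimal, self-contained sketch of that dispatch pattern; V1StylePolicy, V2StylePolicy, and get_dist_inputs are hypothetical stand-ins for illustration, not RLlib classes.

# Hypothetical stand-ins (not RLlib APIs) illustrating the two calling
# conventions and the TypeError-based fallback introduced by this commit.

class V2StylePolicy:
    model = "v2-model"

    # TorchPolicyV2 style: an overridden method whose first argument is the model.
    def action_distribution_fn(self, model, *, obs_batch, **kwargs):
        return ("v2 dist inputs", "v2 dist class", None)


class V1StylePolicy:
    model = "v1-model"

    # TorchPolicyV1 style: a plain function taking (policy, model, input_dict=...).
    @staticmethod
    def action_distribution_fn(policy, model, *, input_dict, **kwargs):
        return ("v1 dist inputs", "v1 dist class", None)


def get_dist_inputs(policy, input_dict):
    try:
        # Try the TorchPolicyV2 signature first.
        return policy.action_distribution_fn(policy.model, obs_batch=input_dict)
    except TypeError:
        # Fall back to the TorchPolicyV1 signature (e.g. DQNTorchPolicy).
        return policy.action_distribution_fn(
            policy, policy.model, input_dict=input_dict
        )


print(get_dist_inputs(V2StylePolicy(), {"obs": [0.1]}))  # V2 path
print(get_dist_inputs(V1StylePolicy(), {"obs": [0.1]}))  # falls back to V1 path
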
40 changes: 26 additions & 14 deletions rllib/offline/estimators/fqe_torch_model.py
@@ -110,8 +110,9 @@ def train(self, batch: SampleBatch) -> TensorType:
             A list of losses for each training iteration
         """
         losses = []
-        if self.minibatch_size is None:
-            minibatch_size = batch.count
+        minibatch_size = self.minibatch_size or batch.count
+        # Copy batch for shuffling
+        batch = batch.copy(shallow=True)
         for _ in range(self.n_iters):
             minibatch_losses = []
             batch.shuffle()
@@ -209,18 +210,29 @@ def _compute_action_probs(self, obs: TensorType) -> TensorType:
         input_dict = {SampleBatch.OBS: obs}
         seq_lens = torch.ones(len(obs), device=self.device, dtype=int)
         state_batches = []
-        if self.policy.action_distribution_fn and is_overridden(
-            self.policy.action_distribution_fn
-        ):
-            dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
-                self.policy,
-                self.policy.model,
-                input_dict=input_dict,
-                state_batches=state_batches,
-                seq_lens=seq_lens,
-                explore=False,
-                is_training=False,
-            )
+        if is_overridden(self.policy.action_distribution_fn):
+            try:
+                # TorchPolicyV2 function signature
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy.model,
+                    obs_batch=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
+            except TypeError:
+                # TorchPolicyV1 function signature for compatibility with DQN
+                # TODO: Remove this once DQNTorchPolicy is migrated to PolicyV2
+                dist_inputs, dist_class, _ = self.policy.action_distribution_fn(
+                    self.policy,
+                    self.policy.model,
+                    input_dict=input_dict,
+                    state_batches=state_batches,
+                    seq_lens=seq_lens,
+                    explore=False,
+                    is_training=False,
+                )
         else:
             dist_class = self.policy.dist_class
             dist_inputs, _ = self.policy.model(input_dict, state_batches, seq_lens)
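Downstream of this hunk, _compute_action_probs converts dist_inputs into per-action probabilities. For a discrete action space with a categorical action distribution, that reduces to a softmax over the logits; the snippet below is an illustrative sketch under that assumption, not the RLlib code path itself.

import torch
import torch.nn.functional as F

def action_probs_from_logits(dist_inputs: torch.Tensor) -> torch.Tensor:
    # dist_inputs: [batch_size, num_actions] logits of a categorical policy.
    return F.softmax(dist_inputs, dim=-1)

# Example: 2 observations, 3 discrete actions; each row sums to 1.
logits = torch.tensor([[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]])
probs = action_probs_from_logits(logits)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))
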
46 changes: 46 additions & 0 deletions rllib/offline/estimators/tests/test_ope.py
@@ -10,10 +10,14 @@
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.numpy import convert_to_numpy
from pathlib import Path
import os
import copy
import numpy as np
import gym
import torch


class TestOPE(unittest.TestCase):
@@ -162,6 +166,48 @@ def test_ope_in_algo(self):
         print(*list(std_est.items()), sep="\n")
         print("\n\n\n")
 
+    def test_fqe_model(self):
+        # Test FQETorchModel for:
+        # (1) Check that it does not modify the underlying batch during training
+        # (2) Check that the stopping criteria from FQE are working correctly
+        # (3) Check that fqe._compute_action_probs matches brute-force
+        # iteration over all actions with policy.compute_log_likelihoods
+        fqe = FQETorchModel(
+            policy=self.algo.get_policy(),
+            gamma=self.gamma,
+            **self.q_model_config,
+        )
+        tmp_batch = copy.deepcopy(self.batch)
+        losses = fqe.train(self.batch)
+
+        # Make sure FQETorchModel.train() does not modify self.batch
+        check(tmp_batch, self.batch)
+
+        # Make sure FQE stopping criteria are respected
+        assert len(losses) == fqe.n_iters or losses[-1] < fqe.delta, (
+            f"FQE.train() terminated early in {len(losses)} steps with final loss "
+            f"{losses[-1]} for n_iters: {fqe.n_iters} and delta: {fqe.delta}"
+        )
+
+        # Test fqe._compute_action_probs against "brute force" method
+        # of computing log_prob for each possible action individually
+        # using policy.compute_log_likelihoods
+        obs = torch.tensor(self.batch["obs"], device=fqe.device)
+        action_probs = fqe._compute_action_probs(obs)
+        action_probs = convert_to_numpy(action_probs)
+
+        tmp_probs = []
+        for act in range(fqe.policy.action_space.n):
+            tmp_actions = np.zeros_like(self.batch["actions"]) + act
+            log_probs = fqe.policy.compute_log_likelihoods(
+                actions=tmp_actions,
+                obs_batch=self.batch["obs"],
+            )
+            tmp_probs.append(torch.exp(log_probs))
+        tmp_probs = torch.stack(tmp_probs).transpose(0, 1)
+        tmp_probs = convert_to_numpy(tmp_probs)
+        check(action_probs, tmp_probs, decimals=3)
+
     def test_multiple_inputs(self):
         # TODO (Rohan138): Test with multiple input files
         pass
