[RLlib]: Off-Policy Evaluation fixes. (ray-project#25899)
Rohan138 authored Jun 21, 2022
1 parent e108766 commit 28df3f3
Showing 9 changed files with 97 additions and 73 deletions.
10 changes: 6 additions & 4 deletions rllib/evaluation/metrics.py
@@ -147,8 +147,8 @@ def summarize_episodes(
if new_episodes is None:
new_episodes = episodes

episodes, estimates = _partition(episodes)
new_episodes, _ = _partition(new_episodes)
episodes, _ = _partition(episodes)
new_episodes, estimates = _partition(new_episodes)

episode_rewards = []
episode_lengths = []
@@ -223,9 +223,11 @@ def summarize_episodes(
for k, v in e.metrics.items():
acc[k].append(v)
for name, metrics in estimators.items():
out = {}
for k, v_list in metrics.items():
metrics[k] = np.mean(v_list)
estimators[name] = dict(metrics)
out[k + "_mean"] = np.mean(v_list)
out[k + "_std"] = np.std(v_list)
estimators[name] = out

return dict(
episode_reward_max=max_reward,
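
For context on the metrics.py change above: per-estimator metrics are now aggregated into both a mean and a standard deviation instead of a mean alone. A minimal standalone sketch of the resulting dictionary shape (the estimator and metric names below are made up for illustration):

    import numpy as np

    # Hypothetical raw metrics gathered from several OffPolicyEstimate objects.
    estimators = {"dm_fqe": {"v_new": [9.5, 10.1, 10.4], "v_old": [8.0, 8.3, 7.9]}}

    for name, metrics in estimators.items():
        out = {}
        for k, v_list in metrics.items():
            out[k + "_mean"] = np.mean(v_list)  # average across episodes
            out[k + "_std"] = np.std(v_list)    # spread across episodes
        estimators[name] = out

    # estimators["dm_fqe"] now holds v_new_mean, v_new_std, v_old_mean, v_old_std.
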
85 changes: 54 additions & 31 deletions rllib/offline/estimators/direct_method.py
@@ -1,11 +1,12 @@
from typing import Tuple, List, Generator
import logging
from typing import Tuple, Generator, List
from ray.rllib.offline.estimators.off_policy_estimator import (
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType
@@ -16,44 +17,58 @@

torch, nn = try_import_torch()

logger = logging.getLogger()

# TODO (rohan): replace with AIR/parallel workers
# (And find a better name than `should_train`)
@DeveloperAPI
def k_fold_cv(
batch: SampleBatchType, k: int, should_train: bool = True
) -> Generator[Tuple[List[SampleBatch]], None, None]:
"""Utility function that returns a k-fold cross validation generator
over episodes from the given batch. If the number of episodes in the
batch is less than `k` or `should_train` is set to False, yields an empty
list for train_episodes and all the episodes in test_episodes.

@ExperimentalAPI
def train_test_split(
batch: SampleBatchType,
train_test_split_val: float = 0.0,
k: int = 0,
) -> Generator[Tuple[List[SampleBatch]], None, None]:
"""Utility function that returns either a train/test split or
a k-fold cross validation generator over episodes from the given batch.
By default, `k` is set to 0.0, which sets eval_batch = batch
and train_batch to an empty SampleBatch.
Args:
batch: A SampleBatch of episodes to split
k: Number of cross-validation splits
should_train: True by default. If False, yield [], [episodes].
train_test_split_val: Split the batch into a training batch with
`train_test_split_val * n_episodes` episodes and an evaluation batch
with `(1 - train_test_split_val) * n_episodes` episodes. If not
specified, use `k` for k-fold cross validation instead.
k: k-fold cross validation for training model and evaluating OPE.
Returns:
A tuple with two lists of SampleBatches (train_episodes, test_episodes)
A tuple with two SampleBatches (eval_batch, train_batch)
"""
if not train_test_split_val and not k:
logger.log(
"`train_test_split_val` and `k` are both 0;" "not generating training batch"
)
yield [batch], [SampleBatch()]
return
episodes = batch.split_by_episode()
n_episodes = len(episodes)
if n_episodes < k or not should_train:
yield [], episodes
# Train-test split
if train_test_split_val:
train_episodes = episodes[: int(n_episodes * train_test_split_val)]
eval_episodes = episodes[int(n_episodes * train_test_split_val) :]
yield eval_episodes, train_episodes
return
# k-fold cv
assert n_episodes >= k, f"Not enough eval episodes in batch for {k}-fold cv!"
n_fold = n_episodes // k
for i in range(k):
train_episodes = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
if i != k - 1:
test_episodes = episodes[i * n_fold : (i + 1) * n_fold]
eval_episodes = episodes[i * n_fold : (i + 1) * n_fold]
else:
# Append remaining episodes onto the last test_episodes
test_episodes = episodes[i * n_fold :]
yield train_episodes, test_episodes
# Append remaining episodes onto the last eval_episodes
eval_episodes = episodes[i * n_fold :]
yield eval_episodes, train_episodes
return


@DeveloperAPI
@ExperimentalAPI
class DirectMethod(OffPolicyEstimator):
"""The Direct Method estimator.
@@ -66,7 +81,8 @@ def __init__(
policy: Policy,
gamma: float,
q_model_type: str = "fqe",
k: int = 5,
train_test_split_val: float = 0.0,
k: int = 0,
**kwargs,
):
"""
@@ -80,8 +96,12 @@ def __init__(
or "qreg" for Q-Regression, or a custom model that implements:
- `estimate_q(states,actions)`
- `estimate_v(states, action_probs)`
k: k-fold cross validation for training model and evaluating OPE
kwargs: Optional arguments for the specified Q model
train_test_split_val: Split the batch into a training batch with
`train_test_split_val * n_episodes` episodes and an evaluation batch
with `(1 - train_test_split_val) * n_episodes` episodes. If not
specified, use `k` for k-fold cross validation instead.
k: k-fold cross validation for training model and evaluating OPE.
kwargs: Optional arguments for the specified Q model.
"""

super().__init__(name, policy, gamma)
@@ -117,17 +137,20 @@ def __init__(
gamma=gamma,
**kwargs,
)
self.train_test_split_val = train_test_split_val
self.k = k
self.losses = []

@override(OffPolicyEstimator)
def estimate(
self, batch: SampleBatchType, should_train: bool = True
) -> OffPolicyEstimate:
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
estimates = []
# Split data into train and test using k-fold cross validation
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
# Split data into train and test batches
for train_episodes, test_episodes in train_test_split(
batch,
self.train_test_split_val,
self.k,
):

# Train Q-function
if train_episodes:
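
As a rough illustration of the new splitting behaviour in direct_method.py, here is a standalone sketch that mirrors the logic of `train_test_split` on plain Python lists instead of SampleBatch objects (the dummy dict episodes stand in for real episode batches; this is not the RLlib implementation):

    from typing import Generator, List, Tuple

    def split_episodes(
        episodes: List[dict], train_test_split_val: float = 0.0, k: int = 0
    ) -> Generator[Tuple[List[dict], List[dict]], None, None]:
        """Yield (eval_episodes, train_episodes) pairs, mirroring the diff's logic."""
        n = len(episodes)
        if not train_test_split_val and not k:
            # Neither option set: evaluate on everything, train on nothing.
            yield episodes, []
            return
        if train_test_split_val:
            # Single train/test split.
            n_train = int(n * train_test_split_val)
            yield episodes[n_train:], episodes[:n_train]
            return
        # k-fold cross validation: every fold serves as the eval split once.
        assert n >= k, f"Not enough episodes in batch for {k}-fold cv!"
        n_fold = n // k
        for i in range(k):
            train = episodes[: i * n_fold] + episodes[(i + 1) * n_fold :]
            end = (i + 1) * n_fold if i != k - 1 else n
            yield episodes[i * n_fold : end], train

    # Usage: 10 dummy episodes with 5-fold cv -> five (eval, train) pairs of sizes (2, 8).
    folds = list(split_episodes([{"id": i} for i in range(10)], k=5))
    assert all(len(ev) == 2 and len(tr) == 8 for ev, tr in folds)
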
17 changes: 11 additions & 6 deletions rllib/offline/estimators/doubly_robust.py
@@ -1,26 +1,31 @@
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimate
from ray.rllib.offline.estimators.direct_method import DirectMethod, k_fold_cv
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.offline.estimators.direct_method import DirectMethod, train_test_split
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.numpy import convert_to_numpy
import numpy as np


@DeveloperAPI
@ExperimentalAPI
class DoublyRobust(DirectMethod):
"""The Doubly Robust (DR) estimator.
DR estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

@override(DirectMethod)
def estimate(
self, batch: SampleBatchType, should_train: bool = True
self,
batch: SampleBatchType,
) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
estimates = []
# Split data into train and test using k-fold cross validation
for train_episodes, test_episodes in k_fold_cv(batch, self.k, should_train):
# Split data into train and test batches
for train_episodes, test_episodes in train_test_split(
batch,
self.train_test_split_val,
self.k,
):

# Train Q-function
if train_episodes:
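
The DR computation itself is untouched by this diff; the changes are to how the batch is split. For readers of the paper linked in the docstring (arXiv:1511.03722), a per-episode sketch of the standard DR recursion follows. `q_hat` and `v_hat` stand in for the trained Q-model's `estimate_q`/`estimate_v` outputs; this is an outside-of-RLlib illustration, not the class's code:

    def dr_episode_estimate(rewards, new_probs, old_probs, q_hat, v_hat, gamma):
        """Per-episode Doubly Robust estimate via the backward recursion
        V_DR(t) = V_hat(s_t) + rho_t * (r_t + gamma * V_DR(t+1) - Q_hat(s_t, a_t)),
        where rho_t is the per-step target/behaviour probability ratio.
        """
        v_dr = 0.0
        for t in reversed(range(len(rewards))):
            rho = new_probs[t] / old_probs[t]  # per-step importance ratio
            v_dr = v_hat[t] + rho * (rewards[t] + gamma * v_dr - q_hat[t])
        return v_dr
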
8 changes: 4 additions & 4 deletions rllib/offline/estimators/fqe_torch_model.py
@@ -5,14 +5,14 @@
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


@DeveloperAPI
@ExperimentalAPI
class FQETorchModel:
"""Pytorch implementation of the Fitted Q-Evaluation (FQE) model from
https://arxiv.org/pdf/1911.06854.pdf
@@ -153,7 +153,7 @@ def train_q(self, batch: SampleBatch) -> TensorType:
)

q_values, _ = self.q_model({"obs": obs}, [], None)
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
q_acts = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
next_q_values, _ = self.target_q_model({"obs": next_obs}, [], None)
next_v = torch.sum(next_q_values * next_action_prob, axis=-1)
@@ -188,7 +188,7 @@ def estimate_q(
q_values, _ = self.q_model({"obs": obs}, [], None)
if actions is not None:
actions = torch.tensor(actions, device=self.device, dtype=int)
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze()
q_values = torch.gather(q_values, -1, actions.unsqueeze(-1)).squeeze(-1)
return q_values.detach()

def estimate_v(
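
The `squeeze()` -> `squeeze(-1)` changes above matter when a batch contains a single row: an unqualified `squeeze()` also drops the batch dimension. A standalone PyTorch illustration:

    import torch

    q_values = torch.randn(1, 4)   # one observation, four actions
    actions = torch.tensor([2])

    gathered = torch.gather(q_values, -1, actions.unsqueeze(-1))  # shape (1, 1)
    print(gathered.squeeze().shape)    # torch.Size([])  -- batch dim collapsed
    print(gathered.squeeze(-1).shape)  # torch.Size([1]) -- batch dim preserved
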
4 changes: 2 additions & 2 deletions rllib/offline/estimators/importance_sampling.py
@@ -2,13 +2,13 @@
OffPolicyEstimator,
OffPolicyEstimate,
)
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.typing import SampleBatchType
from typing import List
import numpy as np


@DeveloperAPI
@ExperimentalAPI
class ImportanceSampling(OffPolicyEstimator):
"""The step-wise IS estimator.
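
Only the API annotation changes here. As a reminder of what the step-wise IS estimator computes, a minimal per-episode sketch (not the class's code) of V_IS = sum_t gamma^t * rho_{0:t} * r_t, where rho_{0:t} is the cumulative product of target/behaviour action probabilities:

    import numpy as np

    def stepwise_is_estimate(rewards, new_probs, old_probs, gamma):
        """Per-episode step-wise importance sampling estimate."""
        rho = np.cumprod(np.asarray(new_probs) / np.asarray(old_probs))  # rho_{0:t}
        discounts = gamma ** np.arange(len(rewards))
        return float(np.sum(discounts * rho * np.asarray(rewards)))
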
26 changes: 13 additions & 13 deletions rllib/offline/estimators/off_policy_estimator.py
@@ -2,7 +2,7 @@
import logging
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.annotations import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
@@ -11,16 +11,16 @@

logger = logging.getLogger(__name__)

OffPolicyEstimate = DeveloperAPI(
OffPolicyEstimate = ExperimentalAPI(
namedtuple("OffPolicyEstimate", ["estimator_name", "metrics"])
)


@DeveloperAPI
@ExperimentalAPI
class OffPolicyEstimator:
"""Interface for an off policy reward estimator."""

@DeveloperAPI
@ExperimentalAPI
def __init__(self, name: str, policy: Policy, gamma: float):
"""Initializes an OffPolicyEstimator instance.
@@ -34,7 +34,7 @@ def __init__(self, name: str, policy: Policy, gamma: float):
self.gamma = gamma
self.new_estimates = []

@DeveloperAPI
@ExperimentalAPI
def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
"""Returns a list of off policy estimates for the given batch of episodes.
@@ -46,7 +46,7 @@ def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
"""
raise NotImplementedError

@DeveloperAPI
@ExperimentalAPI
def train(self, batch: SampleBatchType) -> TensorType:
"""Trains an Off-Policy Estimator on a batch of experiences.
A model-based estimator should override this and train
@@ -60,7 +60,7 @@ def train(self, batch: SampleBatchType) -> TensorType:
"""
pass

@DeveloperAPI
@ExperimentalAPI
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
"""Returns log likelihood for actions in given batch for policy.
@@ -92,7 +92,7 @@ def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
log_likelihoods = convert_to_numpy(log_likelihoods)
return log_likelihoods

@DeveloperAPI
@ExperimentalAPI
def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"""Checks if we support off policy estimation (OPE) on given batch.
@@ -119,7 +119,7 @@ def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"`off_policy_estimation_methods: {}` to disable estimation."
)

@DeveloperAPI
@ExperimentalAPI
def process(self, batch: SampleBatchType) -> None:
"""Computes off policy estimates (OPE) on batch and stores results.
Thus-far collected results can be retrieved then by calling
@@ -130,7 +130,7 @@ def process(self, batch: SampleBatchType) -> None:
"""
self.new_estimates.extend(self.estimate(batch))

@DeveloperAPI
@ExperimentalAPI
def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:
"""Returns list of new episode metric estimates since the last call.
@@ -154,7 +154,7 @@ def get_metrics(self, get_losses: bool = False) -> List[OffPolicyEstimate]:

@Deprecated(help="OffPolicyEstimator.__init__(policy, gamma, config)", error=False)
@classmethod
@DeveloperAPI
@ExperimentalAPI
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Creates an off-policy estimator from an IOContext object.
Extracts Policy and gamma (discount factor) information from the
@@ -178,11 +178,11 @@ def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
return cls(policy, gamma, config)

@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=True)
@DeveloperAPI
@ExperimentalAPI
def create(self, *args, **kwargs):
return self.create_from_io_context(*args, **kwargs)

@Deprecated(new="OffPolicyEstimator.compute_log_likelihoods", error=False)
@DeveloperAPI
@ExperimentalAPI
def action_prob(self, *args, **kwargs):
return self.compute_log_likelihoods(*args, **kwargs)
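
Since this file only swaps `DeveloperAPI` for `ExperimentalAPI`, the interface itself is unchanged. For orientation, here is a hypothetical minimal subclass that reports the mean discounted return of the logged data, using only pieces visible in this diff (`OffPolicyEstimator`, `OffPolicyEstimate`, `check_can_estimate_for`, `split_by_episode`); treat it as a sketch, not an officially supported example:

    from typing import List

    from ray.rllib.offline.estimators.off_policy_estimator import (
        OffPolicyEstimate,
        OffPolicyEstimator,
    )
    from ray.rllib.utils.typing import SampleBatchType


    class BehaviorReturn(OffPolicyEstimator):
        """Hypothetical estimator: mean discounted return of the logged episodes."""

        def estimate(self, batch: SampleBatchType) -> List[OffPolicyEstimate]:
            self.check_can_estimate_for(batch)
            returns = []
            for episode in batch.split_by_episode():
                ret, discount = 0.0, 1.0
                for r in episode["rewards"]:
                    ret += discount * float(r)
                    discount *= self.gamma
                returns.append(ret)
            return [
                OffPolicyEstimate(
                    "behavior_return", {"v_behavior": sum(returns) / len(returns)}
                )
            ]
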