[RLlib] OPE (off policy estimator) API. (ray-project#24384)
sven1977 authored May 2, 2022
1 parent 0c5ac3b commit 7cca778
Showing 9 changed files with 183 additions and 124 deletions.
16 changes: 12 additions & 4 deletions rllib/agents/trainer.py
@@ -45,6 +45,10 @@
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import (
@@ -558,11 +562,15 @@
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences ("input" is not "sampler").
# Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
# - "simulation": Run the environment in the background, but use
# this data for evaluation only and not for learning.
"input_evaluation": ["is", "wis"],
# - Any subclass of OffPolicyEstimator, e.g.
# ray.rllib.offline.estimators.importance_sampling::ImportanceSampling
# or your own custom subclass.
"input_evaluation": [
ImportanceSampling,
WeightedImportanceSampling,
],
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Note that postprocessing will be done using the *current*
# policy, not the *behavior* policy, which is typically undesirable for
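For context on the new default shown above, here is a minimal sketch of passing the estimator classes through a plain config dict. It is not part of this commit; the algorithm choice (PPO) and the offline-data path are illustrative assumptions.

# Illustrative sketch only (not part of this commit): configuring the new
# class-based off-policy estimators in a plain config dict.
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
    WeightedImportanceSampling,
)

config = {
    "env": "CartPole-v0",
    # Hypothetical path to previously recorded experiences.
    "input": "/tmp/cartpole-out",
    # New API: pass OffPolicyEstimator subclasses instead of "is"/"wis" strings.
    "input_evaluation": [ImportanceSampling, WeightedImportanceSampling],
}
trainer = PPOTrainer(config=config)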
20 changes: 14 additions & 6 deletions rllib/agents/trainer_config.py
@@ -15,6 +15,10 @@
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils.typing import (
EnvConfigDict,
EnvType,
@@ -170,7 +174,10 @@ def __init__(self, trainer_class=None):
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
self.input_evaluation = ["is", "wis"]
self.input_evaluation = [
ImportanceSampling,
WeightedImportanceSampling,
]
self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None
@@ -863,13 +870,14 @@ def offline_data(
are already normalized (between -1.0 and 1.0). This is usually the case
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
input_evaluation: Specify how to evaluate the current policy. This only has
an effect when reading offline experiences ("input" is not "sampler").
input_evaluation: How to evaluate the policy performance. Setting this only
makes sense when the input is reading offline data.
Available options:
- "wis": the weighted step-wise importance sampling estimator.
- "is": the step-wise importance sampling estimator.
- "simulation": run the environment in the background, but use
- "simulation" (str): Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass (type) of the OffPolicyEstimator API class, e.g.
`ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
or your own custom subclass.
postprocess_inputs: Whether to run postprocess_trajectory() on the
trajectory fragments from offline inputs. Note that postprocessing will
be done using the *current* policy, not the *behavior* policy, which
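The same setting can also go through the fluent builder whose docstring this hunk updates. A short sketch, assuming the offline_data() keyword names shown in this file; the input_ directory is a hypothetical placeholder.

# Sketch only: class-based estimators via the TrainerConfig builder.
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
    WeightedImportanceSampling,
)

config = TrainerConfig().offline_data(
    input_="/tmp/cartpole-out",  # hypothetical directory of recorded episodes
    input_evaluation=[ImportanceSampling, WeightedImportanceSampling],
)
print(config.input_evaluation)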
58 changes: 39 additions & 19 deletions rllib/evaluation/rollout_worker.py
@@ -35,8 +35,10 @@
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
@@ -343,13 +345,14 @@ def __init__(
DefaultCallbacks for training/policy/rollout-worker callbacks.
input_creator: Function that returns an InputReader object for
loading previous generated experiences.
input_evaluation: How to evaluate the policy
performance. This only makes sense to set when the input is
reading offline data. The possible values include:
- "is": the step-wise importance sampling estimator.
- "wis": the weighted step-wise is estimator.
- "simulation": run the environment in the background, but
use this data for evaluation only and never for learning.
input_evaluation: How to evaluate the policy performance. Setting this only
makes sense when the input is reading offline data.
Available options:
- "simulation" (str): Run the environment in the background, but use
this data for evaluation only and not for learning.
- Any subclass (type) of the OffPolicyEstimator API class, e.g.
`ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
or your own custom subclass.
output_creator: Function that returns an OutputWriter object for
saving generated experiences.
remote_worker_envs: If using num_envs_per_worker > 1,
@@ -710,24 +713,41 @@ def wrap(env):
)
self.reward_estimators: List[OffPolicyEstimator] = []
for method in input_evaluation:
if method == "is":
method = ImportanceSampling
deprecation_warning(
old="config.input_evaluation=[is]",
new="from ray.rllib.offline.is_estimator import "
f"{method.__name__}; config.input_evaluation="
f"[{method.__name__}]",
error=False,
)
elif method == "wis":
method = WeightedImportanceSampling
deprecation_warning(
old="config.input_evaluation=[is]",
new="from ray.rllib.offline.wis_estimator import "
f"{method.__name__}; config.input_evaluation="
f"[{method.__name__}]",
error=False,
)

if method == "simulation":
logger.warning(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics."
)
sample_async = True
elif method == "is":
ise = ImportanceSamplingEstimator.create_from_io_context(
self.io_context
elif isinstance(method, type) and issubclass(method, OffPolicyEstimator):
self.reward_estimators.append(
method.create_from_io_context(self.io_context)
)
self.reward_estimators.append(ise)
elif method == "wis":
wise = WeightedImportanceSamplingEstimator.create_from_io_context(
self.io_context
)
self.reward_estimators.append(wise)
else:
raise ValueError("Unknown evaluation method: {}".format(method))
raise ValueError(
f"Unknown evaluation method: {method}! Must be "
"either `simulation` or a sub-class of ray.rllib.offline."
"off_policy_estimator::OffPolicyEstimator"
)

render = False
if policy_config.get("render_env") is True and (
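As a reading aid, the following condensed sketch restates the dispatch the new loop performs: "simulation" is still accepted as a string, anything else must be an OffPolicyEstimator subclass. It is illustrative only and omits the deprecation handling for the legacy "is"/"wis" strings.

# Condensed, illustrative restatement of the new dispatch (not the actual
# RolloutWorker code; legacy "is"/"wis" string handling omitted).
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator


def build_reward_estimators(input_evaluation, io_context):
    estimators = []
    for method in input_evaluation:
        if method == "simulation":
            # Sampler outputs are discarded; only evaluation metrics are kept.
            continue
        if isinstance(method, type) and issubclass(method, OffPolicyEstimator):
            estimators.append(method.create_from_io_context(io_context))
        else:
            raise ValueError(f"Unknown evaluation method: {method}")
    return estimators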
0 changes: 0 additions & 0 deletions rllib/offline/estimators/__init__.py
Empty file.
41 changes: 41 additions & 0 deletions rllib/offline/estimators/importance_sampling.py
@@ -0,0 +1,41 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class ImportanceSampling(OffPolicyEstimator):
"""The step-wise IS estimator.
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)

rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)

# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])

# calculate stepwise IS estimate
V_prev, V_step_IS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
V_step_IS += p[t] * rewards[t] * self.gamma ** t

estimation = OffPolicyEstimate(
"importance_sampling",
{
"V_prev": V_prev,
"V_step_IS": V_step_IS,
"V_gain_est": V_step_IS / max(1e-8, V_prev),
},
)
return estimation
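To make the loops in estimate() concrete, here is a small NumPy walk-through of the same step-wise IS arithmetic on made-up numbers (the rewards and action probabilities are arbitrary assumptions, not RLlib data).

# Toy walk-through of the step-wise IS arithmetic above (made-up numbers).
import numpy as np

rewards = np.array([1.0, 1.0, 1.0])     # assumed per-step rewards
old_prob = np.array([0.5, 0.5, 0.5])    # behavior-policy action probabilities
new_prob = np.array([0.6, 0.4, 0.7])    # target-policy action probabilities
gamma = 0.99

p = np.cumprod(new_prob / old_prob)             # cumulative importance ratios p_t
discounts = gamma ** np.arange(len(rewards))
V_prev = float(np.sum(rewards * discounts))     # return logged by the behavior policy
V_step_IS = float(np.sum(p * rewards * discounts))  # IS-corrected return estimate
print(V_prev, V_step_IS, V_step_IS / max(1e-8, V_prev))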
55 changes: 55 additions & 0 deletions rllib/offline/estimators/weighted_importance_sampling.py
@@ -0,0 +1,55 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class WeightedImportanceSampling(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""

def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
self.filter_values = []
self.filter_counts = []

@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)

rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)

# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0

# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t

estimation = OffPolicyEstimate(
"weighted_importance_sampling",
{
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
},
)
return estimation
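The only difference from plain IS is the per-step normalizer w_t: the running mean of p_t over all batches seen so far, accumulated in filter_values and filter_counts. A toy sketch of just that weighting step (all numbers invented, not RLlib code):

# Toy sketch of the WIS weighting step (invented numbers).
import numpy as np

p = np.array([1.2, 0.96, 1.34])             # cumulative ratios for this batch
filter_values = np.array([2.3, 1.9, 2.6])   # running sums of p_t from earlier batches
filter_counts = np.array([2.0, 2.0, 2.0])   # batches seen so far at each step

filter_values = filter_values + p
filter_counts = filter_counts + 1.0
w = filter_values / filter_counts           # running mean of p_t per step

rewards = np.array([1.0, 1.0, 1.0])
gamma = 0.99
discounts = gamma ** np.arange(len(rewards))
V_step_WIS = float(np.sum(p / w * rewards * discounts))
print(V_step_WIS)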
47 changes: 8 additions & 39 deletions rllib/offline/is_estimator.py
@@ -1,41 +1,10 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.utils.deprecation import Deprecated


class ImportanceSamplingEstimator(OffPolicyEstimator):
"""The step-wise IS estimator.
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""

@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)

rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)

# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])

# calculate stepwise IS estimate
V_prev, V_step_IS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
V_step_IS += p[t] * rewards[t] * self.gamma ** t

estimation = OffPolicyEstimate(
"is",
{
"V_prev": V_prev,
"V_step_IS": V_step_IS,
"V_gain_est": V_step_IS / max(1e-8, V_prev),
},
)
return estimation
@Deprecated(
new="ray.rllib.offline.estimators.importance_sampling::ImportanceSampling",
error=False,
)
class ImportanceSamplingEstimator(ImportanceSampling):
pass
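A short, hedged check of what this shim provides: the old import path should keep resolving and now points at the new class, with a deprecation warning expected when it is used (error=False) rather than an ImportError.

# Sketch: the deprecated import path still resolves to a subclass of the
# new estimator.
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling

assert issubclass(ImportanceSamplingEstimator, ImportanceSampling)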
2 changes: 1 addition & 1 deletion rllib/offline/off_policy_estimator.py
@@ -21,7 +21,6 @@
class OffPolicyEstimator:
"""Interface for an off policy reward estimator."""

@DeveloperAPI
def __init__(self, policy: Policy, gamma: float):
"""Initializes an OffPolicyEstimator instance.
@@ -33,6 +32,7 @@ def __init__(self, policy: Policy, gamma: float):
self.gamma = gamma
self.new_estimates = []

@DeveloperAPI
@classmethod
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Creates an off-policy estimator from an IOContext object.
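Because input_evaluation now accepts any OffPolicyEstimator subclass, a custom estimator only needs to implement estimate(). The following hypothetical example mirrors the attribute and helper names visible in the estimators added by this commit; it is a sketch, not an RLlib-provided class.

# Hypothetical custom estimator (illustrative only), following the interface
# used by the estimators added in this commit.
from ray.rllib.offline.off_policy_estimator import (
    OffPolicyEstimate,
    OffPolicyEstimator,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class BehaviorReturn(OffPolicyEstimator):
    """Reports only the discounted return logged by the behavior policy."""

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)
        rewards = batch["rewards"]
        V_prev = sum(r * self.gamma ** t for t, r in enumerate(rewards))
        return OffPolicyEstimate("behavior_return", {"V_prev": V_prev})


# Then, e.g.: config["input_evaluation"] = [BehaviorReturn]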
68 changes: 13 additions & 55 deletions rllib/offline/wis_estimator.py
@@ -1,55 +1,13 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""

def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
self.filter_values = []
self.filter_counts = []

@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)

rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)

# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0

# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma ** t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t

estimation = OffPolicyEstimate(
"wis",
{
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
},
)
return estimation
from ray.rllib.offline.estimators.weighted_importance_sampling import (
WeightedImportanceSampling,
)
from ray.rllib.utils.deprecation import Deprecated


@Deprecated(
new="ray.rllib.offline.estimators.weighted_importance_sampling::"
"WeightedImportanceSampling",
error=False,
)
class WeightedImportanceSamplingEstimator(WeightedImportanceSampling):
pass
