diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 9afc91effbe3..652f3765be45 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -45,6 +45,10 @@
 from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
     MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
 )
+from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
+from ray.rllib.offline.estimators.weighted_importance_sampling import (
+    WeightedImportanceSampling,
+)
 from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
 from ray.rllib.execution.common import WORKER_UPDATE_TIMER
 from ray.rllib.execution.rollout_ops import (
@@ -558,11 +562,15 @@
     # Specify how to evaluate the current policy. This only has an effect when
     # reading offline experiences ("input" is not "sampler").
     # Available options:
-    # - "wis": the weighted step-wise importance sampling estimator.
-    # - "is": the step-wise importance sampling estimator.
-    # - "simulation": run the environment in the background, but use
+    # - "simulation": Run the environment in the background, but use
     #   this data for evaluation only and not for learning.
-    "input_evaluation": ["is", "wis"],
+    # - Any subclass of OffPolicyEstimator, e.g.
+    #   ray.rllib.offline.estimators.importance_sampling::ImportanceSampling
+    #   or your own custom subclass.
+    "input_evaluation": [
+        ImportanceSampling,
+        WeightedImportanceSampling,
+    ],
     # Whether to run postprocess_trajectory() on the trajectory fragments from
     # offline inputs. Note that postprocessing will be done using the *current*
     # policy, not the *behavior* policy, which is typically undesirable for
diff --git a/rllib/agents/trainer_config.py b/rllib/agents/trainer_config.py
index 61e1bf322087..c56a2950e000 100644
--- a/rllib/agents/trainer_config.py
+++ b/rllib/agents/trainer_config.py
@@ -15,6 +15,10 @@
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
 from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
 from ray.rllib.models import MODEL_DEFAULTS
+from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
+from ray.rllib.offline.estimators.weighted_importance_sampling import (
+    WeightedImportanceSampling,
+)
 from ray.rllib.utils.typing import (
     EnvConfigDict,
     EnvType,
@@ -170,7 +174,10 @@ def __init__(self, trainer_class=None):
         self.input_ = "sampler"
         self.input_config = {}
         self.actions_in_input_normalized = False
-        self.input_evaluation = ["is", "wis"]
+        self.input_evaluation = [
+            ImportanceSampling,
+            WeightedImportanceSampling,
+        ]
         self.postprocess_inputs = False
         self.shuffle_buffer_size = 0
         self.output = None
@@ -863,13 +870,14 @@ def offline_data(
                 are already normalized (between -1.0 and 1.0). This is usually the case
                 when the offline file has been generated by another RLlib algorithm
                 (e.g. PPO or SAC), while "normalize_actions" was set to True.
-            input_evaluation: Specify how to evaluate the current policy. This only has
-                an effect when reading offline experiences ("input" is not "sampler").
+            input_evaluation: How to evaluate the policy performance. Setting this only
+                makes sense when the input is reading offline data.
                 Available options:
-                - "wis": the weighted step-wise importance sampling estimator.
-                - "is": the step-wise importance sampling estimator.
-                - "simulation": run the environment in the background, but use
+                - "simulation" (str): Run the environment in the background, but use
                   this data for evaluation only and not for learning.
+                - Any subclass (type) of the OffPolicyEstimator API class, e.g.
+                  `ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
+                  or your own custom subclass.
             postprocess_inputs: Whether to run postprocess_trajectory() on the
                 trajectory fragments from offline inputs. Note that postprocessing
                 will be done using the *current* policy, not the *behavior* policy, which
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 77700740b9f3..beee873d6b57 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -35,8 +35,10 @@
 from ray.rllib.models.preprocessors import Preprocessor
 from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
 from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
-from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
-from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
+from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
+from ray.rllib.offline.estimators.weighted_importance_sampling import (
+    WeightedImportanceSampling,
+)
 from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.policy_map import PolicyMap
@@ -343,13 +345,14 @@ def __init__(
                 DefaultCallbacks for training/policy/rollout-worker callbacks.
             input_creator: Function that returns an InputReader object for
                 loading previous generated experiences.
-            input_evaluation: How to evaluate the policy
-                performance. This only makes sense to set when the input is
-                reading offline data. The possible values include:
-                - "is": the step-wise importance sampling estimator.
-                - "wis": the weighted step-wise is estimator.
-                - "simulation": run the environment in the background, but
-                  use this data for evaluation only and never for learning.
+            input_evaluation: How to evaluate the policy performance. Setting this only
+                makes sense when the input is reading offline data.
+                Available options:
+                - "simulation" (str): Run the environment in the background, but use
+                  this data for evaluation only and not for learning.
+                - Any subclass (type) of the OffPolicyEstimator API class, e.g.
+                  `ray.rllib.offline.estimators.importance_sampling::ImportanceSampling`
+                  or your own custom subclass.
             output_creator: Function that returns an OutputWriter object for
                 saving generated experiences.
             remote_worker_envs: If using num_envs_per_worker > 1,
@@ -710,24 +713,41 @@ def wrap(env):
         )
         self.reward_estimators: List[OffPolicyEstimator] = []
         for method in input_evaluation:
+            if method == "is":
+                method = ImportanceSampling
+                deprecation_warning(
+                    old="config.input_evaluation=[is]",
+                    new="from ray.rllib.offline.estimators."
+                    f"importance_sampling import {method.__name__}; "
+                    f"config.input_evaluation=[{method.__name__}]",
+                    error=False,
+                )
+            elif method == "wis":
+                method = WeightedImportanceSampling
+                deprecation_warning(
+                    old="config.input_evaluation=[wis]",
+                    new="from ray.rllib.offline.estimators."
+                    f"weighted_importance_sampling import {method.__name__}; "
+                    f"config.input_evaluation=[{method.__name__}]",
+                    error=False,
+                )
+
             if method == "simulation":
                 logger.warning(
                     "Requested 'simulation' input evaluation method: "
                     "will discard all sampler outputs and keep only metrics."
                 )
                 sample_async = True
-            elif method == "is":
-                ise = ImportanceSamplingEstimator.create_from_io_context(
-                    self.io_context
+            elif isinstance(method, type) and issubclass(method, OffPolicyEstimator):
+                self.reward_estimators.append(
+                    method.create_from_io_context(self.io_context)
                 )
-                self.reward_estimators.append(ise)
-            elif method == "wis":
-                wise = WeightedImportanceSamplingEstimator.create_from_io_context(
-                    self.io_context
-                )
-                self.reward_estimators.append(wise)
             else:
-                raise ValueError("Unknown evaluation method: {}".format(method))
+                raise ValueError(
+                    f"Unknown evaluation method: {method}! Must be "
+                    "either `simulation` or a sub-class of ray.rllib.offline."
+                    "off_policy_estimator::OffPolicyEstimator"
+                )
 
         render = False
         if policy_config.get("render_env") is True and (
diff --git a/rllib/offline/estimators/__init__.py b/rllib/offline/estimators/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/rllib/offline/estimators/importance_sampling.py b/rllib/offline/estimators/importance_sampling.py
new file mode 100644
index 000000000000..f00cad42ac0f
--- /dev/null
+++ b/rllib/offline/estimators/importance_sampling.py
@@ -0,0 +1,41 @@
+from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.typing import SampleBatchType
+
+
+class ImportanceSampling(OffPolicyEstimator):
+    """The step-wise IS estimator.
+
+    Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
+
+    @override(OffPolicyEstimator)
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
+        self.check_can_estimate_for(batch)
+
+        rewards, old_prob = batch["rewards"], batch["action_prob"]
+        new_prob = self.action_log_likelihood(batch)
+
+        # calculate importance ratios
+        p = []
+        for t in range(batch.count):
+            if t == 0:
+                pt_prev = 1.0
+            else:
+                pt_prev = p[t - 1]
+            p.append(pt_prev * new_prob[t] / old_prob[t])
+
+        # calculate stepwise IS estimate
+        V_prev, V_step_IS = 0.0, 0.0
+        for t in range(batch.count):
+            V_prev += rewards[t] * self.gamma ** t
+            V_step_IS += p[t] * rewards[t] * self.gamma ** t
+
+        estimation = OffPolicyEstimate(
+            "importance_sampling",
+            {
+                "V_prev": V_prev,
+                "V_step_IS": V_step_IS,
+                "V_gain_est": V_step_IS / max(1e-8, V_prev),
+            },
+        )
+        return estimation
diff --git a/rllib/offline/estimators/weighted_importance_sampling.py b/rllib/offline/estimators/weighted_importance_sampling.py
new file mode 100644
index 000000000000..7c6876b8a574
--- /dev/null
+++ b/rllib/offline/estimators/weighted_importance_sampling.py
@@ -0,0 +1,55 @@
+from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
+from ray.rllib.policy import Policy
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.typing import SampleBatchType
+
+
+class WeightedImportanceSampling(OffPolicyEstimator):
+    """The weighted step-wise IS estimator.
+
+    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
+
+    def __init__(self, policy: Policy, gamma: float):
+        super().__init__(policy, gamma)
+        self.filter_values = []
+        self.filter_counts = []
+
+    @override(OffPolicyEstimator)
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
+        self.check_can_estimate_for(batch)
+
+        rewards, old_prob = batch["rewards"], batch["action_prob"]
+        new_prob = self.action_log_likelihood(batch)
+
+        # calculate importance ratios
+        p = []
+        for t in range(batch.count):
+            if t == 0:
+                pt_prev = 1.0
+            else:
+                pt_prev = p[t - 1]
+            p.append(pt_prev * new_prob[t] / old_prob[t])
+        for t, v in enumerate(p):
+            if t >= len(self.filter_values):
+                self.filter_values.append(v)
+                self.filter_counts.append(1.0)
+            else:
+                self.filter_values[t] += v
+                self.filter_counts[t] += 1.0
+
+        # calculate stepwise weighted IS estimate
+        V_prev, V_step_WIS = 0.0, 0.0
+        for t in range(batch.count):
+            V_prev += rewards[t] * self.gamma ** t
+            w_t = self.filter_values[t] / self.filter_counts[t]
+            V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t
+
+        estimation = OffPolicyEstimate(
+            "weighted_importance_sampling",
+            {
+                "V_prev": V_prev,
+                "V_step_WIS": V_step_WIS,
+                "V_gain_est": V_step_WIS / max(1e-8, V_prev),
+            },
+        )
+        return estimation
diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py
index cb179aa477ff..793e9612e2ae 100644
--- a/rllib/offline/is_estimator.py
+++ b/rllib/offline/is_estimator.py
@@ -1,41 +1,10 @@
-from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
+from ray.rllib.utils.deprecation import Deprecated
 
 
-class ImportanceSamplingEstimator(OffPolicyEstimator):
-    """The step-wise IS estimator.
-
-    Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
-
-    @override(OffPolicyEstimator)
-    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
-        self.check_can_estimate_for(batch)
-
-        rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_log_likelihood(batch)
-
-        # calculate importance ratios
-        p = []
-        for t in range(batch.count):
-            if t == 0:
-                pt_prev = 1.0
-            else:
-                pt_prev = p[t - 1]
-            p.append(pt_prev * new_prob[t] / old_prob[t])
-
-        # calculate stepwise IS estimate
-        V_prev, V_step_IS = 0.0, 0.0
-        for t in range(batch.count):
-            V_prev += rewards[t] * self.gamma ** t
-            V_step_IS += p[t] * rewards[t] * self.gamma ** t
-
-        estimation = OffPolicyEstimate(
-            "is",
-            {
-                "V_prev": V_prev,
-                "V_step_IS": V_step_IS,
-                "V_gain_est": V_step_IS / max(1e-8, V_prev),
-            },
-        )
-        return estimation
+@Deprecated(
+    new="ray.rllib.offline.estimators.importance_sampling::ImportanceSampling",
+    error=False,
+)
+class ImportanceSamplingEstimator(ImportanceSampling):
+    pass
diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py
index 512969422946..38e8b4a684f9 100644
--- a/rllib/offline/off_policy_estimator.py
+++ b/rllib/offline/off_policy_estimator.py
@@ -21,7 +21,6 @@
 class OffPolicyEstimator:
     """Interface for an off policy reward estimator."""
 
-    @DeveloperAPI
     def __init__(self, policy: Policy, gamma: float):
         """Initializes an OffPolicyEstimator instance.
 
@@ -33,6 +32,7 @@ def __init__(self, policy: Policy, gamma: float):
         self.gamma = gamma
         self.new_estimates = []
 
+    @DeveloperAPI
     @classmethod
     def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
         """Creates an off-policy estimator from an IOContext object.
diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py
index 70fe0b819f75..f0d4596f5e41 100644
--- a/rllib/offline/wis_estimator.py
+++ b/rllib/offline/wis_estimator.py
@@ -1,55 +1,13 @@
-from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, OffPolicyEstimate
-from ray.rllib.policy import Policy
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.typing import SampleBatchType
-
-
-class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
-    """The weighted step-wise IS estimator.
-
-    Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
-
-    def __init__(self, policy: Policy, gamma: float):
-        super().__init__(policy, gamma)
-        self.filter_values = []
-        self.filter_counts = []
-
-    @override(OffPolicyEstimator)
-    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
-        self.check_can_estimate_for(batch)
-
-        rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_log_likelihood(batch)
-
-        # calculate importance ratios
-        p = []
-        for t in range(batch.count):
-            if t == 0:
-                pt_prev = 1.0
-            else:
-                pt_prev = p[t - 1]
-            p.append(pt_prev * new_prob[t] / old_prob[t])
-        for t, v in enumerate(p):
-            if t >= len(self.filter_values):
-                self.filter_values.append(v)
-                self.filter_counts.append(1.0)
-            else:
-                self.filter_values[t] += v
-                self.filter_counts[t] += 1.0
-
-        # calculate stepwise weighted IS estimate
-        V_prev, V_step_WIS = 0.0, 0.0
-        for t in range(batch.count):
-            V_prev += rewards[t] * self.gamma ** t
-            w_t = self.filter_values[t] / self.filter_counts[t]
-            V_step_WIS += p[t] / w_t * rewards[t] * self.gamma ** t
-
-        estimation = OffPolicyEstimate(
-            "wis",
-            {
-                "V_prev": V_prev,
-                "V_step_WIS": V_step_WIS,
-                "V_gain_est": V_step_WIS / max(1e-8, V_prev),
-            },
-        )
-        return estimation
+from ray.rllib.offline.estimators.weighted_importance_sampling import (
+    WeightedImportanceSampling,
+)
+from ray.rllib.utils.deprecation import Deprecated
+
+
+@Deprecated(
+    new="ray.rllib.offline.estimators.weighted_importance_sampling::"
+    "WeightedImportanceSampling",
+    error=False,
+)
+class WeightedImportanceSamplingEstimator(WeightedImportanceSampling):
+    pass
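
Usage note (not part of the diff above): with this change, the config passes the estimator classes themselves instead of the old "is"/"wis" string keys. A minimal sketch of the new-style configuration follows; the choice of DQNTrainer, the CartPole env, and the "/tmp/cartpole-out" offline dataset path are illustrative assumptions, not something introduced by this PR.

# Illustrative sketch only: trainer, env, and input path are assumptions.
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling
from ray.rllib.offline.estimators.weighted_importance_sampling import (
    WeightedImportanceSampling,
)

ray.init()
trainer = DQNTrainer(
    config={
        "env": "CartPole-v0",
        # Read previously collected experiences instead of sampling an env.
        "input": "/tmp/cartpole-out",
        # New style: pass OffPolicyEstimator subclasses (old style: ["is", "wis"]).
        "input_evaluation": [ImportanceSampling, WeightedImportanceSampling],
    }
)
results = trainer.train()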
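
Since input_evaluation now accepts any OffPolicyEstimator subclass, a custom estimator can be plugged in the same way. A rough sketch under the API shown in this diff (estimate(), check_can_estimate_for(), self.gamma); the class name and the single metric it reports are made up for illustration.

from ray.rllib.offline.off_policy_estimator import (
    OffPolicyEstimate,
    OffPolicyEstimator,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class BehaviorReturnEstimator(OffPolicyEstimator):
    """Toy estimator: discounted return of the logged (behavior) data only."""

    @override(OffPolicyEstimator)
    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
        self.check_can_estimate_for(batch)
        rewards = batch["rewards"]
        # Plain discounted sum of the logged rewards, no importance correction.
        v_behavior = sum(r * self.gamma ** t for t, r in enumerate(rewards))
        return OffPolicyEstimate("behavior_return", {"V_behavior": v_behavior})


# The class itself (not an instance) goes into the config, e.g.:
# config["input_evaluation"] = [BehaviorReturnEstimator]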