[RLlib] Remove tf.py_function from all Schedule classes (not differentiable and causes other bugs in MA setups). (ray-project#8304)
sven1977 authored May 4, 2020
1 parent a00144f commit 6c2b9a4
Showing 11 changed files with 111 additions and 36 deletions.
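What changed and why: Schedule.value() used to wrap the python-side _value() computation in a tf.py_function op which, per the commit title, is not differentiable and causes problems in multi-agent setups. The commit instead builds native tf ops through a new _tf_value_op() hook. Below is a minimal sketch of the before/after pattern; the LinearDecay* classes are illustrative stand-ins, not RLlib code.

import tensorflow as tf


class LinearDecayOld:
    """Old pattern: route the python computation through tf.py_function."""

    def __init__(self, initial_p, final_p, horizon):
        self.initial_p, self.final_p, self.horizon = initial_p, final_p, horizon

    def _value(self, t):
        frac = min(float(t) / self.horizon, 1.0)
        return self.initial_p + frac * (self.final_p - self.initial_p)

    def value(self, t):
        # Wraps python code as a graph op: the result cannot be
        # differentiated through and executes python at session-run time.
        return tf.cast(
            tf.py_function(self._value, [t], tf.float64), tf.float32)


class LinearDecayNew(LinearDecayOld):
    """New pattern: express the same math directly as native tf ops."""

    def value(self, t):
        frac = tf.minimum(tf.cast(t, tf.float32) / self.horizon, 1.0)
        return self.initial_p + frac * (self.final_p - self.initial_p)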
6 changes: 3 additions & 3 deletions .travis.yml
@@ -145,7 +145,7 @@ matrix:
     install:
       - . ./ci/travis/ci.sh build
     script:
-      - travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
+      - travis_wait 90 bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
 
   # RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
   # Requested by Edi (MS): Test all learning capabilities with tf1.x
@@ -163,7 +163,7 @@ matrix:
     install:
       - . ./ci/travis/ci.sh build
     script:
-      - travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
+      - travis_wait 90 bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_tf rllib/...
 
   # RLlib: Learning tests with torch (from rllib/tuned_examples/regression_tests/*.yaml).
   - os: linux
@@ -180,7 +180,7 @@ matrix:
     install:
       - . ./ci/travis/ci.sh build
     script:
-      - travis_wait 90 bazel test --config=ci --test_output=streamed --build_tests_only --test_tag_filters=learning_tests_torch rllib/...
+      - travis_wait 90 bazel test --config=ci --test_output=errors --build_tests_only --test_tag_filters=learning_tests_torch rllib/...
 
   # RLlib: Quick Agent train.py runs (compilation & running, no(!) learning).
   # Agent single tests (compilation, loss-funcs, etc..).
2 changes: 1 addition & 1 deletion rllib/agents/mock.py
@@ -61,7 +61,7 @@ def set_info(self, info):
         self.info = info
         return info
 
-    def get_info(self):
+    def get_info(self, sess=None):
         return self.info
 
 
3 changes: 1 addition & 2 deletions rllib/policy/tf_policy.py
@@ -148,7 +148,6 @@ def __init__(self,
         self._update_ops = update_ops
         self._stats_fetches = {}
         self._loss_input_dict = None
-        self.exploration_info = self.exploration.get_info()
         self._timestep = timestep if timestep is not None else \
             tf.placeholder(tf.int32, (), name="timestep")
 
@@ -346,7 +345,7 @@ def learn_on_batch(self, postprocessed_batch):
 
     @override(Policy)
     def get_exploration_info(self):
-        return self._sess.run(self.exploration_info)
+        return self.exploration.get_info(sess=self.get_session())
 
     @override(Policy)
     def get_weights(self):
8 changes: 7 additions & 1 deletion rllib/utils/exploration/epsilon_greedy.py
@@ -55,6 +55,10 @@ def __init__(self,
         self.last_timestep = get_variable(
             0, framework=framework, tf_name="timestep")
 
+        # Build the tf-info-op.
+        if self.framework == "tf":
+            self._tf_info_op = self.get_info()
+
     @override(Exploration)
     def get_exploration_action(self,
                                *,
@@ -150,6 +154,8 @@ def _get_torch_exploration_action(self, q_values, explore, timestep):
         return exploit_action, action_logp
 
     @override(Exploration)
-    def get_info(self):
+    def get_info(self, sess=None):
+        if sess:
+            return sess.run(self._tf_info_op)
         eps = self.epsilon_schedule(self.last_timestep)
         return {"cur_epsilon": eps}
5 changes: 4 additions & 1 deletion rllib/utils/exploration/exploration.py
@@ -137,12 +137,15 @@ def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
         return sample_batch
 
     @DeveloperAPI
-    def get_info(self):
+    def get_info(self, sess=None):
         """Returns a description of the current exploration state.
 
         This is not necessarily the state itself (and cannot be used in
         set_state!), but rather useful (e.g. debugging) information.
 
+        Args:
+            sess (Optional[tf.Session]): An optional tf Session object to use.
+
         Returns:
             dict: A description of the Exploration (not necessarily its state).
                 This may include tf.ops as values in graph mode.
8 changes: 7 additions & 1 deletion rllib/utils/exploration/gaussian_noise.py
@@ -71,6 +71,10 @@ def __init__(self,
         self.last_timestep = get_variable(
             0, framework=self.framework, tf_name="timestep")
 
+        # Build the tf-info-op.
+        if self.framework == "tf":
+            self._tf_info_op = self.get_info()
+
     @override(Exploration)
     def get_exploration_action(self,
                                *,
@@ -157,11 +161,13 @@ def _get_torch_exploration_action(self, action_dist, explore, timestep):
         return action, logp
 
     @override(Exploration)
-    def get_info(self):
+    def get_info(self, sess=None):
         """Returns the current scale value.
 
         Returns:
             Union[float,tf.Tensor[float]]: The current scale value.
         """
+        if sess:
+            return sess.run(self._tf_info_op)
         scale = self.scale_schedule(self.last_timestep)
         return {"cur_scale": scale}
14 changes: 6 additions & 8 deletions rllib/utils/exploration/parameter_noise.py
@@ -260,18 +260,16 @@ def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
                 noise_free_action_dist *
                 np.log(noise_free_action_dist /
                        (noisy_action_dist + SMALL_NUMBER)), 1))
-            current_epsilon = self.sub_exploration.get_info()["cur_epsilon"]
-            if tf_sess is not None:
-                current_epsilon = tf_sess.run(current_epsilon)
+            current_epsilon = self.sub_exploration.get_info(
+                sess=tf_sess)["cur_epsilon"]
             delta = -np.log(1 - current_epsilon +
                             current_epsilon / self.action_space.n)
         elif policy.dist_class in [Deterministic, TorchDeterministic]:
             # Calculate MSE between noisy and non-noisy output (see [2]).
             distance = np.sqrt(
                 np.mean(np.square(noise_free_action_dist - noisy_action_dist)))
-            current_scale = self.sub_exploration.get_info()["cur_scale"]
-            if tf_sess is not None:
-                current_scale = tf_sess.run(current_scale)
+            current_scale = self.sub_exploration.get_info(
+                sess=tf_sess)["cur_scale"]
             delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * \
                 current_scale
 
@@ -408,5 +406,5 @@ def _tf_remove_noise_op(self):
         return tf.no_op()
 
     @override(Exploration)
-    def get_info(self):
-        return {"cur_stddev": self.stddev}
+    def get_info(self, sess=None):
+        return {"cur_stddev": self.stddev_val}
2 changes: 2 additions & 0 deletions rllib/utils/schedules/constant_schedule.py
@@ -1,3 +1,4 @@
+from ray.rllib.utils.annotations import override
 from ray.rllib.utils.schedules.schedule import Schedule
 
 
@@ -14,5 +15,6 @@ def __init__(self, value, framework):
         super().__init__(framework=framework)
         self._v = value
 
+    @override(Schedule)
     def _value(self, t):
         return self._v
6 changes: 3 additions & 3 deletions rllib/utils/schedules/exponential_schedule.py
@@ -1,3 +1,4 @@
+from ray.rllib.utils.annotations import override
 from ray.rllib.utils.schedules.schedule import Schedule
 
 
@@ -28,10 +29,9 @@ def __init__(self,
         self.initial_p = initial_p
         self.decay_rate = decay_rate
 
+    @override(Schedule)
     def _value(self, t):
-        """
-        Returns the result of:
-        initial_p * decay_rate ** (`t`/t_max)
+        """Returns the result of: initial_p * decay_rate ** (`t`/t_max)
         """
         return self.initial_p * \
             self.decay_rate ** (t / self.schedule_timesteps)
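A one-line sanity check of the docstring formula, with hypothetical parameters:

# initial_p * decay_rate ** (t / schedule_timesteps), hypothetical values:
initial_p, decay_rate, schedule_timesteps = 1.0, 0.1, 100
value_at_50 = initial_p * decay_rate ** (50 / schedule_timesteps)
print(round(value_at_50, 4))  # 0.3162 -- at the halfway point, sqrt(0.1)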
40 changes: 40 additions & 0 deletions rllib/utils/schedules/piecewise_schedule.py
@@ -1,5 +1,9 @@
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.schedules.schedule import Schedule
 
+tf = try_import_tf()
+
 
 def _linear_interpolation(l, r, alpha):
     return l + alpha * (r - l)
@@ -41,12 +45,48 @@ def __init__(self,
         self.outside_value = outside_value
         self.endpoints = endpoints
 
+    @override(Schedule)
     def _value(self, t):
         # Find t in our list of endpoints.
         for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
+            # When found, return an interpolation (default: linear).
             if l_t <= t < r_t:
                 alpha = float(t - l_t) / (r_t - l_t)
                 return self.interpolation(l, r, alpha)
 
+        # t does not belong to any of the pieces, return `self.outside_value`.
         assert self.outside_value is not None
         return self.outside_value
+
+    @override(Schedule)
+    def _tf_value_op(self, t):
+        assert self.outside_value is not None, \
+            "tf-version of PiecewiseSchedule requires `outside_value` to be " \
+            "provided!"
+
+        endpoints = tf.cast(
+            tf.stack([e[0] for e in self.endpoints] + [-1]), tf.int32)
+
+        # Create all possible interpolation results.
+        results_list = []
+        for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
+            alpha = tf.cast(t - l_t, tf.float32) / \
+                tf.cast(r_t - l_t, tf.float32)
+            results_list.append(self.interpolation(l, r, alpha))
+        # If t does not belong to any of the pieces, return `outside_value`.
+        results_list.append(self.outside_value)
+        results_list = tf.stack(results_list)
+
+        # Return correct results tensor depending on where we find t.
+        def _cond(i, x):
+            return tf.logical_not(
+                tf.logical_or(
+                    tf.equal(endpoints[i + 1], -1),
+                    tf.logical_and(endpoints[i] <= x, x < endpoints[i + 1])))
+
+        def _body(i, x):
+            return (i + 1, t)
+
+        idx_and_t = tf.while_loop(_cond, _body,
+                                  [tf.constant(0, dtype=tf.int32), t])
+        return results_list[idx_and_t[0]]
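For intuition, a quick python-side check of what _value() computes, using hypothetical endpoints (the constructor keywords mirror the fields shown above, so treat the exact call signature as an assumption):

from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule

# Hypothetical schedule: 1.0 at t=0, linearly down to 0.1 at t=100,
# flat at 0.1 until t=200, and `outside_value` beyond the last endpoint.
ps = PiecewiseSchedule(
    endpoints=[(0, 1.0), (100, 0.1), (200, 0.1)],
    framework=None,
    outside_value=0.1)

assert abs(ps.value(50) - 0.55) < 1e-6  # alpha=0.5: 1.0 + 0.5 * (0.1 - 1.0)
assert ps.value(150) == 0.1             # inside the flat piece
assert ps.value(10**6) == 0.1           # past all pieces -> outside_value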
53 changes: 37 additions & 16 deletions rllib/utils/schedules/schedule.py
@@ -1,18 +1,19 @@
 from abc import ABCMeta, abstractmethod
 
-from ray.rllib.utils.framework import check_framework
-from ray.rllib.utils.framework import try_import_tf
+from ray.rllib.utils.annotations import DeveloperAPI
+from ray.rllib.utils.framework import check_framework, try_import_tf
 
 tf = try_import_tf()
 
 
+@DeveloperAPI
 class Schedule(metaclass=ABCMeta):
-    """
-    Schedule classes implement various time-dependent scheduling schemas, such
-    as:
+    """Schedule classes implement various time-dependent scheduling schemas.
+
     - Constant behavior.
     - Linear decay.
     - Piecewise decay.
     - Exponential decay.
+
     Useful for backend-agnostic rate/weight changes for learning rates,
     exploration epsilons, beta parameters for prioritized replay, loss weights
@@ -25,6 +26,25 @@ class Schedule(metaclass=ABCMeta):
     def __init__(self, framework):
         self.framework = check_framework(framework)
 
+    def value(self, t):
+        """Generates the value given a timestep (based on schedule's logic).
+
+        Args:
+            t (int): The time step. This could be a tf.Tensor.
+
+        Returns:
+            any: The calculated value depending on the schedule and `t`.
+        """
+        if self.framework == "tf" and not tf.executing_eagerly():
+            return self._tf_value_op(t)
+        return self._value(t)
+
+    def __call__(self, t):
+        """Simply calls self.value(t). Implemented to make Schedules callable.
+        """
+        return self.value(t)
+
+    @DeveloperAPI
     @abstractmethod
     def _value(self, t):
         """
@@ -38,16 +58,17 @@ def _value(self, t):
         """
         raise NotImplementedError
 
-    def value(self, t):
-        if self.framework == "tf":
-            return tf.cast(
-                tf.py_function(self._value, [t], tf.float64),
-                tf.float32,
-                name="schedule_value")
-        return self._value(t)
-
-    def __call__(self, t):
-        """
-        Simply calls `self.value(t)`.
-        """
-        return self.value(t)
+    @DeveloperAPI
+    def _tf_value_op(self, t):
+        """
+        Returns the tf-op that calculates the value based on a time step input.
+
+        Args:
+            t (tf.Tensor): The time step op (int tf.Tensor).
+
+        Returns:
+            tf.Tensor: The calculated value depending on the schedule and `t`.
+        """
+        # By default (most of the time), tf should work with python code.
+        # Override only if necessary.
+        return tf.constant(self._value(t))
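Putting the new dispatch together: under graph-mode tf, value() returns the op built by _tf_value_op(); on every other path it stays pure python. A small usage sketch, assuming ConstantSchedule's constructor as shown in this diff and that framework=None passes check_framework:

from ray.rllib.utils.schedules.constant_schedule import ConstantSchedule

# Python path: no tf involved -> value() falls through to _value().
const = ConstantSchedule(0.1, framework=None)
assert const.value(0) == 0.1
assert const(10**9) == 0.1  # __call__ simply forwards to value()

# Graph-mode tf path (sketch): with framework="tf" under tf1 graph mode,
# value(t) would instead return the tensor built by _tf_value_op(t),
# i.e. tf.constant(0.1) here, to be evaluated via session.run().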
