[RLlib] Fix use_lstm flag for ModelV2 (w/o ModelV1 wrapping) and add it for PyTorch. (ray-project#8734)
sven1977 authored Jun 5, 2020
1 parent d787576 commit c74dc58
Showing 21 changed files with 331 additions and 85 deletions.
2 changes: 1 addition & 1 deletion ci/travis/ci.sh
@@ -168,7 +168,7 @@ build_sphinx_docs() {
if [ "${OSTYPE}" = msys ]; then
echo "WARNING: Documentation not built on Windows due to currently-unresolved issues"
else
sphinx-build -q -W -E -T -b html source _build/html
sphinx-build -q -E -T -b html source _build/html
fi
)
}
2 changes: 1 addition & 1 deletion doc/README.md
@@ -12,5 +12,5 @@ open _build/html/index.html
To test if there are any build errors with the documentation, do the following.

```
sphinx-build -W -b html -d _build/doctrees source _build/html
sphinx-build -b html -d _build/doctrees source _build/html
```
2 changes: 1 addition & 1 deletion doc/source/rllib-models.rst
@@ -121,7 +121,7 @@ Once implemented, the model can then be registered and used in place of a built-
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
class CustomTorchModel(nn.Module, TorchModelV2):
class CustomTorchModel(TorchModelV2):
def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
def forward(self, input_dict, state, seq_lens): ...
def value_function(self): ...
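For context, a custom model class like the one above is made available to RLlib by registering it with the `ModelCatalog` and then referencing it by name in the trainer's model config. A minimal, hedged sketch (the registration name and the framework-selection key are illustrative, not taken from this commit):

```python
from ray.rllib.models import ModelCatalog

# Register the custom class under a string key ...
ModelCatalog.register_custom_model("custom_torch_model", CustomTorchModel)

# ... and reference that key in the model config of a torch-based trainer.
config = {
    "model": {
        "custom_model": "custom_torch_model",
    },
    "use_pytorch": True,  # framework-selection key around this Ray version (assumed)
}
```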
36 changes: 28 additions & 8 deletions rllib/BUILD
@@ -283,6 +283,26 @@ py_test(
args = ["--torch", "--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "run_regression_tests_repeat_after_me_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/repeatafterme-ppo-lstm.yaml"],
args = ["--yaml-dir=tuned_examples/ppo"]
)

py_test(
name = "run_regression_tests_repeat_after_me_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/repeatafterme-ppo-lstm.yaml"],
args = ["--torch", "--yaml-dir=tuned_examples/ppo"]
)
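
The two new BUILD targets above run the `repeatafterme-ppo-lstm.yaml` tuned example in TF and Torch mode. The YAML itself is not part of this excerpt; as a rough sketch, the experiment it describes boils down to something like the following Python (environment registration name, stop criterion, and cell size are assumptions):

```python
import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

if __name__ == "__main__":
    ray.init()
    # Make the example env available under a string name.
    register_env("RepeatAfterMeEnv", lambda cfg: RepeatAfterMeEnv(cfg))
    tune.run(
        "PPO",
        stop={"episode_reward_mean": 50},  # assumed stop criterion
        config={
            "env": "RepeatAfterMeEnv",
            "num_workers": 0,
            # The point of the experiment: solve a memory task with the
            # built-in LSTM wrapper this commit fixes/adds.
            "model": {"use_lstm": True, "lstm_cell_size": 64},
            # "use_pytorch": True,  # what the --torch flag toggles (assumed key)
        },
    )
```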

# SAC
py_test(
name = "run_regression_tests_cartpole_sac_tf",
@@ -1285,7 +1305,7 @@ py_test(
name = "test_rollout",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "enormous",
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"]
)
@@ -1417,7 +1437,7 @@ py_test(
name = "examples/cartpole_lstm_impala_tf",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1426,7 +1446,7 @@ py_test(
name = "examples/cartpole_lstm_impala_torch",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1435,7 +1455,7 @@ py_test(
name = "examples/cartpole_lstm_ppo_tf",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1444,7 +1464,7 @@ py_test(
name = "examples/cartpole_lstm_ppo_torch",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)
@@ -1453,7 +1473,7 @@ py_test(
name = "examples/cartpole_lstm_ppo_tf_with_prev_a_and_r",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action-reward", "--num-cpus=4"]
)
@@ -1462,7 +1482,7 @@ py_test(
name = "examples/centralized_critic_tf",
main = "examples/centralized_critic.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/centralized_critic.py"],
args = ["--as-test", "--stop-reward=7.2"]
)
@@ -1471,7 +1491,7 @@ py_test(
name = "examples/centralized_critic_torch",
main = "examples/centralized_critic.py",
tags = ["examples", "examples_C"],
size = "medium",
size = "large",
srcs = ["examples/centralized_critic.py"],
args = ["--as-test", "--torch", "--stop-reward=7.2"]
)
7 changes: 4 additions & 3 deletions rllib/agents/impala/tests/test_impala.py
@@ -26,8 +26,9 @@ def test_impala_compilation(self):
local_cfg = config.copy()
for env in ["Pendulum-v0", "CartPole-v0"]:
print("Env={}".format(env))
print("w/ LSTM")
print("w/o LSTM")
# Test w/o LSTM.
local_cfg["model"]["use_lstm"] = False
local_cfg["num_aggregation_workers"] = 0
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
@@ -36,13 +37,13 @@ def test_impala_compilation(self):
trainer.stop()

# Test w/ LSTM.
print("w/o LSTM")
print("w/ LSTM")
local_cfg["model"]["use_lstm"] = True
local_cfg["num_aggregation_workers"] = 2
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
for i in range(num_iterations):
print(trainer.train())
check_compute_action(trainer)
check_compute_action(trainer, include_state=True)
trainer.stop()


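`check_compute_action(trainer, include_state=True)` exercises single-action inference with an explicit RNN state. Outside the test utility, that pattern looks roughly like this (a hedged sketch; trainer/config details and the tuple returned by `compute_action` are assumptions for this Ray version):

```python
import gym
import ray
from ray.rllib.agents import impala

ray.init()
trainer = impala.ImpalaTrainer(
    env="CartPole-v0",
    config={"num_workers": 1, "num_gpus": 0, "model": {"use_lstm": True}})

env = gym.make("CartPole-v0")
policy = trainer.get_policy()
state = policy.get_initial_state()      # zero-filled LSTM state tensors
obs, done = env.reset(), False
while not done:
    # With a state argument, compute_action returns (action, state_out, info).
    action, state, _ = trainer.compute_action(obs, state=state)
    obs, reward, done, _ = env.step(action)
```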
2 changes: 1 addition & 1 deletion rllib/agents/impala/vtrace_tf_policy.py
@@ -177,7 +177,7 @@ def make_time_major(*args, **kw):
values = model.value_function()

if policy.is_recurrent():
max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1
max_seq_len = tf.reduce_max(train_batch["seq_lens"])
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = tf.reshape(mask, [-1])
else:
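The change above drops the stray `- 1`: the mask should span `max(seq_lens)` time steps so that it lines up with the zero-padded `[B, max_seq_len]` training batch, otherwise the last step of the longest sequences is cut off. A small illustration (values shown in comments):

```python
import tensorflow as tf

seq_lens = tf.constant([2, 3])
# Buggy: width max(seq_lens) - 1 == 2 -> the mask is one step too narrow.
# tf.sequence_mask(seq_lens, 2) == [[True, True], [True, True]]
# Fixed: width max(seq_lens) == 3 -> every valid time step is kept.
mask = tf.sequence_mask(seq_lens, tf.reduce_max(seq_lens))
# mask == [[True, True, False],
#          [True, True, True]]
mask_flat = tf.reshape(mask, [-1])  # flattened to match per-timestep loss terms
```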
6 changes: 3 additions & 3 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -147,9 +147,9 @@ def _make_time_major(*args, **kw):
values = model.value_function()

if policy.is_recurrent():
max_seq_len = torch.max(train_batch["seq_lens"]) - 1
mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = torch.reshape(mask, [-1])
max_seq_len = torch.max(train_batch["seq_lens"])
mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
mask = torch.reshape(mask_orig, [-1])
else:
mask = torch.ones_like(rewards)

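Same fix on the Torch side. The flattened boolean mask is what the loss later uses to average only over valid (non-padding) time steps. A sketch of the equivalent logic in plain PyTorch (not RLlib's `sequence_mask` helper itself):

```python
import torch

seq_lens = torch.tensor([2, 3])
max_seq_len = int(seq_lens.max())                               # 3, not max - 1
# True where t < seq_len, i.e. the time step is real data, not padding.
mask = torch.arange(max_seq_len)[None, :] < seq_lens[:, None]   # [B, T] bool
mask_flat = mask.reshape(-1).float()                            # [B * T]

# Masked mean over per-timestep loss terms of shape [B * T]:
loss_per_ts = torch.randn(mask_flat.shape[0]) ** 2              # stand-in values
masked_loss = (loss_per_ts * mask_flat).sum() / mask_flat.sum()
```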
2 changes: 1 addition & 1 deletion rllib/examples/env/repeat_after_me_env.py
@@ -9,7 +9,7 @@ class RepeatAfterMeEnv(gym.Env):
def __init__(self, config):
self.observation_space = Discrete(2)
self.action_space = Discrete(2)
self.delay = config["repeat_delay"]
self.delay = config.get("repeat_delay", 1)
assert self.delay >= 1, "`repeat_delay` must be at least 1!"
self.history = []

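With the `.get()` default above, the env can now be constructed without specifying `repeat_delay`, which matters when a generic test or tuning harness passes an empty env config:

```python
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

env = RepeatAfterMeEnv({})                      # previously raised KeyError
assert env.delay == 1                           # falls back to the default
env = RepeatAfterMeEnv({"repeat_delay": 2})     # explicit value still honored
assert env.delay == 2
```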
37 changes: 28 additions & 9 deletions rllib/models/catalog.py
@@ -11,15 +11,18 @@
from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork
from ray.rllib.models.tf.lstm_v1 import LSTM
from ray.rllib.models.tf.modelv1_compat import make_v1_wrapper
from ray.rllib.models.tf.recurrent_net import LSTMWrapper
from ray.rllib.models.tf.tf_action_dist import Categorical, \
Deterministic, DiagGaussian, Dirichlet, \
MultiActionDistribution, MultiCategorical
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet_v1 import VisionNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import LSTMWrapper as \
TorchLSTMWrapper
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
@@ -57,13 +60,13 @@
"vf_share_layers": True,

# == LSTM ==
# Whether to wrap the model with a LSTM
# Whether to wrap the model with an LSTM.
"use_lstm": False,
# Max seq len for training the LSTM, defaults to 20
# Max seq len for training the LSTM, defaults to 20.
"max_seq_len": 20,
# Size of the LSTM cell
# Size of the LSTM cell.
"lstm_cell_size": 256,
# Whether to feed a_{t-1}, r_{t-1} to LSTM
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
"lstm_use_prev_action_reward": False,
# When using modelv1 models with a modelv2 algorithm, you may have to
# define the state shape here (e.g., [256, 256]).
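These are the flags this commit wires up for ModelV2 default models (and now also for PyTorch). Enabling the built-in LSTM wrapper is purely a config change, e.g. (values other than `use_lstm` shown at their defaults):

```python
config = {
    "model": {
        "use_lstm": True,                      # wrap the default model in an LSTM
        "max_seq_len": 20,                     # max sequence length for LSTM training
        "lstm_cell_size": 256,
        "lstm_use_prev_action_reward": False,  # optionally feed a_{t-1}, r_{t-1}
    },
}
```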
@@ -107,8 +110,9 @@ class ModelCatalog:
>>> observation = prep.transform(raw_observation)
>>> dist_class, dist_dim = ModelCatalog.get_action_dist(
env.action_space, {})
>>> model = ModelCatalog.get_model(inputs, dist_dim, options)
... env.action_space, {})
>>> model = ModelCatalog.get_model_v2(
... obs_space, action_space, num_outputs, options)
>>> dist = dist_class(model.outputs, model)
>>> action = dist.sample()
"""
@@ -307,6 +311,7 @@ def get_model_v2(obs_space,
else:
model_cls = _global_registry.get(RLLIB_MODEL,
model_config["custom_model"])

# TODO(sven): Hard-deprecate Model(V1).
if issubclass(model_cls, ModelV2):
logger.info("Wrapping {} as {}".format(model_cls,
@@ -374,10 +379,18 @@ def track_var_creation(next_creator, **kw):

if framework in ["tf", "tfe"]:
v2_class = None
# try to get a default v2 model
# Try to get a default v2 model.
if not model_config.get("custom_model"):
v2_class = default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)

if model_config.get("use_lstm"):
wrapped_cls = v2_class
forward = wrapped_cls.forward
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper)
v2_class._wrapped_forward = forward

# fallback to a default v1 model
if v2_class is None:
if tf.executing_eagerly():
@@ -387,14 +400,20 @@ def track_var_creation(next_creator, **kw):
"observation space: {}, use_lstm={}".format(
obs_space, model_config.get("use_lstm")))
v2_class = make_v1_wrapper(ModelCatalog.get_model)
# wrap in the requested interface
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
name, **model_kwargs)
elif framework == "torch":
v2_class = \
default_model or ModelCatalog._get_v2_model_class(
obs_space, model_config, framework=framework)
if model_config.get("use_lstm"):
wrapped_cls = v2_class
forward = wrapped_cls.forward
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, TorchLSTMWrapper)
v2_class._wrapped_forward = forward
# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
return wrapper(obs_space, action_space, num_outputs, model_config,
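This is the core of the fix: when `use_lstm` is set, the default ModelV2 class (TF or Torch) is wrapped in an `LSTMWrapper`, with the original `forward` stored on the resulting class as `_wrapped_forward` so the wrapper can still call it for feature extraction. A stripped-down sketch of that pattern (illustrative only, not RLlib's actual `_wrap_if_needed`/`LSTMWrapper` code):

```python
class Base:
    def forward(self, x):
        return x + 1          # stand-in for the original model's forward pass

def wrap_forward(cls):
    original_forward = cls.forward

    class Wrapper(cls):
        def forward(self, x):
            # `_wrapped_forward` is the original function assigned on the class
            # below, so normal attribute lookup binds it to `self`.
            features = self._wrapped_forward(x)
            return features * 10  # stand-in for "run features through an LSTM"

    Wrapper._wrapped_forward = original_forward
    return Wrapper

WrappedBase = wrap_forward(Base)
assert WrappedBase().forward(1) == 20   # original forward ran, then the wrapper's
```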
35 changes: 21 additions & 14 deletions rllib/models/tf/fcnet.py
@@ -21,7 +21,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
vf_share_layers = model_config.get("vf_share_layers")
free_log_std = model_config.get("free_log_std")

# Maybe generate free-floating bias variables for the second half of
# Generate free-floating bias variables for the second half of
# the outputs.
if free_log_std:
assert num_outputs % 2 == 0, (
@@ -34,7 +34,10 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
# We are using obs_flat, so take the flattened shape as input.
inputs = tf.keras.layers.Input(
shape=(np.product(obs_space.shape), ), name="observations")
last_layer = layer_out = inputs
# Last hidden layer output (before logits outputs).
last_layer = inputs
# The action distribution outputs.
logits_out = None
i = 1

# Create layers 0 to second-last.
Expand All @@ -49,7 +52,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
# The last layer is adjusted to be of size num_outputs, but it's a
# layer with activation.
if no_final_linear and num_outputs:
layer_out = tf.keras.layers.Dense(
logits_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=activation,
@@ -64,46 +67,50 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
if num_outputs:
layer_out = tf.keras.layers.Dense(
logits_out = tf.keras.layers.Dense(
num_outputs,
name="fc_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
# Adjust num_outputs to be the number of nodes in the last layer.
else:
self.num_outputs = (
[np.product(obs_space.shape)] + hiddens[-1:-1])[-1]
[np.product(obs_space.shape)] + hiddens[-1:])[-1]

# Concat the log std vars to the end of the state-dependent means.
if free_log_std:
if free_log_std and logits_out is not None:

def tiled_log_std(x):
return tf.tile(
tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])

log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
layer_out = tf.keras.layers.Concatenate(axis=1)(
[layer_out, log_std_out])
logits_out = tf.keras.layers.Concatenate(axis=1)(
[logits_out, log_std_out])

last_vf_layer = None
if not vf_share_layers:
# build a parallel set of hidden layers for the value net
last_layer = inputs
# Build a parallel set of hidden layers for the value net.
last_vf_layer = inputs
i = 1
for size in hiddens:
last_layer = tf.keras.layers.Dense(
last_vf_layer = tf.keras.layers.Dense(
size,
name="fc_value_{}".format(i),
activation=activation,
kernel_initializer=normc_initializer(1.0))(last_layer)
kernel_initializer=normc_initializer(1.0))(last_vf_layer)
i += 1

value_out = tf.keras.layers.Dense(
1,
name="value_out",
activation=None,
kernel_initializer=normc_initializer(0.01))(last_layer)
kernel_initializer=normc_initializer(0.01))(
last_vf_layer if last_vf_layer is not None else last_layer)

self.base_model = tf.keras.Model(inputs, [layer_out, value_out])
self.base_model = tf.keras.Model(
inputs, [(logits_out
if logits_out is not None else last_layer), value_out])
self.register_variables(self.base_model.variables)

def forward(self, input_dict, state, seq_lens):
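One of the fixes above is easy to miss: when the net is built without a final logits layer (`num_outputs` falsy), `self.num_outputs` should fall back to the last hidden layer size, but `hiddens[-1:-1]` is always the empty slice, so the old expression always fell back to the flattened observation size instead:

```python
hiddens = [256, 256]
obs_flat_size = 4

# Old: hiddens[-1:-1] == [], so the expression always yields obs_flat_size.
assert ([obs_flat_size] + hiddens[-1:-1])[-1] == 4
# Fixed: hiddens[-1:] == [256], so the last hidden layer size is used.
assert ([obs_flat_size] + hiddens[-1:])[-1] == 256
```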
(Diffs for the remaining changed files are not shown.)
