
[RLlib] Policy.compute_log_likelihoods() and SAC refactor. (issue ray-project#7107) (ray-project#7124)

* Exploration API (+EpsilonGreedy sub-class).

* Exploration API (+EpsilonGreedy sub-class).

* Cleanup/LINT.

* Add `deterministic` to generic Trainer config (NOTE: this is still ignored by most Agents).

* Add `error` option to deprecation_warning().

* WIP.

* Bug fix: Get exploration-info for tf framework.
Bug fix: Properly deprecate some DQN config keys.

* WIP.

* LINT.

* WIP.

* Split PerWorkerEpsilonGreedy out of EpsilonGreedy.
Docstrings.

* Fix bug in sampler.py in case Policy has self.exploration = None

* Update rllib/agents/dqn/dqn.py

Co-Authored-By: Eric Liang <[email protected]>

* WIP.

* Update rllib/agents/trainer.py

Co-Authored-By: Eric Liang <[email protected]>

* WIP.

* Change requests.

* LINT

* In tune/utils/util.py::deep_update(): only keep deep-updating if both original and value are dicts. If value is not a dict, simply set it (a short sketch follows this commit message).

* Fully deprecate sync_replay_optimizer.py's parameters schedule_max_timesteps AND beta_annealing_fraction (replaced with prioritized_replay_beta_annealing_timesteps).

* Update rllib/evaluation/worker_set.py

Co-Authored-By: Eric Liang <[email protected]>

* Review fixes.

* Fix default value for DQN's exploration spec.

* LINT

* Fix recursion bug (wrong parent constructor).

* Do not pass timestep to get_exploration_info.

* Update tf_policy.py

* Fix some remaining issues with test cases and remove more deprecated DQN/APEX exploration configs.

* Bug fix: tf action dist.

* DDPG incompatibility bug fix with new DQN exploration handling (which is imported by DDPG).

* Switch off exploration when getting action probs from off-policy-estimator's policy.

* LINT

* Fix test_checkpoint_restore.py.

* Deprecate all SAC exploration (unused) configs.

* Properly use `model.last_output()` everywhere instead of `model._last_output`.

* WIP.

* Take out set_epsilon from multi-agent-env test (not needed, decays anyway).

* WIP.

* Trigger re-test (flaky checkpoint-restore test).

* WIP.

* WIP.

* Add test case for deterministic action sampling in PPO.

* bug fix.

* Added deterministic test cases for different Agents.

* Fix problem with TupleActions in dynamic-tf-policy.

* Separate supported_spaces tests so they can be run separately for easier debugging.

* LINT.

* Fix autoregressive_action_dist.py test case.

* Re-test.

* Fix.

* Remove duplicate py_test rule from bazel.

* LINT.

* WIP.

* WIP.

* SAC fix.

* SAC fix.

* WIP.

* WIP.

* WIP.

* FIX 2 examples tests.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* Fix.

* LINT.

* Renamed test file.

* WIP.

* Add unittest.main.

* Make action_dist_class mandatory.

* fix

* FIX.

* WIP.

* WIP.

* Fix.

* Fix.

* Fix explorations test case (contextlib cannot find its own nullcontext??).

* Force torch to be installed for QMIX.

* LINT.

* Fix determine_tests_to_run.py.

* Fix determine_tests_to_run.py.

* WIP

* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).

* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).

* Rename some stuff.

* Rename some stuff.

* WIP.

* WIP.

* Fix SAC.

* Fix SAC.

* Fix strange tf-error in ray core tests.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

* Fix test_io.py.

* LINT.

* Update SAC yaml files' config.

Co-authored-by: Eric Liang <[email protected]>
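
As referenced in the deep_update() bullet above, a minimal sketch of the intended guard (illustrative only, not the exact tune/utils/util.py code):

def deep_update(original, new_dict):
    """Recursively merge new_dict into original.
    Keep recursing only while both sides are dicts; otherwise overwrite."""
    for key, value in new_dict.items():
        if isinstance(value, dict) and isinstance(original.get(key), dict):
            deep_update(original[key], value)
        else:
            original[key] = value
    return original

# Example: nested dicts are merged, non-dict values simply replace the old ones.
conf = {"model": {"fcnet_hiddens": [256, 256]}, "lr": 1e-3}
deep_update(conf, {"model": {"fcnet_hiddens": [64]}, "lr": 5e-4})
# -> {"model": {"fcnet_hiddens": [64]}, "lr": 0.0005}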
sven1977 and ericl authored Feb 22, 2020
1 parent 4c2de7b commit 0db2046
Showing 35 changed files with 767 additions and 231 deletions.
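
For context, a minimal usage sketch of the Policy.compute_log_likelihoods() API this commit introduces. Only the method signature follows the diff below; the trainer, environment, and config values are illustrative assumptions:

import gym
import numpy as np
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(env="CartPole-v0", config={"num_workers": 0})
policy = trainer.get_policy()

# One log-likelihood per (action, observation) pair, under the policy's
# current action distribution (a Categorical over Q-values for DQN).
obs = gym.make("CartPole-v0").reset()
logp = policy.compute_log_likelihoods(
    actions=np.array([0]), obs_batch=np.array([obs]))
print(logp)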
2 changes: 1 addition & 1 deletion ci/travis/install-dependencies.sh
@@ -79,7 +79,7 @@ fi
if [[ "$RLLIB_TESTING" == "1" ]]; then
pip install -q tensorflow-probability==$tfp_version gast==0.2.2 \
torch==$torch_version torchvision \
gym[atari] atari_py smart_open
gym[atari] atari_py smart_open lz4
fi

if [[ "$PYTHON" == "3.6" ]] || [[ "$MAC_WHEELS" == "1" ]]; then
1 change: 0 additions & 1 deletion python/ray/tune/logger.py
@@ -4,7 +4,6 @@
import os
import yaml
import numbers

import numpy as np

import ray.cloudpickle as cloudpickle
25 changes: 15 additions & 10 deletions rllib/BUILD
@@ -796,7 +796,6 @@ py_test(
]
)


# --------------------------------------------------------------------
# Models and Distributions
# rllib/models/
@@ -811,6 +810,20 @@ py_test(
srcs = ["models/tests/test_distributions.py"]
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/
#
# Tag: policy
# --------------------------------------------------------------------

py_test(
name = "policy/tests/test_compute_log_likelihoods",
tags = ["policy"],
size = "small",
srcs = ["policy/tests/test_compute_log_likelihoods.py"]
)

# --------------------------------------------------------------------
# Utils:
# rllib/utils/
@@ -880,14 +893,6 @@ py_test(
srcs = ["tests/test_dependency.py"]
)

# PR 7086
#py_test(
# name = "tests/test_deterministic_support",
# tags = ["tests_dir", "tests_dir_D"],
# size = "small",
# srcs = ["tests/test_deterministic_support.py"]
#)

py_test(
name = "tests/test_eager_support",
tags = ["tests_dir", "tests_dir_E"],
Expand All @@ -912,7 +917,7 @@ py_test(

py_test(
name = "tests/test_explorations",
tags = ["tests_dir", "tests_dir_E"],
tags = ["tests_dir", "tests_dir_E", "explorations"],
size = "medium",
srcs = ["tests/test_explorations.py"]
)
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/ddpg_policy.py
@@ -291,7 +291,7 @@ def __init__(self, observation_space, action_space, config):
self.config,
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
sampled_action=self.output_actions,
loss=self.actor_loss + self.critic_loss,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops + policy_batchnorm_update_ops)
20 changes: 15 additions & 5 deletions rllib/agents/dqn/dqn_policy.py
@@ -202,20 +202,29 @@ def build_q_model(policy, obs_space, action_space, config):
return policy.q_model


def get_log_likelihood(policy, q_model, actions, input_dict, obs_space,
action_space, config):
# Action Q network.
q_vals = _compute_q_values(policy, q_model,
input_dict[SampleBatch.CUR_OBS], obs_space,
action_space)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
action_dist = Categorical(q_vals, q_model)
return action_dist.logp(actions)


def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
action_space, explore, config, timestep):

# Action Q network.
q_vals = _compute_q_values(policy, q_model,
input_dict[SampleBatch.CUR_OBS], obs_space,
action_space)

policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals
policy.q_func_vars = q_model.variables()

policy.output_actions, policy.action_logp = \
policy.output_actions, policy.sampled_action_logp = \
policy.exploration.get_exploration_action(
policy.q_values, q_model, Categorical, explore, timestep)
policy.q_values, Categorical, q_model, explore, timestep)

# Noise vars for Q network except for layer normalization vars.
if config["parameter_noise"]:
Expand All @@ -224,7 +233,7 @@ def sample_action_from_q_network(policy, q_model, input_dict, obs_space,
[var for var in policy.q_func_vars if "LayerNorm" not in var.name])
policy.action_probs = tf.nn.softmax(policy.q_values)

return policy.output_actions, policy.action_logp
return policy.output_actions, policy.sampled_action_logp


def _build_parameter_noise(policy, pnet_params):
Expand Down Expand Up @@ -448,6 +457,7 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None):
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
make_model=build_q_model,
action_sampler_fn=sample_action_from_q_network,
log_likelihood_fn=get_log_likelihood,
loss_fn=build_q_losses,
stats_fn=build_q_stats,
postprocess_fn=postprocess_nstep_and_prio,
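
Conceptually, the get_log_likelihood() added above treats the Q-values as logits of a Categorical distribution over actions. A standalone NumPy sketch of the same computation (illustrative, not RLlib code):

import numpy as np

def categorical_logp(q_values, actions):
    # Stable log-softmax over the action dimension, then gather the log-prob
    # of each taken action (mirrors Categorical(q_vals, q_model).logp(actions)).
    z = q_values - q_values.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return log_probs[np.arange(len(actions)), actions]

q = np.array([[1.0, 2.0, 0.5], [0.0, 0.0, 3.0]])  # batch of 2 obs, 3 actions
print(categorical_logp(q, np.array([1, 2])))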
24 changes: 16 additions & 8 deletions rllib/agents/dqn/simple_q_policy.py
@@ -88,22 +88,32 @@ def build_q_models(policy, obs_space, action_space, config):
return policy.q_model


def get_log_likelihood(policy, q_model, actions, input_dict, obs_space,
action_space, config):
# Action Q network.
q_vals = _compute_q_values(policy, q_model,
input_dict[SampleBatch.CUR_OBS], obs_space,
action_space)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
action_dist = Categorical(q_vals, q_model)
return action_dist.logp(actions)


def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space,
action_space, explore, config,
timestep):
# Action Q network.
q_vals = _compute_q_values(policy, q_model,
input_dict[SampleBatch.CUR_OBS], obs_space,
action_space)

policy.q_values = q_vals[0] if isinstance(q_vals, tuple) else q_vals
policy.q_func_vars = q_model.variables()

policy.output_actions, policy.action_logp = \
policy.output_actions, policy.sampled_action_logp = \
policy.exploration.get_exploration_action(
policy.q_values, q_model, Categorical, explore, timestep)
policy.q_values, Categorical, q_model, explore, timestep)

return policy.output_actions, policy.action_logp
return policy.output_actions, policy.sampled_action_logp


def build_q_losses(policy, model, dist_class, train_batch):
@@ -167,13 +177,11 @@ def setup_late_mixins(policy, obs_space, action_space, config):
get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
make_model=build_q_models,
action_sampler_fn=simple_sample_action_from_q_network,
log_likelihood_fn=get_log_likelihood,
loss_fn=build_q_losses,
extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
before_init=setup_early_mixins,
after_init=setup_late_mixins,
obs_include_prev_action_reward=False,
mixins=[
ParameterNoiseMixin,
TargetNetworkMixin,
])
mixins=[ParameterNoiseMixin, TargetNetworkMixin])
76 changes: 45 additions & 31 deletions rllib/agents/qmix/qmix_policy.py
@@ -1,8 +1,6 @@
from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
import torch as th
import torch.nn as nn
from torch.optim import RMSprop
from torch.distributions import Categorical

@@ -16,9 +14,13 @@
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.model import _unpack_obs
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tuple_actions import TupleActions

# Torch must be installed.
torch, nn = try_import_torch(error=True)

logger = logging.getLogger(__name__)

# if the obs space is Dict type, look for the global state under this key
@@ -85,7 +87,7 @@ def forward(self,
mac_out = _unroll_mac(self.model, obs)

# Pick the Q-Values for the actions taken -> [B * n_agents, T]
chosen_action_qvals = th.gather(
chosen_action_qvals = torch.gather(
mac_out, dim=3, index=actions.unsqueeze(3)).squeeze(3)

# Calculate the Q-Values necessary for the target
@@ -114,8 +116,8 @@ def forward(self,

# use the target network to estimate the Q-values of policy
# network's selected actions
target_max_qvals = th.gather(target_mac_out, 3,
cur_max_actions).squeeze(3)
target_max_qvals = torch.gather(target_mac_out, 3,
cur_max_actions).squeeze(3)
else:
target_max_qvals = target_mac_out.max(dim=3)[0]

@@ -167,8 +169,8 @@ def __init__(self, obs_space, action_space, config):
self.h_size = config["model"]["lstm_cell_size"]
self.has_env_global_state = False
self.has_action_mask = False
self.device = (th.device("cuda")
if th.cuda.is_available() else th.device("cpu"))
self.device = (torch.device("cuda")
if torch.cuda.is_available() else torch.device("cpu"))

agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, Dict):
@@ -262,20 +264,21 @@ def compute_actions(self,
# to compute actions

# Compute actions
with th.no_grad():
with torch.no_grad():
q_values, hiddens = _mac(
self.model,
th.as_tensor(obs_batch, dtype=th.float, device=self.device), [
th.as_tensor(
np.array(s), dtype=th.float, device=self.device)
for s in state_batches
])
avail = th.as_tensor(
action_mask, dtype=th.float, device=self.device)
torch.as_tensor(
obs_batch, dtype=torch.float, device=self.device), [
torch.as_tensor(
np.array(s), dtype=torch.float, device=self.device)
for s in state_batches
])
avail = torch.as_tensor(
action_mask, dtype=torch.float, device=self.device)
masked_q_values = q_values.clone()
masked_q_values[avail == 0.0] = -float("inf")
# epsilon-greedy action selector
random_numbers = th.rand_like(q_values[:, :, 0])
random_numbers = torch.rand_like(q_values[:, :, 0])
pick_random = (random_numbers < (self.cur_epsilon
if explore else 0.0)).long()
random_actions = Categorical(avail).sample().long()
Expand All @@ -286,6 +289,16 @@ def compute_actions(self,

return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

@override(Policy)
def compute_log_likelihoods(self,
actions,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None):
obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
return np.zeros(obs_batch.size()[0])

@override(Policy)
def learn_on_batch(self, samples):
obs_batch, action_mask, env_global_state = self._unpack_observation(
@@ -323,31 +336,32 @@ def learn_on_batch(self, samples):

def to_batches(arr, dtype):
new_shape = [B, T] + list(arr.shape[1:])
return th.as_tensor(
return torch.as_tensor(
np.reshape(arr, new_shape), dtype=dtype, device=self.device)

rewards = to_batches(rew, th.float)
actions = to_batches(act, th.long)
obs = to_batches(obs, th.float).reshape(
rewards = to_batches(rew, torch.float)
actions = to_batches(act, torch.long)
obs = to_batches(obs, torch.float).reshape(
[B, T, self.n_agents, self.obs_size])
action_mask = to_batches(action_mask, th.float)
next_obs = to_batches(next_obs, th.float).reshape(
action_mask = to_batches(action_mask, torch.float)
next_obs = to_batches(next_obs, torch.float).reshape(
[B, T, self.n_agents, self.obs_size])
next_action_mask = to_batches(next_action_mask, th.float)
next_action_mask = to_batches(next_action_mask, torch.float)
if self.has_env_global_state:
env_global_state = to_batches(env_global_state, th.float)
next_env_global_state = to_batches(next_env_global_state, th.float)
env_global_state = to_batches(env_global_state, torch.float)
next_env_global_state = to_batches(next_env_global_state,
torch.float)

# TODO(ekl) this treats group termination as individual termination
terminated = to_batches(dones, th.float).unsqueeze(2).expand(
terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
B, T, self.n_agents)

# Create mask for where index is < unpadded sequence length
filled = np.reshape(
np.tile(np.arange(T, dtype=np.float32), B),
[B, T]) < np.expand_dims(seq_lens, 1)
mask = th.as_tensor(
filled, dtype=th.float, device=self.device).unsqueeze(2).expand(
mask = torch.as_tensor(
filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
B, T, self.n_agents)

# Compute loss
@@ -359,7 +373,7 @@ def to_batches(arr, dtype):
# Optimise
self.optimiser.zero_grad()
loss_out.backward()
grad_norm = th.nn.utils.clip_grad_norm_(
grad_norm = torch.nn.utils.clip_grad_norm_(
self.params, self.config["grad_norm_clipping"])
self.optimiser.step()

@@ -432,7 +446,7 @@ def _get_group_rewards(self, info_batch):

def _device_dict(self, state_dict):
return {
k: th.as_tensor(v, device=self.device)
k: torch.as_tensor(v, device=self.device)
for k, v in state_dict.items()
}

@@ -539,7 +553,7 @@ def _unroll_mac(model, obs_tensor):
for t in range(T):
q, h = _mac(model, obs_tensor[:, t], h)
mac_out.append(q)
mac_out = th.stack(mac_out, dim=1) # Concat over time
mac_out = torch.stack(mac_out, dim=1) # Concat over time

return mac_out

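
The epsilon-greedy branch in compute_actions() above, read in isolation: an illustrative torch sketch of the same masked action selection (not a drop-in replacement for the policy code):

import torch
from torch.distributions import Categorical

def masked_epsilon_greedy(q_values, avail, epsilon):
    # q_values, avail: [B, n_agents, n_actions]; avail is 1.0 for allowed actions.
    masked_q = q_values.clone()
    masked_q[avail == 0.0] = -float("inf")  # never pick unavailable actions
    greedy_actions = masked_q.argmax(dim=-1)
    random_actions = Categorical(avail).sample().long()  # uniform over available
    pick_random = (torch.rand_like(q_values[..., 0]) < epsilon).long()
    return pick_random * random_actions + (1 - pick_random) * greedy_actions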
2 changes: 1 addition & 1 deletion rllib/agents/sac/sac_model.py
@@ -80,7 +80,7 @@ def __init__(self,
shift_and_log_scale_diag = tf.keras.Sequential([
tf.keras.layers.Dense(
units=hidden,
activation=getattr(tf.nn, actor_hidden_activation),
activation=getattr(tf.nn, actor_hidden_activation, None),
name="action_hidden_{}".format(i))
for i, hidden in enumerate(actor_hiddens)
] + [
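
The None default added above makes the activation lookup tolerant of names that are not attributes of tf.nn (Keras treats activation=None as linear). A small illustration with assumed example names:

import tensorflow as tf

def resolve_activation(name):
    # Unknown names (e.g. "linear") now fall back to None instead of raising
    # AttributeError; Keras layers then apply no (i.e. a linear) activation.
    return getattr(tf.nn, name, None)

print(resolve_activation("relu"))    # <function relu ...>
print(resolve_activation("linear"))  # None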
