[rllib] Add back get_policy_output method for SAC model (ray-project#…
ericl authored Mar 20, 2020
1 parent 9392cdb commit 7ebc678
Showing 3 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -1073,7 +1073,7 @@ py_test(
     name = "test_rollout",
     main = "tests/test_rollout.py",
     tags = ["tests_dir", "tests_dir_R"],
-    size = "large",
+    size = "enormous",
     data = ["train.py", "rollout.py"],
     srcs = ["tests/test_rollout.py"]
 )
15 changes: 15 additions & 0 deletions rllib/agents/sac/sac_model.py
@@ -161,6 +161,21 @@ def get_twin_q_values(self, model_out, actions=None):
         else:
             return self.twin_q_net(model_out)
 
+    def get_policy_output(self, model_out):
+        """Return the action output for the most recent forward pass.
+
+        This outputs the support for pi(s). For continuous action spaces, this
+        is the action directly. For discrete, it is the mean / std dev.
+
+        Arguments:
+            model_out (Tensor): obs embeddings from the model layers, of shape
+                [BATCH_SIZE, num_outputs].
+
+        Returns:
+            tensor of shape [BATCH_SIZE, action_out_size]
+        """
+        return self.action_model(model_out)
+
     def policy_variables(self):
         """Return the list of variables for the policy net."""
 
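For context, a minimal usage sketch of the restored accessor, mirroring how the sac_policy.py changes below consume it. This is not new API surface beyond the diff: `model`, `policy`, `input_dict`, `action_space`, `SampleBatch`, and `get_dist_class` are assumed to be the objects and helpers already in scope inside the policy-building functions of sac_policy.py.

    # Hypothetical sketch: obtain pi(s) inputs via the restored accessor
    # instead of reaching into model.action_model directly.
    model_out, _ = model({
        "obs": input_dict[SampleBatch.CUR_OBS],
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)
    distribution_inputs = model.get_policy_output(model_out)
    action_dist_class = get_dist_class(policy.config, action_space)
    action_dist = action_dist_class(distribution_inputs, policy.model)
    sampled_actions = action_dist.sample()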
22 changes: 11 additions & 11 deletions rllib/agents/sac/sac_policy.py
@@ -12,8 +12,8 @@
 from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.models import ModelCatalog
-from ray.rllib.models.tf.tf_action_dist import (
-    Categorical, SquashedGaussian, DiagGaussian)
+from ray.rllib.models.tf.tf_action_dist import (Categorical, SquashedGaussian,
+                                                DiagGaussian)
 from ray.rllib.utils import try_import_tf, try_import_tfp
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
@@ -95,9 +95,8 @@ def get_dist_class(config, action_space):
     if isinstance(action_space, Discrete):
         action_dist_class = Categorical
     else:
-        action_dist_class = (
-            SquashedGaussian if config["normalize_actions"]
-            else DiagGaussian)
+        action_dist_class = (SquashedGaussian
+                             if config["normalize_actions"] else DiagGaussian)
     return action_dist_class


@@ -107,7 +106,7 @@ def get_log_likelihood(policy, model, actions, input_dict, obs_space,
         "obs": input_dict[SampleBatch.CUR_OBS],
         "is_training": policy._get_is_training_placeholder(),
     }, [], None)
-    distribution_inputs = model.action_model(model_out)
+    distribution_inputs = model.get_policy_output(model_out)
     action_dist_class = get_dist_class(policy.config, action_space)
     return action_dist_class(distribution_inputs, model).logp(actions)

@@ -118,7 +117,7 @@ def build_action_output(policy, model, input_dict, obs_space, action_space,
         "obs": input_dict[SampleBatch.CUR_OBS],
         "is_training": policy._get_is_training_placeholder(),
     }, [], None)
-    distribution_inputs = model.action_model(model_out)
+    distribution_inputs = model.get_policy_output(model_out)
     action_dist_class = get_dist_class(policy.config, action_space)
 
     policy.output_actions, policy.sampled_action_logp = \
@@ -147,9 +146,10 @@ def actor_critic_loss(policy, model, _, train_batch):
     # Discrete case.
     if model.discrete:
         # Get all action probs directly from pi and form their logp.
-        log_pis_t = tf.nn.log_softmax(model.action_model(model_out_t), -1)
+        log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1)
         policy_t = tf.exp(log_pis_t)
-        log_pis_tp1 = tf.nn.log_softmax(model.action_model(model_out_tp1), -1)
+        log_pis_tp1 = tf.nn.log_softmax(
+            model.get_policy_output(model_out_tp1), -1)
         policy_tp1 = tf.exp(log_pis_tp1)
         # Q-values.
         q_t = model.get_q_values(model_out_t)
@@ -178,11 +178,11 @@ def actor_critic_loss(policy, model, _, train_batch):
         # Sample single actions from distribution.
         action_dist_class = get_dist_class(policy.config, policy.action_space)
         action_dist_t = action_dist_class(
-            model.action_model(model_out_t), policy.model)
+            model.get_policy_output(model_out_t), policy.model)
         policy_t = action_dist_t.sample()
         log_pis_t = tf.expand_dims(action_dist_t.sampled_action_logp(), -1)
         action_dist_tp1 = action_dist_class(
-            model.action_model(model_out_tp1), policy.model)
+            model.get_policy_output(model_out_tp1), policy.model)
         policy_tp1 = action_dist_tp1.sample()
         log_pis_tp1 = tf.expand_dims(action_dist_tp1.sampled_action_logp(), -1)
 
