Commit 30db920

[rllib] Fix centralized critic example to use right policy (ray-project#8341)

* update

* update
ericl authored May 7, 2020
1 parent 325aec8 commit 30db920
Showing 1 changed file with 3 additions and 4 deletions.
rllib/examples/centralized_critic.py (3 additions, 4 deletions)
@@ -142,7 +142,6 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
         train_batch[OPPONENT_ACTION])
 
     policy.loss_obj = PPOLoss(
-        policy.action_space,
         dist_class,
         model,
         train_batch[Postprocessing.VALUE_TARGETS],
@@ -159,8 +158,7 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
         clip_param=policy.config["clip_param"],
         vf_clip_param=policy.config["vf_clip_param"],
         vf_loss_coeff=policy.config["vf_loss_coeff"],
-        use_gae=policy.config["use_gae"],
-        model_config=policy.config["model"])
+        use_gae=policy.config["use_gae"])
 
     return policy.loss_obj.loss
 
@@ -193,7 +191,8 @@ def central_vf_stats(policy, train_batch, grads):
         CentralizedValueMixin
     ])
 
-CCTrainer = PPOTrainer.with_updates(name="CCPPOTrainer", default_policy=CCPPO)
+CCTrainer = PPOTrainer.with_updates(
+    name="CCPPOTrainer", default_policy=CCPPO, get_policy_class=None)
 
 if __name__ == "__main__":
     args = parser.parse_args()
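
Note on the fix: in the trainer-template API this example targets (Ray/RLlib ~0.8.x), a trainer built with build_trainer picks its policy from the get_policy_class(config) hook when one is set and only falls back to default_policy when that hook is None. Because CCTrainer previously inherited PPO's hook (which selects the stock TF or Torch PPO policy), the custom CCPPO policy was never instantiated; passing get_policy_class=None makes default_policy=CCPPO take effect, which is what the commit title means by "use right policy". The first two hunks drop PPOLoss arguments (policy.action_space and model_config) that the PPOLoss helper no longer takes. A minimal sketch of the pattern, with MyPPOPolicy as a hypothetical stand-in for CCPPO and import paths mirroring the example file:

# Sketch only: assumes the Ray 0.8.x-era with_updates API and TensorFlow installed.
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

# Hypothetical custom policy, standing in for the example's CCPPO.
MyPPOPolicy = PPOTFPolicy.with_updates(name="MyPPOPolicy")

# Buggy pattern: the new trainer inherits PPO's get_policy_class hook, which
# still resolves to the stock PPO policy, so default_policy is ignored.
BuggyTrainer = PPOTrainer.with_updates(
    name="BuggyTrainer", default_policy=MyPPOPolicy)

# Fixed pattern (what this commit does): overriding the hook with None makes
# the trainer fall back to default_policy, so the custom policy is used.
FixedTrainer = PPOTrainer.with_updates(
    name="FixedTrainer", default_policy=MyPPOPolicy, get_policy_class=None)

The same pattern applies to any trainer created via with_updates that should always use its custom default_policy rather than the parent's framework-based selection.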
