Skip to content

Commit

Permalink
[RLlib] Custom view requirements (e.g. for prev-n-obs) work with `compute_single_action` and `compute_actions_from_input_dict`. (ray-project#18921)
Browse files: Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Sep 30, 2021
1 parent 6dc1a6b commit 828f5d2
Show file tree
Hide file tree
Showing 7 changed files with 410 additions and 203 deletions.
10 changes: 8 additions & 2 deletions rllib/agents/ars/tests/test_ars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@


class TestARS(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(num_cpus=3)

@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_ars_compilation(self):
"""Test whether an ARSTrainer can be built on all frameworks."""
ray.init(num_cpus=3)
config = ars.DEFAULT_CONFIG.copy()
# Keep it simple.
config["model"]["fcnet_hiddens"] = [10]
Expand All @@ -30,7 +37,6 @@ def test_ars_compilation(self):

check_compute_single_action(trainer)
trainer.stop()
ray.shutdown()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/dqn/simple_q_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def compute_q_values(policy: Policy,
explore,
is_training=None) -> TensorType:
model_out, _ = model({
SampleBatch.CUR_OBS: obs,
SampleBatch.OBS: obs,
"is_training": is_training
if is_training is not None else policy._get_is_training_placeholder(),
}, [], None)
Expand Down
Loading

0 comments on commit 828f5d2

Please sign in to comment.