Skip to content

Commit

Permalink
[RLlib] Fix broken tune tests in master due to framework=auto errors. (
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored May 29, 2020
1 parent c64b694 commit d483ed2
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 7 deletions.
4 changes: 4 additions & 0 deletions python/ray/tests/test_memory_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def testTuneDriverHeapLimit(self):
config={
"env": "CartPole-v0",
"memory": 100 * 1024 * 1024, # too little
"framework": "tf",
},
raise_on_failed_trial=False)
self.assertEqual(result.trials[0].status, "ERROR")
Expand All @@ -99,6 +100,7 @@ def testTuneDriverStoreLimit(self):
"env": "CartPole-v0",
# too large
"object_store_memory": 10000 * 1024 * 1024,
"framework": "tf",
}))
finally:
ray.shutdown()
Expand All @@ -113,6 +115,7 @@ def testTuneWorkerHeapLimit(self):
"env": "CartPole-v0",
"num_workers": 1,
"memory_per_worker": 100 * 1024 * 1024, # too little
"framework": "tf",
},
raise_on_failed_trial=False)
self.assertEqual(result.trials[0].status, "ERROR")
Expand All @@ -134,6 +137,7 @@ def testTuneWorkerStoreLimit(self):
"num_workers": 1,
# too large
"object_store_memory_per_worker": 10000 * 1024 * 1024,
"framework": "tf",
}))
finally:
ray.shutdown()
Expand Down
4 changes: 2 additions & 2 deletions python/ray/tune/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
tune.run(
"PG",
name="experiment",
config=dict(env="CartPole-v1"),
config=dict(env="CartPole-v1", framework="tf"),
stop=dict(training_iteration=10),
local_dir="{checkpoint_dir}",
checkpoint_freq=1,
Expand Down Expand Up @@ -593,7 +593,7 @@ def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
"experiment": {
"run": "PG",
"checkpoint_freq": 1,
"local_dir": dirpath
"local_dir": dirpath,
}
},
resume=True)
Expand Down
3 changes: 3 additions & 0 deletions python/ray/tune/tests/test_tune_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def setUp(self):
local_dir=tmpdir,
config={
"env": "CartPole-v0",
"framework": "tf",
},
)

Expand All @@ -58,6 +59,7 @@ def testTuneRestore(self):
restore=self.checkpoint_path, # Restore the checkpoint
config={
"env": "CartPole-v0",
"framework": "tf",
},
)

Expand All @@ -73,6 +75,7 @@ def testPostRestoreCheckpointExistence(self):
restore=self.checkpoint_path,
config={
"env": "CartPole-v0",
"framework": "tf",
},
)
self.assertTrue(os.path.isfile(self.checkpoint_path))
Expand Down
9 changes: 6 additions & 3 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,8 @@ py_test(

py_test(
name = "test_impala_cartpole_v0_buffers_2_lstm",
main = "train.py", srcs = ["train.py"],
main = "train.py",
srcs = ["train.py"],
tags = ["quick_train"],
args = [
"--env", "CartPole-v0",
Expand All @@ -730,12 +731,14 @@ py_test(

py_test(
name = "test_impala_pong_deterministic_v4_40k_ts_1G_obj_store",
main = "train.py", srcs = ["train.py"],
main = "train.py",
srcs = ["train.py"],
tags = ["quick_train"],
size = "medium",
args = [
"--env", "PongDeterministic-v4",
"--run", "IMPALA",
"--stop", "'{\"timesteps_total\": 40000}'",
"--stop", "'{\"timesteps_total\": 30000}'",
"--ray-object-store-memory=1000000000",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
]
Expand Down
1 change: 1 addition & 0 deletions rllib/agents/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class _MockTrainer(Trainer):
"user_checkpoint_freq": 0,
"object_store_memory_per_worker": 0,
"object_store_memory": 0,
"framework": "tf",
})

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions rllib/tests/test_supported_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_ars(self):
check_support(
"ARS", {
"num_workers": 1,
"noise_size": 100000,
"noise_size": 1500000,
"num_rollouts": 1,
"rollouts_used": 1
})
Expand All @@ -147,7 +147,7 @@ def test_es(self):
check_support(
"ES", {
"num_workers": 1,
"noise_size": 100000,
"noise_size": 1500000,
"episodes_per_batch": 1,
"train_batch_size": 1
})
Expand Down

0 comments on commit d483ed2

Please sign in to comment.