Skip to content

Commit

Permalink
[rllib] bug fix: merging --config params with params.pkl (ray-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
AmeerHajAli authored and ericl committed Mar 13, 2019
1 parent 87bfa1c commit 8a6403c
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions python/ray/rllib/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.util import merge_dicts

EXAMPLE_USAGE = """
Example Usage via RLlib CLI:
Expand Down Expand Up @@ -69,22 +70,23 @@ def create_parser(parser_creator=None):


def run(args, parser):
config = args.config
if not config:
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.pkl")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.pkl")
if not os.path.exists(config_path):
config = {}
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.pkl")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.pkl")
if not os.path.exists(config_path):
if not args.config:
raise ValueError(
"Could not find params.pkl in either the checkpoint dir or "
"its parent directory.")
else:
with open(config_path, 'rb') as f:
config = pickle.load(f)
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])

if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
config = merge_dicts(config, args.config)
if not args.env:
if not config.get("env"):
parser.error("the following arguments are required: --env")
Expand Down

0 comments on commit 8a6403c

Please sign in to comment.