[RLlib] Better PolicyServer example (w/ or w/o tune) and add printing out actual listen port address in log-level=INFO. (ray-project#18254)
sven1977 authored Aug 31, 2021
1 parent a3123b6 commit 82465f9
Showing 7 changed files with 176 additions and 71 deletions.
14 changes: 3 additions & 11 deletions rllib/agents/trainer.py
@@ -31,6 +31,7 @@
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.multi_agent import check_multi_agent
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.typing import AgentID, EnvInfoDict, EnvType, EpisodeID, \
PartialTrainerConfigDict, PolicyID, ResultDict, TrainerConfigDict
@@ -1391,16 +1392,7 @@ def _validate_config(config: PartialTrainerConfigDict,
deprecation_warning(old="simple_optimizer", error=False)

# Loop through all policy definitions in multi-agent policies.
multiagent_config = config["multiagent"]
policies = multiagent_config.get("policies")
if not policies:
policies = {DEFAULT_POLICY_ID}
if isinstance(policies, set):
policies = multiagent_config["policies"] = {
pid: PolicySpec()
for pid in policies
}
is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
policies, is_multi_agent = check_multi_agent(config)

for pid, policy_spec in policies.copy().items():
# Policy IDs must be strings.
@@ -1448,7 +1440,7 @@ def _validate_config(config: PartialTrainerConfigDict,
config["simple_optimizer"] = True
# Multi-agent case: Try using MultiGPU optimizer (only
# if all policies used are DynamicTFPolicies or TorchPolicies).
elif is_multiagent:
elif is_multi_agent:
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
default_policy_cls = None if trainer_obj_or_none is None else \
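
The inline policy-dict normalization that used to live in `_validate_config` is now delegated to the new `check_multi_agent` helper (added in `rllib/utils/multi_agent.py` further down). As a rough sketch of what the call yields for a plain single-agent config (illustrative, not part of the diff):

    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.multi_agent import check_multi_agent

    # A config whose "multiagent" section is empty, as in a default
    # single-agent Trainer config.
    config = {"multiagent": {"policies": {}}}
    policies, is_multi_agent = check_multi_agent(config)

    # The helper fills in {"default_policy": PolicySpec()} and reports
    # that this is not a multi-agent setup.
    assert DEFAULT_POLICY_ID in policies and not is_multi_agent
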
2 changes: 1 addition & 1 deletion rllib/env/base_env.py
@@ -259,7 +259,7 @@ def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
while len(results[0]) == 0:
self.external_env._results_avail_condition.wait()
results = self._poll()
if not self.external_env.isAlive():
if not self.external_env.is_alive():
raise Exception("Serving thread has stopped.")
limit = self.external_env._max_concurrent_episodes
assert len(results[0]) < limit, \
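
The rename tracks the standard library: `threading.Thread.isAlive()` was deprecated and then removed in Python 3.9; `is_alive()` is the supported spelling (RLlib's `ExternalEnv` runs as a thread). A minimal sketch:

    import threading

    t = threading.Thread(target=lambda: None)
    t.start()
    # isAlive() raises AttributeError on Python 3.9+; is_alive() works everywhere.
    print(t.is_alive())
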
3 changes: 2 additions & 1 deletion rllib/env/policy_client.py
@@ -15,6 +15,7 @@
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.multi_agent import check_multi_agent
from ray.rllib.utils.typing import MultiAgentDict, EnvInfoDict, EnvObsType, \
EnvActionType

@@ -354,7 +355,7 @@ def _create_embedded_rollout_worker(kwargs, send_fn):
"action_space": kwargs["policy_config"]["action_space"],
"observation_space": kwargs["policy_config"]["observation_space"],
}
is_ma = kwargs["policy_config"]["multiagent"].get("policies")
_, is_ma = check_multi_agent(kwargs["policy_config"])
kwargs["env_creator"] = _auto_wrap_external(
lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(config))
kwargs["policy_config"]["env"] = True
3 changes: 2 additions & 1 deletion rllib/env/policy_server_input.py
@@ -91,7 +91,8 @@ def get_metrics():
self.metrics_queue)
HTTPServer.__init__(self, (address, port), handler)

logger.info("Starting connector server at {}:{}".format(address, port))
logger.info("Starting connector server at {}:{}".format(
self.server_name, self.server_port))

# Start the serving thread, listening on socket and handling commands.
serving_thread = threading.Thread(
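
With `"log_level": "INFO"`, the server now logs the host name and port it actually listens on (as resolved by `HTTPServer`), not just the values passed in. A hedged sketch of pointing a `PolicyClient` at that logged address, assuming the log reads something like `Starting connector server at localhost:9900`:

    from ray.rllib.env.policy_client import PolicyClient

    # Address and port are illustrative -- use whatever the INFO log printed.
    client = PolicyClient("http://localhost:9900", inference_mode="remote")
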
2 changes: 1 addition & 1 deletion rllib/examples/attention_net.py
@@ -157,7 +157,7 @@ def get_cli_args():
"episode_reward_mean": args.stop_reward,
}

# training loop
# Manual training loop (no Ray tune).
if args.no_tune:
# manual training loop using PPO and manually keeping track of state
if args.run != "PPO":
195 changes: 139 additions & 56 deletions rllib/examples/serving/cartpole_server.py
@@ -27,6 +27,7 @@
import os

import ray
from ray import tune
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
@@ -41,32 +41,84 @@

CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, choices=["DQN", "PPO"], default="DQN")
parser.add_argument(
"--framework",
choices=["tf", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument(
"--no-restore",
action="store_true",
help="Do not restore from a previously saved checkpoint (location of "
"which is saved in `last_checkpoint_[algo-name].out`).")
parser.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of workers to use. Each worker will create "
"its own listening socket for incoming experiences.")
parser.add_argument(
"--chatty-callbacks",
action="store_true",
help="Activates info-messages for different events on "
"server/client (episode steps, postprocessing, etc..).")

if __name__ == "__main__":
def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()

# Example-specific args.
parser.add_argument(
"--port",
type=int,
default=SERVER_BASE_PORT,
help="The base-port to use (on localhost). "
f"Default is {SERVER_BASE_PORT}.")
parser.add_argument(
"--callbacks-verbose",
action="store_true",
help="Activates info-messages for different events on "
"server/client (episode steps, postprocessing, etc..).")
parser.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of workers to use. Each worker will create "
"its own listening socket for incoming experiences.")
parser.add_argument(
"--no-restore",
action="store_true",
help="Do not restore from a previously saved checkpoint (location of "
"which is saved in `last_checkpoint_[algo-name].out`).")

# General args.
parser.add_argument(
"--run",
default="PPO",
choices=["DQN", "PPO"],
help="The RLlib-registered algorithm to use.")
parser.add_argument("--num-cpus", type=int, default=3)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.")
parser.add_argument(
"--stop-iters",
type=int,
default=200,
help="Number of iterations to train.")
parser.add_argument(
"--stop-timesteps",
type=int,
default=500000,
help="Number of timesteps to train.")
parser.add_argument(
"--stop-reward",
type=float,
default=80.0,
help="Reward at which we stop training.")
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
"--no-tune",
action="store_true",
help="Run without Tune using a manual train loop instead. Here,"
"there is no TensorBoard support.")
parser.add_argument(
"--local-mode",
action="store_true",
help="Init Ray in local mode for easier debugging.")

args = parser.parse_args()
print(f"Running with following CLI args: {args}")
return args


if __name__ == "__main__":
args = get_cli_args()
ray.init()

# `InputReader` generator (returns None if no input reader is needed on
@@ -76,7 +129,7 @@ def _input(ioctx):
# Create a PolicyServerInput.
if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
return PolicyServerInput(
ioctx, SERVER_ADDRESS, SERVER_BASE_PORT + ioctx.worker_index -
ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index -
(1 if ioctx.worker_index > 0 else 0))
# No InputReader (PolicyServerInput) needed.
else:
@@ -102,47 +155,77 @@ def _input(ioctx):
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
# Create a "chatty" client/server or not.
"callbacks": MyCallbacks if args.chatty_callbacks else None,
"callbacks": MyCallbacks if args.callbacks_verbose else None,
# DL framework to use.
"framework": args.framework,
# Set to INFO so we'll see the server's actual address:port.
"log_level": "INFO",
}

# DQN.
if args.run == "DQN":
# Example of using DQN (supports off-policy actions).
trainer = DQNTrainer(
config=dict(
config, **{
"learning_starts": 100,
"timesteps_per_iteration": 200,
"model": {
"fcnet_hiddens": [64],
"fcnet_activation": "linear",
},
"n_step": 3,
"framework": args.framework,
}))
config.update({
"learning_starts": 100,
"timesteps_per_iteration": 200,
"n_step": 3,
})
config["model"] = {
"fcnet_hiddens": [64],
"fcnet_activation": "linear",
}

# PPO.
else:
# Example of using PPO (does NOT support off-policy actions).
trainer = PPOTrainer(
config=dict(
config, **{
"rollout_fragment_length": 1000,
"train_batch_size": 4000,
"framework": args.framework,
}))
config.update({
"rollout_fragment_length": 1000,
"train_batch_size": 4000,
})

checkpoint_path = CHECKPOINT_FILE.format(args.run)

# Attempt to restore from checkpoint, if possible.
if not args.no_restore and os.path.exists(checkpoint_path):
checkpoint_path = open(checkpoint_path).read()
print("Restoring from checkpoint path", checkpoint_path)
trainer.restore(checkpoint_path)

# Serving and training loop.
while True:
print(pretty_print(trainer.train()))
checkpoint = trainer.save()
print("Last checkpoint", checkpoint)
with open(checkpoint_path, "w") as f:
f.write(checkpoint)
else:
checkpoint_path = None

# Manual training loop (no Ray tune).
if args.no_tune:
if args.run == "DQN":
trainer = DQNTrainer(config=config)
else:
trainer = PPOTrainer(config=config)

if checkpoint_path:
print("Restoring from checkpoint path", checkpoint_path)
trainer.restore(checkpoint_path)

# Serving and training loop.
ts = 0
for _ in range(args.stop_iters):
results = trainer.train()
print(pretty_print(results))
checkpoint = trainer.save()
print("Last checkpoint", checkpoint)
with open(checkpoint_path, "w") as f:
f.write(checkpoint)
if results["episode_reward_mean"] >= args.stop_reward or \
ts >= args.stop_timesteps:
break
ts += results["timesteps_total"]

# Run with Tune for auto env and trainer creation and TensorBoard.
else:
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}

tune.run(
args.run,
config=config,
stop=stop,
verbose=2,
restore=checkpoint_path)
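
The server script (with or without Tune) only runs the learning side; episodes have to be driven by an external client. A rough sketch of a matching client loop, in the spirit of RLlib's `cartpole_client.py` example (the port is assumed to be the server's base port; with `--num-workers=2`, the two workers listen on consecutive ports per the offset logic in `_input()` above):

    import gym
    from ray.rllib.env.policy_client import PolicyClient

    # Port is illustrative -- it must match the server's --port (base port).
    client = PolicyClient("http://localhost:9900", inference_mode="remote")
    env = gym.make("CartPole-v0")

    obs = env.reset()
    episode_id = client.start_episode(training_enabled=True)
    while True:
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = env.step(action)
        client.log_returns(episode_id, reward, info=info)
        if done:
            client.end_episode(episode_id, obs)
            obs = env.reset()
            episode_id = client.start_episode(training_enabled=True)
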
28 changes: 28 additions & 0 deletions rllib/utils/multi_agent.py
@@ -0,0 +1,28 @@
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.typing import PartialTrainerConfigDict


def check_multi_agent(config: PartialTrainerConfigDict):
"""Checks, whether a (partial) config defines a multi-agent setup.
Args:
config (PartialTrainerConfigDict): The user/Trainer/Policy config
to check for multi-agent.
Returns:
Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all
fixed) multi-agent policy dict and whether we have a
multi-agent setup or not.
"""
multiagent_config = config["multiagent"]
policies = multiagent_config.get("policies")
if not policies:
policies = {DEFAULT_POLICY_ID}
if isinstance(policies, set):
policies = multiagent_config["policies"] = {
pid: PolicySpec()
for pid in policies
}
is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
return policies, is_multiagent

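For a genuinely multi-agent config, the helper keeps the explicit policy dict as-is and reports `is_multiagent=True`. A small usage sketch (policy names are illustrative):

    from ray.rllib.policy.policy import PolicySpec
    from ray.rllib.utils.multi_agent import check_multi_agent

    config = {
        "multiagent": {
            "policies": {"pol_a": PolicySpec(), "pol_b": PolicySpec()},
        },
    }
    policies, is_multi_agent = check_multi_agent(config)
    assert is_multi_agent and set(policies) == {"pol_a", "pol_b"}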