Skip to content

Commit

Permalink
Added Done to ExternalMultiAgentEnv. (ray-project#8478)
Browse files Browse the repository at this point in the history
Co-authored-by: devanderhoff <[email protected]>
  • Loading branch information
devanderhoff and devanderhoff authored May 17, 2020
1 parent 87cbf2a commit be1f158
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 13 deletions.
8 changes: 6 additions & 2 deletions rllib/env/external_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def log_returns(self, episode_id, reward, info=None):

episode = self._get(episode_id)
episode.cur_reward += reward

if info:
episode.cur_info = info or {}

Expand Down Expand Up @@ -238,6 +239,9 @@ def done(self, observation):

def _send(self):
if self.multiagent:
if not self.training_enabled:
for agent_id in self.cur_info_dict:
self.cur_info_dict[agent_id]["training_enabled"] = False
item = {
"obs": self.new_observation_dict,
"reward": self.cur_reward_dict,
Expand All @@ -261,8 +265,8 @@ def _send(self):
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
if not self.training_enabled:
item["info"]["training_enabled"] = False
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
17 changes: 15 additions & 2 deletions rllib/env/external_multi_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def log_action(self, episode_id, observation_dict, action_dict):

@PublicAPI
@override(ExternalEnv)
def log_returns(self, episode_id, reward_dict, info_dict=None):
def log_returns(self,
episode_id,
reward_dict,
info_dict=None,
multiagent_done_dict=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
Expand All @@ -115,7 +119,8 @@ def log_returns(self, episode_id, reward_dict, info_dict=None):
Arguments:
episode_id (str): Episode id returned from start_episode().
reward_dict (dict): Reward from the environment agents.
info (dict): Optional info dict.
info_dict (dict): Optional info dict.
multiagent_done_dict (dict): Optional done dict for agents.
"""

episode = self._get(episode_id)
Expand All @@ -127,6 +132,14 @@ def log_returns(self, episode_id, reward_dict, info_dict=None):
episode.cur_reward_dict[agent] += rew
else:
episode.cur_reward_dict[agent] = rew

if multiagent_done_dict:
for agent, done in multiagent_done_dict.items():
if agent in episode.cur_done_dict:
episode.cur_done_dict[agent] = done
else:
episode.cur_done_dict[agent] = done

if info_dict:
episode.cur_info_dict = info_dict or {}

Expand Down
34 changes: 27 additions & 7 deletions rllib/env/policy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def __init__(self, address, inference_mode="local", update_interval=10.0):
address (str): Server to connect to (e.g., "localhost:9090").
inference_mode (str): Whether to use 'local' or 'remote' policy
inference for computing actions.
update_interval (float): If using 'local' inference mode, the
policy is refreshed after this many seconds have passed.
update_interval (float or None): If using 'local' inference mode,
the policy is refreshed after this many seconds have passed,
or None for manual control via client.
"""
self.address = address
if inference_mode == "local":
Expand Down Expand Up @@ -130,7 +131,11 @@ def log_action(self, episode_id, observation, action):
})

@PublicAPI
def log_returns(self, episode_id, reward, info=None):
def log_returns(self,
episode_id,
reward,
info=None,
multiagent_done_dict=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
Expand All @@ -140,17 +145,24 @@ def log_returns(self, episode_id, reward, info=None):
Arguments:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Extra info dict.
multiagent_done_dict (dict): Multi-agent done information.
"""

if self.local:
self._update_local_policy()
return self.env.log_returns(episode_id, reward, info)
if multiagent_done_dict:
return self.env.log_returns(episode_id, reward, info,
multiagent_done_dict)
else:
return self.env.log_returns(episode_id, reward, info)

self._send({
"command": PolicyClient.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
"done": multiagent_done_dict,
})

@PublicAPI
Expand All @@ -172,6 +184,12 @@ def end_episode(self, episode_id, observation):
"episode_id": episode_id,
})

@PublicAPI
def update_policy_weights(self):
    """Force an immediate refresh of the local policy's weights.

    Only meaningful when 'local' inference mode is enabled: the client
    queries the server for the latest weights right away, bypassing the
    usual update-interval timer.
    """
    self._update_local_policy(force=True)

def _send(self, data):
payload = pickle.dumps(data)
response = requests.post(self.address, data=payload)
Expand All @@ -195,9 +213,10 @@ def _setup_local_rollout_worker(self, update_interval):
kwargs, self._send)
self.env = self.rollout_worker.env

def _update_local_policy(self):
def _update_local_policy(self, force=False):
assert self.inference_thread.is_alive()
if time.time() - self.last_updated > self.update_interval:
if (self.update_interval and time.time() - self.last_updated >
self.update_interval) or force:
logger.info("Querying server for new policy weights.")
resp = self._send({
"command": PolicyClient.GET_WEIGHTS,
Expand Down Expand Up @@ -253,7 +272,7 @@ def wrapped_creator(env_config):
"Attempting to convert it automatically to ExternalEnv.")

if isinstance(real_env, MultiAgentEnv):
external_cls = MultiAgentEnv
external_cls = ExternalMultiAgentEnv
else:
external_cls = ExternalEnv

Expand All @@ -268,6 +287,7 @@ def run(self):
time.sleep(999999)

return ExternalEnvWrapper(real_env)
return real_env

return wrapped_creator

Expand Down
9 changes: 7 additions & 2 deletions rllib/env/policy_server_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,13 @@ def execute_command(self, args):
args["episode_id"], args["observation"], args["action"])
elif command == PolicyClient.LOG_RETURNS:
assert inference_thread.is_alive()
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"])
if args["done"]:
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"],
args["done"])
else:
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"])
elif command == PolicyClient.END_EPISODE:
assert inference_thread.is_alive()
child_rollout_worker.env.end_episode(args["episode_id"],
Expand Down

0 comments on commit be1f158

Please sign in to comment.