Skip to content

Commit

Permalink
Back out "add async_run_episode to gymrunner to support envs with async step methods" (facebookresearch#382)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: facebookresearch#382

Original commit changeset: 25e0c9171ca0

Reviewed By: kittipatv

Differential Revision: D26143054

fbshipit-source-id: 52856f848957c1c0b3304be318bcdc31404e133e
  • Loading branch information
alexnikulkov authored and facebook-github-bot committed Jan 29, 2021
1 parent 037b5c0 commit ae031c5
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 159 deletions.
36 changes: 2 additions & 34 deletions reagent/gym/runners/gymrunner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import asyncio
import logging
import pickle
from typing import Optional, Sequence
Expand All @@ -22,33 +21,9 @@


def run_episode(
env: EnvWrapper,
agent: Agent,
mdp_id: int = 0,
max_steps: Optional[int] = None,
fill_info: bool = False,
) -> Trajectory:
return asyncio.run(
async_run_episode(
env=env,
agent=agent,
mdp_id=mdp_id,
max_steps=max_steps,
fill_info=fill_info,
)
)


async def async_run_episode(
env: EnvWrapper,
agent: Agent,
mdp_id: int = 0,
max_steps: Optional[int] = None,
fill_info: bool = False,
env: EnvWrapper, agent: Agent, mdp_id: int = 0, max_steps: Optional[int] = None
) -> Trajectory:
"""
NOTE: this function is an async coroutine in order to support async env.step(). If you are using
it with regular env.step() method, use non-async run_episode(), which wraps this function.
Return sum of rewards from episode.
After max_steps (if specified), the environment is assumed to be terminal.
Can also specify the mdp_id and gamma of episode.
Expand All @@ -58,15 +33,9 @@ async def async_run_episode(
possible_actions_mask = env.possible_actions_mask
terminal = False
num_steps = 0
step_is_coroutine = asyncio.iscoroutinefunction(env.step)
while not terminal:
action, log_prob = agent.act(obs, possible_actions_mask)
if step_is_coroutine:
next_obs, reward, terminal, info = await env.step(action)
else:
next_obs, reward, terminal, info = env.step(action)
if not fill_info:
info = None
next_obs, reward, terminal, _ = env.step(action)
next_possible_actions_mask = env.possible_actions_mask
if max_steps is not None and num_steps >= max_steps:
terminal = True
Expand All @@ -81,7 +50,6 @@ async def async_run_episode(
terminal=bool(terminal),
log_prob=log_prob,
possible_actions_mask=possible_actions_mask,
info=info,
)
agent.post_step(transition)
trajectory.add_transition(transition)
Expand Down
1 change: 0 additions & 1 deletion reagent/gym/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Transition(rlt.BaseDataClass):
terminal: bool
log_prob: Optional[float] = None
possible_actions_mask: Optional[np.ndarray] = None
info: Optional[Dict] = None

# Same as asdict but filters out none values.
def asdict(self):
Expand Down
Loading

0 comments on commit ae031c5

Please sign in to comment.