add rendering and update README
vitchyr committed Apr 16, 2019
1 parent 99e080f commit 90cbd9c
Showing 10 changed files with 66 additions and 20 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -37,6 +37,11 @@ To get started, checkout the example scripts, linked above.
## What's New
### Version 0.2

#### 04/05-15/2019
- Add rendering
- Fix SAC bug to account for future entropy (#41, #43)
- Add online algorithm mode (#42)

#### 04/05/2019

The initial release for 0.2 has the following major changes:
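The "Add rendering" entry above boils down to the new `render`/`render_kwargs` arguments threaded through the samplers and collectors changed below. As a minimal sketch of how a user might exercise it, assuming a standard gym environment, a throwaway random-policy stub (neither is part of this commit), and that `mode='human'` is a valid kwarg for the env's `render()`:

```python
import gym

from rlkit.samplers.rollout_functions import rollout


class RandomPolicy:
    """Placeholder agent exposing the reset()/get_action() interface rollout expects."""

    def __init__(self, action_space):
        self.action_space = action_space

    def reset(self):
        pass

    def get_action(self, observation):
        # rollout expects (action, agent_info_dict)
        return self.action_space.sample(), {}


env = gym.make('Pendulum-v0')
policy = RandomPolicy(env.action_space)
path = rollout(
    env,
    policy,
    max_path_length=200,
    render=True,                       # new in this commit (replaces `animated`)
    render_kwargs={'mode': 'human'},   # forwarded as env.render(**render_kwargs)
)
print(len(path['rewards']))
```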
14 changes: 14 additions & 0 deletions rlkit/samplers/data_collector/path_collector.py
@@ -12,11 +12,17 @@ def __init__(
env,
policy,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs

self._num_steps_total = 0
self._num_paths_total = 0
@@ -85,12 +91,18 @@ def __init__(
env,
policy,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
observation_key='observation',
desired_goal_key='desired_goal',
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._render = render
self._render_kwargs = render_kwargs
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._observation_key = observation_key
self._desired_goal_key = desired_goal_key
@@ -115,6 +127,8 @@ def collect_new_paths(
self._env,
self._policy,
max_path_length=max_path_length_this_loop,
render=self._render,
render_kwargs=self._render_kwargs,
observation_key=self._observation_key,
desired_goal_key=self._desired_goal_key,
return_dict_obs=True,
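For the goal-conditioned collector changed above, the new keywords would be wired in roughly like this. The class name `GoalConditionedPathCollector` and the `collect_new_paths` argument names are assumptions based on the surrounding rlkit code, not shown in this diff; only the constructor keywords come from the hunk itself, and the env/policy placeholders follow the stub from the earlier sketch:

```python
from rlkit.samplers.data_collector.path_collector import GoalConditionedPathCollector

# `eval_env` must be a dict-observation (goal-conditioned) env and `eval_policy`
# any object with reset()/get_action(); both are assumed to be defined elsewhere.
eval_path_collector = GoalConditionedPathCollector(
    eval_env,
    eval_policy,
    max_num_epoch_paths_saved=10,
    render=True,                        # new keyword added by this commit
    render_kwargs={'mode': 'human'},    # forwarded to env.render(**render_kwargs)
    observation_key='observation',
    desired_goal_key='desired_goal',
)
paths = eval_path_collector.collect_new_paths(
    max_path_length=50,                 # argument names assumed, not shown in the diff
    num_steps=500,
    discard_incomplete_paths=True,
)
```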
16 changes: 16 additions & 0 deletions rlkit/samplers/data_collector/step_collector.py
@@ -13,11 +13,17 @@ def __init__(
env,
policy,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs

self._num_steps_total = 0
self._num_paths_total = 0
@@ -70,6 +76,8 @@ def collect_one_step(
next_ob, reward, terminal, env_info = (
self._env.step(action)
)
if self._render:
self._env.render(**self._render_kwargs)
terminal = np.array([terminal])
reward = np.array([reward])
# store path obs
@@ -118,13 +126,19 @@ def __init__(
env,
policy,
max_num_epoch_paths_saved=None,
render=False,
render_kwargs=None,
observation_key='observation',
desired_goal_key='desired_goal',
):
if render_kwargs is None:
render_kwargs = {}
self._env = env
self._policy = policy
self._max_num_epoch_paths_saved = max_num_epoch_paths_saved
self._epoch_paths = deque(maxlen=self._max_num_epoch_paths_saved)
self._render = render
self._render_kwargs = render_kwargs
self._observation_key = observation_key
self._desired_goal_key = desired_goal_key

@@ -192,6 +206,8 @@ def collect_one_step(
next_ob, reward, terminal, env_info = (
self._env.step(action)
)
if self._render:
self._env.render(**self._render_kwargs)
terminal = np.array([terminal])
reward = np.array([reward])
# store path obs
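The step collectors get the same constructor keywords plus an in-loop render call right after each `env.step`. A hypothetical online-mode usage sketch, again assuming the `MdpStepCollector` class name and the `collect_one_step` argument names (only the constructor keywords are taken from the diff):

```python
from rlkit.samplers.data_collector.step_collector import MdpStepCollector

# env and policy as in the earlier rollout sketch (any gym env plus an object
# exposing reset()/get_action()); they are placeholders, not part of this commit.
expl_step_collector = MdpStepCollector(
    env,
    policy,
    render=True,                        # new keyword added by this commit
    render_kwargs={'mode': 'human'},    # used as env.render(**render_kwargs) each step
)
for _ in range(1000):
    expl_step_collector.collect_one_step(
        max_path_length=200,            # argument names assumed, not shown in the diff
        discard_incomplete_paths=False,
    )
```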
31 changes: 21 additions & 10 deletions rlkit/samplers/rollout_functions.py
@@ -5,12 +5,15 @@ def multitask_rollout(
env,
agent,
max_path_length=np.inf,
animated=False,
render=False,
render_kwargs=None,
observation_key=None,
desired_goal_key=None,
get_action_kwargs=None,
return_dict_obs=False,
):
if render_kwargs is None:
render_kwargs = {}
if get_action_kwargs is None:
get_action_kwargs = {}
dict_obs = []
@@ -25,8 +28,8 @@
path_length = 0
agent.reset()
o = env.reset()
if animated:
env.render()
if render:
env.render(**render_kwargs)
goal = o[desired_goal_key]
while path_length < max_path_length:
dict_obs.append(o)
@@ -35,8 +38,8 @@
new_obs = np.hstack((o, goal))
a, agent_info = agent.get_action(new_obs, **get_action_kwargs)
next_o, r, d, env_info = env.step(a)
if animated:
env.render()
if render:
env.render(**render_kwargs)
observations.append(o)
rewards.append(r)
terminals.append(d)
@@ -70,7 +73,13 @@
)


def rollout(env, agent, max_path_length=np.inf, animated=False):
def rollout(
env,
agent,
max_path_length=np.inf,
render=False,
render_kwargs=None,
):
"""
The following value for the following keys will be a 2D array, with the
first dimension corresponding to the time dimension.
@@ -85,6 +94,8 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
- agent_infos
- env_infos
"""
if render_kwargs is None:
render_kwargs = {}
observations = []
actions = []
rewards = []
@@ -95,8 +106,8 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
agent.reset()
next_o = None
path_length = 0
if animated:
env.render()
if render:
env.render(**render_kwargs)
while path_length < max_path_length:
a, agent_info = agent.get_action(o)
next_o, r, d, env_info = env.step(a)
@@ -110,8 +121,8 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
if d:
break
o = next_o
if animated:
env.render()
if render:
env.render(**render_kwargs)

actions = np.array(actions)
if len(actions.shape) == 1:
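The signature changes above are the migration point for existing call sites: `animated=True` becomes `render=True`, optionally with `render_kwargs`. A sketch of calling the goal-conditioned rollout with the new keywords; the env and agent are placeholders, and only the keyword names come from the new signatures:

```python
from rlkit.samplers.rollout_functions import multitask_rollout

# `env` is assumed to be a dict-observation, goal-conditioned environment and
# `agent` any policy with reset()/get_action(); both are defined elsewhere.
path = multitask_rollout(
    env,
    agent,
    max_path_length=100,
    render=True,                        # was: animated=True
    render_kwargs={'mode': 'human'},    # new: forwarded to env.render()
    observation_key='observation',
    desired_goal_key='desired_goal',
    return_dict_obs=True,
)
```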
8 changes: 4 additions & 4 deletions rlkit/samplers/util.py
@@ -1,7 +1,7 @@
import numpy as np


def rollout(env, agent, max_path_length=np.inf, animated=False):
def rollout(env, agent, max_path_length=np.inf, render=False):
"""
The following value for the following keys will be a 2D array, with the
first dimension corresponding to the time dimension.
@@ -19,7 +19,7 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
:param env:
:param agent:
:param max_path_length:
:param animated:
:param render:
:return:
"""
observations = []
@@ -31,7 +31,7 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
o = env.reset()
next_o = None
path_length = 0
if animated:
if render:
env.render()
while path_length < max_path_length:
a, agent_info = agent.get_action(o)
@@ -46,7 +46,7 @@ def rollout(env, agent, max_path_length=np.inf, animated=False):
if d:
break
o = next_o
if animated:
if render:
env.render()

actions = np.array(actions)
2 changes: 1 addition & 1 deletion rlkit/torch/ddpg/ddpg.py
@@ -7,7 +7,7 @@
import torch
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.policies.simple import RandomPolicy
from rlkit.samplers.util import rollout
from rlkit.samplers.rollout_functions import rollout
from rlkit.torch.torch_rl_algorithm import TorchRLAlgorithm
from rlkit.torch.data_management.normalizer import TorchFixedNormalizer
from torch import nn as nn
2 changes: 1 addition & 1 deletion rlkit/torch/skewfit/video_gen.py
@@ -60,7 +60,7 @@ def dump_video(
env,
policy,
max_path_length=horizon,
animated=False,
render=False,
)
is_vae_env = isinstance(env, VAEWrappedEnv)
l = []
2 changes: 1 addition & 1 deletion rlkit/util/video.py
@@ -35,7 +35,7 @@ def dump_video(
env,
policy,
max_path_length=horizon,
animated=False,
render=False,
)
is_vae_env = isinstance(env, VAEWrappedEnv)
l = []
2 changes: 1 addition & 1 deletion scripts/run_goal_conditioned_policy.py
@@ -26,7 +26,7 @@ def simulate_policy(args):
env,
policy,
max_path_length=args.H,
animated=not args.hide,
render=not args.hide,
observation_key='observation',
desired_goal_key='desired_goal',
))
4 changes: 2 additions & 2 deletions scripts/run_policy.py
@@ -1,4 +1,4 @@
from rlkit.samplers.util import rollout
from rlkit.samplers.rollout_functions import rollout
from rlkit.torch.pytorch_util import set_gpu_mode
import argparse
import joblib
@@ -21,7 +21,7 @@ def simulate_policy(args):
env,
policy,
max_path_length=args.H,
animated=True,
render=True,
)
if hasattr(env, "log_diagnostics"):
env.log_diagnostics([path])
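Putting the pieces together, scripts/run_policy.py now passes `render=True` to the shared rollout function. A condensed sketch of what the script does after this change; the snapshot keys ('policy', 'env') and the CLI flags are assumptions, and only the imports and the `rollout(...)` call reflect the diff above:

```python
import argparse

import joblib

from rlkit.samplers.rollout_functions import rollout
from rlkit.torch.pytorch_util import set_gpu_mode

parser = argparse.ArgumentParser()
parser.add_argument('file', help='path to the snapshot, e.g. params.pkl')
parser.add_argument('--H', type=int, default=300, help='max path length')
args = parser.parse_args()

data = joblib.load(args.file)
policy, env = data['policy'], data['env']   # snapshot layout assumed
set_gpu_mode(False)
while True:
    path = rollout(
        env,
        policy,
        max_path_length=args.H,
        render=True,                         # replaces the old animated=True
    )
    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics([path])
```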
