Skip to content

Commit

Permalink
[RLlib] MAML: Add cartpole mass test for PyTorch. (ray-project#13679)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored Jan 25, 2021
1 parent e9103ee commit 9423930
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 12 deletions.
3 changes: 3 additions & 0 deletions python/requirements_rllib.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ pettingzoo>=1.4.0
# For tests on RecSim and Kaggle envs.
recsim
kaggle_environments

# For MAML on PyTorch.
higher
24 changes: 15 additions & 9 deletions rllib/agents/maml/tests/test_maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@ def test_maml_compilation(self):
num_iterations = 1

# Test for tf framework (torch not implemented yet).
for _ in framework_iterator(config, frameworks=("tf")):
trainer = maml.MAMLTrainer(
config=config,
env="ray.rllib.examples.env.pendulum_mass.PendulumMassEnv")
for i in range(num_iterations):
trainer.train()
check_compute_single_action(
trainer, include_prev_action_reward=True)
trainer.stop()
for fw in framework_iterator(config, frameworks=("tf", "torch")):
for env in [
"pendulum_mass.PendulumMassEnv",
"cartpole_mass.CartPoleMassEnv"
]:
if fw == "tf" and env.startswith("cartpole"):
continue
print("env={}".format(env))
env_ = "ray.rllib.examples.env.{}".format(env)
trainer = maml.MAMLTrainer(config=config, env=env_)
for i in range(num_iterations):
trainer.train()
check_compute_single_action(
trainer, include_prev_action_reward=True)
trainer.stop()


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions rllib/examples/env/cartpole_mass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import gym
from gym.envs.classic_control.cartpole import CartPoleEnv
from ray.rllib.env.meta_env import MetaEnv


class CartPoleMassEnv(CartPoleEnv, gym.utils.EzPickle, MetaEnv):
"""CartPoleMassEnv varies the weights of the cart and the pole.
"""

def sample_tasks(self, n_tasks):
# Sample new cart- and pole masses (random floats between 0.5 and 2.0
# (cart) and between 0.05 and 0.2 (pole)).
cart_masses = np.random.uniform(low=0.5, high=2.0, size=(n_tasks, 1))
pole_masses = np.random.uniform(low=0.05, high=0.2, size=(n_tasks, 1))
return np.concatenate([cart_masses, pole_masses], axis=-1)

def set_task(self, task):
"""
Args:
task (Tuple[float]): Masses of the cart and the pole.
"""
self.masscart = task[0]
self.masspole = task[1]

def get_task(self):
"""
Returns:
Tuple[float]: The current mass of the cart- and pole.
"""
return np.array([self.masscart, self.masspole])
9 changes: 6 additions & 3 deletions rllib/examples/env/pendulum_mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,22 @@ class PendulumMassEnv(PendulumEnv, gym.utils.EzPickle, MetaEnv):
"""

def sample_tasks(self, n_tasks):
# Mass is a random float between 0.5 and 2
# Sample new pendulum masses (random floats between 0.5 and 2).
return np.random.uniform(low=0.5, high=2.0, size=(n_tasks, ))

def set_task(self, task):
"""
Args:
task: task of the meta-learning environment
task (float): Task of the meta-learning environment (here: mass of
the pendulum).
"""
# self.m is the mass property of the pendulum.
self.m = task

def get_task(self):
"""
Returns:
task: task of the meta-learning environment
float: The current mass of the pendulum (self.m in the PendulumEnv
object).
"""
return self.m

0 comments on commit 9423930

Please sign in to comment.