
Commit 5ca53b4

Chapter 10: Machine Teaching
1 parent 3e4d6be commit 5ca53b4

File tree

4 files changed: +426 −0 lines changed

Chapter10/custom_mcar.py (new file, +100 lines)
import gym
from gym.spaces import Box, Dict
import numpy as np


class MountainCar(gym.Env):
    """MountainCar-v0 wrapped with curriculum lessons, reward shaping, and optional action masking."""

    def __init__(self, env_config={}):
        self.wrapped = gym.make("MountainCar-v0")
        self.action_space = self.wrapped.action_space
        self.t = 0
        self.reward_fun = env_config.get("reward_fun")
        self.lesson = env_config.get("lesson")
        self.use_action_masking = env_config.get("use_action_masking", False)
        self.action_mask = None
        self.reset()
        if self.use_action_masking:
            # Expose a dict observation so the model also receives the action mask.
            self.observation_space = Dict(
                {
                    "action_mask": Box(0, 1, shape=(self.action_space.n,)),
                    "actual_obs": self.wrapped.observation_space,
                }
            )
        else:
            self.observation_space = self.wrapped.observation_space

    def _get_obs(self):
        raw_obs = np.array(self.wrapped.unwrapped.state)
        if self.use_action_masking:
            self.update_avail_actions()
            obs = {
                "action_mask": self.action_mask,
                "actual_obs": raw_obs,
            }
        else:
            obs = raw_obs
        return obs

    def reset(self):
        self.wrapped.reset()
        self.t = 0
        self.wrapped.unwrapped.state = self._get_init_conditions()
        obs = self._get_obs()
        return obs

    def _get_init_conditions(self):
        # Curriculum: lower lesson numbers start the car closer to the goal;
        # lesson 4 (or None) reproduces the original start distribution.
        if self.lesson == 0:
            low = 0.1
            high = 0.4
            velocity = self.wrapped.np_random.uniform(
                low=0, high=self.wrapped.max_speed
            )
        elif self.lesson == 1:
            low = -0.4
            high = 0.1
            velocity = self.wrapped.np_random.uniform(
                low=0, high=self.wrapped.max_speed
            )
        elif self.lesson == 2:
            low = -0.6
            high = -0.4
            velocity = self.wrapped.np_random.uniform(
                low=-self.wrapped.max_speed, high=self.wrapped.max_speed
            )
        elif self.lesson == 3:
            low = -0.6
            high = -0.1
            velocity = self.wrapped.np_random.uniform(
                low=-self.wrapped.max_speed, high=self.wrapped.max_speed
            )
        elif self.lesson == 4 or self.lesson is None:
            low = -0.6
            high = -0.4
            velocity = 0
        else:
            raise ValueError("Invalid lesson: {}".format(self.lesson))
        obs = (self.wrapped.np_random.uniform(low=low, high=high), velocity)
        return obs

    def set_lesson(self, lesson):
        self.lesson = lesson

    def step(self, action):
        self.t += 1
        state, reward, done, info = self.wrapped.step(action)
        if self.reward_fun == "custom_reward":
            position, velocity = state
            # Shaped bonus for progress to the right of the valley bottom (x = -0.5).
            reward += (abs(position + 0.5) ** 2) * (position > -0.5)
        obs = self._get_obs()
        if self.t >= 200:
            done = True
        return obs, reward, done, info

    def update_avail_actions(self):
        self.action_mask = np.array([1.0] * self.action_space.n)
        pos, vel = self.wrapped.unwrapped.state
        # Actions: 0 = push left, 1 = no push, 2 = push right.
        if (pos < -0.3) and (pos > -0.8) and (vel < 0) and (vel > -0.05):
            # Rolling back down slowly: forbid coasting and pushing right,
            # forcing a left push to build momentum.
            self.action_mask[1] = 0
            self.action_mask[2] = 0
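
Not part of the commit, but a minimal usage sketch of the wrapper above: it assumes the file is importable as custom_mcar and that a random policy over the unmasked actions is enough for a smoke test of the lessons and the action mask.

import numpy as np

from custom_mcar import MountainCar

# Lesson 0 starts the car near the goal; action masking switches obs to a dict.
env = MountainCar(
    {"reward_fun": "custom_reward", "lesson": 0, "use_action_masking": True}
)
obs = env.reset()
done = False
while not done:
    # Sample uniformly among actions whose mask entry is 1.
    valid_actions = [a for a in range(env.action_space.n) if obs["action_mask"][a] > 0]
    action = int(np.random.choice(valid_actions))
    obs, reward, done, info = env.step(action)

# Promote to a harder lesson; the next reset() samples from the new start range.
env.set_lesson(1)
obs = env.reset()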

Chapter10/masking_model.py (new file, +39 lines)
from gym.spaces import Box
from ray.rllib.agents.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class ParametricActionsModel(DistributionalQTFModel):
    """DQN model that masks out the actions flagged as unavailable by the environment."""

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        true_obs_shape=(2,),
        **kw
    ):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw
        )
        # The Q-network only sees the actual observation, not the mask.
        self.action_value_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape),
            action_space,
            num_outputs,
            model_config,
            name + "_action_values",
        )
        self.register_variables(self.action_value_model.variables())

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        action_values, _ = self.action_value_model(
            {"obs": input_dict["obs"]["actual_obs"]}
        )
        # log(0) yields -inf for masked actions, clipped to the smallest float32,
        # so masked actions can never win the argmax.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_values + inf_mask, state
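
As a sketch of how this model might be wired up (not shown in this commit): it is registered with RLlib's ModelCatalog and referenced from a DQN config that disables the post-model Q-value layers, following RLlib's parametric-actions examples for this (ray 1.x era) API; the registration key "masking_model" and the stop criterion below are illustrative assumptions.

import ray
from ray import tune
from ray.rllib.models import ModelCatalog

from custom_mcar import MountainCar
from masking_model import ParametricActionsModel

# "masking_model" is an illustrative registration key.
ModelCatalog.register_custom_model("masking_model", ParametricActionsModel)

ray.init()
tune.run(
    "DQN",
    config={
        "env": MountainCar,
        "env_config": {"use_action_masking": True},
        "model": {"custom_model": "masking_model"},
        # Let the masked logits coming out of the model act directly as Q-values.
        "hiddens": [],
        "dueling": False,
    },
    stop={"timesteps_total": 200000},
)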

Chapter10/mcar_demo.py (new file, +133 lines)
#!/usr/bin/env python
import os

import numpy as np
import sys, gym, time

import ray.utils

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter

from custom_mcar import MountainCar

DEMO_DATA_DIR = "mcar-out"


def key_press(key, mod):
    global human_agent_action, human_wants_restart, human_sets_pause
    if key == 0xFF0D:  # Enter: restart the episode
        human_wants_restart = True
    if key == 32:  # Space: toggle pause
        human_sets_pause = not human_sets_pause
    a = int(key - ord("0"))
    if a <= 0 or a >= ACTIONS:
        return
    human_agent_action = a


def key_release(key, mod):
    global human_agent_action
    a = int(key - ord("0"))
    if a <= 0 or a >= ACTIONS:
        return
    if human_agent_action == a:
        human_agent_action = 0


def rollout(env, eps_id):
    global human_agent_action, human_wants_restart, human_sets_pause
    human_wants_restart = False
    obs = env.reset()
    prev_action = np.zeros_like(env.action_space.sample())
    prev_reward = 0
    t = 0
    skip = 0
    total_reward = 0
    total_timesteps = 0
    while 1:
        if not skip:
            print("taking action {}".format(human_agent_action))
            a = human_agent_action
            total_timesteps += 1
            skip = SKIP_CONTROL
        else:
            skip -= 1

        new_obs, r, done, info = env.step(a)
        # Record the transition in RLlib's offline data format.
        batch_builder.add_values(
            t=t,
            eps_id=eps_id,
            agent_index=0,
            obs=prep.transform(obs),
            actions=a,
            action_prob=1.0,  # put the true action probability here
            action_logp=0,
            action_dist_inputs=None,
            rewards=r,
            prev_actions=prev_action,
            prev_rewards=prev_reward,
            dones=done,
            infos=info,
            new_obs=prep.transform(new_obs),
        )
        obs = new_obs
        prev_action = a
        prev_reward = r
        t += 1  # advance the per-episode timestep recorded in the batch

        if r != 0:
            print("reward %0.3f" % r)
        total_reward += r
        window_still_open = env.wrapped.render()
        if window_still_open == False:
            return False
        if done:
            break
        if human_wants_restart:
            break
        while human_sets_pause:
            env.wrapped.render()
            time.sleep(0.1)
        time.sleep(0.1)
    print("timesteps %i reward %0.2f" % (total_timesteps, total_reward))
    writer.write(batch_builder.build_and_reset())


if __name__ == "__main__":
    batch_builder = SampleBatchBuilder()  # or MultiAgentSampleBatchBuilder
    writer = JsonWriter(DEMO_DATA_DIR)

    env = MountainCar()

    # RLlib uses preprocessors to implement transforms such as one-hot encoding
    # and flattening of tuple and dict observations. For a plain Box observation
    # like this one a no-op preprocessor is used, but this matters for more
    # complex envs.
    prep = get_preprocessor(env.observation_space)(env.observation_space)
    print("The preprocessor is", prep)

    if not hasattr(env.action_space, "n"):
        raise Exception("Keyboard agent only supports discrete action spaces")
    ACTIONS = env.action_space.n
    # Repeat the previous control decision SKIP_CONTROL extra steps; useful to
    # test how much frame-skip is still playable.
    SKIP_CONTROL = 0

    human_agent_action = 0
    human_wants_restart = False
    human_sets_pause = False

    env.reset()
    env.wrapped.render()
    env.wrapped.unwrapped.viewer.window.on_key_press = key_press
    env.wrapped.unwrapped.viewer.window.on_key_release = key_release

    print("ACTIONS={}".format(ACTIONS))
    print("Press keys 1 2 3 ... to take actions 1 2 3 ...")
    print("No keys pressed is taking action 0")

    for i in range(20):
        window_still_open = rollout(env, i)
        if window_still_open == False:
            break
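
The episodes recorded by this script land as JSON sample batches under mcar-out. One way (not shown in the commit) they could be replayed is through RLlib's offline-data input, e.g. to warm up DQN from the human demonstrations; the config below is an illustrative sketch for the same ray 1.x-era API, and the stop criterion is an assumption.

import ray
from ray import tune

from custom_mcar import MountainCar

ray.init()
tune.run(
    "DQN",
    config={
        "env": MountainCar,
        # Read experiences from the recorded demonstration batches instead of sampling.
        "input": "mcar-out",
        "input_evaluation": [],  # no off-policy estimation over the demo data
    },
    stop={"timesteps_total": 50000},
)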
