
Commit aa727ed

Chapter 6: Deep Q-Learning
1 parent 6c2c6c8 commit aa727ed

8 files changed: +493 -0 lines

Chapter06/actor.py

+116 lines

from collections import deque
import ray
import gym
import numpy as np
from models import get_Q_network


@ray.remote
class Actor:
    def __init__(self,
                 actor_id,
                 replay_buffer,
                 parameter_server,
                 config,
                 eps,
                 eval=False):
        self.actor_id = actor_id
        self.replay_buffer = replay_buffer
        self.parameter_server = parameter_server
        self.config = config
        self.eps = eps
        self.eval = eval
        self.Q = get_Q_network(config)
        self.env = gym.make(config["env"])
        self.local_buffer = []
        self.obs_shape = config["obs_shape"]
        self.n_actions = config["n_actions"]
        self.multi_step_n = config.get("n_step", 1)
        self.q_update_freq = config.get("q_update_freq", 100)
        self.send_experience_freq = \
            config.get("send_experience_freq", 100)
        self.continue_sampling = True
        self.cur_episodes = 0
        self.cur_steps = 0

    def update_q_network(self):
        if self.eval:
            pid = \
                self.parameter_server.get_eval_weights.remote()
        else:
            pid = \
                self.parameter_server.get_weights.remote()
        new_weights = ray.get(pid)
        if new_weights:
            self.Q.set_weights(new_weights)
        else:
            print("Weights are not available yet, skipping.")

    def get_action(self, observation):
        observation = observation.reshape((1, -1))
        q_estimates = self.Q.predict(observation)[0]
        if np.random.uniform() <= self.eps:
            action = np.random.randint(self.n_actions)
        else:
            action = np.argmax(q_estimates)
        return action

    def get_n_step_trans(self, n_step_buffer):
        gamma = self.config['gamma']
        discounted_return = 0
        cum_gamma = 1
        for trans in list(n_step_buffer)[:-1]:
            _, _, reward, _ = trans
            discounted_return += cum_gamma * reward
            cum_gamma *= gamma
        observation, action, _, _ = n_step_buffer[0]
        last_observation, _, _, done = n_step_buffer[-1]
        experience = (observation, action, discounted_return,
                      last_observation, done, cum_gamma)
        return experience

    def stop(self):
        self.continue_sampling = False

    def sample(self):
        print("Starting sampling in actor {}".format(self.actor_id))
        self.update_q_network()
        observation = self.env.reset()
        episode_reward = 0
        episode_length = 0
        n_step_buffer = deque(maxlen=self.multi_step_n + 1)
        while self.continue_sampling:
            action = self.get_action(observation)
            next_observation, reward, \
                done, info = self.env.step(action)
            n_step_buffer.append((observation, action,
                                  reward, done))
            if len(n_step_buffer) == self.multi_step_n + 1:
                self.local_buffer.append(
                    self.get_n_step_trans(n_step_buffer))
            self.cur_steps += 1
            episode_reward += reward
            episode_length += 1
            if done:
                if self.eval:
                    break
                next_observation = self.env.reset()
                if len(n_step_buffer) > 1:
                    self.local_buffer.append(
                        self.get_n_step_trans(n_step_buffer))
                self.cur_episodes += 1
                episode_reward = 0
                episode_length = 0
            observation = next_observation
            if self.cur_steps % \
                    self.send_experience_freq == 0 and not self.eval:
                self.send_experience_to_replay()
            if self.cur_steps % \
                    self.q_update_freq == 0 and not self.eval:
                self.update_q_network()
        return episode_reward

    def send_experience_to_replay(self):
        rf = self.replay_buffer.add.remote(self.local_buffer)
        ray.wait([rf])
        self.local_buffer = []
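A quick sanity check of the n-step bookkeeping in get_n_step_trans (illustration only, not part of the commit): a full buffer of n_step + 1 transitions yields the discounted n-step return plus the cumulative discount gamma**n_step that the learner later uses to bootstrap from last_observation.

# Standalone illustration with made-up values; n_step = 3, gamma = 0.99.
from collections import deque

gamma = 0.99
n_step_buffer = deque(
    [("s0", 1, 0.0, False), ("s1", 0, 1.0, False),
     ("s2", 1, 2.0, False), ("s3", 0, 3.0, False)],
    maxlen=4)

discounted_return, cum_gamma = 0.0, 1.0
for _, _, reward, _ in list(n_step_buffer)[:-1]:
    discounted_return += cum_gamma * reward
    cum_gamma *= gamma

print(discounted_return)  # 0.0 + 0.99 * 1.0 + 0.99**2 * 2.0 = 2.9502
print(cum_gamma)          # 0.99**3 ~= 0.970299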

Chapter06/learner.py

+93 lines

import time
import ray
import numpy as np
from models import get_trainable_model
from tensorflow.keras.models import clone_model


@ray.remote
class Learner:
    def __init__(self, config, replay_buffer, parameter_server):
        self.config = config
        self.replay_buffer = replay_buffer
        self.parameter_server = parameter_server
        self.Q, self.trainable = get_trainable_model(config)
        self.target_network = clone_model(self.Q)
        self.train_batch_size = config["train_batch_size"]
        self.total_collected_samples = 0
        self.samples_since_last_update = 0
        self.send_weights_to_parameter_server()
        self.stopped = False

    def send_weights_to_parameter_server(self):
        self.parameter_server.update_weights.remote(self.Q.get_weights())

    def start_learning(self):
        print("Learning starting...")
        self.send_weights()
        while not self.stopped:
            sid = self.replay_buffer.get_total_env_samples.remote()
            total_samples = ray.get(sid)
            if total_samples >= self.config["learning_starts"]:
                self.optimize()

    def optimize(self):
        samples = ray.get(self.replay_buffer
                          .sample.remote(self.train_batch_size))
        if samples:
            N = len(samples)
            self.total_collected_samples += N
            self.samples_since_last_update += N
            ndim_obs = 1
            for s in self.config["obs_shape"]:
                if s:
                    ndim_obs *= s
            n_actions = self.config["n_actions"]
            obs = np.array([sample[0] for sample
                            in samples]).reshape((N, ndim_obs))
            actions = np.array([sample[1] for sample
                                in samples]).reshape((N,))
            rewards = np.array([sample[2] for sample
                                in samples]).reshape((N,))
            last_obs = np.array([sample[3] for sample
                                 in samples]).reshape((N, ndim_obs))
            done_flags = np.array([sample[4] for sample
                                   in samples]).reshape((N,))
            gammas = np.array([sample[5] for sample
                               in samples]).reshape((N,))
            masks = np.zeros((N, n_actions))
            masks[np.arange(N), actions] = 1
            dummy_labels = np.zeros((N,))
            # double DQN
            maximizer_a = np.argmax(self.Q.predict(last_obs),
                                    axis=1)
            target_network_estimates = \
                self.target_network.predict(last_obs)
            q_value_estimates = \
                np.array([target_network_estimates[i,
                                                   maximizer_a[i]]
                          for i in range(N)]).reshape((N,))
            sampled_bellman = rewards + gammas * \
                q_value_estimates * \
                (1 - done_flags)
            trainable_inputs = [obs, masks,
                                sampled_bellman]
            self.trainable.fit(trainable_inputs,
                               dummy_labels, verbose=0)
            self.send_weights()

            if self.samples_since_last_update > 500:
                self.target_network.set_weights(self.Q.get_weights())
                self.samples_since_last_update = 0
            return True
        else:
            print("No samples received from the buffer.")
            time.sleep(5)
            return False

    def send_weights(self):
        id = self.parameter_server.update_weights.remote(self.Q.get_weights())
        ray.get(id)

    def stop(self):
        self.stopped = True
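A side note on the double DQN gather in optimize() above (not part of the commit): the per-sample Python loop that picks the target network's estimate for each online-network argmax action is equivalent to NumPy integer-array indexing. A minimal check with made-up arrays:

import numpy as np

N, n_actions = 4, 2
rng = np.random.default_rng(0)
target_network_estimates = rng.normal(size=(N, n_actions))
maximizer_a = rng.integers(0, n_actions, size=N)

# Loop-based gather, as written in optimize()
looped = np.array([target_network_estimates[i, maximizer_a[i]]
                   for i in range(N)]).reshape((N,))
# Equivalent vectorized gather
gathered = target_network_estimates[np.arange(N), maximizer_a]
assert np.allclose(looped, gathered)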

Chapter06/models.py

+62 lines

from tensorflow.keras import backend as K
from tensorflow.keras import Input
from tensorflow.keras.layers import Dense, Flatten, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model


def masked_loss(args):
    y_true, y_pred, mask = args
    masked_pred = K.sum(mask * y_pred, axis=1, keepdims=True)
    loss = K.square(y_true - masked_pred)
    return K.mean(loss, axis=-1)


def get_Q_network(config):
    obs_input = Input(shape=config["obs_shape"],
                      name='Q_input')

    x = Flatten()(obs_input)
    for i, n_units in enumerate(config["fcnet_hiddens"]):
        layer_name = 'Q_' + str(i + 1)
        x = Dense(n_units,
                  activation=config["fcnet_activation"],
                  name=layer_name)(x)
    q_estimate_output = Dense(config["n_actions"],
                              activation='linear',
                              name='Q_output')(x)
    # Q Model
    Q_model = Model(inputs=obs_input,
                    outputs=q_estimate_output)
    Q_model.summary()
    Q_model.compile(optimizer=Adam(), loss='mse')
    return Q_model


def get_trainable_model(config):
    Q_model = get_Q_network(config)
    obs_input = Q_model.get_layer("Q_input").output
    q_estimate_output = Q_model.get_layer("Q_output").output
    mask_input = Input(shape=(config["n_actions"],),
                       name='Q_mask')
    sampled_bellman_input = Input(shape=(1,),
                                  name='Q_sampled')

    # Trainable model
    loss_output = Lambda(masked_loss,
                         output_shape=(1,),
                         name='Q_masked_out')([sampled_bellman_input,
                                               q_estimate_output,
                                               mask_input])
    trainable_model = Model(inputs=[obs_input,
                                    mask_input,
                                    sampled_bellman_input],
                            outputs=loss_output)
    trainable_model.summary()
    trainable_model.compile(optimizer=Adam(lr=config["lr"],
                                           clipvalue=config["grad_clip"]),
                            loss=[lambda y_true, y_pred: y_pred])
    return Q_model, trainable_model
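For context, a minimal usage sketch (not part of the commit): the trainable model consumes the observation, a one-hot action mask, and the sampled Bellman target, and its compiled identity loss simply returns the Lambda output, so gradients flow only through the Q-value of the action that was actually taken. The config below is a hypothetical CartPole-sized example; the keys match what models.py expects.

import numpy as np
from models import get_trainable_model

# Hypothetical config; values are illustrative assumptions.
config = {"obs_shape": (4,), "n_actions": 2,
          "fcnet_hiddens": [32, 32], "fcnet_activation": "relu",
          "lr": 1e-4, "grad_clip": 10}
Q, trainable = get_trainable_model(config)

obs = np.zeros((8, 4))
masks = np.zeros((8, 2))
masks[:, 0] = 1              # pretend action 0 was taken in every sample
targets = np.ones((8, 1))    # pretend sampled Bellman targets
trainable.fit([obs, masks, targets], np.zeros((8,)), verbose=0)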

Chapter06/parameter_server.py

+31 lines

import ray
from models import get_Q_network


@ray.remote
class ParameterServer:
    def __init__(self, config):
        self.weights = None
        self.eval_weights = None
        self.Q = get_Q_network(config)

    def update_weights(self, new_parameters):
        self.weights = new_parameters
        return True

    def get_weights(self):
        return self.weights

    def get_eval_weights(self):
        return self.eval_weights

    def set_eval_weights(self):
        self.eval_weights = self.weights
        return True

    def save_eval_weights(self,
                          filename='checkpoints/model_checkpoint'):
        self.Q.set_weights(self.eval_weights)
        self.Q.save_weights(filename)
        print("Saved.")

Chapter06/ray_primer.py

+44 lines

# These examples are taken from Ray's own documentation at
# https://docs.ray.io/en/latest/index.html

import ray

# Initialize Ray
ray.init()

# Using remote functions
@ray.remote
def remote_function():
    return 1

object_ids = []
for _ in range(4):
    y_id = remote_function.remote()
    object_ids.append(y_id)

@ray.remote
def remote_chain_function(value):
    return value + 1

y1_id = remote_function.remote()
chained_id = remote_chain_function.remote(y1_id)


# Using remote objects
y = 1
object_id = ray.put(y)

# Using remote classes (actors)
@ray.remote
class Counter(object):
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value

a = Counter.remote()
obj_id = a.increment.remote()
assert ray.get(obj_id) == 1
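A possible follow-up to the snippets above (not in the original file, but it would run if appended to it): ray.get also accepts a list of object IDs, so the results of the four remote_function calls can be fetched in one shot, and chained remote calls receive the resolved value rather than the ID.

# Fetch all pending results at once; each remote_function call returned 1.
results = ray.get(object_ids)
assert results == [1, 1, 1, 1]

# remote_chain_function received the value behind y1_id (1), not the ID.
assert ray.get(chained_id) == 2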

Chapter06/replay.py

+27 lines

from collections import deque
import ray
import numpy as np


@ray.remote
class ReplayBuffer:
    def __init__(self, config):
        self.replay_buffer_size = config["buffer_size"]
        self.buffer = deque(maxlen=self.replay_buffer_size)
        self.total_env_samples = 0

    def add(self, experience_list):
        for e in experience_list:
            self.buffer.append(e)
            self.total_env_samples += 1
        return True

    def sample(self, n):
        if len(self.buffer) > n:
            sample_ix = np.random.randint(
                len(self.buffer), size=n)
            return [self.buffer[ix] for ix in sample_ix]

    def get_total_env_samples(self):
        return self.total_env_samples
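The script that wires these classes together is not among the files shown above, so the following is a hypothetical minimal driver, using only the classes and config keys defined in actor.py, learner.py, models.py, parameter_server.py, and replay.py. All names and values here are illustrative assumptions, not the actual training script of this commit.

# Hypothetical driver script; not part of this commit.
import ray
from actor import Actor
from learner import Learner
from parameter_server import ParameterServer
from replay import ReplayBuffer

config = {
    "env": "CartPole-v0",
    "obs_shape": (4,),
    "n_actions": 2,
    "fcnet_hiddens": [64, 64],
    "fcnet_activation": "relu",
    "lr": 1e-4,
    "grad_clip": 10,
    "gamma": 0.99,
    "n_step": 3,
    "buffer_size": 100_000,
    "train_batch_size": 32,
    "learning_starts": 1_000,
    "q_update_freq": 100,
    "send_experience_freq": 100,
}

ray.init()
parameter_server = ParameterServer.remote(config)
replay_buffer = ReplayBuffer.remote(config)
learner = Learner.remote(config, replay_buffer, parameter_server)
learner.start_learning.remote()

# A few exploration actors with different epsilons, sampling in parallel.
training_actors = [
    Actor.remote(i, replay_buffer, parameter_server, config, eps)
    for i, eps in enumerate([0.5, 0.1, 0.01])]
sample_ids = [a.sample.remote() for a in training_actors]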

Chapter06/rllib_apex_dqn.py

+15 lines

import pprint
from ray import tune
from ray.rllib.agents.dqn.apex import APEX_DEFAULT_CONFIG
from ray.rllib.agents.dqn.apex import ApexTrainer

if __name__ == '__main__':
    config = APEX_DEFAULT_CONFIG.copy()
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(config)
    config['env'] = "CartPole-v0"
    config['num_workers'] = 50
    config['evaluation_num_workers'] = 10
    config['evaluation_interval'] = 1
    config['learning_starts'] = 5000
    tune.run(ApexTrainer, config=config)
