-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: learner.py
93 lines (85 loc) · 3.59 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import time
import ray
import numpy as np
from models import get_trainable_model
from tensorflow.keras.models import clone_model
@ray.remote
class Learner:
    """Distributed double-DQN learner actor.

    Pulls sampled transitions from a shared replay buffer actor, runs
    double-DQN updates on the online network ``Q``, periodically syncs
    the target network, and publishes fresh weights to the parameter
    server for the rollout workers to pick up.
    """

    def __init__(self, config, replay_buffer, parameter_server):
        """Build the online/target networks and publish initial weights.

        Args:
            config: dict with at least "train_batch_size",
                "learning_starts", "obs_shape" and "n_actions".
                Optionally "target_network_update_freq" (defaults to the
                previous hard-coded 500 samples).
            replay_buffer: Ray actor handle exposing ``sample.remote``
                and ``get_total_env_samples.remote``.
            parameter_server: Ray actor handle exposing
                ``update_weights.remote``.
        """
        self.config = config
        self.replay_buffer = replay_buffer
        self.parameter_server = parameter_server
        self.Q, self.trainable = get_trainable_model(config)
        # BUG FIX: clone_model copies the architecture only, NOT the
        # weights. Without an explicit sync the target network starts
        # with independent random weights and early updates bootstrap
        # from garbage targets.
        self.target_network = clone_model(self.Q)
        self.target_network.set_weights(self.Q.get_weights())
        self.train_batch_size = config["train_batch_size"]
        # Samples between target-network syncs; backward compatible
        # with the previous hard-coded 500.
        self.target_update_freq = config.get(
            "target_network_update_freq", 500)
        self.total_collected_samples = 0
        self.samples_since_last_update = 0
        self.send_weights_to_parameter_server()
        self.stopped = False

    def send_weights_to_parameter_server(self):
        """Publish current Q weights (kept for interface compatibility).

        Delegates to send_weights so both paths block until the
        parameter server has actually applied the update.
        """
        self.send_weights()

    def start_learning(self):
        """Main learner loop: optimize until stop() is called."""
        print("Learning starting...")
        self.send_weights()
        while not self.stopped:
            sample_count_ref = \
                self.replay_buffer.get_total_env_samples.remote()
            total_samples = ray.get(sample_count_ref)
            if total_samples >= self.config["learning_starts"]:
                self.optimize()
            else:
                # Avoid a hot spin loop (one remote round-trip per
                # iteration) while the buffer is still warming up.
                time.sleep(1)

    def optimize(self):
        """Run one double-DQN update on a sampled minibatch.

        Returns:
            True if a gradient step was taken, False if the buffer
            returned no samples.
        """
        samples = ray.get(self.replay_buffer
                          .sample.remote(self.train_batch_size))
        if not samples:
            print("No samples received from the buffer.")
            time.sleep(5)
            return False
        N = len(samples)
        self.total_collected_samples += N
        self.samples_since_last_update += N
        # Flattened observation size; skips None/0 dims (e.g. a
        # batch placeholder in obs_shape).
        ndim_obs = 1
        for s in self.config["obs_shape"]:
            if s:
                ndim_obs *= s
        n_actions = self.config["n_actions"]
        # Transition tuples are (obs, action, reward, next_obs,
        # done, gamma).
        obs = np.array([sample[0] for sample
                        in samples]).reshape((N, ndim_obs))
        actions = np.array([sample[1] for sample
                            in samples]).reshape((N,))
        rewards = np.array([sample[2] for sample
                            in samples]).reshape((N,))
        last_obs = np.array([sample[3] for sample
                             in samples]).reshape((N, ndim_obs))
        done_flags = np.array([sample[4] for sample
                               in samples]).reshape((N,))
        gammas = np.array([sample[5] for sample
                           in samples]).reshape((N,))
        # One-hot mask selecting each sampled action in the loss.
        masks = np.zeros((N, n_actions))
        masks[np.arange(N), actions] = 1
        dummy_labels = np.zeros((N,))
        # Double DQN: the online network selects the argmax action,
        # the target network evaluates it.
        maximizer_a = np.argmax(self.Q.predict(last_obs), axis=1)
        target_network_estimates = \
            self.target_network.predict(last_obs)
        # Vectorized gather replaces the previous O(N) Python loop.
        q_value_estimates = \
            target_network_estimates[np.arange(N), maximizer_a]
        # Terminal transitions (done=1) contribute reward only.
        sampled_bellman = rewards + gammas * \
            q_value_estimates * \
            (1 - done_flags)
        trainable_inputs = [obs, masks,
                            sampled_bellman]
        self.trainable.fit(trainable_inputs,
                           dummy_labels, verbose=0)
        self.send_weights()
        if self.samples_since_last_update > self.target_update_freq:
            # Hard target-network sync.
            self.target_network.set_weights(self.Q.get_weights())
            self.samples_since_last_update = 0
        return True

    def send_weights(self):
        """Push current Q weights to the parameter server (blocking)."""
        # Named 'ref' rather than shadowing the builtin 'id'.
        ref = self.parameter_server.update_weights.remote(
            self.Q.get_weights())
        ray.get(ref)

    def stop(self):
        """Signal start_learning to exit its loop."""
        self.stopped = True