forked from vmayoral/basic_reinforcement_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
1,211 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# Deep Convolutional Reinforcement Learning | ||
|
||
### Code explained | ||
Let's analyze a 2D fruit fetch example based on [@bitwise-ben](https://github.com/bitwise-ben/Fruit)'s work. Code is available [here](examples/Fruit/qlearning.py): | ||
|
||
Dependencies used by the Deep Q-learning implementation: | ||
```Python | ||
import os | ||
from random import sample as rsample | ||
import numpy as np | ||
from keras.models import Sequential | ||
from keras.layers.convolutional import Convolution2D | ||
from keras.layers.core import Dense, Flatten | ||
from keras.optimizers import SGD, RMSprop | ||
from matplotlib import pyplot as plt | ||
``` | ||
|
||
The `GRID_SIZE` determines how big the environment will be (the bigger the environment, the tougther is to train it) | ||
```Python | ||
GRID_SIZE = 15 | ||
``` | ||
The following function defines a Python coroutine that controls the generic Fruit game dynamics | ||
(read about Python coroutines [here](https://jeffknupp.com/blog/2013/04/07/improve-your-python-yield-and-generators-explained/)). The coroutine is basically instantiated into a variable that receives `.next()` and `.send()` calls. The first one gets the function code to execute until the point where there's a call to `yield`. The `.send()` call includes an action as parameters which allows the function to finish its execution (it actually never finishes since the code is wrapped in an infinite loop, luckily we control its execution through the primitives just described). | ||
|
||
```Python | ||
def episode(): | ||
""" | ||
Coroutine function for an episode. | ||
Action has to be explicitly sent (via "send") to this co-routine. | ||
""" | ||
x, y, x_basket = ( | ||
np.random.randint(0, GRID_SIZE), # X of fruit | ||
0, # Y of dot | ||
np.random.randint(1, GRID_SIZE - 1)) # X of basket | ||
|
||
while True: | ||
# Reset grid | ||
X = np.zeros((GRID_SIZE, GRID_SIZE)) | ||
# Draw the fruit in the screen | ||
X[y, x] = 1. | ||
# Draw the basket | ||
bar = range(x_basket - 1, x_basket + 2) | ||
X[-1, bar] = 1. | ||
|
||
# End of game is known when fruit is at penultimate line of grid. | ||
# End represents either the reward (a win or a loss) | ||
end = int(y >= GRID_SIZE - 2) | ||
if end and x not in bar: | ||
end *= -1 | ||
|
||
action = yield X[np.newaxis], end | ||
if end: | ||
break | ||
|
||
x_basket = min(max(x_basket + action, 1), GRID_SIZE - 2) | ||
y += 1 | ||
``` | ||
|
||
Experience replay gets implemented in the coroutine below. Within this code, one should notice that the code blocks at `yield` expecting a `.send()` call that includes a `experience=(S, action, reward, S_prime)` tuple where: | ||
|
||
- `S`: current state | ||
- `action`: action to take | ||
- `reward`: reward obtained after taking `action` | ||
- `S_prime`: next state after taking `action` | ||
|
||
```Python | ||
def experience_replay(batch_size): | ||
""" | ||
Coroutine function for implementing experience replay. | ||
Provides a new experience by calling "send", which in turn yields | ||
a random batch of previous replay experiences. | ||
""" | ||
memory = [] | ||
while True: | ||
# experience is a tuple containing (S, action, reward, S_prime) | ||
experience = yield rsample(memory, batch_size) if batch_size <= len(memory) else None | ||
memory.append(experience) | ||
``` | ||
|
||
Similar to what was described above, the images are saved using another coroutine: | ||
```Python | ||
def save_img(): | ||
""" | ||
Coroutine to store images in the "images" directory | ||
""" | ||
if 'images' not in os.listdir('.'): | ||
os.mkdir('images') | ||
frame = 0 | ||
while True: | ||
screen = (yield) | ||
plt.imshow(screen[0], interpolation='none') | ||
plt.savefig('images/%03i.png' % frame) | ||
frame += 1 | ||
``` | ||
|
||
The model and hyperparameters are defined as follows: | ||
```Python | ||
nb_epochs = 500 | ||
batch_size = 128 | ||
epsilon = .8 | ||
gamma = .8 | ||
|
||
# Recipe of deep reinforcement learning model | ||
model = Sequential() | ||
model.add(Convolution2D(16, nb_row=3, nb_col=3, input_shape=(1, GRID_SIZE, GRID_SIZE), activation='relu')) | ||
model.add(Convolution2D(16, nb_row=3, nb_col=3, activation='relu')) | ||
model.add(Flatten()) | ||
model.add(Dense(100, activation='relu')) | ||
model.add(Dense(3)) | ||
model.compile(RMSprop(), 'MSE') | ||
``` | ||
|
||
The main loop of the code implementing Deep Q-learning: | ||
```Python | ||
exp_replay = experience_replay(batch_size) | ||
exp_replay.next() # Start experience-replay coroutine | ||
|
||
for i in xrange(nb_epochs): | ||
ep = episode() | ||
S, reward = ep.next() # Start coroutine of single entire episode | ||
loss = 0. | ||
try: | ||
while True: | ||
action = np.random.randint(-1, 2) | ||
if np.random.random() > epsilon: | ||
# Get the index of the maximum q-value of the model. | ||
# Subtract one because actions are either -1, 0, or 1 | ||
action = np.argmax(model.predict(S[np.newaxis]), axis=-1)[0] - 1 | ||
|
||
S_prime, reward = ep.send(action) | ||
experience = (S, action, reward, S_prime) | ||
S = S_prime | ||
|
||
batch = exp_replay.send(experience) | ||
if batch: | ||
inputs = [] | ||
targets = [] | ||
for s, a, r, s_prime in batch: | ||
# The targets of unchosen actions are the q-values of the model, | ||
# so that the corresponding errors are 0. The targets of chosen actions | ||
# are either the rewards, in case a terminal state has been reached, | ||
# or future discounted q-values, in case episodes are still running. | ||
t = model.predict(s[np.newaxis]).flatten() | ||
t[a + 1] = r | ||
if not r: | ||
t[a + 1] = r + gamma * model.predict(s_prime[np.newaxis]).max(axis=-1) | ||
targets.append(t) | ||
inputs.append(s) | ||
|
||
loss += model.train_on_batch(np.array(inputs), np.array(targets)) | ||
|
||
except StopIteration: | ||
pass | ||
|
||
if (i + 1) % 100 == 0: | ||
print 'Epoch %i, loss: %.6f' % (i + 1, loss) | ||
``` | ||
|
||
To test the model obtained: | ||
```Python | ||
img_saver = save_img() | ||
img_saver.next() | ||
|
||
for _ in xrange(10): | ||
g = episode() | ||
S, _ = g.next() | ||
img_saver.send(S) | ||
try: | ||
while True: | ||
act = np.argmax(model.predict(S[np.newaxis]), axis=-1)[0] - 1 | ||
S, _ = g.send(act) | ||
img_saver.send(S) | ||
|
||
except StopIteration: | ||
pass | ||
|
||
img_saver.close() | ||
|
||
``` | ||
|
||
resulting in images like the ones pictured below. | ||
|
||
![](images/fruit_grid10.gif) | ||
![](images/fruit_grid15.gif) | ||
|
||
|
||
### Resources: | ||
- Toy example of deep reinforcement model playing the game of snake, https://github.com/bitwise-ben/Snake | ||
- Toy example of a deep reinforcement learning model playing a game of catching fruit, https://github.com/bitwise-ben/Fruit | ||
- Keras plays catch, a single file Reinforcement Learning example, Eder Santana, http://edersantana.github.io/articles/keras_rl/ | ||
- Keras plays catch - a single file Reinforcement Learning example | ||
Raw, https://gist.github.com/EderSantana/c7222daa328f0e885093 | ||
- [Create a GIF from static images](http://askubuntu.com/questions/648244/how-to-create-a-gif-from-the-command-line) | ||
- [Improve Your Python: 'yield' and Generators Explained](https://jeffknupp.com/blog/2013/04/07/improve-your-python-yield-and-generators-explained/) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"{\"class_name\": \"Sequential\", \"config\": [{\"class_name\": \"Dense\", \"config\": {\"W_constraint\": null, \"b_constraint\": null, \"name\": \"dense_1\", \"output_dim\": 100, \"activity_regularizer\": null, \"trainable\": true, \"init\": \"glorot_uniform\", \"bias\": true, \"input_dtype\": \"float32\", \"input_dim\": null, \"b_regularizer\": null, \"W_regularizer\": null, \"activation\": \"relu\", \"batch_input_shape\": [null, 100]}}, {\"class_name\": \"Dense\", \"config\": {\"W_constraint\": null, \"b_constraint\": null, \"name\": \"dense_2\", \"activity_regularizer\": null, \"trainable\": true, \"init\": \"glorot_uniform\", \"bias\": true, \"input_dim\": null, \"b_regularizer\": null, \"W_regularizer\": null, \"activation\": \"relu\", \"output_dim\": 100}}, {\"class_name\": \"Dense\", \"config\": {\"W_constraint\": null, \"b_constraint\": null, \"name\": \"dense_3\", \"activity_regularizer\": null, \"trainable\": true, \"init\": \"glorot_uniform\", \"bias\": true, \"input_dim\": null, \"b_regularizer\": null, \"W_regularizer\": null, \"activation\": \"linear\", \"output_dim\": 3}}]}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import json | ||
import numpy as np | ||
from keras.models import Sequential | ||
from keras.layers.core import Dense | ||
from keras.optimizers import sgd | ||
|
||
|
||
class Catch(object): | ||
def __init__(self, grid_size=10): | ||
self.grid_size = grid_size | ||
self.reset() | ||
|
||
def _update_state(self, action): | ||
""" | ||
Input: action and states | ||
Ouput: new states and reward | ||
""" | ||
state = self.state | ||
if action == 0: # left | ||
action = -1 | ||
elif action == 1: # stay | ||
action = 0 | ||
else: | ||
action = 1 # right | ||
f0, f1, basket = state[0] | ||
new_basket = min(max(1, basket + action), self.grid_size-1) | ||
f0 += 1 | ||
out = np.asarray([f0, f1, new_basket]) | ||
out = out[np.newaxis] | ||
|
||
assert len(out.shape) == 2 | ||
self.state = out | ||
|
||
def _draw_state(self): | ||
im_size = (self.grid_size,)*2 | ||
state = self.state[0] | ||
canvas = np.zeros(im_size) | ||
canvas[state[0], state[1]] = 1 # draw fruit | ||
canvas[-1, state[2]-1:state[2] + 2] = 1 # draw basket | ||
return canvas | ||
|
||
def _get_reward(self): | ||
fruit_row, fruit_col, basket = self.state[0] | ||
if fruit_row == self.grid_size-1: | ||
if abs(fruit_col - basket) <= 1: | ||
return 1 | ||
else: | ||
return -1 | ||
else: | ||
return 0 | ||
|
||
def _is_over(self): | ||
if self.state[0, 0] == self.grid_size-1: | ||
return True | ||
else: | ||
return False | ||
|
||
def observe(self): | ||
canvas = self._draw_state() | ||
return canvas.reshape((1, -1)) | ||
|
||
def act(self, action): | ||
self._update_state(action) | ||
reward = self._get_reward() | ||
game_over = self._is_over() | ||
return self.observe(), reward, game_over | ||
|
||
def reset(self): | ||
n = np.random.randint(0, self.grid_size-1, size=1) | ||
m = np.random.randint(1, self.grid_size-2, size=1) | ||
self.state = np.asarray([0, n, m])[np.newaxis] | ||
|
||
|
||
class ExperienceReplay(object): | ||
def __init__(self, max_memory=100, discount=.9): | ||
self.max_memory = max_memory | ||
self.memory = list() | ||
self.discount = discount | ||
|
||
def remember(self, states, game_over): | ||
# memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?] | ||
self.memory.append([states, game_over]) | ||
if len(self.memory) > self.max_memory: | ||
del self.memory[0] | ||
|
||
def get_batch(self, model, batch_size=10): | ||
len_memory = len(self.memory) | ||
num_actions = model.output_shape[-1] | ||
env_dim = self.memory[0][0][0].shape[1] | ||
inputs = np.zeros((min(len_memory, batch_size), env_dim)) | ||
targets = np.zeros((inputs.shape[0], num_actions)) | ||
for i, idx in enumerate(np.random.randint(0, len_memory, | ||
size=inputs.shape[0])): | ||
state_t, action_t, reward_t, state_tp1 = self.memory[idx][0] | ||
game_over = self.memory[idx][1] | ||
|
||
inputs[i:i+1] = state_t | ||
# There should be no target values for actions not taken. | ||
# Thou shalt not correct actions not taken #deep | ||
targets[i] = model.predict(state_t)[0] | ||
Q_sa = np.max(model.predict(state_tp1)[0]) | ||
if game_over: # if game_over is True | ||
targets[i, action_t] = reward_t | ||
else: | ||
# reward_t + gamma * max_a' Q(s', a') | ||
targets[i, action_t] = reward_t + self.discount * Q_sa | ||
return inputs, targets | ||
|
||
|
||
if __name__ == "__main__": | ||
# parameters | ||
epsilon = .1 # exploration | ||
num_actions = 3 # [move_left, stay, move_right] | ||
epoch = 1000 | ||
max_memory = 500 | ||
hidden_size = 100 | ||
batch_size = 50 | ||
grid_size = 10 | ||
|
||
model = Sequential() | ||
model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu')) | ||
model.add(Dense(hidden_size, activation='relu')) | ||
model.add(Dense(num_actions)) | ||
model.compile(sgd(lr=.2), "mse") | ||
|
||
# If you want to continue training from a previous model, just uncomment the line bellow | ||
# model.load_weights("model.h5") | ||
|
||
# Define environment/game | ||
env = Catch(grid_size) | ||
|
||
# Initialize experience replay object | ||
exp_replay = ExperienceReplay(max_memory=max_memory) | ||
|
||
# Train | ||
win_cnt = 0 | ||
for e in range(epoch): | ||
loss = 0. | ||
env.reset() | ||
game_over = False | ||
# get initial input | ||
input_t = env.observe() | ||
|
||
while not game_over: | ||
input_tm1 = input_t | ||
# get next action | ||
if np.random.rand() <= epsilon: | ||
action = np.random.randint(0, num_actions, size=1) | ||
else: | ||
q = model.predict(input_tm1) | ||
action = np.argmax(q[0]) | ||
|
||
# apply action, get rewards and new state | ||
input_t, reward, game_over = env.act(action) | ||
if reward == 1: | ||
win_cnt += 1 | ||
|
||
# store experience | ||
exp_replay.remember([input_tm1, action, reward, input_t], game_over) | ||
|
||
# adapt model | ||
inputs, targets = exp_replay.get_batch(model, batch_size=batch_size) | ||
|
||
loss += model.train_on_batch(inputs, targets) | ||
print("Epoch {:03d}/999 | Loss {:.4f} | Win count {}".format(e, loss, win_cnt)) | ||
|
||
# Save trained model weights and architecture, this will be used by the visualization code | ||
model.save_weights("model.h5", overwrite=True) | ||
with open("model.json", "w") as outfile: | ||
json.dump(model.to_json(), outfile) |
Binary file not shown.
Oops, something went wrong.