
Commit 43aa9ab
feat: implement double dqn
committed Jan 2, 2018
1 parent bcafdf7

5 files changed: +95 −47 lines
 

‎agent.py

+53 −13

@@ -1,26 +1,44 @@
+import sys
 import numpy as np
 import warnings
 import utils
+from enum import Enum
 from time import time, sleep
 import matplotlib.pyplot as plt
 from policy import EpsGreedyPolicy
 from memory import ExperienceReplay
+from keras.models import Sequential
+from keras.layers import *
+from keras.optimizers import *
+from keras.models import load_model
 
-class Agent:
-    def __init__(self, game, model, nb_epoch=10000, memory_size=1000, batch_size=50, nb_frames=4, epsilon=1., discount=.9, learning_rate=.1):
-
-        channels = model.input_shape[1]
+TEST = 0
+SIMPLE = 1
+DOUBLE = 2
 
-        if nb_frames != channels:
-            warnings.warn("Dimension mismatch: Using number of channels for number of frames")
-            nb_frames = channels
+class Agent:
+    def __init__(self, game, mode=SIMPLE, nb_epoch=10000, memory_size=1000, batch_size=50, nb_frames=4, epsilon=1., discount=.9, learning_rate=.1, model=None):
 
         self.game = game
-        self.model = model
+        self.mode = mode
+        self.target_model = None
+        self.rows, self.columns = game.field_shape()
         self.nb_epoch = nb_epoch
         self.nb_frames = nb_frames
         self.nb_actions = game.nb_actions()
 
+        if mode == TEST:
+            print('Test Mode: Loading model...')
+            self.model = load_model(model)
+        elif mode == SIMPLE:
+            print('Using Plain DQN: Building model...')
+            self.model = self.build_model()
+        elif mode == DOUBLE:
+            print('Using Double DQN: Building primary and target model...')
+            self.model = self.build_model()
+            self.target_model = self.build_model()
+            self.update_target_model()
+
         # Trades off the importance of sooner versus later rewards.
         # A factor of 0 means it rather prefers immediate rewards
         # and it will mostly consider current rewards. A factor of 1
@@ -39,18 +57,34 @@ def __init__(self, game, model, nb_epoch=10000, memory_size=1000, batch_size=50,
         # a random action by the probability 'eps'. Without this policy the network
         # is greedy and it settles with the first effective strategy it finds.
         # Hence, we introduce certain randomness.
-        # Epsilon reaches its minimum at 2/3 of the games
+        # Epsilon reaches its minimum at 1/2 of the games
         epsilon_end = self.nb_epoch - (self.nb_epoch / 2)
         self.policy = EpsGreedyPolicy(self.model, epsilon_end, self.nb_actions, epsilon, .1)
 
         # Create new experience replay memory. Without this optimization
         # the training takes extremely long even on a GPU and most
         # importantly the approximation of Q-values using non-linear
         # functions, that is used for our NN, is not very stable.
-        self.memory = ExperienceReplay(self.model, self.nb_actions, memory_size, batch_size, self.discount, self.learning_rate)
+        self.memory = ExperienceReplay(self.model, self.target_model, self.nb_actions, memory_size, batch_size, self.discount, self.learning_rate)
 
         self.frames = None
 
+    def build_model(self):
+        model = Sequential()
+        model.add(Conv2D(32, (2, 2), activation='relu', input_shape=(self.nb_frames, self.rows, self.columns), data_format="channels_first"))
+        model.add(Conv2D(64, (2, 2), activation='relu'))
+        model.add(Conv2D(64, (3, 3), activation='relu'))
+        model.add(Flatten())
+        model.add(Dropout(0.1))
+        model.add(Dense(512, activation='relu'))
+        model.add(Dense(self.nb_actions))
+        model.compile(Adam(), 'MSE')
+
+        return model
+
+    def update_target_model(self):
+        self.target_model.set_weights(self.model.get_weights())
+
     def get_frames(self):
         frame = self.game.get_state()
         if self.frames is None:
@@ -85,7 +119,8 @@ def print_stats(self, data, y_label, x_label='Epoch', marker='-'):
         path = './plots/{name}_{size}x{size}_{timestamp}'
         fig.savefig(path.format(size=self.game.grid_size, name=file_name, timestamp=int(time())))
 
-    def train(self, visualize=True):
+    def train(self, update_freq=10):
+        total_steps = 0
         max_steps = self.game.grid_size**2 * 3
         loops = 0
         nb_wins = 0
@@ -119,6 +154,7 @@ def train(self, visualize=True):
 
                 cumulative_reward += reward
                 steps += 1
+                total_steps += 1
 
                 if steps == max_steps and not done:
                     loops += 1
@@ -145,6 +181,9 @@ def train(self, visualize=True):
                 if done:
                     duration = utils.get_time_difference(start_time, time())
 
+                if self.mode == DOUBLE and self.target_model is not None and total_steps % (update_freq) == 0:
+                    self.update_target_model()
+
             current_epoch = epoch + 1
             reward_buffer.append([current_epoch, cumulative_reward])
             duration_buffer.append([current_epoch, duration])
@@ -160,8 +199,9 @@ def train(self, visualize=True):
         self.print_stats(steps_buffer, 'Steps per Game')
         self.print_stats(wins_buffer, 'Wins')
 
-        path = './models/model_{size}x{size}_{epochs}_{timestamp}.h5'
-        self.model.save(path.format(size=self.game.grid_size, epochs=self.nb_epoch, timestamp=int(time())))
+        path = './models/model_{mode}_{size}x{size}_{epochs}_{timestamp}.h5'
+        mode = 'dqn' if self.mode == SIMPLE else 'ddqn'
+        self.model.save(path.format(mode=mode, size=self.game.grid_size, epochs=self.nb_epoch, timestamp=int(time())))
 
     def play(self, nb_games=5, interval=.7):
         nb_wins = 0
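
As a quick orientation, the sketch below shows how the reworked Agent could be driven directly from Python. The imports, the DOUBLE mode constant, and the method signatures come from this diff; the concrete constructor values, and the Snake arguments mirroring main.py, are illustrative assumptions only.

from games import Snake
from agent import Agent, DOUBLE

# Illustrative values; defaults mirror the argparse setup in main.py (assumed).
game = Snake(grid_size=10, walls=False)

# DOUBLE builds a primary and a target network and syncs them via update_target_model().
agent = Agent(game, mode=DOUBLE, nb_epoch=10000, memory_size=1000, batch_size=50,
              nb_frames=4, epsilon=1., discount=.9, learning_rate=.1)

# update_freq controls how often the target network is refreshed from the primary one.
agent.train(update_freq=10)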

‎main.py

+9 −25

@@ -1,17 +1,15 @@
-from keras.models import Sequential
-from keras.layers import *
-from keras.optimizers import *
 from games import Snake
-from keras.models import load_model
-from agent import Agent
+from agent import Agent, TEST
 import argparse
 
 boolean = lambda x: (str(x).lower() == 'true')
 
-# Command line arguments
+# Command line argumentss
 parser = argparse.ArgumentParser()
 parser.add_argument("--train", nargs='?', type=boolean, const=True, default=True)
 parser.add_argument("--model", nargs='?', const=True)
+parser.add_argument("--mode", nargs='?', type=int, const=True, default=1, choices=[0,1,2])
+parser.add_argument("--update-freq", nargs='?', type=int, const=True, default=10)
 parser.add_argument("--grid-size", nargs='?', type=int, const=True, default=10)
 parser.add_argument("--frames", nargs='?', type=int, const=True, default=4)
 parser.add_argument("--epochs", nargs='?', type=int, const=True, default=10000)
@@ -27,14 +25,13 @@
 args = parser.parse_args()
 
 if not args.train and args.model is None:
-    parser.error("Non-training mode requires a model")
+    parser.error("Non-training mode requires a model")
 
 print(args)
 
 game = Snake(grid_size=args.grid_size, walls=args.walls)
 
 # Hyper parameter for the neural net and the agent
-rows, columns = game.field_shape()
 nb_frames = args.frames
 nb_epoch = args.epochs
 memory_size = args.memory_size
@@ -43,25 +40,12 @@
 discount = args.discount
 learning_rate = args.learning_rate
 nb_actions = game.nb_actions()
+mode = args.mode if args.train else TEST
+update_freq = args.update_freq
 
-model = None
+agent = Agent(game, mode, nb_epoch, memory_size, batch_size, nb_frames, epsilon, discount, learning_rate, model=args.model)
 
 if args.train:
-    model = Sequential()
-    model.add(Conv2D(32, (2, 2), activation='relu', input_shape=(nb_frames, rows, columns), data_format="channels_first"))
-    model.add(Conv2D(64, (2, 2), activation='relu'))
-    model.add(Conv2D(64, (3, 3), activation='relu'))
-    model.add(Flatten())
-    model.add(Dropout(0.1))
-    model.add(Dense(512, activation='relu'))
-    model.add(Dense(nb_actions))
-    model.compile(Adam(), 'MSE')
-else:
-    model = load_model(args.model)
-
-agent = Agent(game, model, nb_epoch, memory_size, batch_size, nb_frames, epsilon, discount, learning_rate)
-
-if args.train:
-    agent.train()
+    agent.train(update_freq=update_freq)
 else:
     agent.play(nb_games=args.games, interval=args.interval)
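
For reference, two example invocations of the new flags (the flag names come straight from the argparse calls above; the values and the model filename are hypothetical):

# Train a Double DQN (mode 2) with a target-network update frequency of 10
python main.py --mode 2 --update-freq 10 --grid-size 10 --epochs 10000

# Play with a previously saved model (non-training mode requires --model)
python main.py --train false --model ./models/model_ddqn_10x10_10000_<timestamp>.h5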

‎memory.py

+33 −9

@@ -2,9 +2,10 @@
 from random import sample
 
 class ExperienceReplay():
-    def __init__(self, model, nb_actions, memory_size=100, batch_size=50, discount=.9, learning_rate=.1):
+    def __init__(self, model, target_model, nb_actions, memory_size=100, batch_size=50, discount=.9, learning_rate=.1):
         self.memory = []
         self.model = model
+        self.target_model = target_model
         self.nb_actions = nb_actions
         self.memory_size = memory_size
         self.batch_size = batch_size
@@ -57,7 +58,7 @@ def get_batch(self):
         q_t = self.model.predict(batch)
 
         # q-values for the next states (states_tn)
-        q_tn = self.get_q_next(q_t, batch_size)
+        q_tn = self.get_q_next(q_t, states_tn, batch_size)
 
         # Delta (learning rate). Determines how aggressively
         # the q-values should be updated. 1 means very a
@@ -68,18 +69,41 @@ def get_batch(self):
 
         inputs = states_t
 
-        # Update q-values for states_t given the reward and the max q-value for states_tn
+        # Update q-values based on the next states (states_tn)
        # q_t[:batch_size] = q-values for the current states (states_t)
        targets = (1 - delta) * q_t[:batch_size] + delta * (rewards + self.discount * (1 - done) * q_tn)
 
        return inputs, targets
 
-    def get_q_next(self, q_t, batch_size):
-        # Take max q-value from each next state (state_tn) and reshape into
-        # [[ .5 .5 .5 .5 .5 ] | max q for state_tn[0]
-        #  [ .2 .2 .2 .2 .2 ] | max q for state_tn[1]
-        #  ... #state_tn ]
-        return np.max(q_t[batch_size:], axis=1).repeat(self.nb_actions).reshape((batch_size, self.nb_actions))
+    def get_q_next(self, q_t, states_tn, batch_size):
+        if not self.target_model:
+            # Plain DQN
+            # A single network for action selection and generation of target q-values
+            # Take max q-value from each next state (state_tn) and reshape into
+            # [[ .5 .5 .5 .5 .5 ] | max q for state_tn[0]
+            #  [ .2 .2 .2 .2 .2 ] | max q for state_tn[1]
+            #  ... #state_tn ]
+            q_next = np.max(q_t[batch_size:], axis=1)
+        else:
+            # Double DQN
+            # The problem with plain DQN is that it tends to overestimate the q-values due to the
+            # 'max' used in the formula to update the targets. The 'max' leads to a positive bias
+            # because the highest q-value is propagated to previous states.
+            # The solution is to have two separate networks, one primary network for determining the
+            # action and a second (target) network to generate the target q-values for that action.
+            # By decoupling the action choice from the target Q-value generation, we are able to
+            # substantially reduce the overestimation, and train faster and more reliably.
+
+            # Select max action from primary network (from states_tn)
+            next_actions = np.argmax(q_t[batch_size:], axis=1)
+
+            # Generate target q-values with secondary (target) network
+            target_q_values = self.target_model.predict(states_tn)
+
+            # Take the highest q-values
+            q_next = target_q_values[range(batch_size), next_actions]
+
+        return q_next.repeat(self.nb_actions).reshape((batch_size, self.nb_actions))
 
     def extract_transition(self, experience, batch_size):
         input_dim = self.input_dim
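
To make the comment block in get_q_next() concrete, here is a small, self-contained numpy sketch of the Double DQN target computation. The toy arrays and names are placeholders and do not depend on this repository's classes; they only mirror the roles of q_t, states_tn, and the target model above.

import numpy as np

batch_size, nb_actions = 3, 5
discount, delta = .9, .1

q_t = np.random.rand(batch_size, nb_actions)            # Q(s, .) from the primary network
q_tn_primary = np.random.rand(batch_size, nb_actions)   # Q(s', .) from the primary network
q_tn_target = np.random.rand(batch_size, nb_actions)    # Q(s', .) from the target network
rewards = np.random.rand(batch_size, 1)
done = np.zeros((batch_size, 1))                         # 1 where the episode ended

# Double DQN: the primary network picks the next action, the target network scores it
next_actions = np.argmax(q_tn_primary, axis=1)
q_next = q_tn_target[np.arange(batch_size), next_actions].reshape(batch_size, 1)

# Same soft update as get_batch(): blend the old estimate with the bootstrapped target
targets = (1 - delta) * q_t + delta * (rewards + discount * (1 - done) * q_next)

The repeat/reshape at the end of get_q_next() spreads q_next across all nb_actions columns before blending, which is equivalent to the column broadcasting used in this sketch.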
2 binary files changed (contents not shown; one adds 11.4 MB).
