
Commit 2b86c29

Committed Sep 21, 2018
done
1 parent b279694 commit 2b86c29

21 files changed: +144 -303 lines changed
 

.idea/workspace.xml

+44 -99 (generated file; diff not rendered by default)

DQN.py

+68 -173
@@ -1,13 +1,13 @@
-#import snakeClass
+# import snakeClass
 import pygame
-from keras.models import Sequential,Model
+from keras.models import Sequential, Model
 from keras.layers.core import Flatten, Dense, Dropout
 from keras.optimizers import RMSprop, Adam
 import random
 import numpy as np
 import pandas as pd
 import pygame
-from keras.models import Sequential,Model
+from keras.models import Sequential, Model
 from keras.layers.core import Flatten, Dense, Dropout
 from keras.optimizers import RMSprop
 import random
@@ -18,9 +18,9 @@
 import copy
 from operator import sub, add
 
-
 pd.set_option('display.max_columns', 500)
 
+
 class DQNAgent(object):
 
     def __init__(self):
@@ -30,217 +30,112 @@ def __init__(self):
         self.short_memory = np.array([])
         self.agent_target = 1
         self.agent_predict = 0
-        self.learning_rate = 0.001
-        self.model = self.network()
-        #self.model = self.network("weights_new3.hdf5")
-        self.epsilon = 2
+        self.learning_rate = 0.000005
+
+        #self.model = self.network()
+        self.model = self.network("weights_new17_1.hdf5")
+        self.epsilon = 0
         self.actual = []
         self.memory = []
 
-
     def get_state(self, game, player, food):
 
         state = [
-            (list(map(add, player.position[-1], [-20,0])) in player.position and player.x_change != 20) or
-            player.position[-1][0] - 20 < 0,  # danger left
-            (list(map(add, player.position[-1], [-40,0])) in player.position and player.x_change != 20) or
-            player.position[-1][0] - 40 < 0,  # danger 2 left
-            (list(map(add, player.position[-1], [20,0])) in player.position and player.x_change != -20) or
-            player.position[-1][0] + 20 > game.display_width,  # danger right
-            (list(map(add, player.position[-1], [40, 0])) in player.position and player.x_change != -20) or
-            player.position[-1][0] + 40 > game.display_width,  # danger 2 right
-            (list(map(add, player.position[-1], [0, -20])) in player.position and player.y_change != 20) or
-            player.position[-1][-1] - 20 < 0,  # danger up
-            (list(map(add, player.position[-1], [0, -40])) in player.position and player.y_change != 20) or
-            player.position[-1][-1] - 40 < 0,  # danger 2 up
-            (list(map(add, player.position[-1], [0, 20])) in player.position and player.y_change != -20) or
-            player.position[-1][-1] + 20 >= game.display_height,  # danger down
-            (list(map(add, player.position[-1], [0, 40])) in player.position and player.y_change != -20) or
-            player.position[-1][-1] + 40 > game.display_height,  # danger 2 down
-
-            # (player.position[-1][0] - 20 in self.get_position_x_y(player)[0] and player.x_change!=20) or player.position[-1][0] - 20 <= 0,  # danger left
-            # (player.position[-1][0] - 40 in self.get_position_x_y(player)[0] and player.x_change!=20) or player.position[-1][0] - 40 <= 0,  # danger 2 left
-            # (player.position[-1][0] + 20 in self.get_position_x_y(player)[0] and player.x_change != -20) or player.position[-1][0] + 20 >= game.display_width,  # danger right
-            # (player.position[-1][0] + 40 in self.get_position_x_y(player)[0] and player.x_change != -20) or player.position[-1][0] + 40 >= game.display_width,  # danger 2 right
-            # (player.position[-1][-1] - 20 in self.get_position_x_y(player)[1] and player.y_change != 20) or player.position[-1][-1] - 20 <= 0,  # danger up
-            # (player.position[-1][-1] - 40 in self.get_position_x_y(player)[1] and player.y_change != 20) or player.position[-1][-1] - 40 <= 0,  # danger 2 up
-            # (player.position[-1][-1] + 20 in self.get_position_x_y(player)[1] and player.y_change != -20) or player.position[-1][-1] + 20 >= game.display_height,  # danger down
-            # (player.position[-1][-1] + 40 in self.get_position_x_y(player)[1] and player.y_change != -20) or player.position[-1][-1] + 40 >= game.display_height,  # danger 2 down
-            #player.x_change == - 20 and (player.position[-1][0] - 20 < 0 or player.position[-1][0] - 20 in self.get_position_x_y(player)[0]),  #danger straight
-            player.x_change == -20,  # move left
-            player.x_change == 20,  # move right
-            player.y_change == -20,  # move up
-            player.y_change == 20,  # move down
-            food.x_food < player.x,  # food left
-            food.x_food > player.x,  # food right
-            food.y_food < player.y,  # food up
-            food.y_food > player.y  # food down
+            (player.x_change == 20 and player.y_change == 0 and ((list(map(add, player.position[-1], [20, 0])) in player.position) or
+            player.position[-1][0] + 20 >= game.display_width)) or (player.x_change == -20 and player.y_change == 0 and ((list(map(add, player.position[-1], [-20, 0])) in player.position) or
+            player.position[-1][0] - 20 < 0)) or (player.x_change == 0 and player.y_change == -20 and ((list(map(add, player.position[-1], [0, -20])) in player.position) or
+            player.position[-1][-1] - 20 < 0)) or (player.x_change == 0 and player.y_change == 20 and ((list(map(add, player.position[-1], [0, 20])) in player.position) or
+            player.position[-1][-1] + 20 >= game.display_height)),  # danger straight
+
+            (player.x_change == 0 and player.y_change == -20 and ((list(map(add,player.position[-1],[20, 0])) in player.position) or
+            player.position[-1][0] + 20 > game.display_width)) or (player.x_change == 0 and player.y_change == 20 and ((list(map(add,player.position[-1],
+            [-20,0])) in player.position) or player.position[-1][0] - 20 < 0)) or (player.x_change == -20 and player.y_change == 0 and ((list(map(
+            add,player.position[-1],[0,-20])) in player.position) or player.position[-1][-1] - 20 < 0)) or (player.x_change == 20 and player.y_change == 0 and (
+            (list(map(add,player.position[-1],[0,20])) in player.position) or player.position[-1][
+            -1] + 20 >= game.display_height)),  # danger right
+
+            (player.x_change == 0 and player.y_change == 20 and ((list(map(add,player.position[-1],[20,0])) in player.position) or
+            player.position[-1][0] + 20 > game.display_width)) or (player.x_change == 0 and player.y_change == -20 and ((list(map(
+            add, player.position[-1],[-20,0])) in player.position) or player.position[-1][0] - 20 < 0)) or (player.x_change == 20 and player.y_change == 0 and (
+            (list(map(add,player.position[-1],[0,-20])) in player.position) or player.position[-1][-1] - 20 < 0)) or (
+            player.x_change == -20 and player.y_change == 0 and ((list(map(add,player.position[-1],[0,20])) in player.position) or
+            player.position[-1][-1] + 20 >= game.display_height)),  #danger left
+
+
+            player.x_change == -20,  # move left
+            player.x_change == 20,  # move right
+            player.y_change == -20,  # move up
+            player.y_change == 20,  # move down
+            food.x_food < player.x,  # food left
+            food.x_food > player.x,  # food right
+            food.y_food < player.y,  # food up
+            food.y_food > player.y  # food down
         ]
 
         for i in range(len(state)):
            if state[i]:
                state[i]=1
            else:
                state[i]=0
-        # if state[0] == 1:
-        #     print('DANGER LEFT')
-        # if state[1] == 1:
-        #     print('DANGER 2 LEFT')
-        # if state[2] == 1:
-        #     print('DANGER RIGHT')
-        # if state[3] == 1:
-        #     print('DANGER 2 RIGHT')
-        # if state[4] == 1:
-        #     print('DANGER UP')
-        # if state[5] == 1:
-        #     print('DANGER 2 UP')
-        # if state[6] == 1:
-        #     print('DANGER DOWN')
-        # if state[7] == True:
-        #     print('DANGER 2 DOWN')
 
         return np.asarray(state)
 
-    def get_position_x_y(self, player):
-        position_x = []
-        position_y = []
-        for i in player.position:
-            position_x.append(i[0])
-            position_y.append(i[1])
-        return position_x, position_y
-
-    def set_reward(self, game, player, food, crash):
+    def set_reward(self, player, food, crash):
+        self.reward = 0
        if crash:
            self.reward = -10
            return self.reward
        if player.eaten:
            self.reward = 10
-        elif (player.x_change < 0 and food.x_food < player.x) or (player.x_change > 0 and food.x_food > player.x) or (player.y_change < 0 and food.y_food < player.y) or (player.y_change > 0 and food.y_food > player.y):
-            self.reward = 2
-        else:
-            self.reward = -1
+        # elif (player.x_change < 0 and food.x_food < player.x) or (player.x_change > 0 and food.x_food > player.x) or (player.y_change < 0 and food.y_food < player.y) or (player.y_change > 0 and food.y_food > player.y):
+        #     self.reward = 1
+        # else:
+        #     self.reward = -1
        return self.reward
 
-    def possible_moves(self, player):
-        if player.x_change == -20:
-            return [0,2,3]
-        elif player.x_change == 20:
-            return [1,2,3]
-        elif player.y_change == -20:
-            return [0,1,2]
-        elif player.y_change == 20:
-            return [0,1,3]
-
-    def replay(self, game, player, food, actual):
-        player.position = copy.deepcopy(actual[0])
-        player.x, player.y, player.x_change, player.y_change, food.x_food, food.y_food, game.crash, player.eaten, player.food = actual[1:]
-
-    '''
-    def next_state(self, game, player, food, i):
-        actual = [player.position, player.x, player.y, player.x_change, player.y_change, food.x_food, food.y_food, game.crash, player.eaten]
-        original_state = self.get_state(game, player, food)
-        player.do_move(i, player.x, player.y, game, food)
-        player.display_player(player.x, player.y,player.food,game,player)
-        array = [original_state, i, self.set_reward(game, player), self.get_state(game, player, food)]
-        pygame.time.wait(500)
-        self.replay(game, player, food, actual)
-        player.display_player(player.x, player.y, player.food, game, player)
-        return array
-    '''
-
-    def next_state(self, game, player, food, i):
-        actual = [player.position, player.x, player.y, player.x_change, player.y_change, food.x_food, food.y_food, game.crash, player.eaten, player.food]
-        original_state = self.get_state(game, player, food)
-        player.do_move(i, player.x, player.y, game, food)
-        player.display_player(player.x, player.y,player.food,game,player)
-        array = [original_state, i, self.set_reward(game, player), self.get_state(game, player, food)]
-        pygame.time.wait(500)
-        self.replay(game, player, food, actual)
-        player.display_player(player.x, player.y, player.food, game, player)
-        return array
-
-    def loss(self, target, state, action):
-        return K.mean(K.square(target - self.predict_q(self.model, state, action)), axis=-1)
-
-    def network(self,weights=None):
+    def network(self, weights=None):
        model = Sequential()
-        model.add(Dense(output_dim=30, activation='relu', input_dim=16))
-        model.add(Dense(output_dim=30, activation='relu'))
+        model.add(Dense(output_dim=120, activation='relu', input_dim=11))
+        model.add(Dropout(0.25))
+        model.add(Dense(output_dim=120, activation='relu'))
+        model.add(Dropout(0.25))
+        model.add(Dense(output_dim=120, activation='relu'))
+        model.add(Dropout(0.25))
+
+        #model.add(Dense(output_dim=300, activation='relu'))
        model.add(Dense(output_dim=3, activation='softmax'))
        opt = Adam(self.learning_rate)
        model.compile(loss='mse', optimizer=opt)
 
        if weights:
            model.load_weights(weights)
-        # [self.loss(self.agent_target, self.agent_predict)]
        return model
 
-    def act(self, state):
-        if random.random(0, 1) < self.epsilon:
-            return random.randint(0, 4)
-        else:
-            return np.argmax(self.brain.predictOne(state))
-
-    def observe(self, sequence):  # in (s, a, r, s_) format
-        self.memory.add(sequence)
-
-    def q_parameter(self):
-        q = self.reward + self.gamma * np.argmax(self.fit_q())
-
-    def predict_q(self, model, state, action):
-        predictor = np.array([np.hstack(np.array([state, action]))])
-        q = model.predict(predictor)
-        return q
-
-    def train_q(self, storage, state, action):
-        train = np.array([storage[:17]])
-        test = np.array([storage[17]])
-        self.model.compile(loss='mse', optimizer=RMSprop(lr=0.025))
-        self.model.fit(train, test, epochs=1)
-
-    def train2_q(self,training, test):
-        training = training.values
-        test = test.values
-        self.model.fit(training, test, epochs=1)
-
-
-    def initialize_dataframe(self):
-        state = [0]*12
-        for i in range(12):
-            state[i]= random.choice([0, 1])
-        move = random.randint(1,4)
-        reward = random.choice([-1, -10, 10])
-        future_state = [0]*12
-        for i in range(12):
-            future_state[i] = random.choice([True, False])
-        Q = 1
-        array = [state, move, reward, future_state, Q]
-        self.dataframe = self.dataframe.append([array])
-
-    def store_memory(self, state, action, q):
-        self.short_memory = np.hstack(np.array([state, action, q]))
-        #print(self.short_memory)
-
     def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
 
     def replay_new(self, memory):
-        if len(memory)>1500:
+        if len(memory) > 1000:
            minibatch = random.sample(memory, 1000)
        else:
            minibatch = memory
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
-                target = reward + self.gamma * np.amax(self.model.predict(next_state.reshape((1, 16)))[0])
-            #print('TARGET', target)
-            target_f = self.model.predict(state.reshape((1, 16)))
-            #print('TARGET_F', target_f)
-            target_f[0][np.argmax(action)] = target
-            #print('TARGET_F_AFTER', target_f)
-            self.model.fit(state.reshape((1,16)), target_f, epochs=1, verbose=0)
-
-
+                target = reward + self.gamma * np.amax(self.model.predict(np.array([next_state]))[0])
+            # print('TARGET', target)
+            target_f = self.model.predict(np.array([state]))
+            # print('TARGET_1', target_f[0])
 
+            target_f[0][np.argmax(action)] = target
+            # print('TARGET_2', target_f[0])
 
+            self.model.fit(np.array([state]), target_f, epochs=1, verbose=0)
 
+    def train_short_memory(self, state, action, reward, next_state, done):
+        target = reward
+        if not done:
+            target = reward + self.gamma * np.amax(self.model.predict(next_state.reshape((1, 11)))[0])
+        target_f = self.model.predict(state.reshape((1, 11)))
+        target_f[0][np.argmax(action)] = target
+        self.model.fit(state.reshape((1, 11)), target_f, epochs=1, verbose=0)
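For context on the change above: both replay_new and the new train_short_memory apply the standard Q-learning update, overwriting the network's output for the action that was actually taken with r for a terminal move, or r + gamma * max_a' Q(s', a') otherwise. The sketch below is a minimal, framework-free illustration of that target computation, not the committed code; the helper name q_target_vector and the gamma value 0.9 are assumptions (the commit does not show where self.gamma is set).

# Minimal sketch of the Q-learning target used by train_short_memory / replay_new.
import numpy as np

def q_target_vector(q_state, q_next_state, action_one_hot, reward, done, gamma=0.9):
    """Return the training target for one (state, action, reward, next_state, done) transition.

    q_state        -- model prediction for the current state, shape (3,)
    q_next_state   -- model prediction for the next state, shape (3,)
    action_one_hot -- one-hot vector for the action actually taken, e.g. [0, 1, 0]
    gamma          -- discount factor (assumed value; not shown in this commit)
    """
    target = reward if done else reward + gamma * np.max(q_next_state)
    target_f = q_state.copy()
    target_f[np.argmax(action_one_hot)] = target  # only the taken action's Q-value is updated
    return target_f

# Example: non-terminal step with no food eaten and a best next-state Q-value of 0.5.
print(q_target_vector(np.array([0.1, 0.2, 0.3]),
                      np.array([0.5, 0.0, 0.2]),
                      [0, 1, 0], reward=0, done=False))  # -> [0.1, 0.45, 0.3]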

__pycache__/DQN.cpython-36.pyc   -2.43 KB   (binary file not shown)

food2.png   286 Bytes   (binary file)

snakeClass.py

+32 -31
@@ -11,7 +11,7 @@
 from sklearn import linear_model
 
 display_option = False
-speed = 0
+speed = 50
 class Game:
 
     def __init__(self, display_width, display_height):
@@ -22,7 +22,6 @@ def __init__(self, display_width, display_height):
         self.crash = False
         self.player = Player(self)
         self.food = Food(self, self.player)
-        self.speed = 50
         self.score = 0
 
 
@@ -72,7 +71,6 @@ def do_move(self, move, x, y, game, food,agent):
         self.x = x + self.x_change
         self.y = y + self.y_change
 
-        #print(self.x_change, self.y_change, self.x, self.y, self.position)
         if self.x < 0 or self.x == game.display_width or self.y < 0 or self.y == game.display_height or [self.x, self.y] in self.position:
             game.crash = True
         eat(self, food, game)
@@ -129,48 +127,53 @@ def display(player, food, game):
 def update_screen():
     pygame.display.update()
 
-def initial_move(player, game, food,agent):
-    player.do_move(1, player.x, player.y, game, food,agent)
+
+def initialize_game(player, game, food, agent):
+    state_init1 = agent.get_state(game, player, food)  # [0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0]
+    action = [1, 0, 0]
+    player.do_move(action, player.x, player.y, game, food, agent)
+    state_init2 = agent.get_state(game, player, food)
+    reward1 = agent.set_reward(player, food, game.crash)
+    agent.remember(state_init1, action, reward1, state_init2, game.crash)
+    agent.replay_new(agent.memory)
+
+
+def plot_score(array_counter, array_score):
+    fit = np.polyfit(array_counter, array_score, 1)
+    fit_fn = np.poly1d(fit)
+    plt.plot(array_counter, array_score, 'yo', array_counter, fit_fn(array_counter), '--k')
+    plt.show()
+
 
 def run():
     pygame.init()
     agent = DQNAgent()
     counter_games = 0
     score_plot = []
     counter_plot =[]
-    while counter_games < 100:
-        #Initialize game
+    while counter_games < 200:
+        #Initialize classes
         game = Game(400, 400)
         player1 = game.player
         food1 = game.food
-        #Initialize storage to train first network
-        state_init1 = agent.get_state(game,player1,food1)  #[0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0]
-        action = [1, 0, 0]
-        player1.do_move(action, player1.x, player1.y, game, food1, agent)
-        state_init2 = agent.get_state(game, player1, food1)
-        reward1 = agent.set_reward(game,player1,food1,game.crash)
-        agent.remember(state_init1,action, reward1, state_init2, game.crash)
-        agent.replay_new(agent.memory)
-        #Performn first move
+
+        #Perform first move
+        initialize_game(player1, game, food1, agent)
+
         if display_option:
             display(player1, food1, game)
         while not game.crash:
-            #player1.get_position_x()
-            if counter_games < 15:
-                agent.epsilon = 3
-            elif counter_games < 30:
-                agent.epsilon = 2
-            elif counter_games >= 30:
-                agent.epsilon = 0
+            agent.epsilon = 0
             state_old = agent.get_state(game, player1, food1)
-            if randint(0, 10) < agent.epsilon:
+            if randint(0, 200) < agent.epsilon:
                 final_move = to_categorical(randint(0, 2), num_classes=3)[0]
             else:
-                prediction = agent.model.predict(state_old.reshape((1,16)))
+                prediction = agent.model.predict(state_old.reshape((1,11)))
                 final_move = to_categorical(np.argmax(prediction[0]), num_classes=3)[0]
             player1.do_move(final_move, player1.x, player1.y, game, food1, agent)
             state_new = agent.get_state(game, player1, food1)
-            reward = agent.set_reward(game, player1, food1, game.crash)
+            reward = agent.set_reward(player1, food1, game.crash)
+            agent.train_short_memory(state_old, final_move, reward, state_new, game.crash)
             agent.remember(state_old, final_move, reward, state_new, game.crash)
             if display_option:
                 display(player1, food1, game)
@@ -182,11 +185,9 @@ def run():
         print('Game', counter_games, ' Score:', game.score)
         score_plot.append(game.score)
         counter_plot.append(counter_games)
-        agent.model.save_weights('weights_new5_lr0001.hdf5')
+        agent.model.save_weights('weights_new17_1.hdf5')
+
+    plot_score(counter_plot, score_plot)
 
-    fit = np.polyfit(counter_plot, score_plot, 1)
-    fit_fn = np.poly1d(fit)
-    plt.plot(counter_plot, score_plot, 'yo', counter_plot, fit_fn(counter_plot), '--k')
-    plt.show()
 
 run()
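For reference, the move selection inside run() is epsilon-greedy: with agent.epsilon fixed at 0 during play, randint(0, 200) < agent.epsilon never fires, so the move is always the argmax of the network's prediction on the 11-element state. The sketch below is a minimal stand-in, not the committed code; choose_move and dummy_predict are made-up names, and np.eye(3)[i] replaces keras.utils.to_categorical to keep the example dependency-free.

# Hedged sketch of the epsilon-greedy move selection used in run().
import numpy as np
from random import randint

def choose_move(predict, state_old, epsilon=0):
    """Return a one-hot move vector of length 3, as consumed by player.do_move."""
    if randint(0, 200) < epsilon:                   # exploration branch (never taken when epsilon == 0)
        move_index = randint(0, 2)
    else:                                           # exploitation branch
        prediction = predict(state_old.reshape((1, 11)))
        move_index = int(np.argmax(prediction[0]))
    return np.eye(3)[move_index]

# Example with a dummy network that always prefers the second action.
dummy_predict = lambda x: np.array([[0.1, 0.8, 0.1]])
print(choose_move(dummy_predict, np.zeros(11)))     # -> [0. 1. 0.]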

weights_new10.hdf5        81.2 KB   (binary file not shown)
weights_new11.hdf5        119 KB    (binary file not shown)
weights_new12.hdf5        44 KB     (binary file not shown)
weights_new13.hdf5        44 KB     (binary file not shown)
weights_new14.hdf5        71.7 KB   (binary file not shown)
weights_new14_1.hdf5      71.7 KB   (binary file not shown)
weights_new14_2.hdf5      71.7 KB   (binary file not shown)
weights_new15.hdf5        202 KB    (binary file not shown)
weights_new15_1.hdf5      202 KB    (binary file not shown)
weights_new16.hdf5        44 KB     (binary file not shown)
weights_new17.hdf5        140 KB    (binary file not shown)
weights_new17_1.hdf5      140 KB    (binary file not shown)
weights_new6.hdf5         31 KB     (binary file not shown)
weights_new7_random.hdf5  31 KB     (binary file not shown)
weights_new8.hdf5         31 KB     (binary file not shown)
weights_new9.hdf5         60.9 KB   (binary file not shown)

0 commit comments