Skip to content

Commit 0625efe

Browse files
author
Mofan Zhou
committed
tf RNN example2
1 parent 2c6c18f commit 0625efe

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

RL/example1/hunter_prey.py

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import tkinter as tk
2+
import numpy as np
3+
import pandas as pd
4+
5+
6+
class TableLUQ(object):
    """Tabular Q-learning agent with epsilon-greedy action selection.

    The Q-table is a pandas DataFrame indexed by state (a stringified
    observation) with one float column per action.
    """

    def __init__(self, actions, epsilon=0.1, alpha=0.2, gamma=0.9):
        """
        Args:
            actions: list of action labels (e.g. ['u', 'd', 'l', 'r']).
            epsilon: exploration probability for epsilon-greedy selection.
            alpha: learning rate.
            gamma: discount factor for future rewards.
        """
        self._actions = actions
        self._epsilon = epsilon
        self._alpha = alpha
        self._gamma = gamma

        # State-action value table: rows are states, columns are actions.
        self._q_table = pd.DataFrame(columns=self._actions, dtype=float)
        self._q_table.index.name = 'state'

    def _ensure_state(self, state):
        # Lazily insert unseen states with zero-initialized action values.
        # (DataFrame.append was removed in pandas 2.0; .loc assignment
        # is the supported way to add a row.)
        if state not in self._q_table.index:
            self._q_table.loc[state] = [0.0] * len(self._actions)

    def learn(self, s0, a0, r, s1):
        """Q-learning update: Q(s0,a0) += alpha*(r + gamma*max_a Q(s1,a) - Q(s0,a0)).

        Args:
            s0: state the action was taken in.
            a0: action taken.
            r: observed reward.
            s1: resulting state (may be unseen; its value then counts as 0).
        """
        # Bugfix: the original indexed s0 unconditionally and raised
        # KeyError when s0 had only been visited via the epsilon branch.
        self._ensure_state(s0)
        q_s0_a0 = self._q_table.loc[s0, a0]
        if s1 in self._q_table.index:
            q_s1_a_max = self._q_table.loc[s1, :].max()
        else:
            q_s1_a_max = 0
        self._q_table.loc[s0, a0] = q_s0_a0 + self._alpha * (
            r + self._gamma * q_s1_a_max - q_s0_a0
        )

    def choose_action(self, state):
        """Epsilon-greedy: explore with probability epsilon, else exploit.

        Unseen states are inserted with zero values and answered with a
        random action. Ties between equal-valued actions are broken at
        random by shuffling the columns before taking the max.
        """
        if np.random.random() < self._epsilon:
            return np.random.choice(self._actions)
        if state not in self._q_table.index:
            self._ensure_state(state)
            return np.random.choice(self._actions)
        state_actions = self._q_table.loc[state, :]
        shuffled = state_actions.reindex(
            np.random.permutation(state_actions.index))
        # Bugfix: Series.argmax() returns a *position* in modern pandas;
        # idxmax() returns the action label, which is what callers need.
        return shuffled.idxmax()

    @property
    def epsilon(self):
        """Exploration probability."""
        return self._epsilon

    @property
    def alpha(self):
        """Learning rate."""
        return self._alpha

    @property
    def gamma(self):
        """Discount factor."""
        return self._gamma
55+
56+
57+
def get_state(h_loc, p_loc):
    """Encode the prey's offset from the hunter as a string state key.

    Args:
        h_loc: hunter [x, y] grid position.
        p_loc: prey [x, y] grid position.

    Returns:
        str like '(dx, dy)' — the relative displacement, usable as a
        Q-table index.
    """
    dx = p_loc[0] - h_loc[0]
    dy = p_loc[1] - h_loc[1]
    return str((dx, dy))
65+
66+
67+
def get_reward(h_loc, p_loc):
    """Reward inversely proportional to the hunter-prey Euclidean distance.

    Args:
        h_loc: hunter [x, y] grid position.
        p_loc: prey [x, y] grid position.

    Returns:
        float reward 1/distance, or a large finite reward (10.0) when the
        hunter is on the prey's cell. (Bugfix: the original divided by
        zero at distance 0.)
    """
    dx = p_loc[0] - h_loc[0]
    dy = p_loc[1] - h_loc[1]
    distance = (dx * dx + dy * dy) ** 0.5
    if distance == 0:
        # Caught the prey: larger than 1/d for any nonzero grid distance.
        return 10.0
    return 1 / distance
71+
72+
73+
def move(h_loc, action):
    """Apply one action to the hunter, clamped to the 5x5 grid.

    Mutates h_loc in place and shifts the canvas icon by the same amount.
    (Bugfix: the original ended with `h_loc = [h_x, h_y]`, which only
    rebound the local name — the caller's location list never changed.)

    Args:
        h_loc: hunter [x, y] grid position; updated in place.
        action: one of 'u', 'd', 'l'; anything else moves right.
    """
    h_x, h_y = h_loc[0], h_loc[1]
    if action == 'u':
        h_y = max(h_y - 1, 0)
    elif action == 'd':
        h_y = min(h_y + 1, 4)
    elif action == 'l':
        h_x = max(h_x - 1, 0)
    else:
        # 'r' (and any unrecognized action) moves right.
        h_x = min(h_x + 1, 4)
    # Each grid cell is 100px on the canvas; move by the pixel delta.
    canvas.move(hunter_icon, (h_x - h_loc[0]) * 100, (h_y - h_loc[1]) * 100)
    h_loc[0], h_loc[1] = h_x, h_y
91+
92+
# --- GUI setup: a 5x5 grid rendered on a 500x500 canvas (100px cells) ---
window = tk.Tk()
window.geometry('500x500')
canvas = tk.Canvas(window, height=500, width=500)
canvas.pack()

hunter_loc = [0, 0]  # can move
prey_loc = [4, 4]  # fixed


def _cell_bbox(loc):
    # Pixel bounding box (x0, y0, x1, y1) of the 100px grid cell at loc.
    x0, y0 = loc[0] * 100, loc[1] * 100
    return (x0, y0, x0 + 100, y0 + 100)


# Hunter is a black square; prey is a red circle.
hunter_icon = canvas.create_rectangle(_cell_bbox(hunter_loc), fill='black')
prey_icon = canvas.create_oval(_cell_bbox(prey_loc), fill='red')

# Displacement vectors per action label (not yet consumed by move()).
move_up = [+1, 0]
move_down = [-1, 0]
move_left = [0, -1]
move_right = [0, +1]

hunter_actions = {
    'u': move_up,
    'd': move_down,
    'l': move_left,
    'r': move_right,
}

window.mainloop()
224 KB
Loading

0 commit comments

Comments
 (0)