
Commit 447885e

update code
1 parent 731cce4 commit 447885e

15 files changed: +544 -393 lines

Reinforcement_learning_TUT/2_Q_Learning_maze/RL_brain.py

+1 -1

@@ -9,7 +9,7 @@
 import pandas as pd


-class QTable:
+class QLearningTable:
     def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
         self.actions = actions  # a list
         self.lr = learning_rate

Reinforcement_learning_TUT/2_Q_Learning_maze/run_this.py

+2 -2

@@ -13,7 +13,7 @@
 """

 from maze_env import Maze
-from RL_brain import QTable
+from RL_brain import QLearningTable


 def update():
@@ -47,7 +47,7 @@ def update():

 if __name__ == "__main__":
     env = Maze()
-    RL = QTable(actions=list(range(env.n_actions)))
+    RL = QLearningTable(actions=list(range(env.n_actions)))

     env.after(100, update)
     env.mainloop()

Reinforcement_learning_TUT/3_Sarsa_maze/RL_brain.py

+12 -40

@@ -8,25 +8,14 @@
 import numpy as np
 import pandas as pd

+
 class RL(object):
     def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
         self.actions = action_space  # a list
         self.lr = learning_rate
         self.gamma = reward_decay
         self.epsilon = e_greedy

-    def choose_action(self, observation):
-        pass
-
-    def learn(self, *args):
-        pass
-
-
-# can be learned offline
-class QTable(RL):
-    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
-        super(QTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
-
         self.q_table = pd.DataFrame(columns=self.actions)

     def check_state_exist(self, state):
@@ -43,7 +32,7 @@ def check_state_exist(self, state):
     def choose_action(self, observation):
         self.check_state_exist(observation)
         # action selection
-        if np.random.uniform() < self.epsilon:
+        if np.random.rand() < self.epsilon:
             # choose best action
             state_action = self.q_table.ix[observation, :]
             state_action = state_action.reindex(np.random.permutation(state_action.index))  # some actions have same value
@@ -53,6 +42,15 @@ def choose_action(self, observation):
             action = np.random.choice(self.actions)
         return action

+    def learn(self, *args):
+        pass
+
+
+# off-policy
+class QLearningTable(RL):
+    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
+        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
+
     def learn(self, s, a, r, s_):
         self.check_state_exist(s_)
         q_predict = self.q_table.ix[s, a]
@@ -63,38 +61,12 @@ def learn(self, s, a, r, s_):
         self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update


-# online learning
+# on-policy
 class SarsaTable(RL):

     def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
         super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

-        self.q_table = pd.DataFrame(columns=self.actions)
-
-    def check_state_exist(self, state):
-        if state not in self.q_table.index:
-            # append new state to q table
-            self.q_table = self.q_table.append(
-                pd.Series(
-                    [0]*len(self.actions),
-                    index=self.q_table.columns,
-                    name=state,
-                )
-            )
-
-    def choose_action(self, observation):
-        self.check_state_exist(observation)
-        # action selection
-        if np.random.rand() < self.epsilon:
-            # choose best action
-            state_action = self.q_table.ix[observation, :]
-            state_action = state_action.reindex(np.random.permutation(state_action.index))  # some actions have same value
-            action = state_action.argmax()
-        else:
-            # choose random action
-            action = np.random.choice(self.actions)
-        return action
-
     def learn(self, s, a, r, s_, a_):
         self.check_state_exist(s_)
         q_predict = self.q_table.ix[s, a]
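
A note on the reorganised RL_brain.py above: the epsilon-greedy choose_action now lives once in the RL base class, and the two subclasses differ only in learn(). QLearningTable (off-policy) bootstraps from the best value available in the next state; SarsaTable (on-policy) bootstraps from the next action it will actually take. A minimal sketch of the two targets, using a hypothetical dict-of-arrays toy Q-table instead of the repo's pandas DataFrame:

    import numpy as np

    # toy Q-table: state -> array of action values (illustrative assumption, not repo code)
    q = {'s0': np.zeros(4), 's1': np.array([0.0, 0.5, 0.1, 0.0])}
    gamma, lr = 0.9, 0.01
    s, a, r, s_, a_ = 's0', 1, 0.0, 's1', 2   # one hypothetical transition

    # Q-learning (off-policy): bootstrap from the greedy value of the next state
    q_target_q = r + gamma * q[s_].max()

    # Sarsa (on-policy): bootstrap from the value of the next action actually chosen
    q_target_sarsa = r + gamma * q[s_][a_]

    # both methods apply the same TD update to the visited pair (s, a)
    q[s][a] += lr * (q_target_q - q[s][a])

Because the Q-learning target ignores which action the behaviour policy picks next, it can learn from off-policy experience, while the Sarsa target follows the policy actually being executed.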

Reinforcement_learning_TUT/4_Sarsa_lambda_maze/RL_brain.py

+9 -66

@@ -16,60 +16,6 @@ def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=
         self.gamma = reward_decay
         self.epsilon = e_greedy

-    def choose_action(self, observation):
-        pass
-
-    def learn(self, *args):
-        pass
-
-
-# off-policy
-class QTable(RL):
-    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
-        super(QTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
-
-        self.q_table = pd.DataFrame(columns=self.actions)
-
-    def check_state_exist(self, state):
-        if state not in self.q_table.index:
-            # append new state to q table
-            self.q_table = self.q_table.append(
-                pd.Series(
-                    [0]*len(self.actions),
-                    index=self.q_table.columns,
-                    name=state,
-                )
-            )
-
-    def choose_action(self, observation):
-        self.check_state_exist(observation)
-        # action selection
-        if np.random.uniform() < self.epsilon:
-            # choose best action
-            state_action = self.q_table.ix[observation, :]
-            state_action = state_action.reindex(np.random.permutation(state_action.index))  # some actions have same value
-            action = state_action.argmax()
-        else:
-            # choose random action
-            action = np.random.choice(self.actions)
-        return action
-
-    def learn(self, s, a, r, s_):
-        self.check_state_exist(s_)
-        q_predict = self.q_table.ix[s, a]
-        if s_ != 'terminal':
-            q_target = r + self.gamma * self.q_table.ix[s_, :].max()  # next state is not terminal
-        else:
-            q_target = r  # next state is terminal
-        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
-
-
-# on-policy
-class SarsaTable(RL):
-
-    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
-        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
-
         self.q_table = pd.DataFrame(columns=self.actions)

     def check_state_exist(self, state):
@@ -96,26 +42,18 @@ def choose_action(self, observation):
             action = np.random.choice(self.actions)
         return action

-    def learn(self, s, a, r, s_, a_):
-        self.check_state_exist(s_)
-        q_predict = self.q_table.ix[s, a]
-        if s_ != 'terminal':
-            q_target = r + self.gamma * self.q_table.ix[s_, a_]  # next state is not terminal
-        else:
-            q_target = r  # next state is terminal
-        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
+    def learn(self, *args):
+        pass


 # backward eligibility traces
-class SarsaLambdaTable(SarsaTable):
+class SarsaLambdaTable(RL):
     def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
         super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

         # backward view, eligibility trace.
         self.lambda_ = trace_decay
-
-    def initialize_trace(self):
-        self.eligibility_trace = self.q_table * 0
+        self.eligibility_trace = self.q_table.copy()

     def check_state_exist(self, state):
         if state not in self.q_table.index:
@@ -140,6 +78,11 @@ def learn(self, s, a, r, s_, a_):
         error = q_target - q_predict

         # increase trace amount for visited state-action pair
+
+        # Method 1:
+        # self.eligibility_trace.ix[s, a] += 1
+
+        # Method 2:
         self.eligibility_trace.ix[s, :] *= 0
         self.eligibility_trace.ix[s, a] = 1
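
The new comments in SarsaLambdaTable.learn() name the two standard ways to credit the visited (s, a) pair: Method 1 is an accumulating trace (increment on every visit), while Method 2, the variant this commit leaves active, is a replacing trace (zero the state's row, then set the taken action to 1). A minimal sketch with a hypothetical NumPy array standing in for the repo's DataFrame trace:

    import numpy as np

    trace = np.zeros((3, 2))     # hypothetical 3 states x 2 actions
    s, a = 1, 0                  # visited state-action pair
    gamma, lambda_ = 0.9, 0.9

    # Method 1: accumulating trace -- repeated visits keep piling up credit
    # trace[s, a] += 1

    # Method 2: replacing trace -- only the action actually taken in s keeps credit
    trace[s, :] *= 0
    trace[s, a] = 1

    # backward-view Sarsa(lambda) then decays the whole trace each step
    # (the decay line itself lies outside the hunks shown in this commit)
    trace *= gamma * lambda_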

Reinforcement_learning_TUT/4_Sarsa_lambda_maze/run_this.py

+0 -3

@@ -16,9 +16,6 @@ def update():
     # initial observation
     observation = env.reset()

-    # initialize eligibility trace
-    RL.initialize_trace()
-
     # RL choose action based on observation
    action = RL.choose_action(str(observation))
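
This removal pairs with the SarsaLambdaTable change above: the eligibility trace is now created once in __init__ as a copy of the (still empty) q_table, so the training loop no longer calls a per-episode initialize_trace(). A minimal standalone sketch of that construction-time setup (hypothetical variable names, outside any class):

    import pandas as pd

    actions = list(range(4))                  # hypothetical action ids
    q_table = pd.DataFrame(columns=actions)   # starts with no states
    eligibility_trace = q_table.copy()        # same columns, also empty at first

    print(eligibility_trace.empty)            # True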

0 commit comments
