@@ -16,60 +16,6 @@ def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=
         self.gamma = reward_decay
         self.epsilon = e_greedy
 
-    def choose_action(self, observation):
-        pass
-
-    def learn(self, *args):
-        pass
-
-
-# off-policy
-class QTable(RL):
-    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
-        super(QTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
-
-        self.q_table = pd.DataFrame(columns=self.actions)
-
-    def check_state_exist(self, state):
-        if state not in self.q_table.index:
-            # append new state to q table
-            self.q_table = self.q_table.append(
-                pd.Series(
-                    [0]*len(self.actions),
-                    index=self.q_table.columns,
-                    name=state,
-                )
-            )
-
-    def choose_action(self, observation):
-        self.check_state_exist(observation)
-        # action selection
-        if np.random.uniform() < self.epsilon:
-            # choose best action
-            state_action = self.q_table.ix[observation, :]
-            state_action = state_action.reindex(np.random.permutation(state_action.index))  # some actions have same value
-            action = state_action.argmax()
-        else:
-            # choose random action
-            action = np.random.choice(self.actions)
-        return action
-
-    def learn(self, s, a, r, s_):
-        self.check_state_exist(s_)
-        q_predict = self.q_table.ix[s, a]
-        if s_ != 'terminal':
-            q_target = r + self.gamma * self.q_table.ix[s_, :].max()  # next state is not terminal
-        else:
-            q_target = r  # next state is terminal
-        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
-
-
-# on-policy
-class SarsaTable(RL):
-
-    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
-        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
-
         self.q_table = pd.DataFrame(columns=self.actions)
 
     def check_state_exist(self, state):
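For reference, the two learners removed above differ only in the bootstrap target: off-policy Q-learning backs up the greedy value of the next state, while on-policy SARSA backs up the value of the action actually taken next. A minimal side-by-side sketch of those targets (hypothetical helper names, not part of this commit; `q` is assumed to be a dict keyed by (state, action)):

# Hypothetical illustration of the two removed update rules; not code from this commit.
def q_learning_target(q, r, s_, actions, gamma, terminal='terminal'):
    # off-policy: bootstrap with the best next-state action, regardless of the policy
    return r if s_ == terminal else r + gamma * max(q[(s_, a)] for a in actions)

def sarsa_target(q, r, s_, a_, gamma, terminal='terminal'):
    # on-policy: bootstrap with the next action a_ the policy actually chose
    return r if s_ == terminal else r + gamma * q[(s_, a_)]

# Both learners then apply the same TD step: q[(s, a)] += lr * (target - q[(s, a)])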
@@ -96,26 +42,18 @@ def choose_action(self, observation):
             action = np.random.choice(self.actions)
         return action
 
-    def learn(self, s, a, r, s_, a_):
-        self.check_state_exist(s_)
-        q_predict = self.q_table.ix[s, a]
-        if s_ != 'terminal':
-            q_target = r + self.gamma * self.q_table.ix[s_, a_]  # next state is not terminal
-        else:
-            q_target = r  # next state is terminal
-        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
+    def learn(self, *args):
+        pass
 
 
 # backward eligibility traces
-class SarsaLambdaTable(SarsaTable):
+class SarsaLambdaTable(RL):
     def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
         super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
 
         # backward view, eligibility trace.
         self.lambda_ = trace_decay
-
-    def initialize_trace(self):
-        self.eligibility_trace = self.q_table * 0
+        self.eligibility_trace = self.q_table.copy()
 
     def check_state_exist(self, state):
         if state not in self.q_table.index:
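With SarsaLambdaTable now inheriting directly from RL and building its eligibility_trace in __init__ as a copy of the still-empty q_table, check_state_exist presumably has to register each newly seen state in both tables. Note that the code relies on DataFrame.ix and DataFrame.append, which were removed in pandas 1.0 and 2.0 respectively; on a current pandas the same bookkeeping could be sketched as below (an illustration under those assumptions, not code from this commit), with .ix[s, a] replaced by .loc[s, a] elsewhere.

    # Sketch: modern-pandas version of check_state_exist for this class,
    # using pd.concat in place of the removed DataFrame.append.
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            new_row = pd.Series([0] * len(self.actions), index=self.q_table.columns, name=state)
            # append the all-zero row for the new state to both tables
            self.q_table = pd.concat([self.q_table, new_row.to_frame().T])
            self.eligibility_trace = pd.concat([self.eligibility_trace, new_row.to_frame().T])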
@@ -140,6 +78,11 @@ def learn(self, s, a, r, s_, a_):
         error = q_target - q_predict
 
         # increase trace amount for visited state-action pair
+
+        # Method 1:
+        # self.eligibility_trace.ix[s, a] += 1
+
+        # Method 2:
         self.eligibility_trace.ix[s, :] *= 0
         self.eligibility_trace.ix[s, a] = 1
 
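The two alternatives left in the comments are the standard trace variants: Method 1 is an accumulating trace, where the trace for the visited pair keeps growing on revisits, while the kept Method 2 is a replacing trace, where the whole row for state s is zeroed and only the taken action is set to 1. The rest of a backward-view SARSA(lambda) update then typically spreads the TD error over all traced pairs and decays the traces; a hedged sketch of that tail, which is not shown in this hunk:

        # Sketch of the remaining learn() steps under backward-view SARSA(lambda),
        # assuming q_table and eligibility_trace share the same index and columns.
        self.q_table += self.lr * error * self.eligibility_trace  # TD update weighted by the traces
        self.eligibility_trace *= self.gamma * self.lambda_       # decay all traces by gamma * lambda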