-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsarsa.py
40 lines (33 loc) · 1.23 KB
/
sarsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import random
class Sarsa:
    """Tabular SARSA learner with epsilon-greedy action selection.

    Action values live in ``self.q``, a dict keyed by ``(state, action)``
    tuples. Pairs that have never been updated read back as ``0.0``.
    """

    def __init__(self, actions, epsilon, alpha, gamma):
        """Store the action set and the learning hyper-parameters.

        Args:
            actions: Indexable sequence of the actions available in every
                state (it is indexed and passed to ``random.choice``).
            epsilon: Threshold compared against ``random.random()``; a draw
                below it triggers a uniformly random action.
            alpha: Step size applied to the error in ``learnQ``.
            gamma: Multiplier applied to the next pair's value in ``learn``.
        """
        self.actions = actions
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        # Q-table: (state, action) -> estimated value.
        self.q = {}

    def getQ(self, state, action):
        """Return the stored value for ``(state, action)``, or 0.0 if absent."""
        key = (state, action)
        return self.q.get(key, 0.0)

    def learnQ(self, state, action, reward, value):
        """Move the entry for ``(state, action)`` toward ``value``.

        When the pair has no stored entry yet, the entry is set to ``reward``
        as-is: neither ``value`` nor ``alpha`` is used on that first visit.
        Otherwise the entry moves by ``alpha * (value - old)``.
        """
        key = (state, action)
        previous = self.q.get(key)
        if previous is not None:
            self.q[key] = previous + self.alpha * (value - previous)
            return
        # First visit: seed the table with the raw reward.
        self.q[key] = reward

    def chooseAction(self, state):
        """Pick an action for ``state`` using an epsilon-greedy rule.

        A random draw below ``epsilon`` returns a uniformly random action.
        Otherwise the highest-valued action is returned, with ties between
        equally valued actions broken uniformly at random.
        """
        explore = random.random() < self.epsilon
        if explore:
            return random.choice(self.actions)

        values = [self.getQ(state, candidate) for candidate in self.actions]
        top = max(values)
        if values.count(top) > 1:
            # Several actions share the maximum; choose one of their indices.
            tied = [idx for idx, v in enumerate(values) if v == top]
            pick = random.choice(tied)
        else:
            pick = values.index(top)
        return self.actions[pick]

    def learn(self, state1, action1, reward, state2, action2):
        """Update ``(state1, action1)`` from one observed transition.

        The target is ``reward + gamma * Q(state2, action2)``, where
        ``action2`` is the action supplied by the caller for ``state2``.
        """
        target = reward + self.gamma * self.getQ(state2, action2)
        self.learnQ(state1, action1, reward, target)