
Commit a576cf8
dumb down random agent and short circuit ADP training
wboag committed Dec 8, 2014
1 parent 994623a commit a576cf8
Showing 1 changed file with 37 additions and 24 deletions: agent.py
@@ -95,16 +95,20 @@ def setIndex(self, value):

class randomAgent(agent):

-    def __init__(self):
-        super(randomAgent, self).__init__()
-        self.type = "random"
-
-    def chooseAction(self, actions):
-        filteredActions = filter(lambda n: n == 'east' or n == 'north' or n == 'finish', actions)
-        return random.choice(filteredActions)
+    def __init__(self):
+        super(randomAgent, self).__init__()
+        self.type = "random"
+
+    def chooseAction(self, actions):
+        if flipCoin(0.08):
+            filteredActions = filter(lambda n: n == 'south' or n == 'west' or n == 'finish', actions)
+            if filteredActions == []: filteredActions = actions
+        else:
+            filteredActions = filter(lambda n: n == 'north' or n == 'east' or n == 'finish', actions)
+        return random.choice(filteredActions)

-    def update(self):
-        pass
+    def update(self):
+        pass
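The new chooseAction calls flipCoin, which is defined elsewhere in the repo and not shown in this diff. In similar codebases it is conventionally a one-line biased coin flip; a minimal sketch, assuming that convention:

import random

def flipCoin(p):
    # Assumed helper, not part of this diff: True with probability p.
    return random.random() < p

Under that reading, the "dumbed down" agent wanders backwards (south/west) roughly 8% of the time instead of always pushing north/east, falling back to the full action list when no backwards move is available.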


class adpAgent(agent):
@@ -127,6 +131,7 @@ def __init__(self, gameworld, all_qstate_results):
        #exit()

        # Keep track of number of completed episodes
+        self.converged = False
        self.completed = 0
        self.nextUpdate = 1

@@ -141,11 +146,20 @@ def update(self, state, terrain, action, nextState, reward):
        #print 'nextState: ', nextState
        #print 'reward: ', reward

-        #return
+        # If already converged, then skip update
+        if self.converged:
+            return

        # update empirical MDP
        self.empirical_mdp.update(state, action, nextState, reward, terrain)

+        # If converged AFTER most recent update, then solve MDP for final time
+        if self.empirical_mdp.converged():
+            #print str(self.completed) + ': final solving'
+            self.solver = PolicyIterationAgent(self.empirical_mdp, iterations=100)
+            self.converged = True
+            return

        # If finished episode
        if action == 'finish':
            #print 'finished\n\n\n'
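Neither empirical_mdp.converged() nor PolicyIterationAgent appears in this diff. For orientation, standard policy iteration on an explicit MDP alternates policy evaluation with greedy improvement until the policy stops changing; the sketch below is a generic version under assumed data structures (the function name, T, and R are illustrative, not the repo's actual API):

import random

def policyIteration(states, actions, T, R, gamma=0.9, iterations=100):
    # Generic policy iteration sketch; T[(s, a)] is a list of
    # (nextState, prob) pairs and R[(s, a, s2)] is the reward.
    policy = dict((s, random.choice(actions)) for s in states)
    V = dict((s, 0.0) for s in states)
    for _ in range(iterations):
        # Policy evaluation: sweep values under the current policy.
        for _ in range(50):
            for s in states:
                a = policy[s]
                V[s] = sum(p * (R[(s, a, s2)] + gamma * V[s2])
                           for s2, p in T[(s, a)])
        # Policy improvement: act greedily with respect to V.
        stable = True
        for s in states:
            qValue = lambda a: sum(p * (R[(s, a, s2)] + gamma * V[s2])
                                   for s2, p in T[(s, a)])
            best = max(actions, key=qValue)
            if best != policy[s]:
                policy[s] = best
                stable = False
        if stable:
            break
    return policy, V

Once the empirical model stops changing, solving it one final time and caching the result behind the self.converged flag avoids re-running the solver after every remaining episode, which is the short circuit named in the commit message.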
@@ -185,17 +199,16 @@ def chooseAction(self, state):

class tdAgent(agent):

-    def __init__(self, goalPosition):
-        super(tdAgent, self).__init__()
-        self.type = "td"
-        self.goalPosition = goalPosition
-        ###Your Code Here :)###
-
-    def update(self, oldState, terrainType, action, newState, reward):
-        ###Your Code Here :)###
-        pass
-
-    def chooseAction(self, actions, state, terrainType):
-        ###Your Code Here :)###
-        filteredActions = filter(lambda n: n == 'east' or n == 'north' or n == 'finish', actions)
-        return random.choice(filteredActions)
+    def __init__(self, goalPosition):
+        super(tdAgent, self).__init__()
+        self.type = "td"
+        self.goalPosition = goalPosition
+        ###Your Code Here :)###
+
+    def update(self, oldState, terrainType, action, newState, reward):
+        ###Your Code Here :)###
+        pass
+
+    def chooseAction(self, actions, state, terrainType):
+        filteredActions = filter(lambda n: n == 'north' or n == 'east' or n == 'finish', actions)
+        return random.choice(filteredActions)
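The tdAgent body is intentionally left as a stub for the assignment. Purely for orientation, a textbook tabular TD(0) value update looks like the following; this is a generic sketch with illustrative names, not the assignment's intended solution:

class TDValueLearner(object):
    # Generic TD(0) sketch; signature and names are illustrative and do
    # not match the tdAgent stub above.
    def __init__(self, alpha=0.1, gamma=0.9):
        self.alpha = alpha   # learning rate
        self.gamma = gamma   # discount factor
        self.V = {}          # state -> estimated value

    def update(self, oldState, newState, reward):
        # V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s))
        vOld = self.V.get(oldState, 0.0)
        vNew = self.V.get(newState, 0.0)
        self.V[oldState] = vOld + self.alpha * (reward + self.gamma * vNew - vOld)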
