Skip to content

Commit

Permalink
SA_PGPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
camelop committed Jun 25, 2019
1 parent 393809b commit 403e09c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
50 changes: 50 additions & 0 deletions doudizhu/apps/game/policy/SA_PGPolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import sys
from .PGPolicy import LearningPolicy, PGPolicy, rule

'''
Use summary act instead of actual act......
(more) see SA_DQNPolicy.py
'''

class SA_PGPolicy(PGPolicy):

def call_score(self, state, default_action=None, learning=True):
'''
call a score from 1-3, if learning is true, use epsilon greedy
'''
vec = LearningPolicy._get_state_vec_sv2(state)
mask = self._get_call_score_mask_sa(state)
self.memory.set_state((vec, mask))
# epsilon-greedy
# generate action_idx and action
action_idx = self.model.choose(vec, mask, self._random_rand())
action = LearningPolicy.CALL_SCORE_ACTIONS[action_idx]
# record action
self.memory.set_action((action_idx, action))
# reward: turn punish - call is not involved so 0
self.memory.set_reward(0)
return action

def shot_poker(self, state, default_action=None, learning=True):
'''
shot a valid poker set, if learning is true, use epsilon greedy
'''
hand_pokers = state['hand_pokers']

vec = LearningPolicy._get_state_vec_sv2(state)
mask = self._get_shot_poker_mask_sa(state)
self.memory.set_state((vec, mask))
# generate action_idx and action
action_idx = self.model.choose(vec, mask, self._random_rand())
action = self._sa_idx_to_pokers(action_idx, state)
# record action
self.memory.set_action((action_idx, action))
# reward: turn punish
self.memory.set_reward(self.turn_reward)
return action

def __str__(self):
if self.comment is None:
self.comment = 'NOCOMMENT'
return "SA_PGPolicy-sv2[{}]-seed({}).{}.v{}".format(str(self.model), self.seed, self.comment, self.generation)
8 changes: 7 additions & 1 deletion doudizhu/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@

from apps.game.policy.DRL.REINFORCE_MLP import REINFORCE_MLP
from apps.game.policy.PGPolicy import PGPolicy
from apps.game.policy.SA_PGPolicy import SA_PGPolicy


player = "Kate"
player = "Ken"
env = 'env1'
epoch_num = 5000
display = False
save = True
'''
# test setting
epoch_num = 1
display = True
save = False
'''
Expand Down Expand Up @@ -106,6 +108,10 @@
a1_model = DQNMLP(action_dim=DQNPolicy.S_ACTION_DIM, hidden_dims=(2048, )*8, learning_rate=1e-3)
a1_policy = SA_DQNPolicy(a1_model, seed=0, comment="test_sa_"+env, e_greedy=(3e-1, -1e-4), save_every=1000)
a1 = Agent(player, a1_policy)
elif player == "Ken":
a1_model = REINFORCE_MLP(action_dim=DQNPolicy.S_ACTION_DIM, hidden_dims=(2048, )*8, learning_rate=1e-3)
a1_policy = SA_PGPolicy(a1_model, seed=0, comment="test_sa_"+env, save_every=1000)
a1 = Agent(player, a1_policy)
else:
raise NotImplementedError

Expand Down

0 comments on commit 403e09c

Please sign in to comment.