small updates about simulation
camelop committed Jun 7, 2019
1 parent 7e1a1f2 commit 1002a4b
Showing 9 changed files with 69 additions and 26 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -46,6 +46,6 @@ coverage.xml
 *,cover
 
 # Database & logs & models
-doudizhu/*.db
+*.db
 doudizhu/*.log
-doudizhu/models/*
\ No newline at end of file
+doudizhu/models/*
10 changes: 5 additions & 5 deletions doudizhu/apps/game/policy/DQNPolicy.py
@@ -29,7 +29,7 @@ def call_score(self, state, default_action=None, learning=True):
         '''
         call a score from 1-3, if learning is true, use epsilon greedy
         '''
-        vec = LearningPolicy._get_state_vec_sv1(state)
+        vec = LearningPolicy._get_state_vec_sv2(state)
         mask = self._get_call_score_mask(state)
         self.memory.set_state((vec, mask))
         # epsilon-greedy
@@ -50,7 +50,7 @@ def shot_poker(self, state, default_action=None, learning=True):
         '''
         hand_pokers = state['hand_pokers']
 
-        vec = LearningPolicy._get_state_vec_sv1(state)
+        vec = LearningPolicy._get_state_vec_sv2(state)
         mask = self._get_shot_poker_mask(state)
         self.memory.set_state((vec, mask))
         # epsilon-greedy
@@ -126,9 +126,9 @@ def reset(self):
 
     def __str__(self):
         if self.comment is not None:
-            return "DQNPolicy-sv1[{}]({}).{}.v{}".format(str(self.model), self.seed, self.comment, self.generation)
+            return "DQNPolicy-sv2[{}]({}).{}.v{}".format(str(self.model), self.seed, self.comment, self.generation)
         else:
-            return "DQNPolicy-sv1[{}]({}).v{}".format(str(self.model), self.seed, self.generation)
+            return "DQNPolicy-sv2[{}]({}).v{}".format(str(self.model), self.seed, self.generation)
 
 class Memory(object):
 
@@ -163,4 +163,4 @@ def generate_sars(self):
         while i + 3 < len(self.memory):
             ret.append((self.memory[i], self.memory[i+1], self.memory[i+2], self.memory[i+3], i+6 >= len(self.memory)))
             i += 3
-        return ret
\ No newline at end of file
+        return ret
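Note: `generate_sars` above assumes the replay memory is a flat list laid out as state, action, reward, state, action, reward, …, so each transition spans three slots and consecutive transitions share the "next state" slot. A minimal sketch of that layout (the `memory` contents here are illustrative placeholders, not real game states):

```python
# Illustrative stride-3 replay layout assumed by Memory.generate_sars:
# [s0, a0, r0, s1, a1, r1, s2, a2, r2, s3]
memory = ['s0', 'a0', 'r0', 's1', 'a1', 'r1', 's2', 'a2', 'r2', 's3']

ret, i = [], 0
while i + 3 < len(memory):
    # (state, action, reward, next_state, done); done marks the final transition
    ret.append((memory[i], memory[i + 1], memory[i + 2], memory[i + 3],
                i + 6 >= len(memory)))
    i += 3

print(ret)
# [('s0', 'a0', 'r0', 's1', False),
#  ('s1', 'a1', 'r1', 's2', False),
#  ('s2', 'a2', 'r2', 's3', True)]
```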
17 changes: 10 additions & 7 deletions doudizhu/apps/game/policy/basePolicy.py
@@ -13,12 +13,15 @@ def shot_poker(self, state, default_action=None):
 
     def _legal_call_score(self, state):
         max_call_score = state['max_call_score']
-        base_score = max_call_score + 1
-        ret = []
-        if base_score <= 3:
-            for i in range(base_score, 4):
-                ret.append(i)
-        return ret
+        base_score = max(2, max_call_score + 1)
+        if max_call_score == 0:
+            return [1, 2, 3]
+        elif max_call_score == 1:
+            return [1, 2, 3]
+        elif max_call_score == 2:
+            return [1, 3]
+        else:
+            raise NotImplementedError
 
     def _legal_shot_poker(self, state):
         '''
@@ -67,4 +70,4 @@ def reset(self):
         pass
 
     def __str__(self):
-        return self.__class__.__name__
\ No newline at end of file
+        return self.__class__.__name__
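Note: the rewritten `_legal_call_score` replaces the old range-based logic with an explicit branch per bid level (`base_score` is now computed but unused). A table-style sketch of the same rule, assuming `max_call_score` is the highest score called so far, with 0 meaning nobody has called yet:

```python
# Sketch mirroring the new if/elif chain in _legal_call_score.
LEGAL_CALLS = {
    0: [1, 2, 3],  # no call yet: any score is allowed
    1: [1, 2, 3],
    2: [1, 3],
}

def legal_call_score(max_call_score):
    try:
        return LEGAL_CALLS[max_call_score]
    except KeyError:
        # presumably a call of 3 already ends the bidding
        raise NotImplementedError
```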
3 changes: 2 additions & 1 deletion doudizhu/apps/game/policy/chooseMinPolicy.py
@@ -26,6 +26,7 @@ def __init__(self, seed=None):
 
     def call_score(self, state, default_action=None):
         random.setstate(self.state)
+        random.seed(self.seed+hash(tuple(sorted(state["hand_pokers"]))))
         ret = random.sample(self._legal_call_score(state), 1)[0]
         self.state = random.getstate()
         return ret
@@ -49,4 +50,4 @@ def reset(self):
         self.state = random.getstate()
 
     def __str__(self):
-        return "ChooseMinPolicy({})".format(self.seed)
\ No newline at end of file
+        return "ChooseMinPolicy({})".format(self.seed)
3 changes: 2 additions & 1 deletion doudizhu/apps/game/policy/chooseMinWithRolePolicy.py
@@ -46,6 +46,7 @@ def __init__(self, seed=None):
 
     def call_score(self, state, default_action=None):
         random.setstate(self.state)
+        random.seed(self.seed+hash(tuple(sorted(state["hand_pokers"]))))
         ret = random.sample(self._legal_call_score(state), 1)[0]
         self.state = random.getstate()
         return ret
@@ -72,4 +73,4 @@ def reset(self):
         self.state = random.getstate()
 
     def __str__(self):
-        return "ChooseMinWithRolePolicy({})".format(self.seed)
\ No newline at end of file
+        return "ChooseMinWithRolePolicy({})".format(self.seed)
3 changes: 2 additions & 1 deletion doudizhu/apps/game/policy/learningPolicy.py
@@ -213,6 +213,7 @@ def _pokers_to_cnt_vec(pokers):
 
     def _random_call_score(self, state):
         random.setstate(self.state)
+        random.seed(self.seed+hash(tuple(sorted(state["hand_pokers"]))))
         prob = random.random()
         ret = random.sample(self._legal_call_score(state), 1)[0]
         self.state = random.getstate()
@@ -228,4 +229,4 @@ def _random_shot_poker(self, state):
     def reset(self):
         random.seed(self.seed)
         self.state = random.getstate()
-
\ No newline at end of file
+
3 changes: 2 additions & 1 deletion doudizhu/apps/game/policy/randomPolicy.py
@@ -14,6 +14,7 @@ def __init__(self, seed=None):
 
     def call_score(self, state, default_action=None):
         random.setstate(self.state)
+        random.seed(self.seed+hash(tuple(sorted(state["hand_pokers"]))))
         ret = random.sample(self._legal_call_score(state), 1)[0]
         self.state = random.getstate()
         return ret
@@ -29,4 +30,4 @@ def reset(self):
         self.state = random.getstate()
 
     def __str__(self):
-        return "RandomPolicy({})".format(self.seed)
\ No newline at end of file
+        return "RandomPolicy({})".format(self.seed)
41 changes: 33 additions & 8 deletions doudizhu/simulate.py
@@ -16,20 +16,45 @@
 # a1_policy = DQNPolicy(a1_model, seed=0, comment="default")
 # a1 = Agent('Apollo', a1_policy)
 
-a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(1000, 1000), freeze=True)
-a1_policy = DQNPolicy(a1_model, seed=0, comment="default", load_tag="DQNPolicy-sv1[MLP-lr(0.001)-h1(1000)-h2(1000)-e(1951)](0).default.v0")
-a1 = Agent('Apollo', a1_policy)
-
-a2 = Agent('Lannister', RandomPolicy(seed=1))
-a3 = Agent('Targaryen', RandomPolicy(seed=2))
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(200, 200, 200, 200))
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="l4")
+# a1 = Agent('Billy', a1_policy)
+
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(128,), learning_rate=1e-2)
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="negd", e_greedy=(0.0, -0.0002))
+# a1 = Agent('Cathy8', a1_policy)
+
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(256, 256), learning_rate=1e-2)
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="egd5e-1", e_greedy=(5e-1, -5e-4))
+# a1 = Agent('Dove', a1_policy)
+
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(1024, 1024, 1024, 1024), learning_rate=1e-3)
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="egd5e-1", e_greedy=(5e-1, -5e-4))
+# a1 = Agent('Emma', a1_policy)
+
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(2048, 2048, 2048, 2048), learning_rate=1e-3)
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="egd5e-1", e_greedy=(5e-1, -5e-4), save_every=100)
+# a1 = Agent('Fisher', a1_policy)
+
+# a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(4096, )*8, learning_rate=1e-3)
+# a1_policy = DQNPolicy(a1_model, seed=0, comment="egd4_vsChooseMinWithRole", e_greedy=(4e-1, -1e-4), save_every=500)
+# a1 = Agent('Gilly', a1_policy)
+
+a1_model = MultiLevelPerceptron(action_dim=DQNPolicy.ACTION_DIM, hidden_dims=(2048, )*4, learning_rate=1e-3)
+a1_policy = DQNPolicy(a1_model, seed=0, comment="egd0_vsChooseMinWithRole", e_greedy=(0, 0), save_every=500)
+a1 = Agent('Hill', a1_policy)
+# a2 = Agent('Lannister', RandomPolicy(seed=1))
+# a3 = Agent('Targaryen', RandomPolicy(seed=2))
+a2 = Agent('Lazarus', ChooseMinWithRolePolicy(seed=1))
+a3 = Agent('Tyrion', ChooseMinWithRolePolicy(seed=2))
 
 sim = Simulator([a1, a2, a3], display=False)
+# sim = Simulator([a1, a2, a3], display=True)
 
 time_start=time.time()
 
 # results = sim.run(seeds=list(range(100)), mirror=True, save=True)
-results = sim.run(seeds=list(range(2000)), mirror=False, save=False)
+results = sim.run(seeds=list(range(5000)), mirror=False, save=True)
 
 time_end=time.time()
-print('Time(s)', time_end-time_start)
\ No newline at end of file
+print('Time(s)', time_end-time_start)
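On the `e_greedy=(...)` pairs in the configurations above: reading them as (initial epsilon, per-step delta) is consistent with the values used here (`(5e-1, -5e-4)` decaying from 0.5, `(0, 0)` meaning fully greedy), but that interpretation is an assumption about `DQNPolicy`, not shown in this diff. A sketch under that assumption:

```python
# Assumed reading of e_greedy=(start, delta): epsilon starts at `start` and
# moves by `delta` after each decision, clamped to [0, 1]. Illustrative only.
def epsilon_schedule(start, delta, steps):
    eps = start
    for _ in range(steps):
        yield eps
        eps = min(1.0, max(0.0, eps + delta))

sched = list(epsilon_schedule(5e-1, -5e-4, 1001))
print(round(sched[0], 6), round(sched[1000], 6))  # 0.5 0.0 -- greedy after ~1000 steps
```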
11 changes: 11 additions & 0 deletions doudizhu/view.py
@@ -0,0 +1,11 @@
+import sqlite3
+conn = sqlite3.connect("result.db")
+cursor = conn.cursor()
+step = 200
+with conn:
+    # cursor.execute("select * from stat")
+    cursor.execute("select (seed / step)*step, AVG(p1_reward), COUNT(p1_reward) FROM result WHERE p1_name=\"Hill\" GROUP BY seed / step ".replace("step", str(step)))
+    # cursor.execute("select AVG(p1_reward), COUNT(p1_reward) FROM result WHERE p1_name=\"Cathy8\" ")
+    for row in cursor:
+        print(row)
+
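The new view.py buckets simulation seeds into windows of `step` games and reports the average reward per window, which is handy for eyeballing learning progress across the 5000-seed run above. Its `.replace("step", str(step))` rewrites every occurrence of the word "step" in the SQL string; a slightly safer variant of the same query using bound parameters (assuming the same `result` table and columns, and that `result.db` exists) might look like:

```python
import sqlite3

conn = sqlite3.connect("result.db")
step = 200  # bucket width in games
with conn:
    rows = conn.execute(
        "SELECT (seed / ?) * ?, AVG(p1_reward), COUNT(p1_reward) "
        "FROM result WHERE p1_name = ? GROUP BY seed / ?",
        (step, step, "Hill", step),
    )
    for bucket_start, avg_reward, n_games in rows:
        print(bucket_start, avg_reward, n_games)
conn.close()
```

Binding the name as a parameter also sidesteps the double-quoted `"Hill"` literal, which SQLite only accepts as a string by fallback.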