Skip to content

Commit

Permalink
ERROR FIXED
Browse files Browse the repository at this point in the history
  • Loading branch information
mailgyc-163 committed Mar 20, 2019
1 parent 62f8ce0 commit cfa5daa
Showing 1 changed file with 76 additions and 69 deletions.
145 changes: 76 additions & 69 deletions doudizhu/apps/game/extra/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import core.extra.card as card
from core.extra.card import action_space, Category, action_space_category
import numpy as np
from collections import Counter
import tensorflow as tf
import argparse
import time
from collections import Counter
from contextlib import contextmanager

import numpy as np
import tensorflow as tf

from apps.game.extra import card
from .card import action_space, Card, Category, action_space_category

action_space_single = action_space[1:16]
action_space_pair = action_space[16:29]
action_space_triple = action_space[29:42]
Expand All @@ -26,21 +27,21 @@ def counter_subset(list1, list2):
# map char cards to 3 - 17
def to_value(cards):
if isinstance(cards, list) or isinstance(cards, np.ndarray):
values = [card.Card.cards.index(c)+3 for c in cards]
values = [Card.cards.index(c) + 3 for c in cards]
return values
else:
return card.Card.cards.index(cards)+3
return Card.cards.index(cards) + 3


# map 3 - 17 to char cards
def to_char(cards):
if isinstance(cards, list) or isinstance(cards, np.ndarray):
if len(cards) == 0:
return []
chars = [card.Card.cards[c-3] for c in cards]
chars = [card.Card.cards[c - 3] for c in cards]
return chars
else:
return card.Card.cards[cards-3]
return card.Card.cards[cards - 3]


def get_mask(cards, action_space, last_cards=None):
Expand All @@ -58,7 +59,7 @@ def get_mask(cards, action_space, last_cards=None):
return mask
if len(last_cards) > 0:
for j in range(1, mask.size):
if mask[j] == 1 and not card.CardGroup.to_cardgroup(action_space[j]).\
if mask[j] == 1 and not card.CardGroup.to_cardgroup(action_space[j]). \
bigger_than(card.CardGroup.to_cardgroup(last_cards)):
mask[j] = 0
# else:
Expand All @@ -80,11 +81,12 @@ def get_mask_onehot60(cards, action_space, last_cards):
return mask
if len(last_cards) > 0:
for j in range(1, len(action_space)):
if np.sum(mask[j]) > 0 and not card.CardGroup.to_cardgroup(action_space[j]).\
if np.sum(mask[j]) > 0 and not card.CardGroup.to_cardgroup(action_space[j]). \
bigger_than(card.CardGroup.to_cardgroup(last_cards)):
mask[j] = np.zeros([60])
return mask


# # get char cards, return valid response
# def get_mask_category(cards, action_space, last_cards=None):
# mask = np.zeros([14]) if last_cards is None else np.zeros([15])
Expand Down Expand Up @@ -112,19 +114,22 @@ def get_seq_length(category, cards_val):
return cards_val.size // 5
return None


# get [-1, 1] minor cards target, input: value cards 3-17
def find_minor_in_three_one(cards):
if cards[0] == cards[1]:
return cards[-1]
else:
return cards[0]


def find_minor_in_three_two(cards):
if cards[1] == cards[2]:
return cards[-1]
else:
return cards[0]


def find_minor_in_three_one_line(cards):
cnt = np.zeros([18])
for i in range(len(cards)):
Expand All @@ -135,6 +140,7 @@ def find_minor_in_three_one_line(cards):
minor.append(i)
return np.array(minor)


def find_minor_in_three_two_line(cards):
cnt = np.zeros([18])
for i in range(len(cards)):
Expand All @@ -145,6 +151,7 @@ def find_minor_in_three_two_line(cards):
minor.append(i)
return np.array(minor)


def find_minor_in_four_two(cards):
cnt = np.zeros([18])
for i in range(len(cards)):
Expand All @@ -155,24 +162,25 @@ def find_minor_in_four_two(cards):
minor.append(i)
return np.array(minor)


def get_minor_cards(cards, category_idx):
minor_cards = np.ones([15])
length = 0
if category_idx == Category.THREE_ONE.value:
length = 1
minor_cards[find_minor_in_three_one(cards)-3] = -1
minor_cards[find_minor_in_three_one(cards) - 3] = -1
if category_idx == Category.THREE_TWO.value:
length = 1
minor_cards[find_minor_in_three_two(cards)-3] = -1
minor_cards[find_minor_in_three_two(cards) - 3] = -1
if category_idx == Category.THREE_ONE_LINE.value:
length = int(cards.size / 4)
minor_cards[find_minor_in_three_one_line(cards)-3] = -1
minor_cards[find_minor_in_three_one_line(cards) - 3] = -1
if category_idx == Category.THREE_TWO_LINE.value:
length = int(cards.size / 5)
minor_cards[find_minor_in_three_two_line(cards)-3] = -1
minor_cards[find_minor_in_three_two_line(cards) - 3] = -1
if category_idx == Category.FOUR_TWO.value:
length = 2
minor_cards[find_minor_in_four_two(cards)-3] = -1
minor_cards[find_minor_in_four_two(cards) - 3] = -1
return minor_cards, length


Expand All @@ -197,7 +205,7 @@ def get_feature_state(env, mask=None):
m = mask[i]
if m:
a = action_space[i]

if not a:
features[i, 1] = 1
continue
Expand Down Expand Up @@ -260,7 +268,7 @@ def train_fake_action(targets, handcards, s, sess, network, category_idx, main_c
cards_onehot = card.Card.char2onehot(main_cards_char)
# we must make the order in each 4 batch correct...
discard_onehot_from_s(s[0], cards_onehot)
assert np.amax(s) < 1.1 and np.amin(s) > -0.1
# assert np.amax(s) < 1.1 and np.amin(s) > -0.1

is_pair = False
if category_idx == Category.THREE_TWO.value or category_idx == Category.THREE_TWO_LINE.value:
Expand All @@ -270,21 +278,21 @@ def train_fake_action(targets, handcards, s, sess, network, category_idx, main_c
target_val = card.Card.char2value_3_17(target) - 3
input_single, input_pair, input_triple, input_quadric = get_masks(handcards, None)

_, response_active_output, fake_loss = sess.run([network.optimize_fake,
network.fc_response_minor_output,
network.minor_response_loss],
feed_dict = {
network.input_state: s,
network.input_single: np.reshape(input_single, [1, -1]),
network.input_pair: np.reshape(input_pair, [1, -1]),
network.input_triple: np.reshape(input_triple, [1, -1]),
network.input_quadric: np.reshape(input_quadric, [1, -1]),
network.input_single_last: np.zeros([1, 15]),
network.input_pair_last: np.zeros([1, 13]),
network.input_triple_last: np.zeros([1, 13]),
network.input_quadric_last: np.zeros([1, 13]),
network.minor_response_input: np.array([target_val]),
})
_, response_active_output, fake_loss = sess.run([network.optimize_fake,
network.fc_response_minor_output,
network.minor_response_loss],
feed_dict={
network.input_state: s,
network.input_single: np.reshape(input_single, [1, -1]),
network.input_pair: np.reshape(input_pair, [1, -1]),
network.input_triple: np.reshape(input_triple, [1, -1]),
network.input_quadric: np.reshape(input_quadric, [1, -1]),
network.input_single_last: np.zeros([1, 15]),
network.input_pair_last: np.zeros([1, 13]),
network.input_triple_last: np.zeros([1, 13]),
network.input_quadric_last: np.zeros([1, 13]),
network.minor_response_input: np.array([target_val]),
})
cards = [target]
handcards.remove(target)
if is_pair:
Expand Down Expand Up @@ -333,10 +341,10 @@ def train_fake_action_60(targets, handcards, s, sess, network, category_idx, mai
target_val = card.Card.char2value_3_17(target) - 3
_, fc_minor_response_output = sess.run([network.optimize[-1],
network.fc_minor_response_output], feed_dict={
network.input_state: s.reshape(1, -1),
network.minor_type: np.array([minor_type]),
network.minor_response_input: np.array([target_val])
})
network.input_state: s.reshape(1, -1),
network.minor_type: np.array([minor_type]),
network.minor_response_input: np.array([target_val])
})
cards = [target]
handcards.remove(target)
if is_pair:
Expand Down Expand Up @@ -374,29 +382,29 @@ def test_fake_action(targets, handcards, s, sess, network, category_idx, dup_mas
target_val = card.Card.char2value_3_17(target) - 3
input_single, input_pair, input_triple, input_quadric = get_masks(handcards, None)
response_minor_output = sess.run(network.fc_response_minor_output,
feed_dict = {
network.input_state: s,
network.input_single: np.reshape(input_single, [1, -1]),
network.input_pair: np.reshape(input_pair, [1, -1]),
network.input_triple: np.reshape(input_triple, [1, -1]),
network.input_quadric: np.reshape(input_quadric, [1, -1])
})
feed_dict={
network.input_state: s,
network.input_single: np.reshape(input_single, [1, -1]),
network.input_pair: np.reshape(input_pair, [1, -1]),
network.input_triple: np.reshape(input_triple, [1, -1]),
network.input_quadric: np.reshape(input_quadric, [1, -1])
})
# give minor cards
response_minor_output = response_minor_output[0]
response_minor_output[dup_mask == 0] = -1

if is_pair:
# fix dimension mismatch
input_pair = np.concatenate([input_pair, [0, 0]])
response_minor_output[input_pair == 0] = -1
else:
response_minor_output[input_single == 0] = -1

response_minor = np.argmax(response_minor_output)
dup_mask[response_minor] = 0

# convert network output to char cards
cards =[target]
cards = [target]
handcards.remove(target)
if is_pair:
handcards.remove(target)
Expand All @@ -421,7 +429,7 @@ def pick_minor_targets(category, cards_char):
return cards_char[-length:]
if category == Category.THREE_TWO_LINE.value:
length = len(cards_char) // 5
return cards_char[-length*2::2]
return cards_char[-length * 2::2]
if category == Category.FOUR_TWO.value:
return cards_char[-2:]
return None
Expand All @@ -437,11 +445,11 @@ def pick_main_cards(category, cards_char):
return cards_char[:-length]
if category == Category.THREE_TWO_LINE.value:
length = len(cards_char) // 5
return cards_char[:-length*2]
return cards_char[:-length * 2]
if category == Category.FOUR_TWO.value:
return cards_char[:-2]
return None


def get_mask_alter(cards, last_cards, last_cards_category):
decision_mask = None
Expand Down Expand Up @@ -486,7 +494,7 @@ def get_mask_alter(cards, last_cards, last_cards_category):
response_mask = np.zeros([15])
subspace = action_space_category[last_cards_category]
for j in range(len(subspace)):
if counter_subset(subspace[j], cards) and card.CardGroup.to_cardgroup(subspace[j]).\
if counter_subset(subspace[j], cards) and card.CardGroup.to_cardgroup(subspace[j]). \
bigger_than(card.CardGroup.to_cardgroup(last_cards)):
# diff = card.Card.to_value(subspace[j][0]) - card.Card.to_value(last_cards[0])
# assert(diff > 0)
Expand All @@ -504,7 +512,7 @@ def get_mask_alter(cards, last_cards, last_cards_category):
if no_bomb:
decision_mask[1] = 0
return decision_mask, response_mask, bomb_mask, length_mask


# return [3-17 value]
def give_cards_without_minor(response, last_cards_value, category_idx, length_output):
Expand Down Expand Up @@ -648,13 +656,13 @@ def inference_minor_util(s, handcards, sess, network, num, is_pair, dup_mask, ma
inter_masks.append([input_single, input_pair, input_triple, input_quadric])

response_minor_output = scheduled_run(sess, network.fc_minor_response_output,
(
(network.input_state, s),
(network.input_single, np.reshape(input_single, [1, -1])),
(network.input_pair, np.reshape(input_pair, [1, -1])),
(network.input_triple, np.reshape(input_triple, [1, -1])),
(network.input_quadric, np.reshape(input_quadric, [1, -1]))
))
(
(network.input_state, s),
(network.input_single, np.reshape(input_single, [1, -1])),
(network.input_pair, np.reshape(input_pair, [1, -1])),
(network.input_triple, np.reshape(input_triple, [1, -1])),
(network.input_quadric, np.reshape(input_quadric, [1, -1]))
))
# response_active_output = sess.run(network.fc_response_active_output,
# feed_dict={
# network.input_state: s,
Expand Down Expand Up @@ -725,10 +733,10 @@ def inference_minor_util60(s, handcards, sess, network, num, is_pair, dup_mask,
inter_states.append(s.copy())
input_single, input_pair, _, _ = get_masks(handcards, None)
response_minor_output = scheduled_run(sess, network.fc_minor_response_output,
(
(network.input_state, s),
(network.minor_type, np.array([minor_type]))
))
(
(network.input_state, s),
(network.minor_type, np.array([minor_type]))
))

# give minor cards
response_minor_output = response_minor_output[0]
Expand Down Expand Up @@ -797,7 +805,7 @@ def gputimeblock(label):
yield
finally:
end = time.perf_counter()
GPUTime.total_time += end-start
GPUTime.total_time += end - start


def update_params(scope_from, scope_to):
Expand All @@ -810,15 +818,14 @@ def update_params(scope_from, scope_to):
ops.append(to_var.assign(from_var))
return ops


if __name__ == '__main__':
# _, response_mask, _, _ = get_mask_alter(['A', 'A', 'A', 'J', 'J', '10', '6', '6', '5'], ['9', '9', '9', '5'], False,
# 5)
# _, response_mask, _, _ = get_mask_alter(['A', 'A', 'A', 'J', 'J', '10', '6', '6', '5'], ['9', '9', '9', '5'],
# False, 5)
pass
# for i in range(14):
# for j in range(len(action_space_category[i])):
# try:
# assert get_category_idx(np.array(action_space_category[i][j])) == i
# except AssertionError as error:
# print(i, action_space_category[i][j])

0 comments on commit cfa5daa

Please sign in to comment.