-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
360 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
num_in_a_row_will_win = 4 # 几子棋 | ||
|
||
|
||
class Board: | ||
"""棋盘类""" | ||
|
||
def __init__(self, board=None, size=6, next_player=-1): | ||
self.size = size # 棋盘大小 size * size | ||
self.board = np.zeros((self.size, self.size), int) if board is None else board # 棋盘初始状态 | ||
|
||
self.next_player = next_player # 当前下棋玩家(-1:黑子,1:白子) | ||
|
||
def get_legal_pos(self): | ||
"""获取当前棋盘可落子处""" | ||
indices = np.where(self.board == 0) # 返回棋盘中未落子处的下标(一个二维数组,第一个对应行坐标,第二个对应列坐标) | ||
# zip:将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组 | ||
return list(zip(indices[0], indices[1])) | ||
|
||
def is_move_legal(self, move_pos): | ||
|
||
x, y = move_pos[0], move_pos[1] | ||
if x < 0 or x > self.size or y < 0 or y > self.size: # 检查落子坐标 | ||
return False | ||
if self.board[x, y] != 0: # 该位置是否还能落子 | ||
return False | ||
|
||
return True | ||
|
||
def move(self, move_pos): | ||
if not self.is_move_legal(move_pos): # 落子位置不合理 | ||
raise ValueError("move {0} on board {1} is not legal". format(move_pos, self.board)) | ||
# 新棋盘,准备赋予新结点使用(-self.next_player: 更新下棋选手) | ||
new_board = Board(board=np.copy(self.board), next_player=-self.next_player) | ||
new_board.board[move_pos[0], move_pos[1]] = self.next_player # 落子 | ||
|
||
return new_board # 返回新棋盘 | ||
|
||
def game_over(self, move_pos): | ||
""" | ||
判断游戏是否结束 | ||
:param move_pos: 落子下标 | ||
:param player: 落子方 | ||
:return: | ||
""" | ||
if self.board_result(move_pos): # player玩家胜利,游戏结束 | ||
return 'win' | ||
elif len(self.get_legal_pos()) == 0: # 未分胜利且无可落子点位,返回平局 | ||
return 'tie' | ||
else: # 游戏未结束 | ||
return None | ||
|
||
def board_result(self, move_pos): | ||
""" | ||
每次落子都需要判断棋盘状态,确定棋局是继续还是结束 | ||
:param move_pos: 落子下标 | ||
:return: | ||
""" | ||
x, y = move_pos[0], move_pos[1] | ||
player = self.board[x, y] # 落子方 | ||
direction = list([[self.board[i][y] for i in range(self.size)]]) # 纵向是否有五颗连子 | ||
direction.append([self.board[x][j] for j in range(self.size)]) # 横向是否有五颗连子 | ||
direction.append(self.board.diagonal(y - x)) # 该点正对角是否有五颗连子 | ||
direction.append(np.fliplr(self.board).diagonal(self.size - 1 - y - x)) # 该点反对角是否有五颗连子 | ||
for v_list in direction: | ||
count = 0 | ||
for v in v_list: | ||
if v == player: | ||
count += 1 | ||
if count == num_in_a_row_will_win: | ||
return True # 该玩家赢下游戏 | ||
else: | ||
count = 0 | ||
return False | ||
|
||
def __str__(self): | ||
return "next_player: {}\nboard:\n{}\n".format(self.next_player, self.board) | ||
|
||
|
||
if __name__ == '__main__': | ||
import random | ||
print(random.randint(0, 0)) | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from Board import Board | ||
from Players import Human, AI | ||
|
||
from datetime import datetime | ||
|
||
|
||
class Game: | ||
|
||
def __init__(self): | ||
self.board = Board() # 初始棋盘(黑子先手) | ||
|
||
def graphic(self): | ||
"""绘制棋盘""" | ||
width, height = self.board.size, self.board.size # 棋盘大小 | ||
|
||
print(" 黑子(-1) 用 X 表示\t\t\t白子(1) 用 O 表示\n") | ||
|
||
for x in range(width): # 打印行坐标 | ||
print("{0:8}".format(x), end='') | ||
|
||
print('\r\n') | ||
for i in range(height): | ||
print("{0:4d}".format(i), end='') | ||
for j in range(width): | ||
if self.board.board[i, j] == -1: | ||
print('X'.center(8), end='') | ||
elif self.board.board[i, j] == 1: | ||
print('O'.center(8), end='') | ||
else: | ||
print('-'.center(8), end='') | ||
print('\r\n\r\n') | ||
|
||
def start_play(self): | ||
human, ai = Human(), AI() | ||
self.graphic() | ||
|
||
while True: | ||
|
||
self.board, move_pos = human.action(self.board) | ||
game_result = self.board.game_over(move_pos) | ||
|
||
self.graphic() | ||
if game_result == 'win' or game_result == 'tie': # 游戏结束 | ||
print('黑子落棋: {}, 黑子(-1)胜利!游戏结束!'.format(move_pos)) if game_result == 'win' \ | ||
else print('黑子落棋: {}, 平局!游戏结束!'.format(move_pos)) | ||
break | ||
else: | ||
print('黑子落棋: {}, 未分胜负, 游戏继续!'.format(move_pos)) | ||
|
||
# start_time = datetime.now() | ||
self.board, move_pos = ai.action(self.board, move_pos) | ||
# print(datetime.now() - start_time) | ||
game_result = self.board.game_over(move_pos) | ||
self.graphic() | ||
if game_result == 'win' or game_result == 'tie': # 游戏结束 | ||
print('白子落棋: {}, 白子(1)胜利!游戏结束!'.format(move_pos)) if game_result == 'win' \ | ||
else print('白子落棋: {}, 平局!游戏结束!'.format(move_pos)) | ||
break | ||
else: | ||
print('白子落棋: {}, 未分胜负, 游戏继续!'.format(move_pos)) | ||
|
||
|
||
if __name__ == "__main__": | ||
game = Game() | ||
game.start_play() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import numpy as np | ||
|
||
from Node import TreeNode | ||
|
||
mcts_times = 11000 # MCTS次数 | ||
|
||
|
||
def monte_carlo_tree_search(board, pre_pos): | ||
root = TreeNode(board=board, pre_pos=pre_pos) # 根结点,根结点无父亲 | ||
for i in range(mcts_times): # 相当于(while resources_left(time, computational power):)即资源限制 | ||
leaf = traverse(root) # 选择和扩展,leaf = unvisited node(遍历根结点) | ||
simulation_result = rollout(leaf) # 模拟 | ||
backpropagate(leaf, simulation_result) # 反向传播 | ||
return best_child(root).pre_pos | ||
# return root.best_uct().pre_pos | ||
|
||
|
||
def traverse(node): | ||
""" | ||
层次遍历该结点及其子结点,遇到叶子结点,遇到未完全扩展的结点则对其进行扩展 | ||
:param node: 某一结点 | ||
:return: | ||
""" | ||
while node.fully_expanded(): # 该结点已经完全扩展, 选择一个UCT最高的孩子 | ||
node = node.best_uct() | ||
# 遇到未完成扩展的结点后退出循环,先检查是否为叶子结点 | ||
if node.non_terminal() is not None: # 是叶子结点(node is terminal) | ||
return node | ||
else: # 不是叶子结点且还没有孩子(in case no children are present) | ||
return node.pick_univisted() # 扩展访问结点 | ||
|
||
|
||
# def traverse(node): | ||
# """ | ||
# 层次遍历该结点及其子结点,遇到叶子结点,遇到未完全扩展的结点则对其进行扩展 | ||
# :param node: 某一结点 | ||
# :return: | ||
# """ | ||
# while node.non_terminal() is None: # 不是叶子结点 | ||
# if node.fully_expanded(): # 该结点已经完全扩展, 选择一个UCT最高的孩子 | ||
# node = node.best_uct() | ||
# else: | ||
# return node.pick_univisted() # 不是叶子结点且还没有孩子, 扩展访问结点(in case no children are present) | ||
# return node # 返回叶子结点(node is terminal) | ||
|
||
|
||
def rollout(node): | ||
while True: | ||
game_result = node.non_terminal() | ||
if game_result is None: # 不是叶子结点, 继续模拟 | ||
node = rollout_policy(node) | ||
else: # 是叶子结点,结束 | ||
break | ||
if game_result == 'win' and -node.board.next_player == 1: # 白子胜(测试过, 没有错误) | ||
# print(node, '模拟白子胜利!') | ||
# print('模拟白子胜利!') | ||
return 1 # 相对于白子是胜利的 | ||
elif game_result == 'win': # 黑子胜(测试过, 没有错误) | ||
# print(node.board.board, node, '模拟黑子胜利!') | ||
return -1 # 相对于白子是失败的 | ||
else: # 平局 | ||
return 0 | ||
|
||
|
||
def rollout_policy(node): | ||
return node.pick_random() # 随机选择了一个结点进行模拟 | ||
|
||
|
||
def backpropagate(node, result): | ||
node.num_of_visit += 1 | ||
node.num_of_wins[result] += 1 | ||
if node.parent: # 如果不是根结点,则继续更新其父节点 | ||
backpropagate(node.parent, result) | ||
|
||
|
||
def best_child(node): | ||
visit_num_of_children = np.array(list([child.num_of_visit for child in node.children])) | ||
best_index = np.argmax(visit_num_of_children) # 获取最大uct的下标 | ||
node = node.children[best_index] | ||
# print('root_child_node_info: ', node.num_of_visit, node.num_of_wins) | ||
return node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import numpy as np | ||
from random import randint | ||
from collections import defaultdict | ||
|
||
|
||
class TreeNode: | ||
"""MCTS Node""" | ||
|
||
def __init__(self, parent=None, pre_pos=None, board=None): | ||
self.pre_pos = pre_pos # (0,1) # 造成这个棋盘的结点下标 | ||
|
||
self.parent = parent # 父结点 | ||
self.children = list() # 子结点 | ||
|
||
self.not_visit_pos = None # 未访问过的节点 | ||
|
||
self.board = board # 每个结点对应一个棋盘状态 | ||
|
||
self.num_of_visit = 0 # 访问次数N | ||
# self.num_of_win = 0 # 胜利次数M 需要实时更新 | ||
self.num_of_wins = defaultdict(int) # 记录该结点模拟的白子、黑子的胜利次数(defaultdict: 当字典里的key不存在但被查找时,返回0) | ||
# self.uct = 0 # 选择该点的机率:uct = (M/N) + c * sqrt(log(parent.N) / N) 需要实时更新 | ||
|
||
def fully_expanded(self): | ||
""" | ||
:return: True: 该结点已经完全扩展, False: 该结点未完全扩展 | ||
""" | ||
if self.not_visit_pos is None: # 如果未访问过的结点为None(初始化为None)则未进行扩展过 | ||
self.not_visit_pos = self.board.get_legal_pos() # 得到可作为该结点扩展结点的所有下标 | ||
# 只剩一个落子点位的叶子结点的未访问结点为0且孩子为0 | ||
# print('len(self.not_visit_pos):', len(self.not_visit_pos), 'len(self.children):', len(self.children)) | ||
# print(True if (len(self.not_visit_pos) == 0 and len(self.children) != 0) else False) | ||
return True if (len(self.not_visit_pos) == 0 and len(self.children) != 0) else False | ||
# return True if len(self.not_visit_pos) == 0 else False | ||
|
||
def pick_univisted(self): | ||
"""选择一个未访问的结点""" | ||
random_index = randint(0, len(self.not_visit_pos) - 1) # 随机选择一个未访问的结点(random.randint: 闭区间) | ||
# print(len(self.not_visit_pos)) | ||
move_pos = self.not_visit_pos.pop(random_index) # 得到一个随机的未访问结点, 并从所有的未访问结点中删除 | ||
# print(len(self.not_visit_pos)) | ||
|
||
new_board = self.board.move(move_pos) # 模拟落子并返回新棋盘 | ||
new_node = TreeNode(parent=self, pre_pos=move_pos, board=new_board) # 新棋盘绑定新结点 | ||
self.children.append(new_node) | ||
return new_node | ||
|
||
def pick_random(self): | ||
"""选择结点的孩子进行扩展""" | ||
possible_moves = self.board.get_legal_pos() # 可以落子的点位 | ||
random_index = randint(0, len(possible_moves) - 1) # 随机选择一个可以落子的点位(random.randint: 闭区间) | ||
move_pos = possible_moves[random_index] # 得到一个随机的可以落子的点位 | ||
|
||
new_board = self.board.move(move_pos) # 模拟落子并返回新棋盘 | ||
new_node = TreeNode(parent=self, pre_pos=move_pos, board=new_board) # 新棋盘绑定新结点 | ||
return new_node | ||
|
||
def non_terminal(self): | ||
""" | ||
:return: None: 不是叶子(终端)结点, 'win' or 'tie': 是叶子(终端)结点 | ||
""" | ||
game_result = self.board.game_over(self.pre_pos) | ||
return game_result | ||
|
||
def num_of_win(self): | ||
# print(self) | ||
# print(-self.board.next_player) | ||
wins = self.num_of_wins[-self.board.next_player] # 孩子结点的棋盘状态是在父节点的next_player之后形成 | ||
loses = self.num_of_wins[self.board.next_player] | ||
return wins - loses | ||
# return wins | ||
|
||
def best_uct(self, c_param=1.98): | ||
"""返回一个自己最好的孩子结点(根据UCT进行比较)""" | ||
uct_of_children = np.array(list([ | ||
(child.num_of_win() / child.num_of_visit) + c_param * np.sqrt(np.log(self.num_of_visit) / child.num_of_visit) | ||
for child in self.children | ||
])) | ||
best_index = np.argmax(uct_of_children) | ||
# max_uct = max(uct_of_children) | ||
# best_index = np.where(uct_of_children == max_uct) # 获取最大uct的下标 | ||
# best_index = np.random.choice(best_index[0]) # 随机选取一个拥有最大uct的孩子 | ||
return self.children[best_index] | ||
|
||
def __str__(self): | ||
return "pre_pos: {}\t pre_player: {}\t num_of_visit: {}\t num_of_wins: {}"\ | ||
.format(self.pre_pos, self.board.board[self.pre_pos[0], self.pre_pos[1]], | ||
self.num_of_visit, dict(self.num_of_wins)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from MCTS import monte_carlo_tree_search | ||
|
||
|
||
class Human: | ||
|
||
def __init__(self, player=-1): | ||
self.player = player | ||
|
||
def get_action_pos(self, board): | ||
"""落子""" | ||
try: | ||
location = input("Your move(please use commas to separate the two index): ") | ||
if isinstance(location, str) and len(location.split(",")) == 2: # for python3, 检测变量类型 | ||
move_pos = tuple([int(n, 10) for n in location.split(",")]) # 转成不可变的元组 | ||
else: | ||
move_pos = -1 | ||
except: | ||
move_pos = -1 | ||
|
||
if move_pos == -1 or move_pos not in board.get_legal_pos(): | ||
print("Invalid Move") | ||
move_pos = self.get_action_pos(board) | ||
return move_pos | ||
|
||
def action(self, board): | ||
move_pos = self.get_action_pos(board) | ||
board = board.move(move_pos) # 新的棋盘 | ||
return board, move_pos | ||
|
||
|
||
class AI: | ||
"""AI player""" | ||
|
||
def __init__(self, player=1): | ||
self.player = player | ||
|
||
@staticmethod | ||
def action(board, pre_pos): | ||
move_pos = monte_carlo_tree_search(board, pre_pos) | ||
board = board.move(move_pos) # 新的棋盘 | ||
return board, move_pos |