
Commit

Added parallel training to alpha zero.
Skirlax committed Aug 15, 2024
1 parent 43b7129 commit 10374a7
Showing 3 changed files with 7 additions and 11 deletions.
8 changes: 4 additions & 4 deletions mu_alpha_zero/AlphaZero/MCTS/az_search_tree.py
@@ -52,7 +52,7 @@ def play_one_game(self, network: GeneralNetwork, device: th.device) -> tuple[
         state = self.game_manager.reset()
         current_player = 1
         game_history = []
-        game_data = SingleGameData()
+        # game_data = SingleGameData()
         results = {"1": 0, "-1": 0, "D": 0}
         while True:
             pi, _ = self.search(network, state, current_player, device)
@@ -88,9 +88,9 @@ def play_one_game(self, network: GeneralNetwork, device: th.device) -> tuple[
         game_history = augment_experience_with_symmetries(game_history, self.game_manager.board_size)
         self.hook_manager.process_hook_executes(self, self.play_one_game.__name__, __file__, HookAt.TAIL,
                                                 args=(game_history, results))
-        for state, pi, r, player, move_mask in game_history:
-            game_data.add_data_point(DataPoint(pi, r, None, None, player, state, move_mask))
-        return [game_data], results["1"], results["-1"], results["D"]
+        # for state, pi, r, player, move_mask in game_history:
+        #     game_data.add_data_point(DataPoint(pi, r, None, None, player, state, move_mask))
+        return game_history, results["1"], results["-1"], results["D"]

     def search(self, network, state, current_player, device, tau=None):
         """
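In effect, play_one_game now returns the raw game_history list of (state, pi, result, player, move_mask) tuples instead of a SingleGameData wrapper. A minimal sketch of why that matters for the parallel training named in the commit message: plain tuples pickle cleanly, so several self-play workers can hand their histories back through a multiprocessing pool. The worker function and pool setup below are illustrative assumptions, not the repository's actual parallel driver.

from multiprocessing import Pool

def run_self_play(seed: int):
    # Stand-in for search_tree.play_one_game(network, device): returns
    # (game_history, p1_wins, p2_wins, draws) built from plain tuples,
    # so the result can be pickled back to the parent process.
    game_history = [(f"state-{seed}", [0.5, 0.5], 1.0, 1, [1, 1])]
    return game_history, 1, 0, 0

if __name__ == "__main__":
    with Pool(processes=4) as pool:
        results = pool.map(run_self_play, range(8))
    # Flatten every game's moves into one list, ready for the replay buffer.
    all_moves = [move for history, *_ in results for move in history]
    print(len(all_moves))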
2 changes: 1 addition & 1 deletion mu_alpha_zero/AlphaZero/Network/nnet.py
@@ -296,7 +296,7 @@ def calculate_loss(self, experience_batch, muzero_alphazero_config):
         device = th.device("cuda" if th.cuda.is_available() else "cpu")
         states, pi, v, _, masks = experience_batch[0], experience_batch[1], experience_batch[2], experience_batch[3], \
             experience_batch[4]
-        # pi = [[y for y in x.values()] for x in pi]
+        pi = [[y for y in x.values()] for x in pi]
         # game = [[y.frame,y.pi,y.v,y.action_mask] for y in experience_batch.datapoints]
         # states, pi, v, masks = zip(*game)
         states = th.tensor(np.array(states), dtype=th.float32, device=device)
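The re-enabled line handles pi arriving as one action-to-probability dict per position: calculate_loss now flattens each dict into a plain probability vector before building the policy target tensor. A tiny illustration with made-up values (dict insertion order is preserved in Python 3.7+, so the vector keeps the action ordering):

pi = [{0: 0.1, 1: 0.7, 2: 0.2}, {0: 0.5, 1: 0.25, 2: 0.25}]  # hypothetical policies
pi = [[y for y in x.values()] for x in pi]
# -> [[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]], ready for th.tensor(np.array(pi), ...)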
8 changes: 2 additions & 6 deletions mu_alpha_zero/mem_buffer.py
@@ -98,12 +98,8 @@ def batch(self, batch_size: int, is_eval: bool = False):
         indeces = np.random.choice(np.arange(len(buf)), size=min(len(buf), batch_size),
                                    replace=False).flatten().tolist()
         items = [buf[i] for i in indeces]
-        item_names = ["frame", "pi", "v", "player", "action_mask"]
-        batch = []
-        for item in items:
-            game_batch = [[getattr(y,item_names[i]) for y in item.datapoints] for i in range(5)]
-            batch.append(game_batch)
-        return batch
+        batch = [[items[i][x] for i in range(len(items))] for x in range(len(items[0]))]
+        return [batch]

     def __call__(self, batch_size, is_eval: bool = False) -> list:
         return self.batch(batch_size, is_eval=is_eval)
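With DataPoint objects gone from the buffer, each stored item is now a plain (state, pi, v, player, action_mask) tuple, and the rewritten comprehension simply transposes the sampled items from a list of tuples into per-field lists, matching the indexing calculate_loss performs on experience_batch. A small illustration with dummy values:

items = [("s1", {"a": 1.0}, 1.0, 1, [1, 0]),
         ("s2", {"a": 0.5}, -1.0, -1, [0, 1])]  # dummy buffer entries
batch = [[items[i][x] for i in range(len(items))] for x in range(len(items[0]))]
# -> [['s1', 's2'], [{'a': 1.0}, {'a': 0.5}], [1.0, -1.0], [1, -1], [[1, 0], [0, 1]]]
# equivalent to list(map(list, zip(*items)))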
