Skip to content

Commit

Permalink
Added option to use the true game state in the search tree.
Browse files Browse the repository at this point in the history
  • Loading branch information
Skirlax committed Aug 23, 2024
1 parent f84eb80 commit 85e8664
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
6 changes: 4 additions & 2 deletions mu_alpha_zero/MuZero/MZ_MCTS/mz_search_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def search(self, network_wrapper, state: np.ndarray, current_player: int or None
game_state = state if use_state_directly else self.buffer.concat_frames(current_player)
state_ = network_wrapper.representation_forward(
game_state.permute(2, 0, 1).unsqueeze(0).to(device)).squeeze(0)
state_ = scale_hidden_state(state_)
if self.muzero_config.scale_hidden_state:
state_ = scale_hidden_state(state_)
pi, v = network_wrapper.prediction_forward(state_.unsqueeze(0), predict=True)
if self.muzero_config.dirichlet_alpha > 0:
pi = pi + np.random.dirichlet([self.muzero_config.dirichlet_alpha] * self.muzero_config.net_action_size)
Expand Down Expand Up @@ -148,7 +149,8 @@ def search(self, network_wrapper, state: np.ndarray, current_player: int or None
self.muzero_config)
next_state, reward = network_wrapper.dynamics_forward(current_node_state_with_action.unsqueeze(0),
predict=True)
next_state = scale_hidden_state(next_state)
if self.muzero_config.scale_hidden_state:
next_state = scale_hidden_state(next_state)
reward = reward[0][0]
pi, v = network_wrapper.prediction_forward(next_state.unsqueeze(0), predict=True)
if self.muzero_config.use_true_game_state_in_tree:
Expand Down
6 changes: 4 additions & 2 deletions mu_alpha_zero/MuZero/Network/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def calculate_losses(self, experience_batch, grad_scales, weights, device, muzer
else:
values = scalar_values
hidden_state = self.representation_forward(init_states)
hidden_state = scale_hidden_state(hidden_state)
if muzero_config.scale_hidden_state:
hidden_state = scale_hidden_state(hidden_state)
pred_pis, pred_vs = self.prediction_forward(hidden_state, return_support=muzero_config.loss_gets_support)
pi_loss, v_loss, r_loss = 0, 0, 0
pi_loss += self.muzero_loss(pred_pis, pis,masks=masks)
Expand All @@ -249,7 +250,8 @@ def calculate_losses(self, experience_batch, grad_scales, weights, device, muzer
hidden_state, pred_rs, pred_pis, pred_vs = self.forward_recurrent(
match_action_with_obs_batch(hidden_state, moves, muzero_config), False,
return_support=muzero_config.loss_gets_support)
hidden_state = scale_hidden_state(hidden_state)
if muzero_config.scale_hidden_state:
hidden_state = scale_hidden_state(hidden_state)
hidden_state.register_hook(lambda grad: grad * 0.5)
current_pi_loss = self.muzero_loss(pred_pis, pis,masks=masks)
current_v_loss = loss_fn(pred_vs, values)
Expand Down
1 change: 1 addition & 0 deletions mu_alpha_zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class MuZeroConfig(Config):
loss_gets_support: bool = False
frame_buffer_ignores_actions: bool = False
actions_are: Literal["columns", "rows", "board"] = "board"
scale_hidden_state: bool = False
# Will keep maximum of len(longest_path_in_tree) actual game states in memory during tree search,
# in order to deduce when the game is over and improve initial policy.
use_true_game_state_in_tree: bool = False
Expand Down

0 comments on commit 85e8664

Please sign in to comment.