-
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy pathengine.py
136 lines (121 loc) · 4.61 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import traceback
import torch
import numpy as np
from torch.distributions import Normal, Categorical
from typing import *
class MortalEngine:
def __init__(
self,
brain,
dqn,
is_oracle,
version,
device = None,
stochastic_latent = False,
enable_amp = False,
enable_quick_eval = True,
enable_rule_based_agari_guard = False,
name = 'NoName',
boltzmann_epsilon = 0,
boltzmann_temp = 1,
top_p = 1,
):
self.engine_type = 'mortal'
self.device = device or torch.device('cpu')
assert isinstance(self.device, torch.device)
self.brain = brain.to(self.device).eval()
self.dqn = dqn.to(self.device).eval()
self.is_oracle = is_oracle
self.version = version
self.stochastic_latent = stochastic_latent
self.enable_amp = enable_amp
self.enable_quick_eval = enable_quick_eval
self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
self.name = name
self.boltzmann_epsilon = boltzmann_epsilon
self.boltzmann_temp = boltzmann_temp
self.top_p = top_p
def react_batch(self, obs, masks, invisible_obs):
try:
with (
torch.autocast(self.device.type, enabled=self.enable_amp),
torch.inference_mode(),
):
return self._react_batch(obs, masks, invisible_obs)
except Exception as ex:
raise Exception(f'{ex}\n{traceback.format_exc()}')
def _react_batch(self, obs, masks, invisible_obs):
obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
if invisible_obs is not None:
invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
batch_size = obs.shape[0]
match self.version:
case 1:
mu, logsig = self.brain(obs, invisible_obs)
if self.stochastic_latent:
latent = Normal(mu, logsig.exp() + 1e-6).sample()
else:
latent = mu
q_out = self.dqn(latent, masks)
case 2 | 3 | 4:
phi = self.brain(obs)
q_out = self.dqn(phi, masks)
if self.boltzmann_epsilon > 0:
is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
sampled = sample_top_p(logits, self.top_p)
actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
else:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
actions = q_out.argmax(-1)
return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
def sample_top_p(logits, p):
    """Sample one index along the last dimension of ``logits`` with top-p filtering.

    Args:
        logits: Tensor of unnormalised log-probabilities.
        p: Cumulative-probability threshold. ``p >= 1`` samples from the full
            categorical distribution; ``p <= 0`` returns the argmax.

    Returns:
        Tensor of sampled indices with the last dimension removed.
    """
    if p >= 1:
        # No filtering requested: plain categorical sampling.
        return Categorical(logits=logits).sample()
    if p <= 0:
        # Degenerate threshold: always the most likely entry.
        return logits.argmax(-1)

    ranked_probs, ranked_idx = logits.softmax(-1).sort(-1, descending=True)
    # Probability mass that precedes each entry in the descending order.
    mass_before = ranked_probs.cumsum(-1) - ranked_probs
    # Drop every entry whose preceding mass already exceeds p; the top entry
    # always survives because nothing precedes it.
    kept = ranked_probs.masked_fill(mass_before > p, 0.)
    choice = kept.multinomial(1)
    # Translate the position in the sorted order back to the original index.
    return ranked_idx.gather(-1, choice).squeeze(-1)
class ExampleMjaiLogEngine:
    """Example engine of type ``'mjai-log'`` that answers with JSON strings.

    For every game state it either emits a ``dahai`` event whose ``pai`` is
    whatever ``state.last_self_tsumo()`` returns (when discarding is allowed)
    or a ``none`` event.
    """

    def __init__(self, name: str):
        self.engine_type = 'mjai-log'
        self.name = name
        # Filled in later through set_player_ids(); indexed by game index.
        self.player_ids = None

    def set_player_ids(self, player_ids: List[int]):
        """Remember the player id to act as in each game."""
        self.player_ids = player_ids

    def react_batch(self, game_states):
        """Return one JSON-encoded reaction string per entry of ``game_states``."""
        return [self._react_one(game_state) for game_state in game_states]

    def _react_one(self, game_state):
        """Build the reaction for a single game state."""
        idx = game_state.game_index
        state = game_state.state
        raw_events = game_state.events_json

        # The event log is only parsed to check that it opens a kyoku.
        parsed = json.loads(raw_events)
        assert parsed[0]['type'] == 'start_kyoku'

        actor = self.player_ids[idx]
        if not state.last_cans.can_discard:
            return '{"type":"none"}'

        reaction = {
            'type': 'dahai',
            'actor': actor,
            'pai': state.last_self_tsumo(),
            'tsumogiri': True,
        }
        return json.dumps(reaction)

    # The hooks below are invoked at specific events. They may do nothing,
    # but they have to exist.
    def start_game(self, game_idx: int):
        pass

    def end_kyoku(self, game_idx: int):
        pass

    def end_game(self, game_idx: int, scores: List[int]):
        pass