-
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy pathclient.py
88 lines (76 loc) · 2.82 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import prelude
import logging
import socket
import torch
import numpy as np
import time
import gc
from os import path
from model import Brain, DQN
from player import TrainPlayer
from common import send_msg, recv_msg
from config import config
def main():
    """Run the online-training client loop until the process is interrupted.

    All settings are read from the project ``config`` mapping. Each iteration
    of the outer loop:

    1. polls the remote server with a ``get_param`` request until the reply
       has ``status == 'ok'``, then loads the returned ``mortal`` and ``dqn``
       state dicts into the local models;
    2. plays one session through ``TrainPlayer.train_play`` and logs the
       rankings of that session plus the totals over the last
       ``history_window`` sessions;
    3. reads every log file reported by the session and uploads them in a
       single ``submit_replay`` request;
    4. runs the garbage collector and, on CUDA devices only, releases cached
       GPU memory.

    The function never returns normally; it only stops when an exception
    (for example ``KeyboardInterrupt``) propagates out of it.
    """
    online_cfg = config['online']
    remote = (online_cfg['remote']['host'], online_cfg['remote']['port'])
    device = torch.device(config['control']['device'])
    version = config['control']['version']
    num_blocks = config['resnet']['num_blocks']
    conv_channels = config['resnet']['conv_channels']

    mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).to(device).eval()
    dqn = DQN(version=version).to(device)
    if online_cfg['enable_compile']:
        mortal.compile()
        dqn.compile()

    train_player = TrainPlayer()
    # -1 is sent on the first request, before any parameters were received.
    param_version = -1

    # Weights applied to `rankings` to get an average placement / average
    # points. Presumably `rankings` holds the counts of 1st..4th place
    # finishes -- TODO confirm against TrainPlayer.train_play.
    places = np.arange(1, 5)
    pts = np.array([90, 45, 0, -135])
    history_window = online_cfg['history_window']
    history = []

    while True:
        # Poll until the server answers 'ok'; any other status means "retry".
        while True:
            with socket.socket() as conn:
                conn.connect(remote)
                msg = {
                    'type': 'get_param',
                    'param_version': param_version,
                }
                send_msg(conn, msg)
                rsp = recv_msg(conn, map_location=device)
                if rsp['status'] == 'ok':
                    param_version = rsp['param_version']
                    break
            # The connection is already closed here, so no socket is held
            # open while waiting for the next attempt.
            time.sleep(3)
        mortal.load_state_dict(rsp['mortal'])
        dqn.load_state_dict(rsp['dqn'])
        logging.info('param has been updated')

        rankings, file_list = train_player.train_play(mortal, dqn, device)
        avg_rank = rankings @ places / rankings.sum()
        avg_pt = rankings @ pts / rankings.sum()

        # Keep only the most recent `history_window` sessions for the
        # aggregated statistics.
        history.append(np.array(rankings))
        if len(history) > history_window:
            del history[0]
        sum_rankings = np.sum(history, axis=0)
        ma_avg_rank = sum_rankings @ places / sum_rankings.sum()
        ma_avg_pt = sum_rankings @ pts / sum_rankings.sum()

        logging.info(f'trainee rankings: {rankings} ({avg_rank:.6}, {avg_pt:.6}pt)')
        logging.info(f'last {len(history)} sessions: {sum_rankings} ({ma_avg_rank:.6}, {ma_avg_pt:.6}pt)')

        # Upload raw file contents keyed by base name, without directories.
        logs = {}
        for filename in file_list:
            with open(filename, 'rb') as f:
                logs[path.basename(filename)] = f.read()

        with socket.socket() as conn:
            conn.connect(remote)
            send_msg(conn, {
                'type': 'submit_replay',
                'logs': logs,
                'param_version': param_version,
            })
            logging.info('logs have been submitted')

        gc.collect()
        # CUDA housekeeping only makes sense on a CUDA device, and
        # torch.cuda.synchronize() raises when no CUDA runtime is available,
        # which would otherwise kill a CPU-only client after its first session.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
if __name__ == '__main__':
    # main() loops forever, so Ctrl-C is how the client gets stopped;
    # swallow the interrupt to exit without printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass