-
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy pathserver.py
172 lines (157 loc) · 5.51 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import prelude
import logging
import shutil
import torch
import sys
import os
from os import path
from io import BytesIO
from typing import *
from collections import OrderedDict
from dataclasses import dataclass
from socketserver import ThreadingTCPServer, BaseRequestHandler
from threading import Lock
from common import send_msg, recv_msg, UnexpectedEOF
from config import config
@dataclass
class State:
    """Mutable server state shared by every request-handler thread.

    Fields fall into groups; the two lock fields each guard one group, and a
    handler must hold the matching lock while touching a field of that group.
    """

    # -- configuration, filled in once by main() --
    buffer_dir: str
    drain_dir: str
    capacity: int
    force_sequential: bool

    # -- locks --
    dir_lock: Lock
    param_lock: Lock

    # -- guarded by dir_lock: replay files waiting in buffer_dir --
    buffer_size: int    # number of files currently sitting in buffer_dir
    submission_id: int  # prefix used for the next submission's file names

    # -- guarded by param_lock: most recent model parameters --
    mortal_param: Union[OrderedDict, None]
    dqn_param: Union[OrderedDict, None]
    param_version: int       # incremented on every submit_param
    idle_param_version: int  # param_version of the last submission flagged is_idle


# Global server state; assigned by main() and None until then.
S: Optional[State] = None
class Handler(BaseRequestHandler):
def handle(self):
msg = self.recv_msg()
match msg['type']:
# called by workers
case 'get_param':
self.handle_get_param(msg)
case 'submit_replay':
self.handle_submit_replay(msg)
# called by trainer
case 'submit_param':
self.handle_submit_param(msg)
case 'drain':
self.handle_drain()
def handle_get_param(self, msg):
with S.dir_lock:
overflow = S.buffer_size >= S.capacity
with S.param_lock:
has_param = S.mortal_param is not None and S.dqn_param is not None
if overflow:
self.send_msg({'status': 'samples overflow'})
return
if not has_param:
self.send_msg({'status': 'empty param'})
return
client_param_version = msg['param_version']
buf = BytesIO()
with S.param_lock:
if S.force_sequential and S.idle_param_version <= client_param_version:
res = {'status': 'trainer is busy'}
else:
res = {
'status': 'ok',
'mortal': S.mortal_param,
'dqn': S.dqn_param,
'param_version': S.param_version,
}
torch.save(res, buf)
self.send_msg(buf.getbuffer(), packed=True)
def handle_submit_replay(self, msg):
with S.dir_lock:
for filename, content in msg['logs'].items():
filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
with open(filepath, 'wb') as f:
f.write(content)
S.buffer_size += len(msg['logs'])
S.submission_id += 1
logging.info(f'total buffer size: {S.buffer_size}')
def handle_submit_param(self, msg):
with S.param_lock:
S.mortal_param = msg['mortal']
S.dqn_param = msg['dqn']
S.param_version += 1
if msg['is_idle']:
S.idle_param_version = S.param_version
def handle_drain(self):
drained_size = 0
with S.dir_lock:
buffer_list = os.listdir(S.buffer_dir)
raw_count = len(buffer_list)
assert raw_count == S.buffer_size
if (not S.force_sequential or raw_count >= S.capacity) and raw_count > 0:
old_drain_list = os.listdir(S.drain_dir)
for filename in old_drain_list:
filepath = path.join(S.drain_dir, filename)
os.remove(filepath)
for filename in buffer_list:
src = path.join(S.buffer_dir, filename)
dst = path.join(S.drain_dir, filename)
shutil.move(src, dst)
drained_size = raw_count
S.buffer_size = 0
logging.info(f'files transferred to trainer: {drained_size}')
logging.info(f'total buffer size: {S.buffer_size}')
self.send_msg({
'count': drained_size,
'drain_dir': S.drain_dir,
})
def send_msg(self, msg, packed=False):
return send_msg(self.request, msg, packed)
def recv_msg(self):
return recv_msg(self.request)
class Server(ThreadingTCPServer):
    """Threaded TCP server that stays quiet when a client drops its connection."""

    def handle_error(self, request, client_address):
        """Swallow routine disconnect errors; defer everything else to the base class.

        ``BrokenPipeError`` / ``ConnectionResetError`` (and subclasses) and
        ``UnexpectedEOF`` from ``common`` typically mean the peer went away
        mid-request, which does not warrant a traceback. The original exact
        ``is`` checks missed subclasses and abrupt resets. Any other error is
        reported by ``BaseServer.handle_error``.
        """
        typ = sys.exc_info()[0]
        if typ is not None and (
            issubclass(typ, (BrokenPipeError, ConnectionResetError))
            or issubclass(typ, UnexpectedEOF)
        ):
            return
        return super().handle_error(request, client_address)
def main():
    """Set up global state, reset the on-disk buffers, and serve until stopped.

    Storage settings come from ``config['online']['server']``; the listening
    address comes from ``config['online']['remote']``. Both the buffer and the
    drain directories are deleted (if present) and recreated empty.
    """
    global S
    server_cfg = config['online']['server']
    S = State(
        buffer_dir=path.abspath(server_cfg['buffer_dir']),
        drain_dir=path.abspath(server_cfg['drain_dir']),
        capacity=server_cfg['capacity'],
        force_sequential=server_cfg['force_sequential'],
        dir_lock=Lock(),
        param_lock=Lock(),
        buffer_size=0,
        submission_id=0,
        mortal_param=None,
        dqn_param=None,
        param_version=0,
        idle_param_version=0,
    )
    remote_cfg = config['online']['remote']
    host = remote_cfg['host']
    port = remote_cfg['port']

    # Start from empty directories; leftovers from an earlier run are discarded.
    storage_dirs = (S.buffer_dir, S.drain_dir)
    for directory in storage_dirs:
        if path.isdir(directory):
            shutil.rmtree(directory)
    for directory in storage_dirs:
        os.makedirs(directory)

    with Server((host, port), Handler, bind_and_activate=False) as server:
        # allow_reuse_address only takes effect if set before server_bind(),
        # which is why binding is deferred via bind_and_activate=False.
        server.allow_reuse_address = True
        server.daemon_threads = True
        server.server_bind()
        server.server_activate()
        logging.info(f'listening on {host}:{port}')
        server.serve_forever()


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        pass