forked from ellisk42/ec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprotonet_server.py
75 lines (56 loc) · 1.92 KB
/
protonet_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import socket
import _thread
import sys
import os
from protonet_score import PretrainedProtonetDistScore, \
load_image_path, load_image
# Memo of (idRef, img) -> score, filled by compute_score and shared by every
# client thread. NOTE(review): nothing in this file evicts entries, so it
# grows with each distinct request.
cache = dict()
# Load the pretrained model once, at import time, from results-OM/best_model.pt
# in the directory that contains this script.
model = PretrainedProtonetDistScore(
    os.path.dirname(os.path.realpath(__file__)) + "/results-OM/best_model.pt"
)
def eprint(*args, **kwargs):
    """Write ``args`` to standard error like ``print`` and flush right away.

    Keyword arguments are forwarded to ``print`` (e.g. ``sep``/``end``).
    Passing ``file`` raises ``TypeError`` because the stream is fixed to stderr.
    """
    err = sys.stderr
    print(*args, file=err, **kwargs)
    # Flush so diagnostics show up immediately even if stderr is buffered.
    err.flush()
def compute_score(idRef, img):
    """Return ``model.score`` for reference ``idRef`` against image ``img``.

    Results are memoized in the module-level ``cache`` dict under the key
    ``(idRef, img)``. A repeated request therefore skips loading and scoring.
    ``img`` is part of the key and must be hashable. ``handle_client`` passes
    the raw bytes it received.

    NOTE(review): the check-then-store is not locked, so two client threads
    may both compute the same entry. That wastes work, but both write the
    same key.
    """
    key = (idRef, img)
    if key in cache:
        return cache[key]

    reference = load_image_path(idRef)
    candidate = load_image(img)
    result = model.score(reference, candidate)
    cache[key] = result
    return result
def _recv_exact(connection, size):
    """Read exactly ``size`` bytes from ``connection``.

    A single ``recv`` on a stream socket may return fewer bytes than requested.
    Large image payloads routinely arrive in several pieces, so keep reading
    until the full amount has been collected.

    Returns the bytes read, or ``None`` if the peer closed the connection
    before ``size`` bytes arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = connection.recv(remaining)
        if not chunk:  # b"" means the peer closed its end of the stream
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _recv_frame(connection):
    """Read one frame: a 4-byte big-endian length, then that many bytes.

    Returns the payload bytes, or ``None`` if the stream ended first.
    """
    header = _recv_exact(connection, 4)
    if header is None:
        return None
    return _recv_exact(connection, int.from_bytes(header, byteorder='big'))


def handle_client(connection):
    """Serve scoring requests on one client connection, then close it.

    Each request is two length-prefixed frames (4-byte big-endian length plus
    payload):

    1. a UTF-8 reference id;
    2. the image bytes.

    The reply is one length-prefixed frame holding the UTF-8 decimal string of
    ``1000000 * score['dist'][0][0]``.

    The session ends when the reference id is ``"DONE"`` or when the client
    closes the stream. A client that disconnects while a reply is being sent
    is ignored. The connection is always closed on return.
    """
    try:
        while True:
            ref_bytes = _recv_frame(connection)
            img = _recv_frame(connection)
            if ref_bytes is None or img is None:
                # Peer hung up without sending "DONE": nothing left to answer.
                break
            idRef = ref_bytes.decode("utf8")
            if idRef == "DONE":
                break
            score = compute_score(idRef, img)
            # Alternative metric used previously: str(score['loss']).
            loss = str(1000000 * score['dist'][0][0]).encode("utf8")
            # Send length prefix and payload in one call so a frame is never split.
            connection.sendall(len(loss).to_bytes(4, byteorder='big') + loss)
    except (BrokenPipeError, ConnectionResetError):
        # The client went away before we could answer. This is deliberately
        # ignored; the connection is closed below either way.
        pass
    finally:
        connection.close()
if __name__ == "__main__":
server_address = "./protonet_socket"
try:
os.unlink(server_address)
except OSError:
if os.path.exists(server_address):
raise
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(server_address)
sock.listen(1)
while True:
c, _ = sock.accept()
_thread.start_new_thread(handle_client, (c,))