-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathpostprocessor.py
85 lines (72 loc) · 2.26 KB
/
postprocessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from . import utils
def decoder(val):
    """Return ``val`` decoded through its own ``decode()`` method.

    Intended for redis replies whose only post-processing step is turning
    the raw reply into ``str``. For ``bytes`` input this uses ``decode()``'s
    default codec (UTF-8).
    """
    decoded = val.decode()
    return decoded
class Processor:
    """Convert raw RedisAI replies into Python objects.

    Every method is a ``staticmethod`` that receives the raw reply ``res``
    for the command of the same name and returns the processed value.
    """

    @staticmethod
    def modelget(res):
        """Turn a ``modelget`` reply into a dict with decoded inputs/outputs.

        ``res`` is converted with ``utils.list2dict``; the values stored under
        the ``"inputs"`` and ``"outputs"`` keys are then passed through
        ``utils.recursive_bytetransform`` with a ``.decode()`` callback.
        """
        resdict = utils.list2dict(res)
        # The helper's return value is discarded here, so this relies on it
        # rewriting the given containers in place.
        utils.recursive_bytetransform(resdict["inputs"], lambda x: x.decode())
        utils.recursive_bytetransform(resdict["outputs"], lambda x: x.decode())
        return resdict

    @staticmethod
    def modelscan(res):
        """Decode a ``modelscan`` reply.

        Applies ``.decode()`` to the elements of ``res`` through
        ``utils.recursive_bytetransform`` and returns that helper's result.
        """
        return utils.recursive_bytetransform(res, lambda x: x.decode())

    @staticmethod
    def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
        """Process the tensorget output.

        If ``as_numpy`` is truthy, it'll be converted to a numpy array. The
        required information such as datatype and shape must be in
        ``rai_result`` itself.

        Args:
            res: raw reply, converted to a dict with ``utils.list2dict``.
            as_numpy: if truthy, return ``utils.blob2numpy(blob, shape,
                dtype, mutable=False)``.
            as_numpy_mutable: if truthy, same as ``as_numpy`` but with
                ``mutable=True``; takes precedence over ``as_numpy``.
            meta_only: if truthy, return the dict as-is; takes precedence
                over both numpy flags.

        Returns:
            The ``blob2numpy`` result when a numpy flag is set; otherwise
            the dict with its ``"values"`` converted element-wise:
            ``.decode()`` for dtype ``"STRING"``, ``float`` for ``"FLOAT"``
            and ``"DOUBLE"``, and ``int`` for every other dtype.
        """
        rai_result = utils.list2dict(res)
        # Flags are tested for truthiness instead of ``is True`` so that
        # numpy booleans (``np.True_ is True`` is False) or ``1`` are honoured
        # rather than silently ignored.
        if meta_only:
            return rai_result
        if as_numpy_mutable or as_numpy:
            # as_numpy_mutable wins when both flags are set.
            return utils.blob2numpy(
                rai_result["blob"],
                rai_result["shape"],
                rai_result["dtype"],
                mutable=bool(as_numpy_mutable),
            )

        dtype = rai_result["dtype"]
        if dtype == "STRING":
            def target(b):
                return b.decode()
        elif dtype in ("FLOAT", "DOUBLE"):
            target = float
        else:
            target = int
        # Return value discarded: relies on in-place conversion of "values".
        utils.recursive_bytetransform(rai_result["values"], target)
        return rai_result

    @staticmethod
    def scriptget(res):
        """Convert a ``scriptget`` reply into a dict via ``utils.list2dict``."""
        return utils.list2dict(res)

    @staticmethod
    def scriptscan(res):
        """Decode a ``scriptscan`` reply.

        Applies ``.decode()`` to the elements of ``res`` through
        ``utils.recursive_bytetransform`` and returns that helper's result.
        """
        return utils.recursive_bytetransform(res, lambda x: x.decode())

    @staticmethod
    def infoget(res):
        """Convert an ``infoget`` reply into a dict via ``utils.list2dict``."""
        return utils.list2dict(res)
# Replies of the commands named below need nothing beyond decoding, so each
# of them is served by ``decoder`` attached to ``Processor`` as a static
# method.
# NOTE(review): this rebinds the module-level name ``decoder`` to a
# ``staticmethod`` object; such objects are only directly callable on
# Python 3.10+, so calling ``decoder(...)`` from module scope after this
# line fails on older interpreters.
decoder = staticmethod(decoder)
decoding_functions = (
    "config", "inforeset", "loadbackend",
    "modeldel", "modelexecute", "modelrun", "modelset", "modelstore",
    "scriptdel", "scriptexecute", "scriptrun", "scriptset", "scriptstore",
    "tensorset",
)
# The same staticmethod object is shared by every attached name.
for fn in decoding_functions:
    setattr(Processor, fn, decoder)