-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CLearnablePredictor.py
71 lines (60 loc) · 1.87 KB
/
CLearnablePredictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import threading
import Core.Utils as Utils
import numpy
class CLearnablePredictor:
    """Run a model on the latest submitted samples in a background thread.

    Intended to be used as a context manager: entering starts a worker
    thread that polls for pending input roughly ``fps`` times per second,
    leaving stops and joins it. Callers feed input and collect output
    through :meth:`async_infer`, which never waits for the model.
    """

    def __init__(self, model, fps=30):
        """Store the model and set up the shared state.

        Args:
            model: Callable invoked as ``model(X)``; must expose a
                ``timesteps`` attribute giving the input window length.
            fps: Polling rate of the worker loop, in iterations per second.

        Raises:
            ValueError: If ``fps`` is not positive (the loop period is
                ``1.0 / fps``).
        """
        if fps <= 0:
            raise ValueError("fps must be positive, got %r" % (fps,))
        self._lock = threading.Lock()  # guards _inferData / _inferResults / _prevSteps
        self._done = threading.Event()  # set to ask the worker to stop
        self._thread = None  # worker thread, created in __enter__
        self._inferData = None  # pending input for the worker (None = nothing new)
        self._inferResults = None  # latest unread result (None = nothing new)
        self._model = model
        self._timesteps = self._model.timesteps
        self._prevSteps = []  # sliding window of the last `timesteps` inputs
        self._fps = fps

    def __enter__(self):
        """Start the worker thread and return ``self``."""
        # A previous __exit__ leaves the event set; without clearing it a
        # re-entered predictor would start a worker that stops immediately.
        self._done.clear()
        self._thread = threading.Thread(target=self._loop)
        self._thread.start()
        return self

    def __exit__(self, type, value, traceback):
        """Signal the worker to stop and wait for it to finish.

        Returns ``None``, so exceptions from the ``with`` body propagate.
        """
        self._done.set()
        # Tolerate __exit__ without a matching __enter__.
        if self._thread is not None:
            self._thread.join()

    def async_infer(self, data):
        """Submit new input (if any) and fetch the latest unread result.

        Args:
            data: New input item, or ``None`` to only poll for a result.

        Returns:
            The ``(result, last_step, {})`` tuple stored by the worker since
            the previous call, or ``None`` if no new result is available.
        """
        with self._lock:
            if data is not None:
                if self._timesteps:
                    # Build a fresh list on every call: the worker may still
                    # hold a reference to the previous window.
                    window = (self._prevSteps + [data])[-self._timesteps:]
                    self._prevSteps = window
                    self._inferData = window
                else:
                    self._inferData = data
            # Hand the result over exactly once.
            res = self._inferResults
            self._inferResults = None
        return res

    def _loop(self):
        """Worker body: run ``_infer`` every ``1 / fps`` seconds until stopped."""
        period = 1.0 / self._fps
        # Event.wait returns True once _done is set, which ends the loop.
        while not self._done.wait(period):
            self._infer()

    def _infer(self):
        """Consume the pending input, run the model, and publish the result.

        Does nothing when there is no pending input or when the pending
        window does not yet contain exactly ``timesteps`` items.
        """
        with self._lock:
            data = self._inferData
            self._inferData = None

        if data is None:
            return
        # NOTE(review): when `timesteps` is falsy, async_infer stores a single
        # item here and this length check / the reshape below still use
        # `timesteps` -- that path looks unusable; confirm the intended behavior.
        if len(data) != self._timesteps:
            return

        samples = [Utils.tracked2sample(x) for x in data]
        samples = Utils.samples2inputs(samples)

        # Replace 'time' with step-to-step differences, 0.0 for the first step.
        # Assumes samples['time'] holds one value per timestep -- TODO confirm.
        deltas = numpy.diff(samples['time'], 1)
        deltas = numpy.insert(deltas, 0, 0.0)
        samples['time'] = deltas.reshape((self._timesteps, 1))

        # (timesteps, ...) => (1, timesteps, ...)
        X = {k: x[None] for k, x in samples.items()}
        last_step = data[-1]  # last step as current
        # The model runs outside the lock so async_infer is never blocked by it.
        res = self._model(X)
        with self._lock:
            self._inferResults = (res, last_step, {})

    @property
    def canPredict(self):
        """Always ``True`` for this predictor."""
        return True