-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CTestLoader.py
49 lines (41 loc) · 1.11 KB
/
CTestLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
import numpy as np
import os, glob
from functools import lru_cache
class CTestLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves pre-built test batches stored as ``.npz`` files.

    Every file matching ``test-*.npz`` directly inside ``testFolder`` is one
    batch. Each file must contain an array named ``'y'`` (the targets); all
    other arrays in the file are returned as a dict of named inputs.
    """

    def __init__(self, testFolder):
        """Index the batch files found in ``testFolder``.

        Args:
            testFolder: directory searched (non-recursively) for ``test-*.npz``.
        """
        # Keras 3's ``PyDataset`` (aliased as ``Sequence``) expects its own
        # ``__init__`` to run; on older Keras this resolves to ``object.__init__``.
        super().__init__()
        self._folder = testFolder
        # glob() yields files in arbitrary, filesystem-dependent order; sort so
        # batch indices (and therefore ``self[0]``) are reproducible across runs.
        self._batchesNpz = sorted(glob.glob(os.path.join(testFolder, 'test-*.npz')))
        # Memo for parametersIDs(). Kept per instance: ``lru_cache`` on a method
        # would keep every loader alive and share one cache slot between loaders.
        self._parametersIDs = None
        self.on_epoch_end()

    @property
    def folder(self):
        """Directory this loader was created with."""
        return self._folder

    def parametersIDs(self):
        """Return ``(userId, placeId, screenId)`` read from the first batch.

        Only element ``[0, 0, 0]`` of each array is read. This presumably
        assumes every sample in the test set shares the same IDs (TODO confirm).
        The value is computed on the first call and reused afterwards.

        Raises:
            IndexError: if no ``test-*.npz`` files were found.
        """
        if self._parametersIDs is None:
            if not self._batchesNpz:
                raise IndexError(f'no test-*.npz batches found in {self._folder!r}')
            batch, _ = self[0]
            self._parametersIDs = (
                batch['userId'][0, 0, 0],
                batch['placeId'][0, 0, 0],
                batch['screenId'][0, 0, 0],
            )
        return self._parametersIDs

    def on_epoch_end(self):
        """No-op: batches are served in the same (sorted) order every epoch."""
        return

    def __len__(self):
        """Number of batch files (one batch per file)."""
        return len(self._batchesNpz)

    def __getitem__(self, idx):
        """Load batch ``idx`` as ``(inputs, (targets,))``.

        ``inputs`` is a dict of every array in the file except ``'y'``;
        ``targets`` is the ``'y'`` array, wrapped in a 1-tuple.

        Raises:
            KeyError: if the file has no ``'y'`` array.
        """
        # Materialize every array while the file is open; NpzFile loads lazily
        # and the handle is closed when the ``with`` block exits.
        with np.load(self._batchesNpz[idx]) as npz:
            arrays = {name: npz[name] for name in npz.files}
        targets = arrays.pop('y')
        return arrays, (targets,)
if __name__ == '__main__':
    # Smoke test: open the ``test`` folder next to this script and print the
    # batch count, then the shape of every input array and of the targets of
    # the first batch.
    folder = os.path.dirname(os.path.abspath(__file__))
    testFolder = os.path.join(folder, 'test')
    ds = CTestLoader(testFolder)
    print(len(ds))
    if len(ds) == 0:
        # Without this guard ds[0] fails with an unexplained IndexError.
        print('No test-*.npz files found in', testFolder)
    else:
        batch, (y,) = ds[0]
        for k, v in batch.items():
            print(k, v.shape)
        print()
        print(y.shape)