forked from TsingZ0/TLSAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput.py
84 lines (64 loc) · 1.95 KB
/
input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
class DataInput:
    """Iterate over training samples in zero-padded mini-batches.

    Each sample ``t`` is read positionally:

    * ``t[0]`` is collected into ``u``;
    * ``t[1]`` is a variable-length sequence, right-padded into ``hist_i``;
    * ``t[2]`` is a parallel sequence, right-padded into ``hist_t`` (only its
      first ``len(t[1])`` entries are used);
    * ``t[3]`` is collected into ``i``;
    * ``t[4]`` is collected into ``y``.

    NOTE(review): presumably (user, history items, history times, target
    item, label) — confirm against the code that builds ``data``.

    The object is its own iterator and ``__iter__`` does not rewind, so it is
    single-pass: create a new instance for every epoch.
    """

    def __init__(self, data, batch_size):
        """Store the samples and precompute the number of batches.

        Args:
            data: Sliceable sequence of samples (layout in the class docstring).
            batch_size: Maximum number of samples per batch; the final batch
                holds the remainder and may be smaller.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (previously this
                surfaced as ZeroDivisionError or a confusing ``max([])`` error
                mid-iteration).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.data = data
        # A trailing partial batch still counts as one batch (ceiling division).
        full_batches, remainder = divmod(len(self.data), self.batch_size)
        self.epoch_size = full_batches + (1 if remainder else 0)
        self.i = 0  # number of batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Produce the next batch.

        Returns:
            ``(batch_index, (u, i, y, hist_i, hist_t, sl))`` where
            ``batch_index`` is the 1-based index of this batch; ``u``, ``i``,
            ``y`` and ``sl`` are lists with one entry per sample (``sl`` holds
            ``len(t[1])``); ``hist_i`` (int64) and ``hist_t`` (float32) have
            shape ``(len(batch), max(sl))`` and are zero-padded on the right.

        Raises:
            StopIteration: Once all ``epoch_size`` batches have been produced.
        """
        if self.i >= self.epoch_size:
            raise StopIteration
        start = self.i * self.batch_size
        stop = min(start + self.batch_size, len(self.data))
        ts = self.data[start:stop]
        self.i += 1

        u = [t[0] for t in ts]
        i = [t[3] for t in ts]
        y = [t[4] for t in ts]
        sl = [len(t[1]) for t in ts]
        hist_i, hist_t = self._pad_history(ts, sl)
        return self.i, (u, i, y, hist_i, hist_t, sl)

    @staticmethod
    def _pad_history(ts, sl):
        """Right-pad each sample's ``t[1]`` / ``t[2]`` into dense zero arrays."""
        max_sl = max(sl)
        hist_i = np.zeros((len(ts), max_sl), np.int64)
        hist_t = np.zeros((len(ts), max_sl), np.float32)
        for row, (t, length) in enumerate(zip(ts, sl)):
            hist_i[row, :length] = t[1]
            # Only as many t[2] entries as t[1] has are kept.
            hist_t[row, :length] = t[2][:length]
        return hist_i, hist_t
class DataInputTest:
    """Iterate over evaluation samples in zero-padded mini-batches.

    Each sample ``t`` is read positionally:

    * ``t[0]`` is collected into ``u``;
    * ``t[1]`` is a variable-length sequence, right-padded into ``hist_i``;
    * ``t[2]`` is a parallel sequence, right-padded into ``hist_t`` (only its
      first ``len(t[1])`` entries are used);
    * ``t[3]`` is an indexable pair whose first element goes into ``i`` and
      second into ``j``.

    NOTE(review): ``t[3]`` presumably holds a (positive, negative) item pair
    for ranking evaluation — confirm against the code that builds ``data``.

    The object is its own iterator and ``__iter__`` does not rewind, so it is
    single-pass: create a new instance for every evaluation run.
    """

    def __init__(self, data, batch_size):
        """Store the samples and precompute the number of batches.

        Args:
            data: Sliceable sequence of samples (layout in the class docstring).
            batch_size: Maximum number of samples per batch; the final batch
                holds the remainder and may be smaller.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (previously this
                surfaced as ZeroDivisionError or a confusing ``max([])`` error
                mid-iteration).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.data = data
        # A trailing partial batch still counts as one batch (ceiling division).
        full_batches, remainder = divmod(len(self.data), self.batch_size)
        self.epoch_size = full_batches + (1 if remainder else 0)
        self.i = 0  # number of batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Produce the next batch.

        Returns:
            ``(batch_index, (u, i, j, hist_i, hist_t, sl))`` where
            ``batch_index`` is the 1-based index of this batch; ``u``, ``i``,
            ``j`` and ``sl`` are lists with one entry per sample (``sl`` holds
            ``len(t[1])``); ``hist_i`` (int64) and ``hist_t`` (float32) have
            shape ``(len(batch), max(sl))`` and are zero-padded on the right.

        Raises:
            StopIteration: Once all ``epoch_size`` batches have been produced.
        """
        if self.i >= self.epoch_size:
            raise StopIteration
        start = self.i * self.batch_size
        stop = min(start + self.batch_size, len(self.data))
        ts = self.data[start:stop]
        self.i += 1

        u = [t[0] for t in ts]
        i = [t[3][0] for t in ts]
        j = [t[3][1] for t in ts]
        sl = [len(t[1]) for t in ts]
        hist_i, hist_t = self._pad_history(ts, sl)
        return self.i, (u, i, j, hist_i, hist_t, sl)

    @staticmethod
    def _pad_history(ts, sl):
        """Right-pad each sample's ``t[1]`` / ``t[2]`` into dense zero arrays."""
        max_sl = max(sl)
        hist_i = np.zeros((len(ts), max_sl), np.int64)
        hist_t = np.zeros((len(ts), max_sl), np.float32)
        for row, (t, length) in enumerate(zip(ts, sl)):
            hist_i[row, :length] = t[1]
            # Only as many t[2] entries as t[1] has are kept.
            hist_t[row, :length] = t[2][:length]
        return hist_i, hist_t