# load_batch.py
"""
BSEN
2019
Authors:
Wan-Ting Hsieh [email protected]
Jérémy Lefort-Besnard jlefortbesnard [at] tuta [dot] io
"""
import numpy as np
import joblib
import torch
# Loaded eagerly, once, when this module is imported. Presumably the shuffled
# person-ID list the training data is split by; not referenced by the code in
# this file.
# NOTE(review): the path is resolved against the current working directory,
# and joblib.load unpickles its input, so only point it at trusted files.
_PEO_ID_PATH = 'data/shuffled_peoID.pkl'
peoID = joblib.load(_PEO_ID_PATH)
def _is_empty_clip(fea):
    """Return True if a loaded clip holds no usable features.

    Replaces the original ``[] in fea`` test. On an ndarray that test is
    ``(fea == []).any()``, which is never True. Since NumPy 1.25 it also
    raises ValueError whenever ``fea`` cannot be broadcast against an empty
    array (e.g. any 3-D array whose last axis is longer than 1). Zero-size
    arrays are treated as empty. For non-array containers, an empty-list
    entry still marks the clip as empty, as before.
    """
    if isinstance(fea, np.ndarray):
        return fea.size == 0
    return any(isinstance(item, list) and not item for item in fea)


def _prepare_clip(fea):
    """Reshape one loaded clip to ``(1, d0, d1, d2)``; return None to skip it.

    A clip is skipped when it is empty (see ``_is_empty_clip``) or when it
    contains NaN or +/-Inf after reshaping.
    """
    if _is_empty_clip(fea):
        return None
    d0, d1, d2 = fea.shape[0], fea.shape[1], fea.shape[2]
    # NOTE(review): concatenating the d0 slices along axis 1 and reshaping is
    # equivalent to fea.transpose(1, 0, 2).reshape(1, d0, d1, d2). It does NOT
    # preserve fea's element order unless d0 == 1 or d1 == 1. Kept unchanged
    # because existing models may have been trained on this layout.
    sample = np.concatenate(fea, axis=1).reshape(1, d0, d1, d2)
    # Non-finite values would poison the loss; drop the whole clip instead.
    if not np.isfinite(sample).all():
        return None
    return sample


def getAEBatch_centerloss(personDataList, lab_tile, batchSize=32, device='cuda'):
    """Endlessly yield randomly sampled (features, labels) training batches.

    Each iteration draws ``batchSize`` indices uniformly *with replacement*
    from ``personDataList`` and loads each entry with ``joblib.load``. It
    drops clips that are empty or contain NaN/Inf. A dropped clip's label is
    dropped with it, so features and labels stay aligned. If every sampled
    clip is dropped, a fresh batch is drawn instead of yielding empty
    tensors. This loops forever if no entry is ever valid.

    Args:
        personDataList: Indexable, sized sequence of objects accepted by
            ``joblib.load`` (presumably file paths). Each loaded clip must
            expose a 3-D ``.shape``.
        lab_tile: Indexable labels aligned with ``personDataList``; each
            entry must be convertible with ``int``.
        batchSize: Number of clips sampled per batch, before filtering.
        device: Torch device for the yielded tensors. Defaults to
            ``'cuda'``, which matches the previous hard-coded ``.cuda()``.

    Yields:
        ``(feaAll, labAll)``. ``feaAll`` is a float32 tensor of shape
        ``(n, 1, d0, d1, d2)``, assuming every kept clip has the same shape.
        ``labAll`` is an int64 tensor of shape ``(n,)``. Here
        ``1 <= n <= batchSize`` is the number of clips that survived
        filtering.
    """
    while True:
        idx_batch = np.random.choice(len(personDataList), size=batchSize, replace=True)
        features, labels = [], []
        for idx in idx_batch:
            sample = _prepare_clip(joblib.load(personDataList[idx]))
            if sample is None:
                continue
            features.append(sample)
            # Record the label only for kept clips so labels[i] matches features[i].
            labels.append(int(lab_tile[idx]))
        if not features:
            # Every sampled clip was rejected; resample rather than yield an
            # empty, wrongly-ranked batch.
            continue
        feaAll = torch.from_numpy(np.array(features)).float().to(device)
        labAll = torch.tensor(labels, dtype=torch.long, device=device)
        yield feaAll, labAll