-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataLoader.py
46 lines (40 loc) · 1.68 KB
/
DataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from os.path import isfile
class FastDataset(Dataset):
    """Dataset of animation files labelled by emotion and actor.

    Every regular file directly inside ``folder_path`` is one sample. File
    names are parsed as ``<emotion>_<actor>.<ext>``. When the name contains
    ``_FaceTalk``, the emotion is everything before ``_FaceTalk`` and the
    actor name is everything from ``FaceTalk`` onwards (kept whole, even if it
    contains underscores). Otherwise the split happens at the first ``_``.

    Each sample has its actor's template subtracted (NumPy broadcasting)
    before being returned.

    Args:
        folder_path: Directory containing the sample files (not recursive).
        actors: Templates, indexed in parallel with ``name_actors``.
        name_actors: Actor names exactly as they appear in the file names.

    Attributes:
        file_list: Full paths of the files in ``folder_path``, sorted by name.
        dict_emotions: Emotion label -> class index, assigned in order of
            first appearance in ``file_list``.
        num_classes: Number of distinct emotion labels.
    """

    def __init__(self, folder_path, actors, name_actors):
        self.actors = actors
        self.name_actors = name_actors
        self.folder_path = folder_path
        self.file_list = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if isfile(os.path.join(folder_path, f))
        ]
        # dict.fromkeys de-duplicates in O(n) while keeping first-seen order,
        # so class indices depend only on the (sorted) directory listing.
        labels = list(dict.fromkeys(self._parse_name(f)[0] for f in self.file_list))
        self.dict_emotions = {label: idx for idx, label in enumerate(labels)}
        self.num_classes = len(labels)

    @staticmethod
    def _parse_name(path):
        """Split the file name of ``path`` into ``(emotion_label, actor_name)``."""
        # splitext only drops a real extension; str.find(".") returned -1 for
        # extension-less names and silently cut off the last character.
        stem = os.path.splitext(os.path.basename(path))[0]
        marker = stem.find("_FaceTalk")
        if marker != -1:
            # The actor name starts at "FaceTalk" and may contain underscores,
            # so split there rather than at the first underscore.
            return stem[:marker], stem[marker + 1:]
        label, _, face = stem.partition("_")
        return label, face

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(animation - template, class_index, path)`` for sample ``idx``.

        Raises:
            ValueError: if the actor parsed from the file name is not in
                ``name_actors``.
        """
        path = self.file_list[idx]
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code -- only use with trusted files.
        animation = np.load(path, allow_pickle=True)
        label, face = self._parse_name(path)
        template = self.actors[self.name_actors.index(face)]
        # One broadcast subtraction over the first axis instead of a Python loop.
        animation = animation - template
        return torch.Tensor(animation), self.dict_emotions[label], path