forked from yysijie/st-gcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeeder.py
86 lines (71 loc) · 2.44 KB
/
feeder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# sys
import os
import sys
import numpy as np
import random
import pickle
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
# visualization
import time
# operation
from . import tools
class Feeder(torch.utils.data.Dataset):
    """ Feeder for skeleton-based action recognition.

    Arguments:
        data_path: the path to '.npy' data; the shape of data should be
            (N, C, T, V, M) -- presumably (samples, channels, frames,
            joints, persons); TODO confirm against the data generator
        label_path: the path to a pickle file holding the tuple
            (sample_name, label), both indexable per sample
        random_choose: If true, pass each sample through
            tools.random_choose(data, window_size)
        random_move: If true, pass each sample through tools.random_move
        window_size: The length of the output sequence. When random_choose
            is false, a positive value pads the sample with
            tools.auto_pading; a value <= 0 (default -1) leaves it unchanged
        debug: If true, only use the first 100 samples
        mmap: If true (default), memory-map the '.npy' file read-only
            instead of loading it fully into memory
    """

    def __init__(self,
                 data_path,
                 label_path,
                 random_choose=False,
                 random_move=False,
                 window_size=-1,
                 debug=False,
                 mmap=True):
        super().__init__()
        self.debug = debug
        self.data_path = data_path
        self.label_path = label_path
        self.random_choose = random_choose
        self.random_move = random_move
        self.window_size = window_size

        self.load_data(mmap)

    def load_data(self, mmap):
        """Load labels and skeleton data from disk.

        Sets self.sample_name, self.label, self.data and the shape
        attributes self.N, self.C, self.T, self.V, self.M.
        """
        # data: N C T V M (same order as the shape unpacking below)

        # load label
        # NOTE(review): pickle.load can execute arbitrary code; only use
        # label files that come from a trusted source.
        with open(self.label_path, 'rb') as f:
            self.sample_name, self.label = pickle.load(f)

        # load data; mmap_mode='r' maps the file read-only so large
        # datasets are not read into RAM up front (None = normal load)
        self.data = np.load(self.data_path, mmap_mode='r' if mmap else None)

        if self.debug:
            self.label = self.label[0:100]
            self.data = self.data[0:100]
            self.sample_name = self.sample_name[0:100]

        self.N, self.C, self.T, self.V, self.M = self.data.shape

    def __len__(self):
        """Return the number of samples (the length of the label list)."""
        return len(self.label)

    def __getitem__(self, index):
        """Return the tuple (data_numpy, label) for sample `index`."""
        # get data; np.array copies the sample out of the (possibly
        # read-only, memory-mapped) array so later processing never
        # touches the backing storage
        data_numpy = np.array(self.data[index])
        label = self.label[index]

        # processing
        if self.random_choose:
            # NOTE(review): window_size is forwarded even when it is <= 0
            # (the default is -1); confirm tools.random_choose handles that.
            data_numpy = tools.random_choose(data_numpy, self.window_size)
        elif self.window_size > 0:
            data_numpy = tools.auto_pading(data_numpy, self.window_size)
        if self.random_move:
            data_numpy = tools.random_move(data_numpy)

        return data_numpy, label