Skip to content

Commit

Permalink
data processing
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Sep 30, 2019
1 parent 1c213d0 commit 58e752b
Show file tree
Hide file tree
Showing 75 changed files with 744 additions and 1 deletion.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
20 changes: 19 additions & 1 deletion doc/SKELETON_DATA.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
## Skeleton-based Dataset Processing

Coming Soon...
ST-GCN was evaluated on two skeleton-based action recognition datasets: **Kinetics-skeleton** and **NTU RGB+D**.
The raw data should be converted to the proper format before training and testing, following the steps below. Alternatively, you can download
the processed data directly from [GoogleDrive](https://drive.google.com/open?id=103NOL9YYZSW1hLoWmYnv5Fs8mK-Ij7qb).

#### Kinetics-skeleton
[Kinetics](https://deepmind.com/research/open-source/open-source-datasets/kinetics/) is a video-based dataset for action recognition which only provides raw video clips without skeleton data. To obtain the joint locations, we first resized all videos to a resolution of 340x256 and converted the frame rate to 30 fps. Then, we extracted skeletons from each frame in Kinetics with [Openpose](https://github.com/CMU-Perceptual-Computing-Lab/openpose). The extracted skeleton data, which we call **Kinetics-skeleton** (7.5GB), can be downloaded from [GoogleDrive](https://drive.google.com/open?id=1SPQ6FmFsjGg3f59uCWfdUWI-5HJM_YhZ) or [BaiduYun](https://pan.baidu.com/s/1dwKG2TLvG-R1qeIiE4MjeA#list/path=%2FShare%2FAAAI18%2Fkinetics-skeleton&parentPath=%2FShare).

After uncompressing, build the database for mmskeleton by this command:
```
python tools/data_processing/kinetics_gendata.py --data_path <path to kinetics-skeleton>
```

#### NTU RGB+D
NTU RGB+D can be downloaded from [their website](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp).
Only the **3D skeletons**(5.8GB) modality is required in our experiments. After that, this command should be used to build the database for training or evaluation on mmskeleton:
```
python tools/data_processing/ntu_gendata.py --data_path <path to nturgbd+d_skeletons>
```
where ```<path to nturgbd+d_skeletons>``` points to the 3D skeletons modality of the NTU RGB+D dataset you downloaded.
162 changes: 162 additions & 0 deletions mmskeleton/datasets/kinetics_feeder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# sys
import os
import sys
import numpy as np
import random
import pickle
import json
# torch
import torch
import torch.nn as nn
from torchvision import datasets, transforms

# operation
from .utils import skeleton


class KineticsFeeder(torch.utils.data.Dataset):
    """ Feeder for skeleton-based action recognition in kinetics-skeleton dataset
    Arguments:
        data_path: the path to the directory holding one '.json' skeleton file per sample
        label_path: the path to the '.json' label file, keyed by sample id
        ignore_empty_sample: If true, drop the samples whose label entry has a false 'has_skeleton'
        random_choose: If true, randomly choose a portion of the input sequence
        random_shift: If true, randomly pad zeros at the beginning or end of sequence
        random_move: If true, perform randomly but continuously changed transformation to input sequence
        window_size: The length of the output sequence
        pose_matching: If true, match the pose between two frames
        num_person_in: The number of people the feeder can observe in the input sequence
        num_person_out: The number of people the feeder in the output sequence
        debug: If true, only use the first 2 samples
    """
    def __init__(self,
                 data_path,
                 label_path,
                 ignore_empty_sample=True,
                 random_choose=False,
                 random_shift=False,
                 random_move=False,
                 window_size=-1,
                 pose_matching=False,
                 num_person_in=5,
                 num_person_out=2,
                 debug=False):
        self.debug = debug
        self.data_path = data_path
        self.label_path = label_path
        self.random_choose = random_choose
        self.random_shift = random_shift
        self.random_move = random_move
        self.window_size = window_size
        self.num_person_in = num_person_in
        self.num_person_out = num_person_out
        self.pose_matching = pose_matching
        self.ignore_empty_sample = ignore_empty_sample

        self.load_data()

    def load_data(self):
        """Index the sample files and load their labels from the label file."""
        # load file list; os.listdir returns entries in arbitrary order, so
        # sort them to get a reproducible sample order (and debug subset)
        self.sample_name = sorted(os.listdir(self.data_path))

        if self.debug:
            self.sample_name = self.sample_name[0:2]

        # load label
        with open(self.label_path) as f:
            label_info = json.load(f)

        # the sample id is the file name up to its first '.'
        sample_id = [name.split('.')[0] for name in self.sample_name]
        self.label = np.array(
            [label_info[sid]['label_index'] for sid in sample_id])
        # force a boolean mask: an empty list would otherwise become a float
        # array (unusable as an index) and 0/1 integers would be interpreted
        # as positions instead of a mask
        has_skeleton = np.array(
            [label_info[sid]['has_skeleton'] for sid in sample_id],
            dtype=bool)

        # ignore the samples which do not have a skeleton sequence
        if self.ignore_empty_sample:
            self.sample_name = [
                s for h, s in zip(has_skeleton, self.sample_name) if h
            ]
            self.label = self.label[has_skeleton]

        # output data shape (N, C, T, V, M)
        self.N = len(self.sample_name)  # sample
        self.C = 3  # channel
        self.T = 300  # frame
        self.V = 18  # joint
        self.M = self.num_person_out  # person

    def __len__(self):
        return len(self.sample_name)

    def __iter__(self):
        # Yield the samples in index order. Returning ``self`` is not valid
        # here because the class defines no ``__next__``.
        for index in range(len(self)):
            yield self[index]

    def __getitem__(self, index):
        """Return ``(data_numpy, label)`` for the sample at ``index``.

        ``data_numpy`` is laid out as (C, T, V, num_person_out); the three
        channels are x, y and the per-joint score.
        """
        # get data
        sample_name = self.sample_name[index]
        sample_path = os.path.join(self.data_path, sample_name)
        with open(sample_path, 'r') as f:
            video_info = json.load(f)

        # fill data_numpy
        data_numpy = np.zeros((self.C, self.T, self.V, self.num_person_in))
        for frame_info in video_info['data']:
            frame_index = frame_info['frame_index']
            for m, skeleton_info in enumerate(frame_info["skeleton"]):
                # only the first num_person_in skeletons of a frame are kept
                if m >= self.num_person_in:
                    break
                pose = skeleton_info['pose']
                score = skeleton_info['score']
                # 'pose' interleaves the coordinates: x0, y0, x1, y1, ...
                data_numpy[0, frame_index, :, m] = pose[0::2]
                data_numpy[1, frame_index, :, m] = pose[1::2]
                data_numpy[2, frame_index, :, m] = score

        # centralization; joints with a zero score are reset to the origin
        data_numpy[0:2] = data_numpy[0:2] - 0.5
        data_numpy[0][data_numpy[2] == 0] = 0
        data_numpy[1][data_numpy[2] == 0] = 0

        # get & check label index
        label = video_info['label_index']
        assert (self.label[index] == label)

        # data augmentation
        if self.random_shift:
            data_numpy = skeleton.random_shift(data_numpy)
        if self.random_choose:
            data_numpy = skeleton.random_choose(data_numpy, self.window_size)
        elif self.window_size > 0:
            data_numpy = skeleton.auto_pading(data_numpy, self.window_size)
        if self.random_move:
            data_numpy = skeleton.random_move(data_numpy)

        # sort the persons of every frame by descending total score, then
        # keep only the first num_person_out of them
        sort_index = (-data_numpy[2, :, :, :].sum(axis=1)).argsort(axis=1)
        for t, s in enumerate(sort_index):
            # the fancy index puts the person axis first -> (M, C, V);
            # transpose it back to (C, V, M) before writing the frame
            data_numpy[:, t, :, :] = data_numpy[:, t, :, s].transpose(
                (1, 2, 0))
        data_numpy = data_numpy[:, :, :, 0:self.num_person_out]

        # match poses between 2 frames
        if self.pose_matching:
            data_numpy = skeleton.openpose_match(data_numpy)

        return data_numpy, label

    def top_k(self, score, top_k):
        """Return the fraction of samples whose label is among the ``top_k``
        highest-scored classes of the matching row of ``score``."""
        assert (all(self.label >= 0))

        rank = score.argsort()
        hit_top_k = [l in rank[i, -top_k:] for i, l in enumerate(self.label)]
        return sum(hit_top_k) * 1.0 / len(hit_top_k)

    def top_k_by_category(self, score, top_k):
        """Delegate to ``skeleton.top_k_by_category`` with the loaded labels."""
        assert (all(self.label >= 0))
        return skeleton.top_k_by_category(self.label, score, top_k)

    def calculate_recall_precision(self, score):
        """Delegate to ``skeleton.calculate_recall_precision`` with the loaded labels."""
        assert (all(self.label >= 0))
        return skeleton.calculate_recall_precision(self.label, score)
85 changes: 85 additions & 0 deletions tools/data_processing/kinetics_gendata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import sys
import pickle
import argparse

import numpy as np
from numpy.lib.format import open_memmap

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from mmskeleton.datasets.kinetics_feeder import KineticsFeeder
toolbar_width = 30  # number of character cells drawn between the brackets


def print_toolbar(rate, annotation=''):
    """Draw a one-line text progress bar on stdout.

    Args:
        rate: fraction of the work done so far. Cell ``i`` is drawn as '-'
            unless ``i / toolbar_width`` exceeds ``rate``, in which case it
            is left blank.
        annotation: text written immediately before the opening bracket.

    The line ends with a carriage return instead of a newline, so the next
    call overwrites it in place.
    """
    cells = ''.join(
        ' ' if i * 1.0 / toolbar_width > rate else '-'
        for i in range(toolbar_width))
    # Emit the whole bar with one write and flush once, after the closing
    # bracket, rather than flushing after every single cell.
    sys.stdout.write("{}[{}]\r".format(annotation, cells))
    sys.stdout.flush()


def end_toolbar():
    """Finish the progress-bar line by writing a newline to stdout."""
    newline = "\n"
    sys.stdout.write(newline)


def gendata(
        data_path,
        label_path,
        data_out_path,
        label_out_path,
        num_person_in=5,  # observe the first 5 persons
        num_person_out=2,  # then choose 2 persons with the highest score
        max_frame=300):
    """Convert a kinetics-skeleton split into one '.npy' array plus a label pickle.

    Args:
        data_path: directory of per-sample skeleton files, passed to
            ``KineticsFeeder``.
        label_path: label file passed to ``KineticsFeeder``.
        data_out_path: destination of the '.npy' file; it is created as a
            float32 memmap of shape
            (num_samples, 3, max_frame, 18, num_person_out).
        label_out_path: destination of the pickle holding the tuple
            ``(sample_name, sample_label)``.
        num_person_in: forwarded to ``KineticsFeeder``.
        num_person_out: forwarded to ``KineticsFeeder``; also the size of the
            last axis of the output array.
        max_frame: size of the frame axis of the output array; also used as
            the feeder's ``window_size``.
    """
    feeder = KineticsFeeder(data_path=data_path,
                            label_path=label_path,
                            num_person_in=num_person_in,
                            num_person_out=num_person_out,
                            window_size=max_frame)

    sample_name = feeder.sample_name
    sample_label = []
    total = len(sample_name)

    fp = open_memmap(data_out_path,
                     dtype='float32',
                     mode='w+',
                     shape=(total, 3, max_frame, 18, num_person_out))

    for i in range(total):
        data, label = feeder[i]
        print_toolbar(
            i * 1.0 / total,
            '({:>5}/{:<5}) Processing data: '.format(i + 1, total))
        # only the first data.shape[1] frames of the slot are written
        fp[i, :, 0:data.shape[1], :, :] = data
        sample_label.append(label)
    if total:
        # the bar ends in '\r'; terminate the line so later output does not
        # overwrite it
        end_toolbar()

    # Push the data to disk and release the mapping explicitly instead of
    # relying on garbage collection to do it at some later point.
    fp.flush()
    del fp

    with open(label_out_path, 'wb') as f:
        pickle.dump((sample_name, sample_label), f)


if __name__ == '__main__':
    # Command-line entry point: converts the 'train' and 'val' splits found
    # under --data_path and writes the results into --out_folder.
    parser = argparse.ArgumentParser(
        description='Kinetics-skeleton Data Converter.')
    parser.add_argument('--data_path',
                        default='data/Kinetics/kinetics-skeleton')
    parser.add_argument('--out_folder',
                        default='data/Kinetics/kinetics-skeleton')
    arg = parser.parse_args()

    # Create the output folder once, up front. exist_ok avoids the race
    # between a separate existence check and the directory creation.
    os.makedirs(arg.out_folder, exist_ok=True)

    for p in ('train', 'val'):
        data_path = '{}/kinetics_{}'.format(arg.data_path, p)
        label_path = '{}/kinetics_{}_label.json'.format(arg.data_path, p)
        data_out_path = '{}/{}_data.npy'.format(arg.out_folder, p)
        label_out_path = '{}/{}_label.pkl'.format(arg.out_folder, p)

        gendata(data_path, label_path, data_out_path, label_out_path)
Loading

0 comments on commit 58e752b

Please sign in to comment.