-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecision_tree.py
executable file
·66 lines (47 loc) · 1.89 KB
/
decision_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
import sys
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.metrics import accuracy_score
from normalize_frames import normalize_frames
from load_gestures import load_gestures
joints = ['head', 'neck', 'left_shoulder', 'left_elbow', 'left_hand', 'right_shoulder', 'right_elbow', 'right_hand', 'torso', 'left_hip', 'right_hip']
dims = ['x', 'y', 'z']
# remove truncation of numpy array printing
np.set_printoptions(threshold=sys.maxsize)
# properties
num_frames = 36
ratio = 0.7
# 6. FORMAT DATA
gesture_sets = load_gestures()
gesture_sets = normalize_frames(gesture_sets, num_frames)
samples, labels = [], []
for gs in gesture_sets:
for seq in gs.sequences:
sample = np.concatenate(list(map(lambda x: x.frame, seq.frames)))
samples.append(sample)
labels.append(int(gs.label))
X, Y = np.vstack(samples), np.array(labels)
X_train, X_test, y_train, y_test = train_test_split(X, Y, train_size=ratio, random_state= 10)
# 7. CREATE AND TRAIN MODEL
clf = DecisionTreeClassifier(criterion='gini')
clf.fit(X_train, y_train)
print("num features: ", clf.feature_importances_.size)
# 7. TEST MODEr
y_train_pred = clf.predict(X_train)
train_accuracy = accuracy_score(y_train, y_train_pred) * 100
print("train_accuracy", train_accuracy)
y_test_pred = clf.predict(X_test)
test_accuracy = accuracy_score(y_test,y_test_pred) * 100
print("test_accuracy", test_accuracy)
# 8. VISUALIZE MODEL
feature_names = ['frame ' + str(i)+' - '+str(joints[i%11])+' - '+str(dims[i%3]) for i in range(33*num_frames)]
dot_data = export_graphviz(
clf,
out_file='tree.dot',
feature_names=feature_names,
class_names=["pan left", "pan right", "pan up", "pan down", "zoom in", "zoom out", "rotate clockwise", "rotate counterclockwise", "point"],
filled=True,
rounded=True,
special_characters=True
)