driver.py
'''
Authors: Ashwani Kashyap, Anshul Pardhi
'''
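# Driver script: builds a decision tree on a CSV data set, reports accuracy on a
# held-out test split, then tries a simple post-pruning pass (see below).
# Assumes DecisionTree.py (build_tree, prune_tree, computeAccuracy, getLeafNodes,
# getInnerNodes, print_tree) sits next to this file; pandas and scikit-learn are required.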
from DecisionTree import *
import pandas as pd
from sklearn import model_selection
# default data set
df = pd.read_csv('data_set/Social_Network_Ads.csv')
header = list(df.columns)
# to use a different data set, overwrite df and header here (examples below)
# header = ['SepalL', 'SepalW', 'PetalL', 'PetalW', 'Class']
# df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None, names=['SepalL','SepalW','PetalL','PetalW','Class'])
# data-set link: https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer/
# df = pd.read_csv('data_set/breast-cancer.csv')
lst = df.values.tolist()
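# build_tree and computeAccuracy below are called with plain lists of rows,
# so the DataFrame is flattened here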
# splitting the data set into train and test
trainDF, testDF = model_selection.train_test_split(lst, test_size=0.2)
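# 80/20 random split; no random_state is fixed, so accuracy will vary between runs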
# building the tree
t = build_tree(trainDF, header)
# get leaf and inner nodes
print("\nLeaf nodes ****************")
leaves = getLeafNodes(t)
for leaf in leaves:
    print("id = " + str(leaf.id) + " depth = " + str(leaf.depth))
print("\nNon-leaf nodes ****************")
innerNodes = getInnerNodes(t)
for inner in innerNodes:
    print("id = " + str(inner.id) + " depth = " + str(inner.depth))
# baseline accuracy of the unpruned tree on the test set
maxAccuracy = computeAccuracy(testDF, t)
print("\nTree before pruning with accuracy: " + str(maxAccuracy*100) + "%\n")
# print tree before pruning
print_tree(t)
# TODO: You have to decide on a pruning strategy
# Pruning strategy
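# Strategy used below (one simple choice, not the only possible one): greedily try
# pruning each non-root inner node on its own, measure accuracy on the test split,
# rebuild the full tree between attempts, and keep only the single prune that
# improves accuracy the most.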
nodeIdToPrune = -1
for node in innerNodes:
    if node.id != 0:
        prune_tree(t, [node.id])
        currentAccuracy = computeAccuracy(testDF, t)
        print("Pruned node_id: " + str(node.id) + " to achieve accuracy: " + str(currentAccuracy*100) + "%")
        # print("Pruned Tree")
        # print_tree(t)
        if currentAccuracy > maxAccuracy:
            maxAccuracy = currentAccuracy
            nodeIdToPrune = node.id
        # rebuild the full tree so the next candidate prune is evaluated independently
        t = build_tree(trainDF, header)
        # perfect test accuracy, no further pruning needed
        if maxAccuracy == 1:
            break
if nodeIdToPrune != -1:
    # rebuild the full tree and prune only the single best node found above
    t = build_tree(trainDF, header)
    prune_tree(t, [nodeIdToPrune])
    print("\nFinal node Id to prune (for max accuracy): " + str(nodeIdToPrune))
else:
    t = build_tree(trainDF, header)
    print("\nPruning strategy didn't increase accuracy")
print("\n********************************************************************")
print("*********** Final Tree with accuracy: " + str(maxAccuracy*100) + "% ************")
print("********************************************************************\n")
print_tree(t)