Skip to content

Commit

Permalink
e53
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed Oct 5, 2018
1 parent f54d844 commit fcc6170
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
12 changes: 11 additions & 1 deletion CH05/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ $$
> 输入:训练数据集$D$, 特征集$A$,阈值$\epsilon$
> 输出:决策树$T$
>
> 1. $D$属于同一类$C_k$,$ T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. $D$属于同一类$C_k$,$T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. $A$是空集,$T$为单节点树,$D$中实例数最多的类$C_k$作为该节点类标记,返回$T$
> 1. 计算$g$, 选择信息增益最大的特征$A_g$
> 1. 如果$A_g$的信息增益小于$\epsilon$,$T$为单节点树,$D$中实例数最大的类$C_k$作为类标记,返回$T$
Expand All @@ -97,6 +97,16 @@ $$

### 算法5.3 C4.5生成算法

> 输入:训练数据集$D$, 特征集$A$,阈值$\epsilon$
> 输出:决策树$T$
>
> 1. $D$属于同一类$C_k$,$T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. $A$是空集,$T$为单节点树,$D$中实例数最多的类$C_k$作为该节点类标记,返回$T$
> 1. 计算$A$中各特征对$D$的信息增益比$g_R$, 选择信息增益比最大的特征$A_g$
> 1. 如果$A_g$的信息增益比小于$\epsilon$,$T$为单节点树,$D$中实例数最大的类$C_k$作为类标记,返回$T$
> 1. 按$A_g$的每个取值将$D$划分为若干非空子集$D_i$,
> 1. 以$D_i$为训练集,$A-\{A_g\}$为特征集,递归调用前面步骤,得到$T_i$,返回$T_i$

### 算法5.4 树的剪枝算法

### 算法5.5 最小二乘回归树生成算法
Expand Down
35 changes: 32 additions & 3 deletions CH05/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

class dt(object):
def __init__(self,
tol=10e-3):
self.tree = None
tol=10e-3,
criterion = 'ID3'):
self.tree = dict()
self.tol = tol
self.criterion = criterion

def fit(self, X, y):
pass
self._build_tree(X, y)

def predict(self, X):
pass
Expand Down Expand Up @@ -54,6 +56,33 @@ def _gain_ratio(X, y):
def _cal_gini(X, y):
pass

def _build_tree(self, X, y):
    """Recursively grow a decision tree as nested dicts.

    Internal nodes have the form ``{feature: {value: subtree, ...}}`` and
    leaves have the form ``{label: None}``.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature columns still available for splitting (``.columns`` and
        ``.shape`` are used, and rows are selected with a boolean mask).
    y : array-like
        Labels aligned with the rows of ``X``; must support boolean-mask
        indexing.

    Returns
    -------
    dict
        The (sub)tree built from ``X`` and ``y``.

    Notes
    -----
    ``self.tree`` is overwritten at every internal node. The outermost call
    assigns last, so once it returns an internal node ``self.tree`` holds
    the whole tree. Leaf returns do not touch ``self.tree``.
    """
    ck, cnts = np.unique(y, return_counts=True)

    # All samples share one class: single-node tree.
    if ck.shape[0] == 1:
        return {ck[0]: None}

    # No features left: label the leaf with the majority class.
    if X.shape[1] == 0:
        return {ck[np.argmax(cnts)]: None}

    cols = X.columns.tolist()
    best_gain, best_col = 0, cols[0]
    for col in cols:
        # dt._gain is defined elsewhere in the class; presumably the
        # information gain of one feature column with respect to y.
        gain = dt._gain(X[col], y)
        # ">=" keeps the original tie-breaking (later column wins).
        if gain >= best_gain:
            best_gain, best_col = gain, col

    # Compare the BEST gain (not the gain of the last column visited)
    # against the threshold; below it, stop splitting and emit a
    # majority-class leaf instead of returning a stale self.tree.
    if best_gain < self.tol:
        return {ck[np.argmax(cnts)]: None}

    cols.remove(best_col)
    X_sub = X[cols]
    branches = dict()
    for value in np.unique(X[best_col]):
        mask = X[best_col] == value
        branches[value] = self._build_tree(X_sub[mask], y[mask])

    self.tree = {best_col: branches}
    return self.tree


if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
Expand Down
11 changes: 10 additions & 1 deletion CH05/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ def test_e52(self):
self.assertEqual(np.argmax(rst), 2) # index = 2 -> A3

def test_e53(self):
    """Fit ``dt`` on data_5-1.txt and compare the learned tree with the expected dict."""
    frame = pd.read_csv("./Input/data_5-1.txt")
    names = frame.columns
    # The first column is excluded from the features (presumably a row id;
    # verify against the data file); the last column holds the labels.
    features, labels = frame[names[1:-1]], frame[names[-1]]

    model = dt()
    model.fit(features, labels)
    logger.info(model.tree)

    expected = {'有自己的房子': {'否': {'有工作': {'否': {'否': None}, '是': {'是': None}}}, '是': {'是': None}}}
    self.assertEqual(expected, model.tree)

def test_e54(self):
pass
Expand Down

0 comments on commit fcc6170

Please sign in to comment.