Skip to content

Commit

Permalink
add gain_ratio parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed Oct 5, 2018
1 parent fcc6170 commit 026945f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
20 changes: 12 additions & 8 deletions CH05/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ $$
> 输入:训练数据集$D$, 特征集$A$,阈值$\epsilon$
> 输出:决策树$T$
>
> 1. $D$属于同一类$C_k$,$T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. $A$是空集->T为单节点树,实例数最多的作为该节点类标记,返回T
> 1. 如果$D$属于同一类$C_k$,$T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. 如果$A$是空集,置$T$为单节点树,实例数最多的类作为该节点类标记,返回$T$
> 1. 计算$g$, 选择信息增益最大的特征$A_g$
> 1. 如果$A_g$的信息增益小于$\epsilon$,$T$为单节点树,$D$中实例数最大的类$C_k$作为类标记,返回$T$
> 1. 按$A_g$的每一个取值将$D$划分为若干非空子集$D_i$
Expand All @@ -100,15 +100,19 @@ $$
> 输入:训练数据集$D$, 特征集$A$,阈值$\epsilon$
> 输出:决策树$T$
>
> 1. D属于同一类C_k -> T为单节点树,类Ck作为该节点的类标记,返回T
> 1. A是空集->T为单节点树,实例数最多的作为该节点类标记,返回T
> 1. 计算g, 选择信息增益最大的特征Ag
> 1. 如果Ag的信息增益小于ε,T为单节点树,D中实例数最大的类Ck作为类标记,返回T
> 1. Ag划分若干非空子集Di,
> 1. Di训练集,A-{Ag}为特征集,递归调用前面步骤,得到Ti,返回Ti
> 1. 如果$D$属于同一类$C_k$,$T$为单节点树,类$C_k$作为该节点的类标记,返回$T$
> 1. 如果$A$是空集,置$T$为单节点树,实例数最多的类作为该节点类标记,返回$T$
> 1. 计算$g$, 选择**信息增益比**最大的特征$A_g$
> 1. 如果$A_g$的**信息增益比**小于$\epsilon$,$T$为单节点树,$D$中实例数最大的类$C_k$作为类标记,返回$T$
> 1. 按$A_g$的每一个取值将$D$划分为若干非空子集$D_i$
> 1. 以$D_i$为训练集,$A-\{A_g\}$为特征集,递归调用前面步骤,得到子树$T_i$,返回$T_i$
ID3和C4.5的生成算法流程相同,差异只在于特征选择准则:ID3使用信息增益,C4.5使用信息增益比。


### 算法5.4 树的剪枝算法



### 算法5.5 最小二乘回归树生成算法

### 算法5.6 CART生成算法
Expand Down
15 changes: 9 additions & 6 deletions CH05/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@


class dt(object):

def __init__(self,
tol=10e-3,
criterion = 'ID3'):
criterion='gain'):
self.tree = dict()
self.tol = tol
self.criterion = criterion
self.criteria = {"gain": self._gain,
"gain_ratio": self._gain_ratio}

def fit(self, X, y):
self._build_tree(X, y)
return self._build_tree(X, y)

def predict(self, X):
pass
Expand Down Expand Up @@ -68,10 +71,10 @@ def _build_tree(self, X, y):
cols = X.columns.tolist()
rst_col = cols[0]
for col in cols:
gain = dt._gain(X[col], y)
if gain >= rst:
rst, rst_col = gain, col
if gain < self.tol:
criterion = self.criteria[self.criterion](X[col], y)
if criterion >= rst:
rst, rst_col = criterion, col
if criterion < self.tol:
return self.tree

cols.remove(rst_col)
Expand Down
11 changes: 10 additions & 1 deletion CH05/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,21 @@ def test_e53(self):
cols = raw_data.columns
X = raw_data[cols[1:-1]]
y = raw_data[cols[-1]]

# default criterion: gain
clf = dt()
clf.fit(X, y)
logger.info("gain")
rst = {'有自己的房子': {'否': {'有工作': {'否': {'否': None}, '是': {'是': None}}}, '是': {'是': None}}}
self.assertEqual(rst, clf.tree)
logger.info(clf.tree)

# criterion: gain_ratio
clf = dt(criterion="gain_ratio")
clf.fit(X, y)
logger.info("gain_ratio")
rst = {'有自己的房子': {'否': {'有工作': {'否': {'否': None}, '是': {'是': None}}}, '是': {'是': None}}}
self.assertEqual(rst, clf.tree)
logger.info(clf.tree)

def test_e54(self):
pass
Expand Down

0 comments on commit 026945f

Please sign in to comment.