Skip to content

Commit

Permalink
predict
Browse files Browse the repository at this point in the history
  • Loading branch information
SmirkCao committed Oct 6, 2018
1 parent 026945f commit 407c5c5
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
55 changes: 54 additions & 1 deletion CH05/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ $$

## 算法

这部分内容,原始的[5.1数据](./Input/data_5-1.txt)中最后的标签也是是和否,表示树模型的时候,叶结点不是很明显,所以简单改了下[数据标签](./Input/mdata_5-1.txt)。对应同样的树结构,输出的结果如下

```python
# data_5-1.txt
{'有自己的房子': {'': {'有工作': {'': {'': None}, '': {'': None}}}, '': {'': None}}}
# mdata_5-1.txt
{'有自己的房子': {'': {'有工作': {'': {'拒绝': None}, '': {'批准': None}}}, '': {'批准': None}}}
```

### 算法5.1 信息增益

> 输入:训练数据集$D$和特征$A$
Expand Down Expand Up @@ -111,11 +120,53 @@ ID3和C4.5在生成上,差异只在准则的差异。

### 算法5.4 树的剪枝算法

决策树损失函数摘录如下:

> 树$T$的叶结点个数为$|T|$,$t$是树$T$的叶结点,该结点有$N_t$个样本点,其中$k$类的样本点有$N_{tk}$个,$H_t(T)$为叶结点$t$上的经验熵, $\alpha\geqslant 0$为参数,决策树学习的损失函数可以定义为
> $$
> C_\alpha(T)=\sum_{i=1}^{|T|}N_tH_t(T)+\alpha|T|
> $$
> 其中
> $$
> H_t(T)=-\sum_k\color{red}\frac{N_{tk}}{N_t}\color{black}\log \frac{N_{tk}}{N_t}
> $$
>
> $$
> C(T)=\sum_{t=1}^{|T|}\color{red}N_tH_t(T)\color{black}=-\sum_{t=1}^{|T|}\sum_{k=1}^K\color{red}N_{tk}\color{black}\log\frac{N_{tk}}{N_t}
> $$
>
> 这时有
> $$
> C_\alpha(T)=C(T)+\alpha|T|
> $$
> 其中$C(T)$表示模型对训练数据的误差,$|T|$表示模型复杂度,参数$\alpha \geqslant 0$控制两者之间的影响。
上面这组公式中,注意红色部分,下面插入一个图

![熵与概率的关系](assets/熵与概率的关系.png)

这里面没有直接对$H_t(T)$求和,系数$N_t$使得$C(T)$和$|T|$的大小可比拟。这个地方再理解下。

> 输入:生成算法生成的整个树$T$,参数$\alpha$
>
> 输出:修剪后的子树$T_\alpha$
>
> 1. 计算每个节点的经验熵
> 1. 递归的从树的叶结点向上回缩
> 假设一组叶结点回缩到其父结点之前与之后的整体树分别是$T_B$和$T_A$,其对应的损失函数分别是$C_\alpha(T_A)$和$C_\alpha(T_B)$,如果$C_\alpha(T_A)\leqslant C_\alpha(T_B)$则进行剪枝,即将父结点变为新的叶结点
> 1. 返回2,直至不能继续为止,得到损失函数最小的子树$T_\alpha$


这里面没有具体的实现例子,给出的参考文献是李航老师在CL上的文章,文章介绍的MDL是模型选择的一种具体框架,里面有介绍KL散度,这部分可以参考下。

### 算法5.5 最小二乘回归树生成算法

### 算法5.6 CART生成算法
### 算法5.6 CART分类树生成算法

这个算法用到的策略是基尼系数,所以是分类树的生成算法。



### 算法5.7 CART剪枝算法

Expand All @@ -131,6 +182,8 @@ ID3和C4.5在生成上,差异只在准则的差异。

根据信息增益准则选择最优特征

习题的5.1是让用信息增益比生成树,和这个基本一样,换一个准则就可以了。在单元测试里面实现了这部分代码。



## 参考
Expand Down
Binary file added CH05/assets/熵与概率的关系.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 23 additions & 2 deletions CH05/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,31 @@ def __init__(self,
self.criteria = {"gain": self._gain,
"gain_ratio": self._gain_ratio}

def fit(self, X, y):
def fit(self,
X,
y):
return self._build_tree(X, y)

def predict(self, X):
def _search(self,
X,
parent=None):
if parent is None:
parent = self.tree
key_x = list(parent.keys())[0]
# is leaf
if parent[key_x] is None:
# {key_x: None} is leaf node
return key_x
else:
key_child = X[key_x].values[0]
# print("\n%s|%s|%s|%s\n" % (parent, key_x, key_child, parent[key_x][key_child].keys()))
return self._search(X, parent=parent[key_x][key_child])

def predict(self,
X):
return self._search(X)

def _cal_loss(self, X, y):
pass

@staticmethod
Expand Down
34 changes: 34 additions & 0 deletions CH05/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@


class TestDT(unittest.TestCase):
DEBUG = True

@unittest.skipIf(DEBUG, "debug")
def test_e51(self):
raw_data = pd.read_csv("./Input/data_5-1.txt")
logger.info(raw_data)

@unittest.skipIf(DEBUG, "debug")
def test_e52(self):
raw_data = pd.read_csv("./Input/data_5-1.txt")
hd = dt._cal_entropy(raw_data[raw_data.columns[-1]])
Expand All @@ -33,6 +37,7 @@ def test_e52(self):
logger.info(hd)
self.assertEqual(np.argmax(rst), 2) # index = 2 -> A3

@unittest.skipIf(DEBUG, "debug")
def test_e53(self):
raw_data = pd.read_csv("./Input/data_5-1.txt")
cols = raw_data.columns
Expand All @@ -46,6 +51,12 @@ def test_e53(self):
self.assertEqual(rst, clf.tree)
logger.info(clf.tree)

@unittest.skipIf(DEBUG, "debug")
def test_q51(self):
raw_data = pd.read_csv("./Input/data_5-1.txt")
cols = raw_data.columns
X = raw_data[cols[1:-1]]
y = raw_data[cols[-1]]
# criterion: gain_ratio
clf = dt(criterion="gain_ratio")
clf.fit(X, y)
Expand All @@ -54,7 +65,30 @@ def test_e53(self):
self.assertEqual(rst, clf.tree)
logger.info(clf.tree)

@unittest.skipIf(DEBUG, "debug")
def test_e54(self):
raw_data = pd.read_csv("./Input/mdata_5-1.txt")
cols = raw_data.columns
X = raw_data[cols[1:-1]]
y = raw_data[cols[-1]]

clf = dt()
clf.fit(X, y)
logger.info(clf.tree)

def test_predict(self):
raw_data = pd.read_csv("./Input/mdata_5-1.txt")
cols = raw_data.columns
X = raw_data[cols[1:-1]]
y = raw_data[cols[-1]]

clf = dt(criterion="gain_ratio")
clf.fit(X, y)
rst = clf.predict(X[:1])
self.assertEqual(rst, y[:1].values)
print("predict: ", rst, "label: ", y[:1])

def test_pruning(self):
pass


Expand Down

0 comments on commit 407c5c5

Please sign in to comment.