决策树系列0:你需要一棵决策树

关于决策树

决策树是非常流行的一种机器学习算法,基于决策树的集成算法随机森林更是位于机器学习 Top10 的算法之一。决策树之所以备受推崇,很大一个原因是其实现的过程与人类思考问题的方式非常接近,即使没有任何机器学习基础的人也可以看懂决策树做出判断的过程:

相关概念

决策树的概念来自自然界中的树,树有分支、叶子、根和树干;决策树的节点根据选定的特征进行切分,形成分支,这些分支上的节点是上一层节点的子节点,如果某节点不存在子节点(即不再向下延伸),那么这个节点就是叶子节点。每个决策树都有一个根节点(root),和树的跟一样,这是一棵决策树开始生长的起点。

种一棵决策树

决策树的生长意味着对数据空间的不断分割,每个叶子节点都对应于空间中的某一部分,将所有叶子节点对应的空间相加就是原来的数据空间,以二维数据空间为例,\(X\in \mathbf{X}\),有两个维度的特征,每一次我们都选择其中一个特征进行分割,这个过程我们可以用下图形象地表示出来:

上图中叶子节点是 \(X_3,X_5,X_6,X_7,X_8\) ,这些叶子节点对应的子空间之和即为原始数据空间。在二维空间里似乎几何分割更加直观,但是我们很难画出更高维空间的几何分割,对于决策树来说,高维空间和低维空间的分割过程大同小异,无非是再增添一些节点,让树生长的更为茂盛罢了。

构造一颗决策树需要解决四个问题:

  1. 选择什么特征进行分裂?如何给定分裂的临界值?
  2. 如何评估分类的效果?
  3. 分裂到什么时候停止?
  4. 每个叶子节点都有一个对应的分类结果,怎么给?

以 CART 树的分类问题为例,看看这些问题该如何解决:

  • 选择分裂特征和分裂值

假设我们输入的的数据 \(\mathbf{x^{(i)}}=[x_1^{(i)},x_2^{(i)}…x_n^{(i)}]^T\) ,其中既有连续变量又有离散变量,对于连续变量 \(x_i\),我们会选定一个值 \(c\),当 \(x_i^{(j)}\leqslant c\) 时,该样本进入左边的子节点,否则进入右边的子节点(CART 树做二分类),连续变量的取值虽然是无穷的,但训练数据是有限的,如果训练集的数量是 \(m\),那么临界值 \(c\) 最多只有 \(m\) 种选择(考虑所有样本不能落入同一个子节点则是 \(m-1\) 种),其中一种处理思路是,取 \(c\in \{x_i^{(1)}, x_i^{(2)}, x_i^{(3)}...x_i^{(m)}\}\) ,选择分类效果最好的一种作为当前节点的切分值。

对于离散变量 \(x_j\),它的取值范围必然是有限的,假设共有 M 个不同的类别,进行分裂时,我们给定一个子集\(A\subset \{1,2,3…M\}\) ,如果 \(x_j^{(i)}\in A\) 进入左边的子节点,反之进入右边的子节点,遍历子集 \(A\),选择效果最好的一个作为切分点 。

对所有特征的最佳切分点进行评估,从中选出其中效果最好的特征和切分点\((s,c)\),用伪代码的形式表达这个流程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
初始化 当前最佳分类效果,当前最佳分裂特征和当前最佳分裂值;
for(i = 0;i < 特征数量; ++i){
选择待评估的特征 x_i;
if (x_i 是连续变量){
找出最佳切分值 c;
计算选择 x_i 进行分裂的对应的最佳效果 E_i;
}
else{
找出最佳子集 A;
计算选择 x_i 进行分裂的对应的最佳效果 E_i;
}
if (E_i > 当前最佳分类效果){
当前最佳分类效果 = E_i;
当前最佳分裂特征 = x_i;
当前最佳分裂值 = c 或 A;
}
}
  • 评估分裂效果

我们需要一个能定量评估分类效果的函数 \(\phi\) ,通常把它叫做不纯度函数(Impurity Function)。可以想象一个节点中样本类型越多、分布越均匀不纯度越高,举个例子,[+++—] 要比 [+++++-] 的不纯度高,因为分布均匀,前者传递的信息是模糊的,后者则能相对明确地代表 “+” 这个状态。我们的目的是要在子节点中获得尽可能小的不纯度,即每一次切分都让不纯度降低。假设数据总共有 \(K\) 个类型,对于当前节点 \(t\) ,类型为 \(k\) 的数据所占比例为 \(p(k|t)\) ,容易推得 \(\sum_{k=1}^Kp(k|t)=1\)

不纯度函数可以有很多选择,但必须遵循以下原则:

  1. 节点中各类型的数据均匀分布时,不纯度最高,即 \(p(k|t)=1/K,(k=1,2…K)\) 时,\(\phi\) 达到最大;
  2. 节点中只有一种类型的数据时,不纯度最低,即 \(p(k|t)=1,p(i|t)=0,(i=0,1…k-1,k+1…K)\) 时,\(\phi\) 达到最小;
  3. \(\phi\) 对于 \(p\) 是对称的,即任意对调 \(p(i|t)\)\(p(j|t)\) 的值,\(\phi\) 不变;

给定不纯度函数 \(\phi\) 后,我们用 \(i(t)\) 来表示节点 \(t\) 的不纯度: \[ i(t)=\phi(p(1|t), p(2|t)...p(K|t)) \] 有了节点的不纯度之后,我们就可以定量评估对于节点 \(t\),分裂 \(s\) (表示具体的切分特征和切分点)的效果 \(\Phi(s,t)\)\[ \Phi(s,t)=i(t)-p_Ri(t_R)-p_Li(t_L) \] 其中, \(p_R, p_L\) 分别表示节点 \(t\) 中的数据落入右边子节点和左边子节点的概率,\(i(t_R), i(t_L)\) 则为子节点的不纯度,显然 \(\Phi(s,t)\) 表示的是,经过分裂之后,子节点的不纯度和相对于父节点来说降低了多少,这个值越大表示这次分裂的效果越好。

最后我们再定义树的不纯度 \(I(T)\)\[ I(T)=\sum_{t\in \widetilde{T}}i(t)p(t) \] 注意上式只对所有的叶子节点 \(\widetilde{T}\) 进行求和,\(p(t)\) 可以用进入叶子节点 \(t\) 中的样本数量除以样本总量求得。

最后的最后给出三个常用的不纯度函数:

  1. 信息熵函数(Entropy Function):\(\sum_{j=1}^{K}p_j \text{ log }\frac{1}{p_j}\)
  2. 错误分类率(Misclassification Rate):\(1 - \mathbf{max}_j p_j\)
  3. 基尼指数(Gini Index):\(\sum_{j=1}^{K}p_j (1-p_j)=1-\sum_{j=1}^{K}p_{j}^{2}\)

读者可以尝试证明一下这三个不纯度函数是否满足我们之前所说的三个原则。

  • 给叶子节点赋值

所有的叶子节点最后都要给出一个答案,即落入这个叶子节点的数据应该属于哪一类?赋值方式简单且粗暴,少数服从多数,占多数的类即为该叶子节点的分类结果。

  • 什么时候停止分裂

这取决于我们的不纯度函数和种树的策略,可以证明如果以错误分类率作为不纯度函数,子节点的不纯度之和总是小于父节点(证明见最后),意味着分裂将持续进行直至叶子节点只有一个样本时(分无可分),这会让我们的树变得非常庞大,引起严重的过拟合问题。一种处理方法是,设置一个不纯度降低的阈值,当某次分裂带来的不纯度减小小于这个值时,可以认为此次分裂效果并不理想,应当停止分裂,这种方法的缺点在于分裂的依据过于短视,当前效果平平的分裂很有可能对之后几步影响巨大;另一种较为常用的策略是,不限制树的生长,后期进行剪枝。

回归任务

决策树处理回归任务的流程与分类任务基本一样,不同之处在于切分点的评估。在分类任务中,我们使用基尼指数等作为不纯度函数,这些不适用于回归任务(没有类别),我们使用 MSE 来评估分裂的效果。选择分裂值 \(c\) 将某个节点分为M(CART 树 M=2) 个子节点 \(R_1, R_2,…,R_M\),每个子节点的值应为常数 \(c_i\)\[ \min_{j,s}\left[\min_{c_i}\sum_{x^i\in R_i}(y^i-c_i)^2\right] \] 显然对于任一特征 \(j\) 的分裂点 \(s\),当 \(c_i=\mathrm{ave}(y^i| x^i\in R_i)\),上式取到极小值,扫描一遍数据即可确定 \(s\)。出于计算效率的考量,回归树和分类树在分裂时的选择都是贪心的,即选择对于当前步来说效果最好的,而非完成后的整体效果。

处理缺失值

决策树在处理缺失值时相比其他的算法有着天然的优势,一般来说有一下几种处理方法:

  1. 把「缺失」另做一类,或归到已有的一类中,即假设「缺失」也是对应特征的一个值;
  2. 根据概率分布到子节点中,可以全部分配到最大概率子节点,也可以根据各子节点中的数量按比例分配
  3. 替代分割法(surrogate split),当前节点的最佳分组是对于变量 \(s\) 时,如果数据中(训练或者测试数据)有个别条目缺失 \(s\) ,那么将通过其他变量进行一系列的分组,这些分组的结果非常接近 \(s\) 的最佳分组结果,从中选出最接近的分组方法来处理缺失 \(s\) 属性的条目。
  4. 计算不纯度时,基于该特征未缺失的样本进行计算,而对于有缺失的样本可以将其添加到该特征下的所有分类中

对于替代分割法,还需要说明一点,在寻找替代的分组时,不以不纯度减小作为准则,仅仅追求分类结果与目标分组最接近的(是对于当前节点分组最接近的还是所有衍生的子节点?)。替代分割法是最能体现决策树优势的一种处理缺失值的策略。

尽管如此,作为相关问题领域的专家,往往有对于缺失数据更为可靠的补全方法,如果你有更可靠的依据或方法去补全数据,还是不要太依赖决策树的自动补全。

剪枝预热

防止决策树过拟合有这几种方法:

  1. 停止准则:当某次分割的不纯度降低小于我们制定的标准时,该节点不再分裂;
  2. 剪枝:不限制树的生长,在树形成之后合并某些节点,即剪枝;
  3. 限定树的深度:当决策树层数达到限定的值深度时,停止生长;

停止准则最大的问题是视线过于短浅,有时候一次不是那么完美的分裂在当前看起来效果不好,但是可能给后续的继续分裂创造良好的基础。

实际的情况是,如果我们不限制树的大小,决策树总是倾向于尽可能长得大(层数多,叶子节点多),这带来的一个问题就是过拟合(可以证明父节点错分率必然大于子节点,决策树倾向于长大)。

在讲剪枝前(Pruning)先引入几个概念:

  • Descendant:派生节点,如果一个节点 \(t'\) 可以从另一个节点 \(t\) 沿着一条连续的路径向下派生出来,我们就说,\(t'\)\(t\) 的派生节点;
  • Ancestor:先祖节点,如果 \(t'\)\(t\) 的派生节点,那么反过来 \(t\) 就是 \(t'\) 的派生节点;
  • Branch:分支,如果 \(T_t\) 是树 \(T\) 的一个分支,那么 \(T_t\) 的根节点 \(t\) 以及 \(t\) 所包含的所有派生节点构成了这个分支 \(T_t\)
  • Pruning:剪枝的过程是从原来的树 \(T\) 中剪去一个分支 \(T_t\) ,这个过程会去掉 \(T_t\) 中的所有节点(除了 \(T_t\) 的根节点 \(t\)),可以用 \(T-T_t\) 来表示这个剪枝的过程;剪枝过后的树可以用 \(T'\) 来表示,用 \(T'<T\) 表示剪枝过后的树小于原来的树;

一个非常残酷的现实是,即使树 \(T\) 的规模并不是很大,其分支数量也是相当惊人的,穷举所有的分支进行判断并不可行,我们需要一种更为聪明的策略来进行剪枝,敬请期待决策树系列的下一篇文章:Minimal Cost-Complexity Pruning…

一个例子

例子来源于 Programming Collective Intelligence 第七章,根据相关特征预测用户的会员类型:

来源 地域 读过FAQ 浏览网页数量 会员类型
1 digg USA yes 24 Basic
2 kiwitobes France yes 23 Basic
3 (direct) UK no 21 Basic
4 digg New Zealand yes 12 Basic
5 google UK yes 18 Basic
6 kiwitobes France yes 19 Basic
7 slashdot USA yes 18 None
8 (direct) Zealand no 12 None
9 slashdot France yes 19 None
10 digg USA no 18 None
11 google UK no 18 None
12 kiwitobes UK no 19 None
13 slashdot UK no 21 None
14 google France yes 23 Premium
15 google UK no 21 Premium
16 google USA no 24 Premium

使用基尼指数作为不纯度函数,决策树的构建过程如下:

根节点基尼指数为: \[ i_{root}=\frac{6}{16}\times (1-\frac{6}{16})+\frac{7}{16}\times (1-\frac{7}{16})+\frac{3}{16}\times (1-\frac{3}{16})=0.6328 \]

  1. 选择’来源‘作为切分的特征:
  • 切分点为 google 时: \[ \begin{split} i_{google}&=p(google)\times i(t_L) + p(not\ google)\times i(t_R)\\ &=\frac{5}{16}\times \left[\frac{3}{5}\times (1-\frac{3}{5})+\frac{1}{5}\times (1-\frac{1}{5})+\frac{1}{5}\times (1-\frac{1}{5})\right]\\ &+\frac{11}{16}\times \left[\frac{6}{11}\times (1-\frac{6}{11})+\frac{5}{11}\times (1-\frac{5}{11})\right]\\ &=0.5159 \end{split} \]

  • 切分点为 digg 时: \[ \begin{split} i_{digg}&=p(digg)\times i(t_L) + p(not\ digg)\times i(t_R)\\ &=\frac{3}{16}\times \left[\frac{2}{3}\times (1-\frac{2}{3})+\frac{1}{3}\times (1-\frac{1}{3})\right]\\ &+\frac{13}{16}\times \left[\frac{6}{13}\times (1-\frac{6}{13})+\frac{3}{13}\times (1-\frac{3}{13})+\frac{4}{13}\times (1-\frac{4}{13})\right]\\ &=0.6025 \end{split} \]

  • 切分点为 kiwitobes 时: \[ \begin{split} i_{kiwitobes}&=p(kiwitobes)\times i(t_L) + p(not\ kiwitobes)\times i(t_R)\\ &=\frac{3}{16}\times \left[\frac{2}{3}\times (1-\frac{2}{3})+\frac{1}{3}\times (1-\frac{1}{3})\right]\\ &+\frac{13}{16}\times \left[\frac{6}{13}\times (1-\frac{6}{13})+\frac{3}{13}\times (1-\frac{3}{13})+\frac{4}{13}\times (1-\frac{4}{13})\right]\\ &=0.6025 \end{split} \]

  • 切分点为 slashdot 时: \[ \begin{split} i_{slashdot}&=p(slashdot)\times i(t_L) + p(not\ slashdot)\times i(t_R)\\ &=\frac{3}{16}\times \left[0\right]\\ &+\frac{13}{16}\times \left[\frac{6}{13}\times (1-\frac{6}{13})+\frac{3}{13}\times (1-\frac{3}{13})+\frac{4}{13}\times (1-\frac{4}{13})\right]\\ &=0.5192 \end{split} \]

  • 切分点为 (direct) 时: \[ \begin{split} i_{ (direct)}&=p( (direct))\times i(t_L) + p(not\ (direct))\times i(t_R)\\ &=\frac{2}{16}\times \left[\frac{1}{2}\times (1-\frac{1}{2})+\frac{1}{2}\times (1-\frac{1}{2})\right]\\ &+\frac{14}{16}\times \left[\frac{6}{14}\times (1-\frac{6}{14})+\frac{3}{14}\times (1-\frac{3}{14})+\frac{5}{14}\times (1-\frac{5}{14})\right]\\ &=0.625 \end{split} \] 可见选择 “来源” 作为分裂的特征,当切分点为 google 时不纯度下降最大=0.6328-0.5159=0.1169

  1. 选择地域作为分裂特征:
  • 切分点为 USA 时: \[ \begin{split} i_{USA}&=p(USA)\times i(t_L) + p(not\ USA)\times i(t_R)\\ &=\frac{4}{16}\times \left[\frac{2}{4}\times (1-\frac{2}{4})+\frac{1}{4}\times (1-\frac{1}{4})+\frac{1}{4}\times (1-\frac{1}{4})\right]\\ &+\frac{12}{16}\times \left[\frac{2}{12}\times (1-\frac{2}{12})+\frac{5}{12}\times (1-\frac{5}{12})+\frac{5}{12}\times (1-\frac{5}{12})\right]\\ &=0.625 \end{split} \]

  • 切分点为 France 时: \[ \begin{split} i_{France}&=p(France)\times i(t_L) + p(not\ France)\times i(t_R)\\ &=\frac{4}{16}\times \left[\frac{2}{4}\times (1-\frac{2}{4})+\frac{1}{4}\times (1-\frac{1}{4})+\frac{1}{4}\times (1-\frac{1}{4})\right]\\ &+\frac{12}{16}\times \left[\frac{2}{12}\times (1-\frac{2}{12})+\frac{6}{12}\times (1-\frac{6}{12})+\frac{4}{12}\times (1-\frac{4}{12})\right]\\ &=0.6145 \end{split} \]

  • 切分点为 UK 时: \[ \begin{split} i_{UK}&=p(UK)\times i(t_L) + p(not\ UK)\times i(t_R)\\ &=\frac{6}{16}\times \left[\frac{2}{6}\times (1-\frac{2}{6})+\frac{3}{6}\times (1-\frac{3}{6})+\frac{1}{6}\times (1-\frac{1}{6})\right]\\ &+\frac{10}{16}\times \left[\frac{2}{10}\times (1-\frac{2}{10})+\frac{4}{10}\times (1-\frac{4}{10})+\frac{4}{10}\times (1-\frac{4}{10})\right]\\ &=0.6291 \end{split} \]

  • 切分点为 New Zealand 时: \[ \begin{split} i_{New Zealand}&=p(New\ Zealand)\times i(t_L) + p(not\ New\ Zealand)\times i(t_R)\\ &=\frac{2}{16}\times \left[\frac{1}{2}\times (1-\frac{1}{2})+\frac{1}{2}\times (1-\frac{1}{2})\right]\\ &+\frac{14}{16}\times \left[\frac{6}{14}\times (1-\frac{6}{14})+\frac{3}{14}\times (1-\frac{3}{14})+\frac{5}{14}\times (1-\frac{5}{14})\right]\\ &=0.625 \end{split} \]

可见选择 “地域” 作为分裂的特征,当切分点为 France 时不纯度下降最大=0.6328-0.6145=0.0183

  1. 选择“读过 FAQ” 作为特征时:
  • 切分点为 yes 时: \[ \begin{split} i_{yes}&=p(yes)\times i(t_L) + p(no)\times i(t_R)\\ &=\frac{8}{16}\times \left[\frac{1}{8}\times (1-\frac{1}{8})+\frac{2}{8}\times (1-\frac{2}{8})+\frac{5}{8}\times (1-\frac{5}{8})\right]\\ &+\frac{8}{16}\times \left[\frac{2}{8}\times (1-\frac{2}{8})+\frac{1}{8}\times (1-\frac{1}{8})+\frac{5}{8}\times (1-\frac{5}{8})\right]\\ &=0.5312 \end{split} \]

​ “读过 FAQ” 为二元特征,仅有一个分裂点为不纯度下降=0.6328-0.5312=0.1016

  1. 选择“浏览网页数量” 作为特征时:

这是一个连续特征,根据样本中的数据依次选取切分点: \[ \begin{split} i_{浏览数量=12}&=p(浏览数量\leqslant12)\times i(t_L) + p(浏览数量>12)\times i(t_R)\\ &=\frac{2}{16}\times \left[\frac{1}{2}\times (1-\frac{1}{2})+\frac{1}{2}\times (1-\frac{1}{2})\right]\\ &+\frac{14}{16}\times \left[\frac{6}{14}\times (1-\frac{6}{14})+\frac{3}{14}\times (1-\frac{3}{14})+\frac{5}{14}\times (1-\frac{5}{14})\right]\\ &=0.625 \end{split} \]

\[ \begin{split} i_{浏览数量=18}&=p(浏览数量\leqslant18)\times i(t_L) + p(浏览数量>18)\times i(t_R)\\ &=\frac{6}{16}\times \left[\frac{4}{6}\times (1-\frac{4}{6})+\frac{2}{6}\times (1-\frac{2}{6})\right]\\ &+\frac{10}{16}\times \left[\frac{3}{10}\times (1-\frac{3}{10})+\frac{4}{10}\times (1-\frac{4}{10})+\frac{3}{10}\times (1-\frac{3}{10})\right]\\ &=0.5791 \end{split} \]

\[ \begin{split} i_{浏览数量=19}&=p(浏览数量\leqslant19)\times i(t_L) + p(浏览数量>19)\times i(t_R)\\ &=\frac{9}{16}\times \left[\frac{6}{9}\times (1-\frac{6}{9})+\frac{3}{9}\times (1-\frac{3}{9})\right]\\ &+\frac{7}{16}\times \left[\frac{3}{7}\times (1-\frac{3}{7})+\frac{3}{7}\times (1-\frac{3}{7})+\frac{1}{7}\times (1-\frac{1}{7})\right]\\ &=0.5178 \end{split} \]

\[ \begin{split} i_{浏览数量=21}&=p(浏览数量\leqslant21)\times i(t_L) + p(浏览数量>21)\times i(t_R)\\ &=\frac{12}{16}\times \left[\frac{7}{12}\times (1-\frac{7}{12})+\frac{1}{12}\times (1-\frac{1}{12})+\frac{4}{12}\times (1-\frac{4}{12})\right]\\ &+\frac{4}{16}\times \left[\frac{2}{4}\times (1-\frac{2}{4})+\frac{2}{4}\times (1-\frac{2}{4})\right]\\ &=0.5312 \end{split} \]

\[ \begin{split} i_{浏览数量=23}&=p(浏览数量\leqslant23)\times i(t_L) + p(浏览数量>23)\times i(t_R)\\ &=\frac{14}{16}\times \left[\frac{7}{14}\times (1-\frac{7}{14})+\frac{2}{14}\times (1-\frac{2}{14})+\frac{5}{14}\times (1-\frac{5}{14})\right]\\ &+\frac{2}{16}\times \left[\frac{1}{2}\times (1-\frac{1}{2})+\frac{1}{2}\times (1-\frac{1}{2})\right]\\ &=0.5892 \end{split} \]

​ 取 19 进行分裂时不纯度下降最大 = 0.6328-0.5178=0.1150

比较四个特征下的不纯度下降值,可以发现第一次分裂时选择 (来源, google) 作为分裂特征效果最好,并将样本分成以下两部分:

1
2
t_l = [5, 11, 14, 15, 16]
t_r = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13]

接下来再对左右子树重复上述的步骤直到满足停止条件即可,限于篇幅,这里不再赘述。对应的代码 👉戳这里,你也可以看 Patrick L. Lê 写的一个详细的教程,把代码和上面讲的对应起来,对于加深理解很有帮助。

附证明:

\(R(t) \geq R(t_L) + R(t_R)\)

令节点 \(t\) 的多数类是 \(j^*\)

\[ \begin {align}p(j^* |t)& = p(j^*,t_L |t) + p(j^*,t_R |t) \\ & = p(j^*|t_L) p(t_L|t)+p(j^*|t_R) p(t_R|t) \\ & = p_Lp(j^*|t_L)+p_Rp(j^*|t_R) \\ & \le p_L\underset{j}{\text{ max }}p(j|t_L)+p_R\underset{j}{\text{ max }}p(j|t_R) \\ r(t)& = 1-p(j^*|t) \\ & \ge 1-\left[ p_L\underset{j}{\text{ max }}p(j|t_L)+p_R\underset{j}{\text{ max }}p(j|t_R) \right] \\ & = p_L(1-\underset{j}{\text{ max }}p(j|t_L))+p_R(1-\underset{j}{\text{ max }}p(j|t_R)) \\ & = p_Lr(t_L)+p_Rr(t_R) \\ R(t)& = p(t)r(t) \\ & \ge p(t)p_Lr(t_L)+p(t)p_Rr(t_R) \\ & = p(t_L)r(t_L)+p(t_R)r(t_R) \\ & = R(t_L)+R(t_R) \\ \end {align} \]

图一来自 IBM 素材库

图二来自蓝鲸的网站分析笔记

文末证明过程参考:https://onlinecourses.science.psu.edu/stat857/node/53