1. 树的剪枝算法
输入:
ID3或C4.5的决策树
参数a
输出:
剪枝后的决策树
1.1. 递归版本
- 从树的根结点开始
- 如果该结点的孩子中存在子树(不全是叶子结点),则先对子树做prune
- 所有子树都prune之后,再判断该结点的孩子是否都是叶子
- 如果不全是叶子,对该结点的算法结束
- 如果该结点的孩子都是叶子,则尝试对该结点剪枝
5.a 计算,代表该结点split后以该结点为根结点的子树的损失,其损失的计算方式与整棵树的计算方式相同
因此在生成树的时候为每个叶子结点保存它的和
5.b 计算,代表该结点split之前作为一个叶子结点时的熵。因此在生成树时记录该split之前的熵
5.c 比较和,如果,则将该结点修改为叶子结点。即:计算该结点的输出标记,并删除它的所有孩子结点。 - 对该结点的处理结束,由于这个算法是递归调用的,如果该结点有父结点,则要继续处理它的父结点。
def isTree(Node):
return 'child' in Node.keys()
def Clip(Node):
bestNt = 0
for value in Node['child']:
if Node['child'][value]['Nt'] > bestNt:
bestNt = Node['child'][value]['Nt']
bestLabel = Node['child'][value]['label']
Node['label'] = bestLabel
Node.pop('child')
def Merge(Node, alpha):
# 计算CaTb
CT_b = 0
for value in Node['child']:
CT_b = CT_b + Node['child'][value]['Nt'] * Node['child'][value]['entropy'] + alpha
# 计算CaTa
CT_a = Node['entropy'] + alpha
# 剪枝的条件
if CT_a <= CT_b:
Clip(Node)
def prune(Node, alpha):
# 判断子结点中是否存在树
for value in Node['child']:
# 如果存在树
if isTree(Node['child'][value]):
# 先对树子结点做prune
prune(Node['child'][value], alpha)
# 对所有子树都prune之后,判断是否所有子树都是叶子
isAllLeaf = True
for value in Node['child']:
if isTree(Node['child'][value]):
isAllLeaf = False
break
# 如果所有子树都是叶子
if isAllLeaf:
# 尝试对结点做剪枝
Merge(Node, alpha)
1.2. DP版本
【?】如何通过DP实现