1500字范文,内容丰富有趣,写作好帮手!
1500字范文 > 机器学习实战(二)决策树DT(Decision Tree ID3算法)

机器学习实战(二)决策树DT(Decision Tree ID3算法)

时间:2022-10-15 22:20:32

相关推荐

机器学习实战(二)决策树DT(Decision Tree ID3算法)

目录

0.前言

1.信息增益(ID3)

2.决策树(Decision Tree)

3.实战案例

3.1. 隐形眼镜案例

3.2. 存储决策树

3.3. 决策树画图表示

学习完机器学习实战的决策树,简单的做个笔记。文中部分描述属于个人消化后的理解,仅供参考。

所有代码和数据可以访问我的 github

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

0.前言

决策树(Decision Tree)的执行流程很好理解,如下图所示(图源:西瓜书),在树上的每一个结点进行判断,选择分支,直到走到叶子结点,得出分类:

优点:计算复杂度不高、输出结果易于理解、对缺失值不敏感缺点:可能会产生过拟合适用数据类型:数值型和标称型(数值型数据需要离散化)

决策树构建中,目标就是找到当前哪个特征在划分数据时起到决定性作用,划分数据有多种办法,如信息增益(ID3)、信息增益率(C4.5)、基尼系数(CART),本篇主要介绍信息增益(ID3算法)。

1.信息增益(ID3)

首先,介绍香农熵(entropy),熵定义为信息的期望值,熵越高,说明信息的混乱程度越高

其中,表示数据集,表示数据集中的每一个类别,表示这个属于类别的数据占所有数据的比例。

信息增益(information gain)定义为原始的熵减去当前的熵,增益越大,说明当前熵越小,说明数据混乱程度越小

其中,表示按照此特征划分的子集数量,表示第个子集,表示子集的信息熵,表示子集数据占所有数据的比例。

注:信息增益更偏向于选择取值较多的特征,这是它的缺点。

2.决策树(Decision Tree)

算法流程可简单表示为:

遍历当前数据所有的特征,计算信息增益最大的特征,作为当前划分数据的结点,并去除此特征对划分后每个分支上的子集继续进行步骤如果当前子集内的数据都是同一类型,则停止划分,标记叶子结点如果子集内数据还未统一类型,而已经没有特征,则采用多数表决原则

3.实战案例

以下将展示书中的三个案例的代码段,所有代码和数据可以在github中下载:

3.1. 隐形眼镜案例

# coding:utf-8from math import logimport operatorimport pickle"""隐形眼镜案例"""# 计算香农熵def calcShannonEnt(dataSet):numEntries = len(dataSet)labelCounts = {}for featVec in dataSet:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key]) / numEntriesshannonEnt -= prob * log(prob, 2)return shannonEnt# 按照给定特征划分数据集def splitDataSet(dataSet, axis, value):retDataSet = []# 只选择第 axis 列的值为 value 的数据# 去除这个特征,取数据[:axis] 和 [axis+1:] 段for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis + 1:])retDataSet.append(reducedFeatVec)return retDataSet# 选择最好的数据集划分方式def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0bestFeature = -1# 遍历每一个特征for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0.0# 遍历这个特征的所有特征值for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)# 判断这个子集占所有数据集的比例prob = len(subDataSet) / float(len(dataSet))# 新的信息熵 = 所有子集的信息熵乘以比例再求和newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif (infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeature# 多数表决原则def majorityCnt(classList):classCount = {}for vote in classList:if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]# 创建决策树# labels 为特征的标签def createTree(dataSet, labels):# 获取当前数据集最后一列的类别信息classList = [example[-1] for example in dataSet]# 如果最后一列都是一种类别if classList.count(classList[0]) == len(classList):return classList[0]# 如果当前数据集没有可划分的特征if len(dataSet[0]) == 1:return majorityCnt(classList)# 获取最好的划分数据集特征bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel: {}}# 在特征标签中删除当前特征del (labels[bestFeat])# 获取这个特征的列,遍历此特征的所有特征值featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]# 特征有几个取值,这个结点就有几个分支# 每个取值,都划分出子集,递归建树myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)return myTree# 分类函数def classify(inputTree, featLabels, testVec):# 获取第一个特征firstStr = list(inputTree.keys())[0]# 获取这个特征下的键值对的值secondDict = inputTree[firstStr]# 获取这个特征的索引featIndex = featLabels.index(firstStr)# 遍历每一个分支for key in secondDict.keys():if testVec[featIndex] == key:# 判断当前分支下是否还有分支if type(secondDict[key]).__name__ == 'dict':classLabel = classify(secondDict[key], featLabels, testVec)else:classLabel = secondDict[key]return classLabelif __name__ == '__main__':fr = open('lenses.txt')lenses = [inst.strip().split('\t') for inst in fr.readlines()]lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lensesTree = createTree(lenses, lensesLabels)print(lensesTree)

3.2. 存储决策树

# 存储树def storeTree(inputTree, filename):fw = open(filename, 'wb')pickle.dump(inputTree, fw)fw.close()# 取出存储的树def grabTree(filename):fr = open(filename, 'rb')return pickle.load(fr)

3.3. 决策树画图表示

# coding:utf-8import matplotlib.pyplot as plt# 解决显示中文问题from pylab import *mpl.rcParams['font.sans-serif'] = ['SimHei']"""决策树画图"""# 创建树的字典def retrieveTree(i):listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}},'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'young': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}}}}}]return listOfTrees[i]# 获取叶节点的数目def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafs# 获取树的层数def getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepth# 使用文本注解绘制树节点decisionNode = dict(boxstyle='sawtooth', fc='0.8')leafNode = dict(boxstyle='round4', fc='0.8')arrow_args = dict(arrowstyle='<-')# 画节点def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va='center', ha='center', bbox=nodeType,arrowprops=arrow_args)# 在父子节点间填充文本信息def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)# 画树def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD# 主要画图函数def createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalWplotTree.yOff = 1.0plotTree(inTree, (0.5, 1.0), '')plt.show()if __name__ == '__main__':myTree = retrieveTree(1)createPlot(myTree)

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。