1500字范文,内容丰富有趣,写作好帮手!
1500字范文 > 机器学习:决策树(Decision Tree)--ID3算法

机器学习:决策树(Decision Tree)--ID3算法

时间:2018-08-14 12:21:01

相关推荐

机器学习:决策树(Decision Tree)--ID3算法

决策树的主要算法

构建决策树的关键:按照什么样的次序来选择变量(属性/特征)作为分类依据。

根据不同的目标函数,建立决策树主要有以下三种算法

ID3(J.Ross Quinlan-1975) 核心:信息熵,信息增益C4.5——ID3的改进,核心:信息增益比/增益率CART(Breiman-1984),核心:基尼系数

ID3算法

由Ross Quinlan在1986年提出,ID3决策树可以有多个分支,但是不能处理特征值为连续的情况,根据“最大信息熵增益”选取当前最佳的特征来分割数据。开始之前我们先引入几个概念。

香农熵(Shannon Entropy)

熵用来表示事件的不确定性,熵越小代表事件发生的概率就越大。若熵为0,表示事件百分百发生。

对于常见分类系统来说,假设类别D是变量,D可能取值D1,D2…Dn,每个类别出现的概率为P(D1),P(D2)…P(Dn),共n类。

分类系统的熵为:

当我们的概率是由数据统计得到时,此时的熵称为经验熵(empirical entropy)。

举个例子:

我们先不考虑特征,该样本最终分为结果只有买与不买两类,根据统计可知在1024个样本中有641个数据结果为买,383个数据结果为不买。显然买的概率为(641/1024),不买的概率为(383/1024)。显然这两个概率是我们统计得到的,通过这两个统计得到的概率计算的熵称为经验熵,为:

这就是我们第一步要做的,计算决策属性的熵。

该列中买与不买就是决策属性。

条件熵

条件熵H(Y|X)表示在已知随机变量X的条件下随机变量Y的不确定性。

同样的,如果公式中的概率是由统计估计出来的话,这时的条件熵称为经验条件熵。

我们用H(D|A)表示在给定的特征A的条件下D的经验条件熵。假设特征A将D划分为v个子集{D1,D2,…Dv},例:年龄特征(A),将我们的数据集分成了三份{青年,中年,老年}.

此时条件熵为:

信息增益(information gain)

计算完决策属性的熵以后,我们要根据信息增益来排序特征,那什么是信息增益呢?

信息增益是相对于特征而言的,表示得知特征A的信息而使得类Y的信息的不确定性减少的程度。信息增益越大,特征对最终的分类结果影响也就越大,因此此特征就被选上作为我们的分类特征。

简单地将,可以理解为一个特征对最终结果相关程度,信息增益大说明该特征与分类结果的关联性很强。特征的信息增益小,则说明该特征对分类的结果影响很小。

特征A信息增益( g(D,A) ) = 决策属性的熵 - 特征A的平均信息期望

定义为:

上述已经计算了决策属性的熵,现在我们计算某个特征的平均熵期望.

年龄(A1)将数据集为三个组:

青年(D1),中年(D2),老年(D3)

我们把年龄为青年(D1)的样本全都提取出来:

总共D1=(64*4+128)384个样本,做以下定义

D11(买)=128

D12(不买)=256

青年中买与不买的概率:

P(D11) = 128/384

P(D12) = 256/384

根据分类熵的定义得:

同理得到中年H(D2)=0,老年H(D3)=0.9175

青年组所占全部数据集比例(D1) = 384/1024 = 0.375

中年组所占全部数据集比例(D2) = 256/1024 = 0.25

老年组所占全部数据集比例(D3) = 384/1024 = 0.375

得到年龄(A1)的平均熵期望:

年龄信息增益为

上述讲过,信息增益为一个特征对最终分类结果的影响情况,信息增益越大代表该特征对最终分类结果的影响越大.所以我们还需要计算出收入,学生,信誉三个特征的信息增益信息增益越大的特征就要在决策树的前面

经过相同计算原理的我们得到

年龄特征的信息增益=0.2660

收入特征的信息增益=0.0176

学生特征的信息增益=0.1726

信誉特征的信息增益=0.0453

所以要将年龄特征作为我们决策树的根节点,之后采用递归的思想,直到叶子节点(决策属性)。

我们以西瓜书中西瓜数据集2.0举例

先约定:

色泽 :青绿:0 乌黑:1 浅白:2

根蒂:蜷缩:0 稍蜷:1 硬挺:2

敲声:浊响:0 沉闷:1 清脆:2

纹理:清晰:0 模糊:1 稍糊:2

脐部:凹陷:0 平坦:1 稍凹:2

触感:硬滑:0 软粘:1

代码来自大牛jack-cui,

仅修改一小部分代码。

# !/usr/bin/env python# -*- coding:utf-8 -*-from matplotlib.font_manager import FontPropertiesimport matplotlib.pyplot as pltfrom math import logimport operatordef createDataSet():dataSet = [[0, 0, 0, 0, 0, 0,'yes'], #数据集[1, 0, 1, 0, 0, 0,'yes'],[1, 0, 0, 0, 0, 0,'yes'],[0, 0, 1, 0, 0, 0,'yes'],[2, 0, 0, 0, 0, 0,'yes'],[0, 1, 0, 0, 2, 1,'yes'],[1, 1, 0, 2, 2, 1,'yes'],[1, 1, 0, 0, 2, 0,'yes'],[1, 1, 1, 2, 2, 0,'no'],[0, 2, 2, 0, 1, 1,'no'],[2, 2, 2, 1, 1, 0,'no'],[2, 0, 0, 1, 1, 1,'no'],[0, 1, 0, 2, 0, 0,'no'],[2, 1, 1, 2, 0, 0,'no'],[1, 1, 0, 0, 2, 1,'no'],[2, 0, 0, 1, 1, 0,'no'],[0, 0, 1, 2, 2, 0,'no']]labels=['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '好瓜']#分类属性return dataSet, labels#返回数据集和分类属性 #返回数据集和分类属性def calcShannonEnt(dataSet):numEntires = len(dataSet) #返回数据集的行数labelCounts = {} #保存每个标签(Label)出现次数的字典for featVec in dataSet: #对每组特征向量进行统计currentLabel = featVec[-1]#提取标签(Label)信息if currentLabel not in labelCounts.keys(): #如果标签(Label)没有放入统计次数的字典,添加进去labelCounts[currentLabel] = 0labelCounts[currentLabel] += 1#Label计数shannonEnt = 0.0 #经验熵(香农熵)for key in labelCounts: #计算香农熵prob = float(labelCounts[key]) / numEntires #选择该标签(Label)的概率shannonEnt -= prob * log(prob, 2) #利用公式计算return shannonEnt #返回经验熵(香农熵)def splitDataSet(dataSet, axis, value):retDataSet = []#创建返回的数据集列表for featVec in dataSet: #遍历数据集if featVec[axis] == value:reducedFeatVec = featVec[:axis]#去掉axis特征reducedFeatVec.extend(featVec[axis+1:])#将符合条件的添加到返回的数据集retDataSet.append(reducedFeatVec)return retDataSet #返回划分后的数据集def chooseBestFeatureToSplit(dataSet):numFeatures = len(dataSet[0]) - 1#特征数量baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵bestInfoGain = 0.0 #信息增益bestFeature = -1#最优特征的索引值for i in range(numFeatures):#遍历所有特征#获取dataSet的第i个所有特征featList = [example[i] for example in dataSet]uniqueVals = set(featList)#创建set集合{},元素不可重复newEntropy = 0.0 #经验条件熵for value in uniqueVals:#计算信息增益subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率newEntropy += prob * calcShannonEnt(subDataSet)#根据公式计算经验条件熵infoGain = baseEntropy - newEntropy #信息增益# print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益if (infoGain > bestInfoGain): #计算信息增益bestInfoGain = infoGain #更新信息增益,找到最大的信息增益bestFeature = i #记录信息增益最大的特征的索引值return bestFeature#返回信息增益最大的特征的索引值def majorityCnt(classList):classCount = {}for vote in classList:#统计classList中每个元素出现的次数if vote not in classCount.keys():classCount[vote] = 0classCount[vote] += 1sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True) #根据字典的值降序排序return sortedClassCount[0][0] #返回classList中出现次数最多的元素def createTree(dataSet, labels, featLabels):classList = [example[-1] for example in dataSet] #取分类标签(是否放贷:yes or no)if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分return classList[0]if len(dataSet[0]) == 1: #加标签共7列 #遍历完所有特征时返回出现次数最多的类标签return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)#选择最优特征bestFeatLabel = labels[bestFeat] #最优特征的标签featLabels.append(bestFeatLabel)myTree = {bestFeatLabel:{}}#根据最优特征的标签生成树#del(labels[bestFeat])#删除已经使用特征标签featValues = [example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值uniqueVals = set(featValues) #去掉重复的属性值for value in uniqueVals:#遍历特征,创建决策树。del_bestFeat = bestFeatdel_labels = labels[bestFeat]del (labels[bestFeat])myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)labels.insert(del_bestFeat, del_labels)return myTreedef getNumLeafs(myTree):numLeafs = 0 #初始化叶子firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]secondDict = myTree[firstStr] #获取下一组字典for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':#测试该结点是否为字典,如果不是字典,代表此结点为叶子结点numLeafs += getNumLeafs(secondDict[key])else: numLeafs +=1return numLeafsdef getTreeDepth(myTree):maxDepth = 0 #初始化决策树深度firstStr = next(iter(myTree)) #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]secondDict = myTree[firstStr] #获取下一个字典for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':#测试该结点是否为字典,如果不是字典,代表此结点为叶子结点thisDepth = 1 + getTreeDepth(secondDict[key])else: thisDepth = 1if thisDepth > maxDepth: maxDepth = thisDepth #更新层数return maxDepthdef plotNode(nodeTxt, centerPt, parentPt, nodeType):arrow_args = dict(arrowstyle="<-") #定义箭头格式font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) #设置中文字体createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', #绘制结点xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)def plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] #计算标注位置yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)def plotTree(myTree, parentPt, nodeTxt):decisionNode = dict(boxstyle="sawtooth", fc="0.8")#设置结点格式leafNode = dict(boxstyle="round4", fc="0.8") #设置叶结点格式numLeafs = getNumLeafs(myTree) #获取决策树叶结点数目,决定了树的宽度depth = getTreeDepth(myTree)#获取决策树层数firstStr = next(iter(myTree))#下个字典cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #中心位置plotMidText(cntrPt, parentPt, nodeTxt) #标注有向边属性值plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制结点secondDict = myTree[firstStr]#下一个字典,也就是继续绘制子结点plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD#y偏移for key in secondDict.keys():if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点plotTree(secondDict[key],cntrPt,str(key))#不是叶结点,递归调用继续绘制else:#如果是叶结点,绘制叶结点,并标注有向边属性值plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef createPlot(inTree):fig = plt.figure(1, facecolor='white') #创建figfig.clf()#清空figaxprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #去掉x、y轴plotTree.totalW = float(getNumLeafs(inTree)) #获取决策树叶结点数目plotTree.totalD = float(getTreeDepth(inTree)) #获取决策树层数plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #x偏移plotTree(inTree, (0.5,1.0), '')#绘制决策树plt.show() #显示绘制结果def classify(inputTree, featLabels, testVec):firstStr = next(iter(inputTree))#获取决策树结点secondDict = inputTree[firstStr]#下一个字典featIndex = featLabels.index(firstStr)for key in secondDict.keys():if testVec[featIndex] == key:if type(secondDict[key]).__name__ == 'dict':classLabel = classify(secondDict[key], featLabels, testVec)else: classLabel = secondDict[key]return classLabelif __name__ == '__main__':dataSet, labels = createDataSet()featLabels = []myTree = createTree(dataSet, labels, featLabels)testVec = [0,1,1,0]#测试纹理清晰0,根蒂稍蜷1,色泽青绿1,敲声浊响0,是不是好瓜result = classify(myTree, featLabels, testVec)if result == 'yes':print('好瓜')if result == 'no':print('坏瓜')createPlot(myTree)

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。