1500字范文,内容丰富有趣,写作好帮手!
1500字范文 > 手写数字图像识别-SVM算法投票法实现多分类

手写数字图像识别-SVM算法投票法实现多分类

时间:2021-07-27 17:49:47

相关推荐

手写数字图像识别-SVM算法投票法实现多分类

z机器学习课程中svm算法的应用,比起对算法的学习,更重要的收获应该是把算法理论和实际的应用结合起来了。记录一下这个简单的项目。

我没有系统得学习过python,基本上是够用就行,代码能跑就行。程序中有很多漏洞,请多指点。程序使用投票法实现0-9的手写数字图像识别。参考材料是《Python机器学习算法:原理、实现与案例》,参考使用了其中的svm算法和训练集、测试集文件。

1.准备函数

要实现多分类,就需要训练多个二分类器,0-9的数字识别,就需要训练45个二分类器。每次预测都要去训练45个分类器太浪费时间了!调用pickle包把训练的二分类器保存下来,之后使用分类器从文件读取变量就可以了,效率起飞!

# 保存变量函数def save_variable(v,filename):f=open(filename,'wb')pickle.dump(v,f)f.close()return filename# 读取变量函数def load_variable(filename):f=open(filename,'rb')r=pickle.load(f)f.close()return r

第一步,加载数据。为了训练分类器的效率着想,果断标准化。否则代码跑到天荒地老。这里使用的是sklearn的标准化函数,后面计算混淆矩阵、准确率等操作,都是使用sklearn。

# 加载数据集def load_set(train_set,test_set):# train setdata_train = np.genfromtxt(train_set, delimiter=',', dtype=float)X_train, y_train = data_train[:, :-1], data_train[:, -1]ss = StandardScaler()ss.fit(X_train)X_train_std = ss.transform(X_train)# test setdata_test = np.genfromtxt(test_set, delimiter=',', dtype=float)X_test, y_test = data_test[:, :-1], data_test[:, -1]X_test_std = ss.transform(X_test)return X_train,X_train_std,y_train,X_test,X_test_std,y_test

二分类器嘛,要么正例,要么负例,一次只能判断两个值,所以需要对原始的数据集进行处理,提取正例和负例标签。

#标签转换def extract(x,y,pos,neg):y_train = y[(y == pos) | (y == neg)]y_train = np.where(y_train == pos, +1, -1)X_train = x[(y == pos) | (y == neg)]return X_train,y_train

保存图像的函数,我使用的cv2进行保存。每行数据是一个8*8的图像,这里解释一下为什么每个数值都要乘256,我的代码里调用这个函数,输入的是标准化之后的数值,直接输出图片就是一团黑了。如果使用未标准化的数据,就没必要乘以任何数了。

#保存图像def genSamplePic(X,y,path):for num in range(X.shape[0]):img = X[num].reshape(8, 8) * 256if y[num] == 1:filename = path+"\正例_" + str(num) + ".png"else:filename = path+"\负例_" + str(num) + ".png"cv2.imwrite(filename, img)

二分类器预测完结果肯定也是要么正例要么负例,把结果还原成原始标签。

#还原标签def transLabel(clf,res):if res==1:return clf.posLabelelse:return clf.negLabel

训练二分类器的函数,0-9每两个数字都要有一个对应的分类器,一共需要45个分类器,那还是封装成一个函数调用吧。

#获取二分类器def genClassifier(X_train,y_train,pos,neg):data=extract(X_train,y_train,pos,neg)X_trained=data[0]y_trained=data[1]clf = SMO(C=1, tol=0.01, posLabel=pos, negLabel=neg, kernel='rbf', gamma=0.01)clf.train(X_trained, y_trained)return clf

这个函数是实现多分类的关键操作。读取45个分类器,生成一个分类器列表,每个分类器都预测一下这行数据的原始标签,并保存到结果列表里。出现次数最多的结果就是这一行数据最终的预测结果。

#投票法多分类实现def predictMC(X_test_k):#获取分类器clf = [[object for l in range(10)] for h in range(9)]for pos in range(9):for neg in range(pos + 1, 10):filename = "res\分类器" + str(pos) + str(neg) + ".txt"clf[pos][neg] = load_variable(filename)y_predk = list()for pos in range(9):for neg in range(pos + 1, 10):pred = clf[pos][neg].predict(X_test_k.reshape(1,-1))pred_raw = transLabel(clf[pos][neg], pred)y_predk.append(pred_raw)y_bincount = np.bincount(y_predk)result = np.argmax(y_bincount)return result

2.main

#!/usr/bin/python# -*- coding:utf-8 -*-import matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matrixfrom sklearn.metrics import accuracy_scorefrom sklearn.metrics import precision_scorefrom sklearn.metrics import recall_scorefrom case26 import load_set, predictMC# 加载数据集loadset=load_set('optdigits.tra','optdigits.tes')X_train=loadset[0]X_train_std=loadset[1]y_train=loadset[2]X_test=loadset[3]X_test_std=loadset[4]y_test=loadset[5]print("X_train:\n",X_train)print("X_train_std:\n",X_train_std)print("X_train_std shape:",X_train_std.shape)print("y_train:\n",y_train)print("y_train shape:",y_train.shape)#训练分类器并保存(后面通过读取分类器实现预测,所以这一段运行过一次就不需要了)'''clf=[[object for l in range(10)] for h in range(9)]for pos in range(9):for neg in range(pos+1,10):#图形data_tra = extract(X_train_std, y_train, pos, neg)X_tra = data_tra[0]y_tra = data_tra[1]genSamplePic(X_tra, y_tra, "tra")data_tes = extract(X_test_std, y_test, pos, neg)X_tes = data_tes[0]y_tes = data_tes[1]genSamplePic(X_tes, y_tes, "tes")#训练分类器保存clf[pos][neg]=genClassifier(X_train_std,y_train,pos,neg)clf_filename = "res\分类器" + str(pos) + str(neg) + ".txt"save_variable(clf[pos][neg], clf_filename)'''#预测结果y_pred=list()for num in range(X_test_std.shape[0]):y_pred.append(predictMC(X_test_std[num]))print("y_pred:\n",y_pred)print("y_test:\n",y_test)matrix_tp=confusion_matrix(y_test,y_pred)print("matrix_tp:\n",matrix_tp)# 混淆矩阵plt.imshow(matrix_tp,cmap='Greens')plt.xticks(range(len(matrix_tp)), range(0,10))plt.yticks(range(len(matrix_tp)), range(0,10))for first_index in range(len(matrix_tp)):for second_index in range(len(matrix_tp[first_index])):plt.text(first_index, second_index, matrix_tp[first_index][second_index])plt.show()#显示混淆矩阵的图形结果accuracy_tp=accuracy_score(y_test,y_pred)print("accuracy_tp:",accuracy_tp)# 准确率precision_tp=precision_score(y_test,y_pred,average="macro")print("precision_tp:",precision_tp)# 精确率recall_tp=recall_score(y_test,y_pred,average="macro")print("recall_tp:",recall_tp)# 召回率

学习率和步长取0.01,准确率、精确率、召回率大概都在0.97左右,混淆矩阵如下图所示。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。