之前学习了k-近邻算法的实现后,参考《机器学习实战》中的例子进行了k-近邻算法的测验,主要测试了针对约会网站和手写识别系统的数据分类,这两个测试使用的是《机器学习实战》提供的数据集。
在编写函数前,需在.py文件中添加以下内容:
~~~
from numpy import *
import numpy as np
import operator
from os import listdir
~~~
**第一部分是针对约会网站的数据分类**,用于改进约会网站的配对效果。该实例的简介如下:
* 海伦一直使用在线约会网站寻找合适自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:
1.不喜欢的人( 以下简称1 );
2.魅力一般的人( 以下简称2 );
3.极具魅力的人( 以下简称3 )
* 尽管发现了上述规律,但海伦依然无法将约会网站推荐的匹配对象归入恰当的分类。她觉得可以在周一到周五约会哪些魅力一般的人,而周末则更喜欢与那些极具魅力的人为伴。海伦希望我们的分类软件可以更好的帮助她将匹配对象划分到确切的分类中。此外海伦还收集了一些约会网站未曾记录的数据信息,她认为这些数据更有助于匹配对象的归类。
* 这里提取一下这个案例的目标:根据一些数据信息,对指定人选进行分类(1或2或3)。为了使用kNN算法达到这个目标,我们需要哪些信息?前面提到过,就是需要样本数据,仔细阅读我们发现,这些样本数据就是“海伦还收集了一些约会网站未曾记录的数据信息 ”。
针对以上的描述,需要进行以下步骤:
1. 收集数据
2. 准备数据
3. 设计算法分析数据
4. 测试算法
**1.收集数据**
海伦收集的数据是记录一个人的三个特征:每年获得的飞行常客里程数;玩视频游戏所消耗的时间百分比;每周消费的冰淇淋公升数。数据是txt格式文件,如下图,前三列依次是三个特征,第四列是分类(1:代表不喜欢的人,2:代表魅力一般的人,3:代表极具魅力的人),每一行数据代表一个人。
![](https://box.kancloud.cn/2016-01-05_568b38350bfd5.jpg)
**2.准备数据**
计算机需要对数据文件txt读取数据,因此需要把数据进行格式化,对于数学运算,计算机擅长把数据存放在矩阵中。以下代码中`file2matrix(filename)`函数完成了这一工作,该函数输入数据文件名(字符串),输出训练样本矩阵和类标签向量。
这一过程返回两个矩阵:一个矩阵用于存放每个人的三个特征数据,一个矩阵存放每个人对应的分类。
**3.设计算法分析数据**
k-近邻算法的思想是寻找测试数据的前k个距离最近的样本,然后根据这k个样本的分类来确定该数据的分类,**遵循“多数占优”原则**。因此,如何寻找样本成为主要的问题,**在信号处理和模式识别领域中,常常使用“距离”来度量信号或特征的相似度**。在这里,我们假定可以使用三个特征数据来代替每个人,比如第一个人的属性我们用[40920, 8.326976, 0.953952]来代替,并且他的分类是3。那么此时的距离就是点的距离。
求出测试样本与训练样本中每个点的距离,然后进行从小到大排序,前k位的就是k-近邻,然后看看这k位近邻中占得最多的分类是什么,也就获得了最终的答案。这一部分是k-近邻算法的核心,代码中`classify()`函数就实现了k-近邻算法的核心部分。
一个优化算法效果的步骤——归一化数据:
打开数据文件我们可用发现第一列代表的特征数值远远大于其他两项特征,这样在求距离的公式中就会占很大的比重,致使两个样本的距离很大程度上取决于这个特征,其他特征的特性变得可有可无,这显然有悖于实际情况。因此通常我们可用使用归一化这一数学工具对数据进行预处理,这一处理过后的各个特征既不影响相对大小又可以不失公平。`Normalize(data)`函数实现了这一功能。
**4.测试算法**
经过了对数据进行预处理后、归一化数值,可用验证kNN算法有效性,测试代码为:`WebClassTest()` 由于数据有1000条,我们设置一个比率`ratio = 0.1`也就是令 `1000 * ratio = 100` 条作为测试样本,其余`900`条作为训练样本,当然,`ratio`的值可用改变,对算法效果是有影响的。
实现代码:
~~~
def classify(data, sample, label, k):
SampleSize = sample.shape[0]
DataMat = tile(data, (SampleSize, 1))
delta = (DataMat - sample)**2
distance = (delta.sum(axis = 1))**0.5 #
sortedDist = distance.argsort()
classCount = {}
for i in range(k):
votedLabel = label[sortedDist[i]]
classCount[votedLabel] = classCount.get(votedLabel, 0) + 1
result = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)
return result[0][0]
#print classify([10,0], sample, label, 3)
def file2matrix(filename):
fil = open(filename)
fileLines = fil.readlines() # Convert the contents of a file into a list
lenOfLines = len(fileLines)
Mat = zeros((lenOfLines, 3))
classLabel = []
index = 0
for line in fileLines:
line = line.strip()
listFromLine = line.split('\t')
Mat[index,: ] = listFromLine[0:3]
classLabel.append(int(listFromLine[-1])) # the last one of listFromLine is Label
index += 1
return Mat, classLabel
mat,label = file2matrix('datingTestSet2.txt')
#print mat
# draw
import matplotlib
import matplotlib.pyplot as plt
fil = open('datingTestSet2.txt')
fileLines = fil.readlines() # Convert the contents of a file into a list
lenOfLines = len(fileLines)
figure = plt.figure()
axis = figure.add_subplot(111)
lab = ['didntLike', 'smallDoses', 'largeDoses']
for i in range(3):
n = []
l = []
for j in range(lenOfLines):
if label[j] == i + 1:
n.append(list(mat[j,0:3]))
l.append(label[j])
n = np.array(n) # list to numpy.adarray
#reshape(n, (3,k))
axis.scatter(n[:,0], n[:,1], 15.0*array(l), 15.0*array(l), label = lab[i])
print type(mat)
print type(n)
plt.legend()
plt.show()
def Normalize(data):
minValue = data.min(0)
maxValue = data.max(0)
ValueRange = maxValue - minValue
norm_data = zeros(shape(data))
k = data.shape[0]
norm_data = data - tile(minValue, (k, 1))
norm_data = norm_data / tile(ValueRange, (k, 1))
return norm_data, ValueRange, minValue
def WebClassTest():
ratio = 0.1
dataMat, dataLabels = file2matrix('datingTestSet2.txt')
normMat, ValueRange, minValue = Normalize(dataMat)
k = normMat.shape[0]
num = int(k * ratio) # test sample : 10%
errorCount = 0.0
for i in range(num):
result = classify(normMat[i,:], normMat[num:k,:],\
dataLabels[num:k], 7) # k = 3
print "The classifier came back with: %d, the real answer is %d"\
% (result, dataLabels[i])
if (result != dataLabels[i]): errorCount += 1
print "The total error rate is %f " % (errorCount / float(num))
WebClassTest()
~~~
在程序设计过程中,需要注意list、array、adarray等数据结构的使用,numpy.ndarray和标准Python库类array.array功能是不相同的。以上代码中`print type(mat)`和`print type(n)` 就是为了观察各变量的类型。允许以上代码,可用画出散点图如下:
![](https://box.kancloud.cn/2016-01-05_568b38352784d.jpg)
以上散点使用数据集中第二维和第三维数据绘制而出,当然,你可用选择其他维度的数据画二维散点图,或者使用所有维度的数据画高维图(未实现),如下图所示:
![](https://box.kancloud.cn/2016-01-05_568b38354e6eb.jpg)
对约会网站分类的测试,由于分类效果依赖于参数k和测试样本占样本数目的比例,开始测试按照书中的参数进行,取`k=3`,测试样本占总样本比例为`0.1`进行测试,结果如下:
![](https://box.kancloud.cn/2016-01-05_568b38356f86a.jpg)
理论上来说,增大k的取值可以提高准确率,但实际上若k值太大,也会造成准确率下降,而且运算复杂度增大。
k = 7:
![](https://box.kancloud.cn/2016-01-05_568b383592b73.jpg)
k = 17:
![](https://box.kancloud.cn/2016-01-05_568b3835a571a.jpg)
另一方面,降低ratio值(即增大训练样本集的比率)也可以提高算法的准确率,但由于每次算法需要比较更多的样本,因此算法复杂度也会增加。
**第二部分是手写数字识别**:
首先来看看书本给出的数据集:
~~~
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') #load the training set
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits') # iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify(vectorUnderTest, trainingMat, hwLabels, 3) # k = 3
print "The classifier came back with: %d, the real answer is: %d"\
% (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount = errorCount + 1.0
print "\nThe total number of errors is: %d" % errorCount
print "\nThe total error rate is: %f" % (errorCount/float(mTest))
handwritingClassTest()
~~~
一个结果(`k = 3`):
![](https://box.kancloud.cn/2016-01-05_568b3835bd4b2.jpg)
`k = 7`时的结果,正确率并不比`k = 3`时要好:
![](https://box.kancloud.cn/2016-01-05_568b3835ccafb.jpg)
在手写数字识别过程中,随着k值得增大,准确率反而降低了。k的取值并不是越大越好。
至此,完成了k-近邻算法的学习和实例验证。比起其他机器学习方法,k-近邻算法是最简单最有效的分类数据算法,使用算法时必须有接近实际数据的训练样本数据。但是如前一节所说的,该算法的缺点也很多,最大的一点是无法给出数据的内在含义。事实上k决策树是k-近邻算法的优化版本,比起前者,决策树有效减少了储存空间和计算空间的开销,后期需继续深入学习!