knn最简单的实现方法是线性扫描,即计算所有实例两两之间的距离。当训练集很大时,这种计算非常耗时。 用kd树存储数据集,可以减少计算代价。
原理部分可参考这篇文章: https://zhuanlan.zhihu.com/p/23966698
kd树的构建
算法
设数据有n维(1,2,3,4,…,n)
- 构造根节点,深度为1,选择维度1上的中位数m1为切分点,比m1小的在根节点左边,比m1大的在根节点右边;
- 重复,对于深度j,选择维度(j%n)上的中位数mj为切分点,比mj小的在节点左边,比mj大的在节点右边。
源码
def __create(self, a, j):
"""
递归生成kd树
:param a: 待划分数据集
:param j: 选择划分维度
:return:
"""
if len(a) == 1:
a[0]['dimension'] = j
return Node(a[0])
else:
middle_a = self.__midian(a, j)
middle_index = len(a) / 2
middle_a['dimension'] = j
root_node = Node(middle_a)
left_a = a[:middle_index]
right_a = a[middle_index + 1:]
left_node = self.__create(left_a, (j + 1) % self.dimensions)
left_node.location = 'left'
root_node.left = left_node
left_node.parent = root_node
if len(a) > 2: # or len(right_a)>0,此处由于偶数长度序列取中位数时取的是中间两个数中右侧的一个,导致右序列可能为空
right_node = self.__create(right_a, (j + 1) % self.dimensions)
right_node.location = 'right'
root_node.right = right_node
right_node.parent = root_node
return root_node
kd树的搜索
算法
找到叶节点,向上爬到未被访问过的父节点,如果候选集不足k,则把当前节点加入候选集,若已有k, 则比较候选集中最大长度与当前距离,大于则替换,判断父节点另一侧是否可能有满足要求的节点 (test_node到此处切分线的距离大于候选集中最大距离),有的话则进入,找到叶节点重复,没有的话继续向上爬。
if candidate_point的长度小于k:
if current_node访问过:
向上爬到未被访问过的一个父节点n;
if n是根节点:
if n访问过:
结束;
else:
if max(candidate_point)>distance(test_point,n的切分线):
进入另一侧节点,找到叶节点,重复;
else:
结束;
else:
重复;
else:
将current_node加入candidate_point,标记已访问;
if max(candidate_point)>distance(test_point,current_node的切分线):
进入另一侧节点,找到叶节点,重复;
else:
向上爬到未被访问过的一个父节点n;
if n是根节点:
if n访问过:
结束;
else:
if max(candidate_point)>distance(test_point,n的切分线):
进入另一侧节点,找到叶节点,重复;
else:
结束;
else:
重复;
else:
if current_node访问过:
向上爬到未被访问过的一个父节点n;
if n是根节点:
if n访问过:
结束;
else:
if max(candidate_point)>distance(test_point,n的切分线):
进入另一侧节点,找到叶节点,重复;
else:
结束;
else:
重复;
else:
if distance(test_point,current_node)>max(candidate_point):
删除max(candidate_point);
将current_node加入candidate_point,标记已访问;
if max(candidate_point)>distance(test_point,current_node的切分线):
进入另一侧节点,找到叶节点,重复;
else:
向上爬到未被访问过的一个父节点n;
if n是根节点:
if n访问过:
结束;
else:
if max(candidate_point)>distance(test_point,n的切分线):
进入另一侧节点,找到叶节点,重复;
else:
结束;
else:
重复;
源码
candidate_point = [] #所要选择的k个数据点组成的列表
kdtree = KDTree(self.points) #由训练集生成kdTree
current_node = self.find_leaf(kdtree.root, self.test_point) #当前到达节点
while True:
while current_node.visited:
if current_node.parent is None:
break
previous_node = current_node
current_node = current_node.parent
if current_node.parent is None:
if current_node.visited:
break
if len(candidate_point) != self.k:
current_distance = KNN.__distance(current_node.value['coordinate'], self.test_point)
current_node.visited = True
candidate_point.append({'node': current_node, 'distance': current_distance})
if current_node.left is None:
continue
else:
current_distance = KNN.__distance(current_node.value['coordinate'], self.test_point)
candidate_point.sort(key=lambda x: x['distance'])
if current_distance < candidate_point[self.k - 1]['distance']:
current_node.visited = True
candidate_point.pop()
candidate_point.append({'node': current_node, 'distance': current_distance})
else:
current_node.visited = True
if current_node.left is None:
continue
# if current_node.left is not None:会执行以下部分
if self.intersect_check(candidate_point, current_node):
location = previous_node.location
if location == 'left':
if current_node.right is None:
continue
else:
current_node = self.find_leaf(current_node.right, self.test_point)
continue
if location == 'right':
if current_node.left is None:
continue
else:
current_node = self.find_leaf(current_node.left, self.test_point)
else:
continue