knn之kd树

 

knn最简单的实现方法是线性扫描,即计算所有实例两两之间的距离。当训练集很大时,这种计算非常耗时。 用kd树存储数据集,可以减少计算代价。

原理部分可参考这篇文章: https://zhuanlan.zhihu.com/p/23966698

kd树的构建

算法

设数据有n维(1,2,3,4,…,n)

  1. 构造根节点,深度为1,选择维度1上的中位数m1为切分点,比m1小的在根节点左边,比m1大的在根节点右边;
  2. 重复,对于深度j,选择维度(j%n)上的中位数mj为切分点,比mj小的在节点左边,比mj大的在节点右边。

源码

def __create(self, a, j):
    """
    递归生成kd树
    :param a: 待划分数据集
    :param j: 选择划分维度
    :return:
    """
    if len(a) == 1:
        a[0]['dimension'] = j
        return Node(a[0])
    else:
        middle_a = self.__midian(a, j)
        middle_index = len(a) / 2
        middle_a['dimension'] = j
        root_node = Node(middle_a)

        left_a = a[:middle_index]
        right_a = a[middle_index + 1:]

        left_node = self.__create(left_a, (j + 1) % self.dimensions)
        left_node.location = 'left'
        root_node.left = left_node
        left_node.parent = root_node

        if len(a) > 2:  # or len(right_a)>0,此处由于偶数长度序列取中位数时取的是中间两个数中右侧的一个,导致右序列可能为空
            right_node = self.__create(right_a, (j + 1) % self.dimensions)
            right_node.location = 'right'
            root_node.right = right_node
            right_node.parent = root_node

        return root_node

kd树的搜索

算法

找到叶节点,向上爬到未被访问过的父节点,如果候选集不足k,则把当前节点加入候选集,若已有k, 则比较候选集中最大长度与当前距离,大于则替换,判断父节点另一侧是否可能有满足要求的节点 (test_node到此处切分线的距离大于候选集中最大距离),有的话则进入,找到叶节点重复,没有的话继续向上爬。

if candidate_point的长度小于k:
    if current_node访问过:
        向上爬到未被访问过的一个父节点n;
        if n是根节点:
            if n访问过:
                结束;
            else:
                if max(candidate_point)>distance(test_point,n的切分线):
                    进入另一侧节点,找到叶节点,重复;
                else:
                    结束;
        else:
            重复;
    else:
        将current_node加入candidate_point,标记已访问;
        if max(candidate_point)>distance(test_point,current_node的切分线):
            进入另一侧节点,找到叶节点,重复;
        else:
            向上爬到未被访问过的一个父节点n;
            if n是根节点:
                if n访问过:
                    结束;
                else:
                    if max(candidate_point)>distance(test_point,n的切分线):
                        进入另一侧节点,找到叶节点,重复;
                    else:
                        结束;
            else:
                重复;
else:
    if current_node访问过:
        向上爬到未被访问过的一个父节点n;
        if n是根节点:
            if n访问过:
                结束;
            else:
                if max(candidate_point)>distance(test_point,n的切分线):
                    进入另一侧节点,找到叶节点,重复;
                else:
                    结束;
        else:
            重复;
    else:
        if distance(test_point,current_node)>max(candidate_point):
            删除max(candidate_point);
            将current_node加入candidate_point,标记已访问;
        if max(candidate_point)>distance(test_point,current_node的切分线):
            进入另一侧节点,找到叶节点,重复;
        else:
            向上爬到未被访问过的一个父节点n;
            if n是根节点:
                if n访问过:
                    结束;
                else:
                    if max(candidate_point)>distance(test_point,n的切分线):
                        进入另一侧节点,找到叶节点,重复;
                    else:
                        结束;
            else:
                重复;

源码

candidate_point = [] #所要选择的k个数据点组成的列表
kdtree = KDTree(self.points) #由训练集生成kdTree
current_node = self.find_leaf(kdtree.root, self.test_point) #当前到达节点

while True:
    while current_node.visited:
        if current_node.parent is None:
            break
        previous_node = current_node
        current_node = current_node.parent

    if current_node.parent is None:
        if current_node.visited:
            break

    if len(candidate_point) != self.k:
        current_distance = KNN.__distance(current_node.value['coordinate'], self.test_point)
        current_node.visited = True
        candidate_point.append({'node': current_node, 'distance': current_distance})
        if current_node.left is None:
            continue
    else:
        current_distance = KNN.__distance(current_node.value['coordinate'], self.test_point)
        candidate_point.sort(key=lambda x: x['distance'])
        if current_distance < candidate_point[self.k - 1]['distance']:
            current_node.visited = True
            candidate_point.pop()
            candidate_point.append({'node': current_node, 'distance': current_distance})
        else:
            current_node.visited = True
        if current_node.left is None:
            continue

    # if current_node.left is not None:会执行以下部分
    if self.intersect_check(candidate_point, current_node):
        location = previous_node.location
        if location == 'left':
            if current_node.right is None:
                continue
            else:
                current_node = self.find_leaf(current_node.right, self.test_point)
                continue
        if location == 'right':
            if current_node.left is None:
                continue
            else:
                current_node = self.find_leaf(current_node.left, self.test_point)
    else:
        continue