博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
k-d tree代码解析 【转】
阅读量:7052 次
发布时间:2019-06-28

本文共 13234 字,大约阅读时间需要 44 分钟。

上一篇较详细地介绍了k-d树算法。本文来讲解具体的实现代码。

  首先是一些数据结构的定义。我们先来定义单个数据,代码如下:

//单个数据向量结构定义 struct _Examplar {
public: _Examplar():dom_dims(0){} //数据维度初始化为0   //带有完整的两个参数的constructor,这里const是为了保护原数据不被修改 _Examplar(const std::vector
elt, int dims) { if(dims > 0) {
dom_elt = elt; dom_dims = dims; } else {
dom_dims = 0; } }
(一些重载的构造函数和运算符,元素的访问控制函数等)
_Examplar(int dims)    //只含有维度信息的constructor     {
if(dims > 0) {
dom_elt.resize(dims); dom_dims = dims; } else {
dom_dims = 0; } } _Examplar(const _Examplar& rhs) //copy-constructor {
if(rhs.dom_dims > 0) {
dom_elt = rhs.dom_elt; dom_dims = rhs.dom_dims; } else {
dom_dims = 0; } } _Examplar& operator=(const _Examplar& rhs) //重载"="运算符 {
if(this == &rhs) return *this; releaseExamplarMem(); if(rhs.dom_dims > 0) {
dom_elt = rhs.dom_elt; dom_dims = rhs.dom_dims; } return *this; } ~_Examplar() {
} double& dataAt(int dim) //定义访问控制函数 {
assert(dim < dom_dims); return dom_elt[dim]; } double& operator[](int dim) //重载"[]"运算符,实现下标访问 {
return dataAt(dim); } const double& dataAt(int dim) const //定义只读访问函数 {
assert(dim < dom_dims); return dom_elt[dim]; } const double& operator[](int dim) const //重载"[]"运算符,实现下标只读访问 {
return dataAt(dim); } void create(int dims) //创建数据向量 {
releaseExamplarMem(); if(dims > 0) {
dom_elt.resize(dims); //控制数据向量维度 dom_dims = dims; } } int getDomDims() const //获得数据向量维度信息 {
return dom_dims; } void setTo(double val) //数据向量初始化设置 {
if(dom_dims > 0) {
for(int i=0;i
private:     std::vector
dom_elt; //每个数据定义为一个double类型的向量 int dom_dims; //数据向量的维度 };

  结构_Examplar定义了单个数据节点的结构,主要包含的信息有:1.数据向量本身;2.数据向量的维度。接下来定义一整个数据集的结构,代码如下:

//数据集结构定义 class ExamplarSet : public TrainData      //整个数据集类,由一个抽象类TrainData派生 {
private: //_Examplar *_ex_set; std::vector<_Examplar> _ex_set; //定义含有若干个_Examplar类数据向量的数据集 int _size; //数据集大小 int _dims; //数据集中每个数据向量的维度 public:
(一些重载的构造函数运算符,元素访问控制函数等)
ExamplarSet():_size(0), _dims(0){}     ExamplarSet(std::vector<_Examplar> ex_set, int size, int dims);     ExamplarSet(int size, int dims);     ExamplarSet(const ExamplarSet& rhs);     ExamplarSet& operator=(const ExamplarSet& rhs);     ~ExamplarSet(){}     _Examplar& examplarAt(int idx)     {         assert(idx < _size); return _ex_set[idx];     }     _Examplar& operator[](int idx)     {
return examplarAt(idx); } const _Examplar& examplarAt(int idx) const {
assert(idx < _size); return _ex_set[idx]; } void create(int size, int dims); int getDims() const { return _dims;} int getSize() const { return _size;} _HyperRectangle calculateRange(); bool empty() const {
return (_size == 0); }
void sortByDim(int dim);     //按某个方向维的排序函数     bool remove(int idx);        //去除数据集中排序后指定位置的数据向量     void push_back(const _Examplar& ex)    //添加某个数据向量至数据集末尾     {
_ex_set.push_back(ex); _size++; } int readData(char *strFilePath); //从文件读取数据集 private: void releaseExamplarSetMem() //清除现有数据集 {
_ex_set.clear(); _size = 0; } };

  类ExamplarSet定义了整个数据集的结构,其包含的主要信息有:1.含有若干个_Examplar类数据向量的数据集;2.数据集的大小;3.每个数据向量的维度。以上两个结构是整个算法两个基本的数据结构,这里的代码只是展示其主要包含的结构信息,详细的定义及函数实现代码请参看附件。

  接下来就要定义k-d tree的结构。同样采用上述由点定义到集定义的思路,我们先来定义k-d tree中一个节点结构,代码如下:

//k-d tree节点结构定义 class KDTreeNode    {
private: int _split_dim; //该节点的最大区分度方向维 _Examplar _dom_elt; //该节点的数据向量 _HyperRectangle _range_hr; //表示数据范围的超矩形结构 public: KDTreeNode *_left_child, *_right_child, *_parent; //该节点的左右子树和父节点
(一些重载的构造函数,元素访问控制函数等)
public:     KDTreeNode():_left_child(0), _right_child(0), _parent(0),         _split_dim(0){}     KDTreeNode(KDTreeNode *left_child, KDTreeNode *right_child,         KDTreeNode *parent, int split_dim, _Examplar dom_elt, _HyperRectangle range_hr):     _left_child(left_child), _right_child(right_child), _parent(parent),         _split_dim(split_dim), _dom_elt(dom_elt), _range_hr(range_hr){}     KDTreeNode(const KDTreeNode &rhs);     KDTreeNode& operator=(const KDTreeNode &rhs);     _Examplar& getDomElt() { return _dom_elt; }     _HyperRectangle& getHyperRectangle(){ return _range_hr; } int& splitDim(){ return _split_dim; } void create(KDTreeNode *left_child, KDTreeNode *right_child,         KDTreeNode *parent, int split_dim, _Examplar dom_elt,  _HyperRectangle range_hr);
};

  类KDTreeNode就是按照前一篇表1所述定义的。需要注意的是_HyperRectangle这一结构,它表示的就是这一节点所代表的空间范围Range,其定义如下:

struct _HyperRectangle    //定义表示数据范围的超矩形结构 {
_Examplar min; //统计数据集中所有数据向量每个维度上最小值组成的一个数据向量 _Examplar max; //统计数据集中所有数据向量每个维度上最大值组成的一个数据向量
(一些重载的构造函数)
_HyperRectangle() {}     _HyperRectangle(_Examplar mx, _Examplar mn)     {
assert (mx.getDomDims() == mn.getDomDims()); min = mn; max = mx; } _HyperRectangle(const _HyperRectangle& rhs) {
min = rhs.min; max = rhs.max; } _HyperRectangle& operator= (const _HyperRectangle& rhs) {
if(this == &rhs) return *this; min = rhs.min; max = rhs.max; return *this; } void create(_Examplar mx, _Examplar mn) {
assert (mx.getDomDims() == mn.getDomDims()); min = mn; max = mx; }
};

  对于整个数据集来说_HyperRectangle表示的就是对全体的统计范围信息,对部分数据集来说其表示的就是对部分数据的统计范围信息。还是以上篇中实例中的数据{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}为例,_HyperRectangle表示的统计范围如图1所示:

图1  _HyperRectangle表示的统计范围

  • 对于根节点(7,2),其所对应的空间范围是整个数据集,所以根节点(7,2)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,1),dom_dims = 2},max = {dom_elt = (9,7),dom_dims = 2};
  • 对于中间节点(5,4),其所对应的空间范围是根节点的左子树,所以节点(5,4)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的数据范围统计得min = {dom_elt = (2,3),dom_dims = 2},max = {dom_elt = (5,7),dom_dims = 2};
  • 对于叶子节点(4,7),其所对应的空间范围是节点本身,所以节点(4,7)的_range_hr就是对整个数据集所有维度方向(此例即x,y方向)的 数据范围统计得min = {dom_elt = (4,7),dom_dims = 2},max = {dom_elt = (4,7),dom_dims = 2};

  最后再进行整个k-d tree结构的定义。代码如下:

class KDTree    //k-d tree结构定义 {
public: KDTreeNode *_root; //k-d tree的根节点 public: KDTree():_root(NULL){} void create(const ExamplarSet &exm_set); //创建k-d tree,实际上调用createKDTree void destroy(); //销毁k-d tree,实际上调用destroyKDTree ~KDTree(){ destroyKDTree(_root); } std::pair<_Examplar, double> findNearest(_Examplar target); //查找最近邻点函数,返回值是pair类型 //实际是调用findNearest_i //查找距离在range范围内的近邻点,返回这样近邻点的个数,实际是调用findNearest_range int findNearest(_Examplar target, double range, std::vector
<_Examplar, double>> &res_nearest); private: KDTreeNode* createKDTree(const ExamplarSet &exm_set); void destroyKDTree(KDTreeNode *root); std::pair<_Examplar, double> findNearest_i(KDTreeNode *root, _Examplar target); int findNearest_range(KDTreeNode *root, _Examplar target, double range, std::vector
<_Examplar, double>> &res_nearest);

  可见,整个k-d tree结构是由一系列KDTreeNode类的节点构成。整个k-d树的构建算法和基于k-d树的最邻近查找算法主要就是由createKDTree,findNearest_i以及findNearest_range这三个函数完成。代码分别如下:

  • createKDTree
//KDTree::是由于定义了KDTree的namespace KDTree::KDTreeNode* KDTree::KDTree::createKDTree( const ExamplarSet &exm_set ) {
if(exm_set.empty()) return NULL; ExamplarSet exm_set_copy(exm_set); int dims = exm_set_copy.getDims(); int size = exm_set_copy.getSize(); //计算每个维的方差,选出方差值最大的维 double var_max = -0.1; double avg, var; int dim_max_var = -1; for(int i=0;i
var_max) {
var_max = var; dim_max_var = i; } } //确定节点的数据矢量 _HyperRectangle hr = exm_set_copy.calculateRange(); //统计节点空间范围 exm_set_copy.sortByDim(dim_max_var); //将所有数据向量按最大区分度方向排序 int mid = size / 2; _Examplar exm_split = exm_set_copy.examplarAt(mid); //取出排序结果的中间节点 exm_set_copy.remove(mid); //将中间节点作为父(根)节点,所有将其从数据集中去除 //确定左右节点 ExamplarSet exm_set_left(0, exm_set_copy.getDims()); ExamplarSet exm_set_right(0, exm_set_copy.getDims()); exm_set_right.remove(0); int size_new = exm_set_copy.getSize(); //获得子数据空间大小 for(int i=0;i
_left_child = createKDTree(exm_set_left); //递归调用生成左子树 if(pNewNode->_left_child != NULL) //确认左子树父节点 pNewNode->_left_child->_parent = pNewNode; pNewNode->_right_child = createKDTree(exm_set_right); //递归调用生成右子树 if(pNewNode->_right_child != NULL) //确认右子树父节点 pNewNode->_right_child->_parent = pNewNode; return pNewNode; //最终返回k-d tree的根节点 }

  整个createKDTree函数完全符合上篇中表2所述。注意其中统计节点空间范围calculateRange这一函数,其定义如下:

KDTree::_HyperRectangle KDTree::ExamplarSet::calculateRange() {
assert(_size > 0); assert(_dims > 0); _Examplar mn(_dims); _Examplar mx(_dims); for(int j=0;j<_dims;j++) {
mn.dataAt(j) = (*this)[0][j]; //初始化最小范围向量 mx.dataAt(j) = (*this)[0][j]; //初始化最大范围向量 } for(int i=1;i<_size;i++) //统计数据集中每一个数据向量 {
for(int j=0;j<_dims;j++) {
if( (*this)[i][j] < mn[j] ) //比较每一维,寻找最小值 mn[j] = (*this)[i][j]; if( (*this)[i][j] > mx[j] ) //比较每一维,寻找最大值 mx[j] = (*this)[i][j]; } } _HyperRectangle hr(mx, mn); return hr; //返回一个_HyperRectangle结构 }
  • findNearest_i
std::pair
KDTree::KDTree::findNearest_i( KDTreeNode *root, _Examplar target ) {
KDTreeNode *pSearch = root; //堆栈用于保存搜索路径 std::vector
search_path; _Examplar nearest; double max_dist; while(pSearch != NULL) //首先通过二叉查找得到搜索路径 {
search_path.push_back(pSearch); int s = pSearch->splitDim(); if(target[s] <= pSearch->getDomElt()[s]) {
pSearch = pSearch->_left_child; } else {
pSearch = pSearch->_right_child; } } nearest = search_path.back()->getDomElt(); //取路径中最后的叶子节点为回溯前的最邻近点 max_dist = Distance_exm(nearest, target); search_path.pop_back(); //回溯搜索路径 while(!search_path.empty()) {
KDTreeNode *pBack = search_path.back(); search_path.pop_back(); if( pBack->_left_child == NULL && pBack->_right_child == NULL) //如果是叶子节点,就直接比较距离的大小 {
if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) ) {
nearest = pBack->getDomElt(); max_dist = Distance_exm(pBack->getDomElt(), target); } } else {
int s = pBack->splitDim(); if( abs(pBack->getDomElt()[s] - target[s]) < max_dist) //以target为圆心,max_dist为半径的圆和分割面如果 { //有交割,则需要进入另一边子空间搜索 if( Distance_exm(nearest, target) > Distance_exm(pBack->getDomElt(), target) ) {
nearest = pBack->getDomElt(); max_dist = Distance_exm(pBack->getDomElt(), target); } if(target[s] <= pBack->getDomElt()[s]) //如果target位于左子空间,就应进入右子空间 pSearch = pBack->_right_child; else pSearch = pBack->_left_child; //如果target位于右子空间,就应进入左子空间 if(pSearch != NULL) search_path.push_back(pSearch); //将新的节点加入search_path中 } } } std::pair<_Examplar, double> res(nearest, max_dist); return res; //返回包含最邻近点和最近距离的pair }
  • findNearest_range
int KDTree::KDTree::findNearest_range( KDTreeNode *root, _Examplar target, double range,     std::vector
<_Examplar, double>> &res_nearest ) {
if(root == NULL) return 0; double dist_sq, dx; int ret, added_res = 0; dist_sq = 0; dist_sq = Distance_exm(root->getDomElt(), target); //计算搜索路径中每个节点和target的距离 if(dist_sq <= range) {                   //将范围内的近邻添加到结果向量res_nearest中 std::pair<_Examplar,double> temp(root->getDomElt(), dist_sq); res_nearest.push_back(temp); //结果个数+1 added_res = 1; } dx = target[root->splitDim()] - root->getDomElt()[root->splitDim()]; //左子树或右子树递归的查找 ret = findNearest_range(dx <= 0.0 ? root->_left_child : root->_right_child, target, range, res_nearest); //当另外一边可能存在范围内的近邻 if(ret >= 0 && fabs(dx) < range) {
added_res += ret; ret = findNearest_range(dx <= 0.0 ? root->_right_child : root->_left_child, target, range, res_nearest); } added_res += ret; return added_res; //最终返回范围内的近邻个数 }

  依然利用前述实例的数据来做测试,查找(2.1,3.1)和(2,4.5)两点的最近邻,并查找距离在4以内的所有近邻。程序运行结果如下:

                     

           图2  查找(2.1,3.1)的结果                                                       图3  查找(2,4.5)的结果

 

附件:

转载请注明:

转载于:https://www.cnblogs.com/retrieval/archive/2012/04/09/2439134.html

你可能感兴趣的文章
字符串匹配算法之KMP&Boyer-Moore
查看>>
iOS WKWebView的使用
查看>>
开始记录 Windows Phone 生涯
查看>>
django中TypeError: __init__() missing 1 required positional argument: 'app_module'
查看>>
C#窗体之-->窗口外观属性...
查看>>
搜索框auto_complete
查看>>
Java之戳中痛点 - (6)避免类型自动转换,例如两个整数相除得浮点数遇坑
查看>>
安装 libmagic in Mac OS (for Python-magic)
查看>>
设置作者信息等设置
查看>>
【OpenGL学习】使用Shader做图像处理
查看>>
subversion linux使用方法
查看>>
顶部滚动菜单栏
查看>>
java调用sqlldr oracle 安装的bin目录
查看>>
Eclipse 打开已存在的Android项目的问题
查看>>
Activity获取Fragment的值
查看>>
怎么搞差分约束?
查看>>
git基本的使用原理
查看>>
Ubuntu机器学习python实战(一)k-近邻算法
查看>>
Reachability(判断网络是否连接)
查看>>
sql奇特的语句
查看>>