23#include <unordered_map>
/// Number of coordinate axes reported by these key traits.
/// This version always returns 1.
/// (The listing line number that had been fused onto the declaration is removed
/// so the definition is well-formed.)
static inline std::size_t size() { return 1; }
37 static inline bool islhalf(
const T& lhs,
const T& rhs, std::size_t axis) {
return lhs[axis] < rhs[axis]; }
39 static inline double fabs(
const T& lhs,
const T& rhs, std::size_t axis) {
return std::abs(lhs[axis] - rhs[axis]); }
// Distance between two keys: adds up the per-axis fabs(lhs, rhs, i)
// for every axis i in [0, size()).
// NOTE(review): the opening brace, the declaration/initialisation of `result`,
// the return statement and the closing braces are not present in this listing.
41 static inline double distance(
const T& lhs,
const T& rhs)
44 for (std::size_t i = 0; i < size(); ++i) {
45 result += fabs(lhs, rhs, i);
// Declaration of the gsKDTree class template, parameterised on the key (point)
// type and on the type of the value stored with each key.
// NOTE(review): this listing omits the `class gsKDTree` line, the access
// specifiers, the data members and several declarations whose definitions appear
// further down (default and copy constructors, operator=, destructor, clear(), empty()).
57template <
class KeyType,
class ValueType>
// Builds a tree from (key, value) pairs. The vector is taken by non-const
// reference; the definition passes data.begin()/data.end() to buildTree, which
// reorders the range with std::nth_element.
65 gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data);
// Returns gsKDTreeTraits<KeyType>::size(), the modulus used to pick the split axis.
80 std::size_t dimension()
const;
// Size query. NOTE(review): the body of the definition is not in this listing.
83 std::size_t size()
const;
// True when a node whose point compares equal to key is stored in the tree.
89 bool contains(
const KeyType& key)
const;
// Stores value under key: overwrites the value of an existing equal key,
// otherwise attaches a new node below the node where the search for key stopped.
// value defaults to a default-constructed ValueType.
97 void insert(
const KeyType& key,
const ValueType& value=ValueType());
// Mutable access to the value stored under key.
// NOTE(review): the definition looks like it inserts a default value when key is
// absent, but the insert call itself is not visible in this listing — confirm.
105 ValueType& operator[](
const KeyType& key);
// Checked access. The const overload throws std::out_of_range when key is not
// stored; the non-const overload forwards to it and casts constness away.
112 ValueType& at(
const KeyType& key);
113 const ValueType& at(
const KeyType& key)
const;
// k-nearest-neighbour query returning a value by copy. The definition collects
// up to k entries in a bounded priority queue and tallies them in an
// unordered_map; it throws std::out_of_range for an empty tree or when
// bpq.best() is not finite.
122 ValueType kNNValue(
const KeyType& key, std::size_t k)
const;
// Same query, but tallying pointers to the stored values and returning a reference.
131 ValueType& kNNValue(
const KeyType& key, std::size_t k);
// Writes a one-line summary (size and dimension) to os.
134 void print(std::ostream &os)
const;
// Node constructor: stores the key as `point`, the depth as `level`, the value,
// and starts with both child pointers set to NULL.
// NOTE(review): the enclosing `struct Node` line and its members are not in this listing.
150 Node(
const KeyType& _key,
int _level,
const ValueType& _value=ValueType()):
151 point(_key), left(NULL), right(NULL), level(_level), value(_value) {}
// Recursive builder over the iterator range [start, end).
// NOTE(review): the declaration is cut off here; the definition shows a further
// argument named currLevel.
166 Node* buildTree(
typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
167 typename std::vector<std::pair<KeyType, ValueType> >::iterator end,
// Search helper: returns the node holding key, or the last node on the search
// path when key is absent (NULL only when currNode itself is NULL).
174 Node* findNode(Node* currNode,
const KeyType& key)
const;
// Recursive nearest-neighbour search that enqueues values (by copy) into bpq.
// NOTE(review): the `key` parameter line is missing from this listing.
177 void nearestNeighborRecurse(
const Node* currNode,
179 gsBoundedPriorityQueue<ValueType>& bpq)
const;
// Recursive nearest-neighbour search that enqueues pointers to the stored values.
// NOTE(review): the `key` parameter line is missing from this listing.
182 void nearestNeighborRecurse(
const Node* currNode,
184 gsBoundedPriorityQueue<ValueType*>& bpq)
const;
// Returns a newly allocated copy of the subtree rooted at root (NULL for NULL).
190 Node* deepcopyTree(Node* root);
// Recursive clean-up helper for the subtree rooted at currNode.
193 void freeResource(Node* currNode);
197template <
class KeyType,
class ValueType>
198gsKDTree<KeyType, ValueType>::gsKDTree() :
199 root_(NULL), size_(0) { }
// Recursively clones the subtree rooted at `root`.
// A NULL input yields NULL. Otherwise a new Node is copy-constructed from *root
// and its left/right pointers are replaced by clones of the children, so the
// copy shares no nodes with the source.
// NOTE(review): the braces and the final return of newRoot are not in this listing.
201template <
class KeyType,
class ValueType>
202typename gsKDTree<KeyType, ValueType>::Node*
203gsKDTree<KeyType, ValueType>::deepcopyTree(
typename gsKDTree<KeyType, ValueType>::Node* root)
205 if (root == NULL)
return NULL;
206 Node* newRoot =
new Node(*root);
207 newRoot->left = deepcopyTree(root->left);
208 newRoot->right = deepcopyTree(root->right);
// Recursively builds a subtree from the pairs in [start, end).
// Returns NULL for an empty range. Otherwise the split axis is
// currLevel % gsKDTreeTraits<KeyType>::size(); std::nth_element places the
// median (by that coordinate) at `mid`, a node is created from it, and the two
// halves [start, mid) and [mid + 1, end) become the left and right subtrees one
// level deeper. The input range is reordered in place.
// NOTE(review): the parameter line declaring currLevel, the closing of the
// lambda, the body of the while loop, the return statement and the braces are
// missing from this listing.
212template <
class KeyType,
class ValueType>
213typename gsKDTree<KeyType, ValueType>::Node*
214gsKDTree<KeyType, ValueType>::buildTree(
typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
215 typename std::vector<std::pair<KeyType, ValueType>>::iterator end,
218 if (start >= end)
return NULL;
220 int axis = currLevel % gsKDTreeTraits<KeyType>::size();
// Orders pairs by the key coordinate on the current split axis.
221 auto cmp = [axis](
const std::pair<KeyType, ValueType>& p1,
222 const std::pair<KeyType, ValueType>& p2) {
223 return p1.first[axis] < p2.first[axis];
225 std::size_t len = end - start;
226 auto mid = start + len / 2;
227 std::nth_element(start, mid, end, cmp);
// Runs while the element before mid has the same split coordinate as mid.
// Presumably the (unseen) body steps mid backwards so that equal coordinates end
// up in the right subtree, matching the strict '<' used by islhalf — TODO confirm.
231 while (mid > start && (mid - 1)->first[axis] == mid->first[axis]) {
235 Node* newNode =
new Node(mid->first, currLevel, mid->second);
236 newNode->left = buildTree(start, mid, currLevel + 1);
237 newNode->right = buildTree(mid + 1, end, currLevel + 1);
// Constructs the tree from a vector of (key, value) pairs by building from the
// full range at level 0. buildTree reorders `data` in place.
// NOTE(review): the braces and any initialisation of the size member are not in this listing.
241template <
class KeyType,
class ValueType>
242gsKDTree<KeyType, ValueType>::gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data)
244 root_ = buildTree(data.begin(), data.end(), 0);
// Copy constructor: the new tree owns a deep copy of rhs's node structure.
// NOTE(review): the braces and the copy of the size member are not in this listing.
248template <
class KeyType,
class ValueType>
249gsKDTree<KeyType, ValueType>::gsKDTree(
const gsKDTree& rhs)
251 root_ = deepcopyTree(rhs.root_);
// Copy assignment: replaces root_ with a deep copy of rhs's nodes.
// NOTE(review): listing lines 257-259 and 261-263 are missing. Verify against the
// real source that self-assignment is guarded, the previous tree is released
// before root_ is overwritten, the size member is copied, and *this is returned.
255template <
class KeyType,
class ValueType>
256gsKDTree<KeyType, ValueType>& gsKDTree<KeyType, ValueType>::operator=(
const gsKDTree& rhs)
260 root_ = deepcopyTree(rhs.root_);
// Post-order walk over the subtree rooted at currNode: does nothing for NULL,
// otherwise recurses into the left and then the right child first.
// NOTE(review): the statement that follows the two recursive calls (presumably
// the deletion of currNode) and the braces are not in this listing.
266template <
class KeyType,
class ValueType>
267void gsKDTree<KeyType, ValueType>::freeResource(
typename gsKDTree<KeyType, ValueType>::Node* currNode)
269 if (currNode == NULL)
return;
270 freeResource(currNode->left);
271 freeResource(currNode->right);
// Destructor.
// NOTE(review): the body is not in this listing; presumably it releases the
// nodes through freeResource(root_) — confirm against the real source.
275template <
class KeyType,
class ValueType>
276gsKDTree<KeyType, ValueType>::~gsKDTree()
// clear(): takes no arguments and returns nothing.
// NOTE(review): the body is not in this listing; presumably it frees all nodes
// and resets the tree to the empty state — confirm against the real source.
281template <
class KeyType,
class ValueType>
282void gsKDTree<KeyType, ValueType>::clear()
// Returns the number of coordinate axes as reported by the key traits
// (gsKDTreeTraits<KeyType>::size()); this is the modulus used when choosing the
// split axis from a node's level.
287template <
class KeyType,
class ValueType>
288std::size_t gsKDTree<KeyType, ValueType>::dimension()
const
290 return gsKDTreeTraits<KeyType>::size();
// Size query used by print().
// NOTE(review): the body is not in this listing; presumably it returns the
// size member initialised to 0 by the default constructor — confirm.
293template <
class KeyType,
class ValueType>
294std::size_t gsKDTree<KeyType, ValueType>::size()
const
// Emptiness query; kNNValue() throws std::out_of_range when this returns true.
// NOTE(review): the body is not in this listing.
299template <
class KeyType,
class ValueType>
300bool gsKDTree<KeyType, ValueType>::empty()
const
// Descends from currNode looking for `key`.
// Returns:
//   - currNode unchanged when it is NULL or already holds key;
//   - otherwise the result of searching the child selected by islhalf() on the
//     axis (level % traits size); when that child is NULL the current node is
//     returned instead, i.e. the node under which key would have to be attached.
// Callers therefore have to compare the returned node's point with key to tell
// a hit from a miss.
// NOTE(review): the braces and the `else` separating the two return statements
// are not in this listing.
305template <
class KeyType,
class ValueType>
306typename gsKDTree<KeyType, ValueType>::Node*
307gsKDTree<KeyType, ValueType>::findNode(
typename gsKDTree<KeyType, ValueType>::Node* currNode,
308 const KeyType& key)
const
310 if (currNode == NULL || currNode->point == key)
return currNode;
312 const KeyType& currPoint = currNode->point;
313 int currLevel = currNode->level;
// key is strictly below the current point on the split axis: go left.
314 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()))
317 return currNode->left == NULL ? currNode : findNode(currNode->left, key);
// Otherwise continue in the right subtree.
320 return currNode->right == NULL ? currNode : findNode(currNode->right, key);
// Returns true when the tree stores a node whose point equals key.
// findNode() may return the would-be parent on a miss, hence the explicit
// comparison of the found node's point with key.
324template <
class KeyType,
class ValueType>
325bool gsKDTree<KeyType, ValueType>::contains(
const KeyType& key)
const
327 auto node = findNode(root_, key);
328 return node != NULL && node->point == key;
// Inserts (key, value) or updates the value of an existing key.
//   - empty tree (findNode returns NULL): the new node becomes the root at level 0;
//   - key already stored: only the value is overwritten;
//   - otherwise: a node one level below the search end point is created and
//     linked as its left child when islhalf() holds on that node's split axis,
//     else as its right child.
// No rebalancing is performed here.
// NOTE(review): the braces, the `else` branches and any update of the size
// member are not in this listing.
331template <
class KeyType,
class ValueType>
332void gsKDTree<KeyType, ValueType>::insert(
const KeyType& key,
const ValueType& value)
334 auto targetNode = findNode(root_, key);
335 if (targetNode == NULL) {
336 root_ =
new Node(key, 0, value);
339 if (targetNode->point == key) {
340 targetNode->value = value;
342 int currLevel = targetNode->level;
343 Node* newNode =
new Node(key, currLevel + 1, value);
344 if (gsKDTreeTraits<KeyType>::islhalf(key, targetNode->point, currLevel%gsKDTreeTraits<KeyType>::size())) {
345 targetNode->left = newNode;
347 targetNode->right = newNode;
// Checked read-only access to the value stored under key.
// Throws std::out_of_range("Key not found in gsKDTree") when the tree is empty
// or the node reached by findNode() does not hold key.
// NOTE(review): the closing brace of the check and the return of the node's
// value are not in this listing.
354template <
class KeyType,
class ValueType>
355const ValueType& gsKDTree<KeyType, ValueType>::at(
const KeyType& key)
const
357 auto node = findNode(root_, key);
358 if (node == NULL || node->point != key) {
359 throw std::out_of_range(
"Key not found in gsKDTree");
// Checked mutable access to the value stored under key.
// Implemented on top of the const overload: *this is viewed through a const
// reference so the const at() is selected, and the constness of its result is
// then cast away. Any exception thrown by the const overload propagates.
365template <
class KeyType,
class ValueType>
366ValueType& gsKDTree<KeyType, ValueType>::at(
const KeyType& key)
368 const gsKDTree<KeyType, ValueType>& constThis = *
this;
369 return const_cast<ValueType&
>(constThis.at(key));
// Mutable access to the value stored under key.
// `node` is the result of the search made BEFORE any modification: on a miss it
// is either NULL (empty tree) or the node under which key gets attached. The
// last two statements use that to locate the element: the root for a previously
// empty tree, otherwise whichever child of `node` now holds key.
// NOTE(review): listing lines 377-379 are missing; presumably they return
// node->value on a hit and call insert(key) on a miss — confirm. Without that
// insert the final dereferences of root_ / node->right would be invalid.
372template <
class KeyType,
class ValueType>
373ValueType& gsKDTree<KeyType, ValueType>::operator[](
const KeyType& key)
375 auto node = findNode(root_, key);
376 if (node != NULL && node->point == key) {
380 if (node == NULL)
return root_->value;
381 else return (node->left != NULL && node->left->point == key) ? node->left->value: node->right->value;
// Recursive nearest-neighbour search collecting VALUES (copied) into bpq.
// For each visited node:
//   1. its value is enqueued with priority distance(key, node point);
//   2. the child on key's side of the split axis is searched first;
//   3. the other child is searched only if the queue is not yet full, or if the
//      distance from key to the node along the split axis is smaller than the
//      worst priority currently held (bpq.worst()).
// NOTE(review): the `key` parameter line, the declaration/assignment of
// `isLeftTree`, the `else` of the first branch and the braces are not in this listing.
385template <
class KeyType,
class ValueType>
386void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(
const typename gsKDTree<KeyType, ValueType>::Node* currNode,
388 gsBoundedPriorityQueue<ValueType>& bpq)
const
390 if (currNode == NULL)
return;
391 const KeyType& currPoint = currNode->point;
394 bpq.enqueue(currNode->value, gsKDTreeTraits<KeyType>::distance(key, currPoint));
397 int currLevel = currNode->level;
// Descend first into the half that contains key.
399 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
400 nearestNeighborRecurse(currNode->left, key, bpq);
403 nearestNeighborRecurse(currNode->right, key, bpq);
// Visit the opposite half only when it can still contribute a candidate.
407 if (bpq.size() < bpq.maxSize() ||
408 gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
410 if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
411 else nearestNeighborRecurse(currNode->left, key, bpq);
// Recursive nearest-neighbour search collecting POINTERS to the stored values.
// Same traversal as the value-collecting overload; the only difference is that
// the address of the node's value is enqueued. The const_cast is needed because
// this member is const and nodes are reached through a pointer-to-const; it
// hands out mutable pointers into the tree.
// NOTE(review): the `key` parameter line, the declaration/assignment of
// `isLeftTree`, the `else` of the first branch and the braces are not in this listing.
415template <
class KeyType,
class ValueType>
416void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(
const typename gsKDTree<KeyType, ValueType>::Node* currNode,
418 gsBoundedPriorityQueue<ValueType*>& bpq)
const
420 if (currNode == NULL)
return;
421 const KeyType& currPoint = currNode->point;
424 bpq.enqueue(
const_cast<ValueType*
>(&(currNode->value)), gsKDTreeTraits<KeyType>::distance(key, currPoint));
427 int currLevel = currNode->level;
// Descend first into the half that contains key.
429 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
430 nearestNeighborRecurse(currNode->left, key, bpq);
433 nearestNeighborRecurse(currNode->right, key, bpq);
// Visit the opposite half only when it can still contribute a candidate.
437 if (bpq.size() < bpq.maxSize() ||
438 gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
440 if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
441 else nearestNeighborRecurse(currNode->left, key, bpq);
// k-nearest-neighbour query returning a value by copy.
// Steps visible here:
//   - a bounded priority queue of capacity k is filled by nearestNeighborRecurse;
//   - std::out_of_range("gsKDTree is empty") is thrown for an empty tree;
//   - std::out_of_range("gsKDTree does not contain finite value") is thrown when
//     bpq.best() is not finite;
//   - the queue is drained and each dequeued value is tallied in an
//     unordered_map, which requires std::hash<ValueType> and operator==;
//   - the tally is scanned for the entry with the highest count.
// NOTE(review): the declarations of `cnt` and of the result, the body of the
// final `if`, the return statement and the braces are not in this listing;
// presumably the most frequent value is returned — confirm.
445template <
class KeyType,
class ValueType>
446ValueType gsKDTree<KeyType, ValueType>::kNNValue(
const KeyType& key, std::size_t k)
const
449 gsBoundedPriorityQueue<ValueType> bpq(k);
450 if (empty())
throw std::out_of_range(
"gsKDTree is empty");
453 nearestNeighborRecurse(root_, key, bpq);
459 if (!math::isfinite(bpq.best()))
460 throw std::out_of_range(
"gsKDTree does not contain finite value");
// Count how often each value occurs among the collected neighbours.
463 std::unordered_map<ValueType, int> counter;
464 while (!bpq.empty()) {
465 ++counter[bpq.dequeueMin()];
471 for (
const auto &p : counter) {
472 if (p.second > cnt) {
// k-nearest-neighbour query returning a reference into the tree.
// Mirrors the const overload but collects and tallies POINTERS to the stored values.
// NOTE(review): the guard in front of the first throw (the const overload uses
// `if (empty())`), the declaration of `cnt`, the body of the final `if`, the
// return statement and the braces are not in this listing.
// NOTE(review): the tally is keyed by address, and the traversal enqueues each
// node's value address once, so every count looks like it will be 1 and no real
// vote over equal values takes place — verify this is intended.
480template <
class KeyType,
class ValueType>
481ValueType& gsKDTree<KeyType, ValueType>::kNNValue(
const KeyType& key, std::size_t k)
484 gsBoundedPriorityQueue<ValueType*> bpq(k);
486 throw std::out_of_range(
"gsKDTree is empty");
489 nearestNeighborRecurse(root_, key, bpq);
495 if (!math::isfinite(bpq.best()))
496 throw std::out_of_range(
"gsKDTree does not contain finite value");
// Count occurrences of each collected value pointer.
499 std::unordered_map<ValueType*, int> counter;
500 while (!bpq.empty()) {
501 ++counter[bpq.dequeueMin()];
505 ValueType* result =
nullptr;
507 for (
const auto &p : counter) {
508 if (p.second > cnt) {
// Writes a one-line summary of the tree to `os`, reporting size() and
// dimension(), terminated by ".\n". The tree contents are not printed.
// NOTE(review): the signature line of the definition and the braces are not in
// this listing; the class declaration shows `void print(std::ostream &os) const`.
516template <
class KeyType,
class ValueType>
519 os <<
"KD-tree: size= " << size() <<
", dimension= " << dimension() <<
".\n";
An interface representing a kd-tree in some number of dimensions.
Definition gsKDTree.h:58
friend std::ostream & operator<<(std::ostream &os, const gsKDTree &obj)
Print (as string) operator.
Definition gsKDTree.h:137
void print(std::ostream &os) const
Prints the object as a string.
Definition gsKDTree.h:517
Provides declaration of bounded priority queue.
The G+Smo namespace, containing all definitions for the library.
T distance(gsMatrix< T > const &A, gsMatrix< T > const &B, index_t i=0, index_t j=0, bool cols=false)
Computes the distance between point number <i> in the set <A> and point number <j> in the set <B>; by def...