G+Smo  25.01.0
Geometry + Simulation Modules
 
Loading...
Searching...
No Matches
gsKDTree.h
Go to the documentation of this file.
1
18#pragma once
19
20#include <stdexcept>
21#include <cmath>
22#include <vector>
23#include <unordered_map>
24#include <utility>
25#include <algorithm>
26
28
29namespace gismo
30{
31
/*
 * Default policy describing how gsKDTree interprets a key of type T.
 * Keys are accessed through operator[]; this generic version reports a
 * single axis, so distance() reduces to the absolute difference of the
 * first coordinate.
 */
template <class T>
struct gsKDTreeTraits
{
    // Number of coordinate axes the tree cycles through (1 in this generic version).
    static inline std::size_t size() { return 1; }

    // True when 'lhs' lies strictly below 'rhs' along coordinate 'axis'.
    static inline bool islhalf(const T& lhs, const T& rhs, std::size_t axis)
    {
        return lhs[axis] < rhs[axis];
    }

    // Absolute difference of 'lhs' and 'rhs' along coordinate 'axis'.
    static inline double fabs(const T& lhs, const T& rhs, std::size_t axis)
    {
        return std::abs(lhs[axis] - rhs[axis]);
    }

    // Sum of the per-axis absolute differences over the first size() axes.
    static inline double distance(const T& lhs, const T& rhs)
    {
        double total = 0.0;
        const std::size_t axes = size();
        std::size_t axis = 0;
        while (axis != axes) {
            total += fabs(lhs, rhs, axis);
            ++axis;
        }
        return total;
    }
};
50
57template <class KeyType, class ValueType>
58class gsKDTree {
59public:
60
61 // Constructs an empty gsKDTree.
62 gsKDTree();
63
64 // Efficiently build a balanced kd-tree from a large set of data
65 gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data);
66
67 // Frees up all the dynamically allocated resources
68 ~gsKDTree();
69
70 // Frees up all the dynamically allocated resources
71 void clear();
72
73 // Deep-copies the contents of another gsKDTree into this one.
74 gsKDTree(const gsKDTree& other);
75
76 // Deep-copies the contents of another gsKDTree into this one.
77 gsKDTree& operator=(const gsKDTree& other);
78
79 // Returns the dimension of the data stored in this gsKDTree.
80 std::size_t dimension() const;
81
82 // Returns the number of elements in the kd-tree.
83 std::size_t size() const;
84
85 // Returns true if this gsKDTree is empty and false otherwise.
86 bool empty() const;
87
88 // Returns true if the specified key is contained in the gsKDTree.
89 bool contains(const KeyType& key) const;
90
91 /*
92 * Inserts the data with the given key into the gsKDTree,
93 * associating it with the specified value. If another data element
94 * with the same key already existed in the tree, the new value will
95 * overwrite the existing one.
96 */
97 void insert(const KeyType& key, const ValueType& value=ValueType());
98
99 /*
100 * Returns a reference to the value associated with the data stored
101 * under the given key in the gsKDTree. If the key does not exist,
102 * then it is added to the gsKDTree using the default value of
103 * ValueType as its value.
104 */
105 ValueType& operator[](const KeyType& key);
106
107 /*
108 * Returns a reference to the value associated with the given
109 * key. If the key is not in the tree, this function throws an
110 * out_of_range exception.
111 */
112 ValueType& at(const KeyType& key);
113 const ValueType& at(const KeyType& key) const;
114
115 /*
116 * Given a key and an integer k, finds the k data elements in the
117 * gsKDTree nearest to the data element associated with the given
118 * key and returns the most common value associated with those data
119 * elements. In the event of a tie, one of the most frequent value
120 * will be chosen.
121 */
122 ValueType kNNValue(const KeyType& key, std::size_t k) const;
123
124 /*
125 * Given a key and an integer k, finds the k data elements in the
126 * gsKDTree nearest to the data element associated with the given
127 * key and returns a reference to the most common value associated
128 * with those data elements. In the event of a tie, one of the most
129 * frequent value will be chosen.
130 */
131 ValueType& kNNValue(const KeyType& key, std::size_t k);
132
134 void print(std::ostream &os) const;
135
137 friend std::ostream &operator<<(std::ostream &os, const gsKDTree &obj)
138 {
139 obj.print(os);
140 return os;
141 }
142
143private:
144 struct Node {
145 KeyType point;
146 Node *left;
147 Node *right;
148 int level; // level of the node in the tree, starts at 0 for the root
149 ValueType value;
150 Node(const KeyType& _key, int _level, const ValueType& _value=ValueType()):
151 point(_key), left(NULL), right(NULL), level(_level), value(_value) {}
152 };
153
154 // Root node of the gsKDTree
155 Node* root_;
156
157 // Number of points in the gsKDTree
158 std::size_t size_;
159
160 /*
161 * Recursively build a subtree that satisfies the kd-tree invariant using points in [start, end)
162 * At each level, we split points into two halves using the median of the points as pivot
163 * The root of the subtree is at level 'currLevel'
164 * O(n) time partitioning algorithm is used to locate the median element
165 */
166 Node* buildTree(typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
167 typename std::vector<std::pair<KeyType, ValueType> >::iterator end,
168 int currLevel);
169
170 /*
171 * Returns the Node that contains element with given key if it is present in subtree 'currNode'
172 * Returns the Node below where key should be inserted if key is not in the subtree
173 */
174 Node* findNode(Node* currNode, const KeyType& key) const;
175
176 // Recursive helper method for kNNValue(key, k)
177 void nearestNeighborRecurse(const Node* currNode,
178 const KeyType& key,
179 gsBoundedPriorityQueue<ValueType>& bpq) const;
180
181 // Recursive helper method for kNNValue(key, k)
182 void nearestNeighborRecurse(const Node* currNode,
183 const KeyType& key,
184 gsBoundedPriorityQueue<ValueType*>& bpq) const;
185
186 /*
187 * Recursive helper method for copy constructor and assignment operator
188 * Deep copies tree 'root' and returns the root of the copied tree
189 */
190 Node* deepcopyTree(Node* root);
191
192 // Recursively free up all resources of subtree rooted at 'currNode'
193 void freeResource(Node* currNode);
194
195}; // class gsKDTree
196
197template <class KeyType, class ValueType>
198gsKDTree<KeyType, ValueType>::gsKDTree() :
199 root_(NULL), size_(0) { }
200
201template <class KeyType, class ValueType>
202typename gsKDTree<KeyType, ValueType>::Node*
203gsKDTree<KeyType, ValueType>::deepcopyTree(typename gsKDTree<KeyType, ValueType>::Node* root)
204{
205 if (root == NULL) return NULL;
206 Node* newRoot = new Node(*root);
207 newRoot->left = deepcopyTree(root->left);
208 newRoot->right = deepcopyTree(root->right);
209 return newRoot;
210}
211
212template <class KeyType, class ValueType>
213typename gsKDTree<KeyType, ValueType>::Node*
214gsKDTree<KeyType, ValueType>::buildTree(typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
215 typename std::vector<std::pair<KeyType, ValueType>>::iterator end,
216 int currLevel)
217{
218 if (start >= end) return NULL; // empty tree
219
220 int axis = currLevel % gsKDTreeTraits<KeyType>::size(); // the axis to split on
221 auto cmp = [axis](const std::pair<KeyType, ValueType>& p1,
222 const std::pair<KeyType, ValueType>& p2) {
223 return p1.first[axis] < p2.first[axis];
224 };
225 std::size_t len = end - start;
226 auto mid = start + len / 2;
227 std::nth_element(start, mid, end, cmp); // linear time partition
228
229 // move left (if needed) so that all the equal points are to the right
230 // The tree will still be balanced as long as there aren't many points that are equal along each axis
231 while (mid > start && (mid - 1)->first[axis] == mid->first[axis]) {
232 --mid;
233 }
234
235 Node* newNode = new Node(mid->first, currLevel, mid->second);
236 newNode->left = buildTree(start, mid, currLevel + 1);
237 newNode->right = buildTree(mid + 1, end, currLevel + 1);
238 return newNode;
239}
240
241template <class KeyType, class ValueType>
242gsKDTree<KeyType, ValueType>::gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data)
243{
244 root_ = buildTree(data.begin(), data.end(), 0);
245 size_ = data.size();
246}
247
248template <class KeyType, class ValueType>
249gsKDTree<KeyType, ValueType>::gsKDTree(const gsKDTree& rhs)
250{
251 root_ = deepcopyTree(rhs.root_);
252 size_ = rhs.size_;
253}
254
255template <class KeyType, class ValueType>
256gsKDTree<KeyType, ValueType>& gsKDTree<KeyType, ValueType>::operator=(const gsKDTree& rhs)
257{
258 if (this != &rhs) { // make sure we don't self-assign
259 freeResource(root_);
260 root_ = deepcopyTree(rhs.root_);
261 size_ = rhs.size_;
262 }
263 return *this;
264}
265
266template <class KeyType, class ValueType>
267void gsKDTree<KeyType, ValueType>::freeResource(typename gsKDTree<KeyType, ValueType>::Node* currNode)
268{
269 if (currNode == NULL) return;
270 freeResource(currNode->left);
271 freeResource(currNode->right);
272 delete currNode;
273}
274
275template <class KeyType, class ValueType>
276gsKDTree<KeyType, ValueType>::~gsKDTree()
277{
278 clear();
279}
280
281template <class KeyType, class ValueType>
282void gsKDTree<KeyType, ValueType>::clear()
283{
284 freeResource(root_);
285}
286
287template <class KeyType, class ValueType>
288std::size_t gsKDTree<KeyType, ValueType>::dimension() const
289{
290 return gsKDTreeTraits<KeyType>::size();
291}
292
293template <class KeyType, class ValueType>
294std::size_t gsKDTree<KeyType, ValueType>::size() const
295{
296 return size_;
297}
298
299template <class KeyType, class ValueType>
300bool gsKDTree<KeyType, ValueType>::empty() const
301{
302 return size_ == 0;
303}
304
305template <class KeyType, class ValueType>
306typename gsKDTree<KeyType, ValueType>::Node*
307gsKDTree<KeyType, ValueType>::findNode(typename gsKDTree<KeyType, ValueType>::Node* currNode,
308 const KeyType& key) const
309{
310 if (currNode == NULL || currNode->point == key) return currNode;
311
312 const KeyType& currPoint = currNode->point;
313 int currLevel = currNode->level;
314 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()))
315 {
316 // recurse to the left side
317 return currNode->left == NULL ? currNode : findNode(currNode->left, key);
318 } else {
319 // recurse to the right side
320 return currNode->right == NULL ? currNode : findNode(currNode->right, key);
321 }
322}
323
324template <class KeyType, class ValueType>
325bool gsKDTree<KeyType, ValueType>::contains(const KeyType& key) const
326{
327 auto node = findNode(root_, key);
328 return node != NULL && node->point == key;
329}
330
331template <class KeyType, class ValueType>
332void gsKDTree<KeyType, ValueType>::insert(const KeyType& key, const ValueType& value)
333{
334 auto targetNode = findNode(root_, key);
335 if (targetNode == NULL) { // this means the tree is empty
336 root_ = new Node(key, 0, value);
337 size_ = 1;
338 } else {
339 if (targetNode->point == key) { // key is already in the tree, simply update its value
340 targetNode->value = value;
341 } else { // construct a new node and insert it to the right place (child of targetNode)
342 int currLevel = targetNode->level;
343 Node* newNode = new Node(key, currLevel + 1, value);
344 if (gsKDTreeTraits<KeyType>::islhalf(key, targetNode->point, currLevel%gsKDTreeTraits<KeyType>::size())) {
345 targetNode->left = newNode;
346 } else {
347 targetNode->right = newNode;
348 }
349 ++size_;
350 }
351 }
352}
353
354template <class KeyType, class ValueType>
355const ValueType& gsKDTree<KeyType, ValueType>::at(const KeyType& key) const
356{
357 auto node = findNode(root_, key);
358 if (node == NULL || node->point != key) {
359 throw std::out_of_range("Key not found in gsKDTree");
360 } else {
361 return node->value;
362 }
363}
364
365template <class KeyType, class ValueType>
366ValueType& gsKDTree<KeyType, ValueType>::at(const KeyType& key)
367{
368 const gsKDTree<KeyType, ValueType>& constThis = *this;
369 return const_cast<ValueType&>(constThis.at(key));
370}
371
372template <class KeyType, class ValueType>
373ValueType& gsKDTree<KeyType, ValueType>::operator[](const KeyType& key)
374{
375 auto node = findNode(root_, key);
376 if (node != NULL && node->point == key) { // key is already in the tree
377 return node->value;
378 } else { // insert key with default ValueType value, and return reference to the new ValueType
379 insert(key);
380 if (node == NULL) return root_->value; // the new node is the root
381 else return (node->left != NULL && node->left->point == key) ? node->left->value: node->right->value;
382 }
383}
384
385template <class KeyType, class ValueType>
386void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(const typename gsKDTree<KeyType, ValueType>::Node* currNode,
387 const KeyType& key,
388 gsBoundedPriorityQueue<ValueType>& bpq) const
389{
390 if (currNode == NULL) return;
391 const KeyType& currPoint = currNode->point;
392
393 // Add the current point to the BPQ if it is closer to 'key' that some point in the BPQ
394 bpq.enqueue(currNode->value, gsKDTreeTraits<KeyType>::distance(key, currPoint));
395
396 // Recursively search the half of the tree that contains Point 'key'
397 int currLevel = currNode->level;
398 bool isLeftTree;
399 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
400 nearestNeighborRecurse(currNode->left, key, bpq);
401 isLeftTree = true;
402 } else {
403 nearestNeighborRecurse(currNode->right, key, bpq);
404 isLeftTree = false;
405 }
406
407 if (bpq.size() < bpq.maxSize() ||
408 gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
409 // Recursively search the other half of the tree if necessary
410 if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
411 else nearestNeighborRecurse(currNode->left, key, bpq);
412 }
413}
414
415template <class KeyType, class ValueType>
416void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(const typename gsKDTree<KeyType, ValueType>::Node* currNode,
417 const KeyType& key,
418 gsBoundedPriorityQueue<ValueType*>& bpq) const
419{
420 if (currNode == NULL) return;
421 const KeyType& currPoint = currNode->point;
422
423 // Add the current point to the BPQ if it is closer to 'key' that some point in the BPQ
424 bpq.enqueue(const_cast<ValueType*>(&(currNode->value)), gsKDTreeTraits<KeyType>::distance(key, currPoint));
425
426 // Recursively search the half of the tree that contains Point 'key'
427 int currLevel = currNode->level;
428 bool isLeftTree;
429 if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
430 nearestNeighborRecurse(currNode->left, key, bpq);
431 isLeftTree = true;
432 } else {
433 nearestNeighborRecurse(currNode->right, key, bpq);
434 isLeftTree = false;
435 }
436
437 if (bpq.size() < bpq.maxSize() ||
438 gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
439 // Recursively search the other half of the tree if necessary
440 if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
441 else nearestNeighborRecurse(currNode->left, key, bpq);
442 }
443}
444
445template <class KeyType, class ValueType>
446ValueType gsKDTree<KeyType, ValueType>::kNNValue(const KeyType& key, std::size_t k) const
447{
448 // BPQ with maximum size k
449 gsBoundedPriorityQueue<ValueType> bpq(k);
450 if (empty()) throw std::out_of_range("gsKDTree is empty");
451
452 // Recursively search the kd-tree with pruning
453 nearestNeighborRecurse(root_, key, bpq);
454
455 // Ensure finite values; non-standard 'distance' functions can be
456 // used to exclude data elements that are close to the given key but
457 // on the 'wrong' side of the hyperplane. This allows to exclude
458 // nearest neighbours that are, e.g., smaller than the given key.
459 if (!math::isfinite(bpq.best()))
460 throw std::out_of_range("gsKDTree does not contain finite value");
461
462 // Count occurrences of all ValueType in the kNN set
463 std::unordered_map<ValueType, int> counter;
464 while (!bpq.empty()) {
465 ++counter[bpq.dequeueMin()];
466 }
467
468 // Return the most frequent element in the kNN set
469 ValueType result;
470 int cnt = -1;
471 for (const auto &p : counter) {
472 if (p.second > cnt) {
473 result = p.first;
474 cnt = p.second;
475 }
476 }
477 return result;
478}
479
480template <class KeyType, class ValueType>
481ValueType& gsKDTree<KeyType, ValueType>::kNNValue(const KeyType& key, std::size_t k)
482{
483 // BPQ with maximum size k
484 gsBoundedPriorityQueue<ValueType*> bpq(k);
485 if (empty())
486 throw std::out_of_range("gsKDTree is empty");
487
488 // Recursively search the kd-tree with pruning
489 nearestNeighborRecurse(root_, key, bpq);
490
491 // Ensure finite values; non-standard 'distance' functions can be
492 // used to exclude data elements that are close to the given key but
493 // on the 'wrong' side of the hyperplane. This allows to exclude
494 // nearest neighbours that are, e.g., smaller than the given key.
495 if (!math::isfinite(bpq.best()))
496 throw std::out_of_range("gsKDTree does not contain finite value");
497
498 // Count occurrences of all ValueType in the kNN set
499 std::unordered_map<ValueType*, int> counter;
500 while (!bpq.empty()) {
501 ++counter[bpq.dequeueMin()];
502 }
503
504 // Return the most frequent element in the kNN set
505 ValueType* result = nullptr;
506 int cnt = -1;
507 for (const auto &p : counter) {
508 if (p.second > cnt) {
509 result = p.first;
510 cnt = p.second;
511 }
512 }
513 return *result;
514}
515
516template <class KeyType, class ValueType>
517void gsKDTree<KeyType, ValueType>::print(std::ostream& os) const
518{
519 os << "KD-tree: size= " << size() << ", dimension= " << dimension() << ".\n";
520}
521
522} //namespace gismo
An interface representing a kd-tree in some number of dimensions.
Definition gsKDTree.h:58
friend std::ostream & operator<<(std::ostream &os, const gsKDTree &obj)
Print (as string) operator.
Definition gsKDTree.h:137
void print(std::ostream &os) const
Prints the object as a string.
Definition gsKDTree.h:517
Provides declaration of bounded priority queue.
The G+Smo namespace, containing all definitions for the library.
T distance(gsMatrix< T > const &A, gsMatrix< T > const &B, index_t i=0, index_t j=0, bool cols=false)
compute a distance between the point number <i> in the set A and the point number <j> in the set B; by def...