G+Smo  24.08.0
Geometry + Simulation Modules
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
gsKDTree.h
Go to the documentation of this file.
1 
18 #pragma once
19 
20 #include <stdexcept>
21 #include <cmath>
22 #include <vector>
23 #include <unordered_map>
24 #include <utility>
25 #include <algorithm>
26 
28 
29 namespace gismo
30 {
31 
// Default traits describing how gsKDTree compares and measures keys of
// type T. Keys are accessed with operator[]; the generic version
// reports a single axis. Specialize this struct for other key types.
template <class T>
struct gsKDTreeTraits
{
    // Number of coordinate axes of a key (one in the generic version).
    static inline std::size_t size() { return 1; }

    // True if 'lhs' lies strictly below 'rhs' along the given axis.
    static inline bool islhalf(const T& lhs, const T& rhs, std::size_t axis)
    {
        return lhs[axis] < rhs[axis];
    }

    // Absolute difference of the two keys along the given axis.
    static inline double fabs(const T& lhs, const T& rhs, std::size_t axis)
    {
        return std::abs(lhs[axis] - rhs[axis]);
    }

    // Sum of the per-axis absolute differences over all size() axes.
    static inline double distance(const T& lhs, const T& rhs)
    {
        const std::size_t numAxes = size();
        double total = 0.0;
        std::size_t axis = 0;
        while (axis != numAxes) {
            total += fabs(lhs, rhs, axis);
            ++axis;
        }
        return total;
    }
};
50 
57 template <class KeyType, class ValueType>
58 class gsKDTree {
59 public:
60 
61  // Constructs an empty gsKDTree.
62  gsKDTree();
63 
64  // Efficiently build a balanced kd-tree from a large set of data
65  gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data);
66 
67  // Frees up all the dynamically allocated resources
68  ~gsKDTree();
69 
70  // Frees up all the dynamically allocated resources
71  void clear();
72 
73  // Deep-copies the contents of another gsKDTree into this one.
74  gsKDTree(const gsKDTree& other);
75 
76  // Deep-copies the contents of another gsKDTree into this one.
77  gsKDTree& operator=(const gsKDTree& other);
78 
79  // Returns the dimension of the data stored in this gsKDTree.
80  std::size_t dimension() const;
81 
82  // Returns the number of elements in the kd-tree.
83  std::size_t size() const;
84 
85  // Returns true if this gsKDTree is empty and false otherwise.
86  bool empty() const;
87 
88  // Returns true if the specified key is contained in the gsKDTree.
89  bool contains(const KeyType& key) const;
90 
91  /*
92  * Inserts the data with the given key into the gsKDTree,
93  * associating it with the specified value. If another data element
94  * with the same key already existed in the tree, the new value will
95  * overwrite the existing one.
96  */
97  void insert(const KeyType& key, const ValueType& value=ValueType());
98 
99  /*
100  * Returns a reference to the value associated with the data stored
101  * under the given key in the gsKDTree. If the key does not exist,
102  * then it is added to the gsKDTree using the default value of
103  * ValueType as its value.
104  */
105  ValueType& operator[](const KeyType& key);
106 
107  /*
108  * Returns a reference to the value associated with the given
109  * key. If the key is not in the tree, this function throws an
110  * out_of_range exception.
111  */
112  ValueType& at(const KeyType& key);
113  const ValueType& at(const KeyType& key) const;
114 
115  /*
116  * Given a key and an integer k, finds the k data elements in the
117  * gsKDTree nearest to the data element associated with the given
118  * key and returns the most common value associated with those data
119  * elements. In the event of a tie, one of the most frequent value
120  * will be chosen.
121  */
122  ValueType kNNValue(const KeyType& key, std::size_t k) const;
123 
124  /*
125  * Given a key and an integer k, finds the k data elements in the
126  * gsKDTree nearest to the data element associated with the given
127  * key and returns a reference to the most common value associated
128  * with those data elements. In the event of a tie, one of the most
129  * frequent value will be chosen.
130  */
131  ValueType& kNNValue(const KeyType& key, std::size_t k);
132 
134  void print(std::ostream &os) const;
135 
137  friend std::ostream &operator<<(std::ostream &os, const gsKDTree &obj)
138  {
139  obj.print(os);
140  return os;
141  }
142 
143 private:
144  struct Node {
145  KeyType point;
146  Node *left;
147  Node *right;
148  int level; // level of the node in the tree, starts at 0 for the root
149  ValueType value;
150  Node(const KeyType& _key, int _level, const ValueType& _value=ValueType()):
151  point(_key), left(NULL), right(NULL), level(_level), value(_value) {}
152  };
153 
154  // Root node of the gsKDTree
155  Node* root_;
156 
157  // Number of points in the gsKDTree
158  std::size_t size_;
159 
160  /*
161  * Recursively build a subtree that satisfies the kd-tree invariant using points in [start, end)
162  * At each level, we split points into two halves using the median of the points as pivot
163  * The root of the subtree is at level 'currLevel'
164  * O(n) time partitioning algorithm is used to locate the median element
165  */
166  Node* buildTree(typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
167  typename std::vector<std::pair<KeyType, ValueType> >::iterator end,
168  int currLevel);
169 
170  /*
171  * Returns the Node that contains element with given key if it is present in subtree 'currNode'
172  * Returns the Node below where key should be inserted if key is not in the subtree
173  */
174  Node* findNode(Node* currNode, const KeyType& key) const;
175 
176  // Recursive helper method for kNNValue(key, k)
177  void nearestNeighborRecurse(const Node* currNode,
178  const KeyType& key,
179  gsBoundedPriorityQueue<ValueType>& bpq) const;
180 
181  // Recursive helper method for kNNValue(key, k)
182  void nearestNeighborRecurse(const Node* currNode,
183  const KeyType& key,
184  gsBoundedPriorityQueue<ValueType*>& bpq) const;
185 
186  /*
187  * Recursive helper method for copy constructor and assignment operator
188  * Deep copies tree 'root' and returns the root of the copied tree
189  */
190  Node* deepcopyTree(Node* root);
191 
192  // Recursively free up all resources of subtree rooted at 'currNode'
193  void freeResource(Node* currNode);
194 
195 }; // class gsKDTree
196 
197 template <class KeyType, class ValueType>
198 gsKDTree<KeyType, ValueType>::gsKDTree() :
199  root_(NULL), size_(0) { }
200 
201 template <class KeyType, class ValueType>
202 typename gsKDTree<KeyType, ValueType>::Node*
203 gsKDTree<KeyType, ValueType>::deepcopyTree(typename gsKDTree<KeyType, ValueType>::Node* root)
204 {
205  if (root == NULL) return NULL;
206  Node* newRoot = new Node(*root);
207  newRoot->left = deepcopyTree(root->left);
208  newRoot->right = deepcopyTree(root->right);
209  return newRoot;
210 }
211 
212 template <class KeyType, class ValueType>
213 typename gsKDTree<KeyType, ValueType>::Node*
214 gsKDTree<KeyType, ValueType>::buildTree(typename std::vector<std::pair<KeyType, ValueType> >::iterator start,
215  typename std::vector<std::pair<KeyType, ValueType>>::iterator end,
216  int currLevel)
217 {
218  if (start >= end) return NULL; // empty tree
219 
220  int axis = currLevel % gsKDTreeTraits<KeyType>::size(); // the axis to split on
221  auto cmp = [axis](const std::pair<KeyType, ValueType>& p1,
222  const std::pair<KeyType, ValueType>& p2) {
223  return p1.first[axis] < p2.first[axis];
224  };
225  std::size_t len = end - start;
226  auto mid = start + len / 2;
227  std::nth_element(start, mid, end, cmp); // linear time partition
228 
229  // move left (if needed) so that all the equal points are to the right
230  // The tree will still be balanced as long as there aren't many points that are equal along each axis
231  while (mid > start && (mid - 1)->first[axis] == mid->first[axis]) {
232  --mid;
233  }
234 
235  Node* newNode = new Node(mid->first, currLevel, mid->second);
236  newNode->left = buildTree(start, mid, currLevel + 1);
237  newNode->right = buildTree(mid + 1, end, currLevel + 1);
238  return newNode;
239 }
240 
241 template <class KeyType, class ValueType>
242 gsKDTree<KeyType, ValueType>::gsKDTree(std::vector<std::pair<KeyType, ValueType> >& data)
243 {
244  root_ = buildTree(data.begin(), data.end(), 0);
245  size_ = data.size();
246 }
247 
248 template <class KeyType, class ValueType>
249 gsKDTree<KeyType, ValueType>::gsKDTree(const gsKDTree& rhs)
250 {
251  root_ = deepcopyTree(rhs.root_);
252  size_ = rhs.size_;
253 }
254 
255 template <class KeyType, class ValueType>
256 gsKDTree<KeyType, ValueType>& gsKDTree<KeyType, ValueType>::operator=(const gsKDTree& rhs)
257 {
258  if (this != &rhs) { // make sure we don't self-assign
259  freeResource(root_);
260  root_ = deepcopyTree(rhs.root_);
261  size_ = rhs.size_;
262  }
263  return *this;
264 }
265 
266 template <class KeyType, class ValueType>
267 void gsKDTree<KeyType, ValueType>::freeResource(typename gsKDTree<KeyType, ValueType>::Node* currNode)
268 {
269  if (currNode == NULL) return;
270  freeResource(currNode->left);
271  freeResource(currNode->right);
272  delete currNode;
273 }
274 
275 template <class KeyType, class ValueType>
276 gsKDTree<KeyType, ValueType>::~gsKDTree()
277 {
278  clear();
279 }
280 
281 template <class KeyType, class ValueType>
282 void gsKDTree<KeyType, ValueType>::clear()
283 {
284  freeResource(root_);
285 }
286 
287 template <class KeyType, class ValueType>
288 std::size_t gsKDTree<KeyType, ValueType>::dimension() const
289 {
290  return gsKDTreeTraits<KeyType>::size();
291 }
292 
293 template <class KeyType, class ValueType>
294 std::size_t gsKDTree<KeyType, ValueType>::size() const
295 {
296  return size_;
297 }
298 
299 template <class KeyType, class ValueType>
300 bool gsKDTree<KeyType, ValueType>::empty() const
301 {
302  return size_ == 0;
303 }
304 
305 template <class KeyType, class ValueType>
306 typename gsKDTree<KeyType, ValueType>::Node*
307 gsKDTree<KeyType, ValueType>::findNode(typename gsKDTree<KeyType, ValueType>::Node* currNode,
308  const KeyType& key) const
309 {
310  if (currNode == NULL || currNode->point == key) return currNode;
311 
312  const KeyType& currPoint = currNode->point;
313  int currLevel = currNode->level;
314  if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()))
315  {
316  // recurse to the left side
317  return currNode->left == NULL ? currNode : findNode(currNode->left, key);
318  } else {
319  // recurse to the right side
320  return currNode->right == NULL ? currNode : findNode(currNode->right, key);
321  }
322 }
323 
324 template <class KeyType, class ValueType>
325 bool gsKDTree<KeyType, ValueType>::contains(const KeyType& key) const
326 {
327  auto node = findNode(root_, key);
328  return node != NULL && node->point == key;
329 }
330 
331 template <class KeyType, class ValueType>
332 void gsKDTree<KeyType, ValueType>::insert(const KeyType& key, const ValueType& value)
333 {
334  auto targetNode = findNode(root_, key);
335  if (targetNode == NULL) { // this means the tree is empty
336  root_ = new Node(key, 0, value);
337  size_ = 1;
338  } else {
339  if (targetNode->point == key) { // key is already in the tree, simply update its value
340  targetNode->value = value;
341  } else { // construct a new node and insert it to the right place (child of targetNode)
342  int currLevel = targetNode->level;
343  Node* newNode = new Node(key, currLevel + 1, value);
344  if (gsKDTreeTraits<KeyType>::islhalf(key, targetNode->point, currLevel%gsKDTreeTraits<KeyType>::size())) {
345  targetNode->left = newNode;
346  } else {
347  targetNode->right = newNode;
348  }
349  ++size_;
350  }
351  }
352 }
353 
354 template <class KeyType, class ValueType>
355 const ValueType& gsKDTree<KeyType, ValueType>::at(const KeyType& key) const
356 {
357  auto node = findNode(root_, key);
358  if (node == NULL || node->point != key) {
359  throw std::out_of_range("Key not found in gsKDTree");
360  } else {
361  return node->value;
362  }
363 }
364 
365 template <class KeyType, class ValueType>
366 ValueType& gsKDTree<KeyType, ValueType>::at(const KeyType& key)
367 {
368  const gsKDTree<KeyType, ValueType>& constThis = *this;
369  return const_cast<ValueType&>(constThis.at(key));
370 }
371 
372 template <class KeyType, class ValueType>
373 ValueType& gsKDTree<KeyType, ValueType>::operator[](const KeyType& key)
374 {
375  auto node = findNode(root_, key);
376  if (node != NULL && node->point == key) { // key is already in the tree
377  return node->value;
378  } else { // insert key with default ValueType value, and return reference to the new ValueType
379  insert(key);
380  if (node == NULL) return root_->value; // the new node is the root
381  else return (node->left != NULL && node->left->point == key) ? node->left->value: node->right->value;
382  }
383 }
384 
385 template <class KeyType, class ValueType>
386 void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(const typename gsKDTree<KeyType, ValueType>::Node* currNode,
387  const KeyType& key,
388  gsBoundedPriorityQueue<ValueType>& bpq) const
389 {
390  if (currNode == NULL) return;
391  const KeyType& currPoint = currNode->point;
392 
393  // Add the current point to the BPQ if it is closer to 'key' that some point in the BPQ
394  bpq.enqueue(currNode->value, gsKDTreeTraits<KeyType>::distance(key, currPoint));
395 
396  // Recursively search the half of the tree that contains Point 'key'
397  int currLevel = currNode->level;
398  bool isLeftTree;
399  if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
400  nearestNeighborRecurse(currNode->left, key, bpq);
401  isLeftTree = true;
402  } else {
403  nearestNeighborRecurse(currNode->right, key, bpq);
404  isLeftTree = false;
405  }
406 
407  if (bpq.size() < bpq.maxSize() ||
408  gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
409  // Recursively search the other half of the tree if necessary
410  if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
411  else nearestNeighborRecurse(currNode->left, key, bpq);
412  }
413 }
414 
415 template <class KeyType, class ValueType>
416 void gsKDTree<KeyType, ValueType>::nearestNeighborRecurse(const typename gsKDTree<KeyType, ValueType>::Node* currNode,
417  const KeyType& key,
418  gsBoundedPriorityQueue<ValueType*>& bpq) const
419 {
420  if (currNode == NULL) return;
421  const KeyType& currPoint = currNode->point;
422 
423  // Add the current point to the BPQ if it is closer to 'key' that some point in the BPQ
424  bpq.enqueue(const_cast<ValueType*>(&(currNode->value)), gsKDTreeTraits<KeyType>::distance(key, currPoint));
425 
426  // Recursively search the half of the tree that contains Point 'key'
427  int currLevel = currNode->level;
428  bool isLeftTree;
429  if (gsKDTreeTraits<KeyType>::islhalf(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size())) {
430  nearestNeighborRecurse(currNode->left, key, bpq);
431  isLeftTree = true;
432  } else {
433  nearestNeighborRecurse(currNode->right, key, bpq);
434  isLeftTree = false;
435  }
436 
437  if (bpq.size() < bpq.maxSize() ||
438  gsKDTreeTraits<KeyType>::fabs(key, currPoint, currLevel%gsKDTreeTraits<KeyType>::size()) < bpq.worst()) {
439  // Recursively search the other half of the tree if necessary
440  if (isLeftTree) nearestNeighborRecurse(currNode->right, key, bpq);
441  else nearestNeighborRecurse(currNode->left, key, bpq);
442  }
443 }
444 
445 template <class KeyType, class ValueType>
446 ValueType gsKDTree<KeyType, ValueType>::kNNValue(const KeyType& key, std::size_t k) const
447 {
448  // BPQ with maximum size k
449  gsBoundedPriorityQueue<ValueType> bpq(k);
450  if (empty()) throw std::out_of_range("gsKDTree is empty");
451 
452  // Recursively search the kd-tree with pruning
453  nearestNeighborRecurse(root_, key, bpq);
454 
455  // Ensure finite values; non-standard 'distance' functions can be
456  // used to exclude data elements that are close to the given key but
457  // on the 'wrong' side of the hyperplane. This allows to exclude
458  // nearest neighbours that are, e.g., smaller than the given key.
459  if (!math::isfinite(bpq.best()))
460  throw std::out_of_range("gsKDTree does not contain finite value");
461 
462  // Count occurrences of all ValueType in the kNN set
463  std::unordered_map<ValueType, int> counter;
464  while (!bpq.empty()) {
465  ++counter[bpq.dequeueMin()];
466  }
467 
468  // Return the most frequent element in the kNN set
469  ValueType result;
470  int cnt = -1;
471  for (const auto &p : counter) {
472  if (p.second > cnt) {
473  result = p.first;
474  cnt = p.second;
475  }
476  }
477  return result;
478 }
479 
480 template <class KeyType, class ValueType>
481 ValueType& gsKDTree<KeyType, ValueType>::kNNValue(const KeyType& key, std::size_t k)
482 {
483  // BPQ with maximum size k
484  gsBoundedPriorityQueue<ValueType*> bpq(k);
485  if (empty())
486  throw std::out_of_range("gsKDTree is empty");
487 
488  // Recursively search the kd-tree with pruning
489  nearestNeighborRecurse(root_, key, bpq);
490 
491  // Ensure finite values; non-standard 'distance' functions can be
492  // used to exclude data elements that are close to the given key but
493  // on the 'wrong' side of the hyperplane. This allows to exclude
494  // nearest neighbours that are, e.g., smaller than the given key.
495  if (!math::isfinite(bpq.best()))
496  throw std::out_of_range("gsKDTree does not contain finite value");
497 
498  // Count occurrences of all ValueType in the kNN set
499  std::unordered_map<ValueType*, int> counter;
500  while (!bpq.empty()) {
501  ++counter[bpq.dequeueMin()];
502  }
503 
504  // Return the most frequent element in the kNN set
505  ValueType* result = nullptr;
506  int cnt = -1;
507  for (const auto &p : counter) {
508  if (p.second > cnt) {
509  result = p.first;
510  cnt = p.second;
511  }
512  }
513  return *result;
514 }
515 
516 template <class KeyType, class ValueType>
517 void gsKDTree<KeyType, ValueType>::print(std::ostream& os) const
518 {
519  os << "KD-tree: size= " << size() << ", dimension= " << dimension() << ".\n";
520 }
521 
522 } //namespace gismo
void print(std::ostream &os) const
Prints the object as a string.
Definition: gsKDTree.h:517
friend std::ostream & operator<<(std::ostream &os, const gsKDTree &obj)
Print (as string) operator.
Definition: gsKDTree.h:137
T distance(gsMatrix< T > const &A, gsMatrix< T > const &B, index_t i=0, index_t j=0, bool cols=false)
compute a distance between the point number <i> in the set <A> and the point number <j> in the set <B>; by def...
An interface representing a kd-tree in some number of dimensions.
Definition: gsKDTree.h:58
bool isfinite(const gsEigen::MatrixBase< Derived > &x)
Check if all the entries of the matrix x are not INF (infinite)
Definition: gsLinearAlgebra.h:109
EIGEN_STRONG_INLINE abs_expr< E > abs(const E &u)
Absolute value.
Definition: gsExpressions.h:4488
Provides declaration of bounded priority queue.