Photon Engine 2.0.0-beta
A physically based renderer.
Loading...
Searching...
No Matches
TIndexedPointKdtree.h
Go to the documentation of this file.
1#pragma once
2
3#include "Math/TVector3.h"
6#include "Utility/TSpan.h"
7#include "Utility/utility.h"
8
9#include <Common/assertion.h>
10
11#include <vector>
12#include <utility>
13#include <cstddef>
14#include <algorithm>
15#include <array>
16#include <memory>
17#include <type_traits>
18#include <limits>
19#include <concepts>
20
21namespace ph::math
22{
23
24template<typename Storage, typename Item>
26 std::default_initializable<Storage> &&
27 std::copyable<Storage> &&
28 requires (Storage storage)
29 {
30 { storage[std::size_t{}] } -> std::convertible_to<Item>;
31 { storage.size() } -> std::convertible_to<std::size_t>;
32 };
33
34template
35<
36 typename Item,
37 typename Index,
38 typename PointCalculator,
39 CIndexedPointKdtreeItemStorage<Item> ItemStorage = std::vector<Item>
40>
42{
43 // TODO: static_assert for signature of PointCalculator
44
45public:
46 struct BuildCache final
47 {
48 std::vector<Index> itemIndices;
49 std::vector<math::Vector3R> itemPoints;
50 };
51
54 TIndexedPointKdtree(const std::size_t maxNodeItems, const PointCalculator& pointCalculator)
55 : m_nodeBuffer ()
56 , m_numNodes (0)
57 , m_items ()
58 , m_rootAABB ()
59 , m_maxNodeItems (maxNodeItems)
60 , m_indexBuffer ()
61 , m_pointCalculator(pointCalculator)
62 {
63 PH_ASSERT_GT(maxNodeItems, 0);
64 }
65
69 void build(ItemStorage items)
70 {
71 BuildCache buildCache;
72 build(std::move(items), buildCache);
73 }
74
80 void build(ItemStorage items, BuildCache& buildCache)
81 {
82 m_nodeBuffer.clear();
83 m_numNodes = 0;
84 m_items = std::move(items);
85 m_rootAABB = math::AABB3D();
86 m_indexBuffer.clear();
87 if(m_items.size() == 0)
88 {
89 return;
90 }
91
92 auto& itemPoints = buildCache.itemPoints;
93 itemPoints.resize(m_items.size());
94 for(std::size_t i = 0; i < m_items.size(); ++i)
95 {
96 const auto& item = m_items[i];
97
98 const math::Vector3R& center = m_pointCalculator(ref_access(item));
99 itemPoints[i] = center;
100 }
101
102 PH_ASSERT_LE(m_items.size() - 1, std::numeric_limits<Index>::max());
103 auto& itemIndices = buildCache.itemIndices;
104 itemIndices.resize(m_items.size());
105 for(std::size_t i = 0; i < m_items.size(); ++i)
106 {
107 itemIndices[i] = static_cast<Index>(i);
108 }
109
110 m_rootAABB = calcPointsAABB(itemIndices, itemPoints);
111
112 buildNodeRecursive(
113 0,
114 m_rootAABB,
115 itemIndices,
116 itemPoints,
117 0);
118 }
119
121 const math::Vector3R& location,
122 const real searchRadius,
123 std::vector<Item>& results) const
124 {
125 PH_ASSERT_GT(m_numNodes, 0);
126
127 const real searchRadius2 = searchRadius * searchRadius;
128
130 location,
131 searchRadius2,
132 [this, location, searchRadius2, &results](const Item& item)
133 {
134 const math::Vector3R itemPoint = m_pointCalculator(item);
135 const real dist2 = (itemPoint - location).lengthSquared();
136 if(dist2 < searchRadius2)
137 {
138 results.push_back(item);
139 }
140 });
141 }
142
144 const math::Vector3R& location,
145 const std::size_t maxItems,
146 std::vector<Item>& results) const
147 {
148 PH_ASSERT_GT(m_numNodes, 0);
149 PH_ASSERT_GT(maxItems, 0);
150
151 // OPT: distance calculation can be cached
152
153 auto isACloserThanB =
154 [this, location](const Item& itemA, const Item& itemB) -> bool
155 {
156 return (m_pointCalculator(itemA) - location).lengthSquared() <
157 (m_pointCalculator(itemB) - location).lengthSquared();
158 };
159
160 real searchRadius2 = std::numeric_limits<real>::max();
161 std::size_t numFoundItems = 0;
162 auto handler =
163 [this, location, maxItems,
164 &searchRadius2, &numFoundItems, &isACloserThanB, &results]
165 (const Item& item)
166 {
167 /*
168 If k nearest neighbors are required and n items are processed
169 by this handler, this handler (not including traversal) will
170 take O(k + (n-k)*log(k)) time in total.
171 */
172
173 // output buffer is not full, just insert the item
174 if(numFoundItems < maxItems)
175 {
176 results.push_back(item);
177 numFoundItems++;
178
179 // once output buffer is full, make it a max heap
180 if(numFoundItems == maxItems)
181 {
182 // this takes O(k) time
183 std::make_heap(
184 results.end() - maxItems,
185 results.end(),
186 isACloserThanB);
187
188 // the furthest one is at the max heap's root
189 const Item& furthestItem = results[results.size() - maxItems];
190
191 // search radius can be shrunk in this case
192 searchRadius2 = (m_pointCalculator(furthestItem) - location).lengthSquared();
193 }
194 }
195 // last <maxItems> items in output buffer forms a max heap now
196 else
197 {
198 // the furthest one is at the max heap's root
199 const Item& furthestItem = results[results.size() - maxItems];
200
201 if(isACloserThanB(item, furthestItem))
202 {
203 // remove furthest item, this takes O(log(k)) time
204 std::pop_heap(
205 results.end() - maxItems,
206 results.end(),
207 isACloserThanB);
208
209 results.back() = item;
210
211 // add new item, this takes O(log(k)) time
212 std::push_heap(
213 results.end() - maxItems,
214 results.end(),
215 isACloserThanB);
216
217 // search radius can be shrunk in this case
218 searchRadius2 = (m_pointCalculator(furthestItem) - location).lengthSquared();
219 }
220 }
221
222 return searchRadius2;
223 };
224
226 location,
227 searchRadius2,
228 handler);
229 }
230
231 template<typename ItemHandler>
233 const math::Vector3R& location,
234 const real squaredSearchRadius,
235 ItemHandler itemHandler) const
236 {
237 static_assert(std::is_invocable_v<ItemHandler, Item>,
238 "ItemHandler must accept an item as input.");
239
240 PH_ASSERT_GT(m_numNodes, 0);
241 PH_ASSERT_LE(squaredSearchRadius, std::numeric_limits<real>::max());
242
243 constexpr std::size_t MAX_STACK_HEIGHT = 64;
244 std::array<const Node*, MAX_STACK_HEIGHT> nodeStack;
245
246 const Node* currentNode = &(m_nodeBuffer[0]);
247 std::size_t stackHeight = 0;
248 while(true)
249 {
250 PH_ASSERT(currentNode);
251 if(!currentNode->isLeaf())
252 {
253 const auto splitAxis = currentNode->getSplitAxis();
254 const real splitPos = currentNode->getSplitPos();
255 const real splitPlaneDiff = location[splitAxis] - splitPos;
256
257 const Node* nearNode;
258 const Node* farNode;
259 if(splitPlaneDiff < 0)
260 {
261 nearNode = currentNode + 1;
262 farNode = &(m_nodeBuffer[currentNode->getPositiveChildIndex()]);
263 }
264 else
265 {
266 nearNode = &(m_nodeBuffer[currentNode->getPositiveChildIndex()]);
267 farNode = currentNode + 1;
268 }
269
270 currentNode = nearNode;
271 if(squaredSearchRadius >= splitPlaneDiff * splitPlaneDiff)
272 {
273 PH_ASSERT(stackHeight < MAX_STACK_HEIGHT);
274 nodeStack[stackHeight++] = farNode;
275 }
276 }
277 // current node is leaf
278 else
279 {
280 const std::size_t numItems = currentNode->numItems();
281 const std::size_t indexBufferOffset = currentNode->getIndexBufferOffset();
282 for(std::size_t i = 0; i < numItems; ++i)
283 {
284 const Index itemIndex = m_indexBuffer[indexBufferOffset + i];
285 const Item& item = m_items[itemIndex];
286
287 itemHandler(item);
288 }
289
290 if(stackHeight > 0)
291 {
292 currentNode = nodeStack[--stackHeight];
293 }
294 else
295 {
296 break;
297 }
298 }
299 }// end while stackHeight > 0
300 }
301
302 template<typename ItemHandler>
304 const math::Vector3R& location,
305 const real initialSquaredSearchRadius,
306 ItemHandler itemHandler) const
307 {
308 static_assert(std::is_invocable_v<ItemHandler, Item>,
309 "ItemHandler must accept an item as input.");
310
311 using Return = decltype(itemHandler(std::declval<Item>()));
312 static_assert(std::is_same_v<Return, real>,
313 "ItemHandler must return an potentially shrunk squared search radius.");
314
315 PH_ASSERT_GT(m_numNodes, 0);
316 PH_ASSERT_LE(initialSquaredSearchRadius, std::numeric_limits<real>::max());
317
318 struct NodeRecord
319 {
320 const Node* node;
321
322 // The value is zero for near nodes as they should not be skipped by distance test.
323 real parentSplitPlaneDiff2;
324 };
325
326 constexpr std::size_t MAX_STACK_HEIGHT = 64;
327 std::array<NodeRecord, MAX_STACK_HEIGHT> nodeStack;
328
329 NodeRecord currentNode = {&(m_nodeBuffer[0]), 0};
330 std::size_t stackHeight = 0;
331 real currentRadius2 = initialSquaredSearchRadius;
332 while(true)
333 {
334 PH_ASSERT(currentNode.node);
335 if(!currentNode.node->isLeaf())
336 {
337 const auto splitAxis = currentNode.node->getSplitAxis();
338 const real splitPos = currentNode.node->getSplitPos();
339 const real splitPlaneDiff = location[splitAxis] - splitPos;
340
341 const Node* nearNode;
342 const Node* farNode;
343 if(splitPlaneDiff < 0)
344 {
345 nearNode = currentNode.node + 1;
346 farNode = &(m_nodeBuffer[currentNode.node->getPositiveChildIndex()]);
347 }
348 else
349 {
350 nearNode = &(m_nodeBuffer[currentNode.node->getPositiveChildIndex()]);
351 farNode = currentNode.node + 1;
352 }
353
354 const real splitPlaneDiff2 = splitPlaneDiff * splitPlaneDiff;
355
356 currentNode = {nearNode, 0};
357 if(currentRadius2 >= splitPlaneDiff2)
358 {
359 PH_ASSERT(stackHeight < MAX_STACK_HEIGHT);
360 nodeStack[stackHeight++] = {farNode, splitPlaneDiff2};
361 }
362 }
363 // current node is leaf
364 else
365 {
366 // For far nodes, they can be culled if radius has shrunk.
367 // For near nodes, they have <parentSplitPlaneDiff2> == 0 hence cannot be skipped.
368 if(currentRadius2 >= currentNode.parentSplitPlaneDiff2)
369 {
370 const std::size_t numItems = currentNode.node->numItems();
371 const std::size_t indexBufferOffset = currentNode.node->getIndexBufferOffset();
372 for(std::size_t i = 0; i < numItems; ++i)
373 {
374 const Index itemIndex = m_indexBuffer[indexBufferOffset + i];
375 const Item& item = m_items[itemIndex];
376
377 // potentially reduce search radius
378 const real shrunkRadius2 = itemHandler(item);
379 PH_ASSERT_LE(shrunkRadius2, currentRadius2);
380 currentRadius2 = shrunkRadius2;
381 }
382 }
383
384 if(stackHeight > 0)
385 {
386 currentNode = nodeStack[--stackHeight];
387 }
388 else
389 {
390 break;
391 }
392 }
393 }// end while stackHeight > 0
394 }
395
396 std::size_t numItems() const
397 {
398 return m_items.size();
399 }
400
401private:
403
404 void buildNodeRecursive(
405 const std::size_t nodeIndex,
406 const math::AABB3D& nodeAABB,
407 const TSpan<Index> nodeItemIndices,
408 const TSpanView<math::Vector3R> itemPoints,
409 const std::size_t currentNodeDepth)
410 {
411 ++m_numNodes;
412 if(m_numNodes > m_nodeBuffer.size())
413 {
414 m_nodeBuffer.resize(m_numNodes * 2);
415 }
416 PH_ASSERT_LT(nodeIndex, m_nodeBuffer.size());
417
418 if(nodeItemIndices.size() <= m_maxNodeItems)
419 {
420 m_nodeBuffer[nodeIndex] = Node::makeLeaf(nodeItemIndices, m_indexBuffer);
421 return;
422 }
423
424 const math::Vector3R& nodeExtents = nodeAABB.getExtents();
425 const auto splitAxis = nodeExtents.maxDimension();
426
427 const std::size_t midIndicesIndex = nodeItemIndices.size() / 2;
428 std::nth_element(
429 nodeItemIndices.begin(),
430 nodeItemIndices.begin() + midIndicesIndex,
431 nodeItemIndices.end(),
432 [itemPoints, splitAxis](const Index& a, const Index& b) -> bool
433 {
434 return itemPoints[a][splitAxis] < itemPoints[b][splitAxis];
435 });
436
437 const real splitPos = itemPoints[nodeItemIndices[midIndicesIndex]][splitAxis];
438
439 math::Vector3R splitPosMinVertex = nodeAABB.getMinVertex();
440 math::Vector3R splitPosMaxVertex = nodeAABB.getMaxVertex();
441 splitPosMinVertex[splitAxis] = splitPos;
442 splitPosMaxVertex[splitAxis] = splitPos;
443 const math::AABB3D negativeNodeAABB(nodeAABB.getMinVertex(), splitPosMaxVertex);
444 const math::AABB3D positiveNodeAABB(splitPosMinVertex, nodeAABB.getMaxVertex());
445
446 buildNodeRecursive(
447 nodeIndex + 1,
448 negativeNodeAABB,
449 nodeItemIndices.subspan(0, midIndicesIndex),
450 itemPoints,
451 currentNodeDepth + 1);
452
453 const std::size_t positiveChildIndex = m_numNodes;
454 m_nodeBuffer[nodeIndex] = Node::makeInner(splitPos, splitAxis, positiveChildIndex);
455
456 buildNodeRecursive(
457 positiveChildIndex,
458 positiveNodeAABB,
459 nodeItemIndices.subspan(midIndicesIndex),
460 itemPoints,
461 currentNodeDepth + 1);
462 }
463
466 static math::AABB3D calcPointsAABB(
467 const TSpanView<Index> pointIndices,
468 const TSpanView<math::Vector3R> points)
469 {
470 PH_ASSERT_GT(pointIndices.size(), 0);
471
472 math::AABB3D pointsAABB(points[pointIndices[0]]);
473 for(std::size_t i = 1; i < pointIndices.size(); ++i)
474 {
475 pointsAABB.unionWith(points[pointIndices[i]]);
476 }
477 return pointsAABB;
478 }
479
480 std::vector<Node> m_nodeBuffer;
481 std::size_t m_numNodes;
482 ItemStorage m_items;
483 math::AABB3D m_rootAABB;
484 std::size_t m_maxNodeItems;
485 std::vector<Index> m_indexBuffer;
486 PointCalculator m_pointCalculator;
487};
488
489}// end namespace ph::math
const TVector3< T > & getMaxVertex() const
Get the corner vertex of the maximum (+++) octant.
Definition TAABB3D.ipp:152
TVector3< T > getExtents() const
Get the side lengths of the bound.
Definition TAABB3D.ipp:171
const TVector3< T > & getMinVertex() const
Get the corner vertex of the minimum (---) octant.
Definition TAABB3D.ipp:146
An indexed kD-tree node with compacted memory layout.
Definition TIndexedKdtreeNode.h:23
bool isLeaf() const
Definition TIndexedKdtreeNode.h:181
std::size_t getSplitAxis() const
Definition TIndexedKdtreeNode.h:219
static TIndexedKdtreeNode makeInner(real splitPos, std::size_t splitAxisIndex, std::size_t positiveChildIndex)
Definition TIndexedKdtreeNode.h:105
std::size_t numItems() const
Definition TIndexedKdtreeNode.h:199
static TIndexedKdtreeNode makeLeaf(Index indexBufferOffset, std::size_t numItems)
Definition TIndexedKdtreeNode.h:128
std::size_t getIndexBufferOffset() const
Definition TIndexedKdtreeNode.h:248
std::size_t getPositiveChildIndex() const
Definition TIndexedKdtreeNode.h:189
real getSplitPos() const
Definition TIndexedKdtreeNode.h:209
Definition TIndexedPointKdtree.h:42
TIndexedPointKdtree(const std::size_t maxNodeItems, const PointCalculator &pointCalculator)
Creates empty tree. Call build() to populate the tree.
Definition TIndexedPointKdtree.h:54
void findWithinRange(const math::Vector3R &location, const real searchRadius, std::vector< Item > &results) const
Definition TIndexedPointKdtree.h:120
void build(ItemStorage items, BuildCache &buildCache)
Populate the tree, reusing the buffers in the given build cache. Preferable when the tree is built multiple times.
Definition TIndexedPointKdtree.h:80
std::size_t numItems() const
Definition TIndexedPointKdtree.h:396
void rangeTraversal(const math::Vector3R &location, const real squaredSearchRadius, ItemHandler itemHandler) const
Definition TIndexedPointKdtree.h:232
void findNearest(const math::Vector3R &location, const std::size_t maxItems, std::vector< Item > &results) const
Definition TIndexedPointKdtree.h:143
void build(ItemStorage items)
Populate the tree. Preferable when the tree is built once and then only read.
Definition TIndexedPointKdtree.h:69
void nearestTraversal(const math::Vector3R &location, const real initialSquaredSearchRadius, ItemHandler itemHandler) const
Definition TIndexedPointKdtree.h:303
std::size_t maxDimension() const
Definition TVectorNBase.ipp:81
Definition TIndexedPointKdtree.h:25
Math functions and utilities.
Definition TransformInfo.h:10
TAABB3D< real > AABB3D
Definition TAABB3D.h:21
TVector3< real > Vector3R
Definition math_fwd.h:52
std::span< const T, EXTENT > TSpanView
Same as TSpan, except that the objects are const-qualified. Note that for pointer types,...
Definition TSpan.h:19
T & ref_access(T &ref)
Definition utility.h:97
std::span< T, EXTENT > TSpan
A contiguous sequence of objects of type T. Effectively the same as std::span.
Definition TSpan.h:12
Definition TIndexedPointKdtree.h:47
std::vector< math::Vector3R > itemPoints
Definition TIndexedPointKdtree.h:49
std::vector< Index > itemIndices
Definition TIndexedPointKdtree.h:48