Initial commit: Final state of the master project
This commit is contained in:
490
Research/core/Util/BinaryTree.h
Normal file
490
Research/core/Util/BinaryTree.h
Normal file
@@ -0,0 +1,490 @@
|
||||
#pragma once
|
||||
#include "../Defines.h"
|
||||
#include "../BitHelper.h"
|
||||
#include "../CollectionHelper.h"
|
||||
#include "../Serializer.h"
|
||||
#include "../../inc/tbb/parallel_for_each.h"
|
||||
#include "../../inc/tbb/parallel_for.h"
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
|
||||
template<typename T>
|
||||
class BinaryTree
|
||||
{
|
||||
private:
|
||||
template<typename U>
|
||||
struct BinaryTreeNode
|
||||
{
|
||||
private:
|
||||
BinaryTreeNode<U>* mChildren[2];
|
||||
BinaryTree<U>* mTree;
|
||||
T mMaterial;
|
||||
|
||||
inline unsigned8 GetChildIndex(size_t coordinate, unsigned8 levelsLeft) const {
|
||||
return BitHelper::GetLS(coordinate, levelsLeft) ? 1 : 0;
|
||||
}
|
||||
|
||||
inline BinaryTreeNode<U>* GetNodeAt(size_t coordinate, unsigned8 levelsLeft) const
|
||||
{
|
||||
return mChildren[GetChildIndex(coordinate, levelsLeft)];
|
||||
}
|
||||
public:
|
||||
BinaryTreeNode(BinaryTree<U>* tree) :
|
||||
mTree(tree)
|
||||
{
|
||||
mChildren[0] = mChildren[1] = NULL;
|
||||
}
|
||||
|
||||
~BinaryTreeNode() {}
|
||||
|
||||
BinaryTreeNode* GetChild(unsigned8 index) const
|
||||
{
|
||||
assert(index < 2);
|
||||
return mChildren[index];
|
||||
}
|
||||
|
||||
void SetChild(BinaryTreeNode<U>* child, unsigned8 index) { mChildren[index] = child; }
|
||||
|
||||
size_t GetChildCount() const
|
||||
{
|
||||
return
|
||||
(mChildren[0] == NULL ? 0 : 1) +
|
||||
(mChildren[1] == NULL ? 0 : 1);
|
||||
}
|
||||
bool IsLeaf() const { return mChildren[0] == NULL && mChildren[1] == NULL; }
|
||||
|
||||
BinaryTreeNode<U>* AddNode(size_t coordinate, unsigned8 levelsLeft)
|
||||
{
|
||||
auto node = GetNodeAt(coordinate, levelsLeft);
|
||||
if (node == NULL)
|
||||
{
|
||||
node = mTree->Create();
|
||||
SetChild(node, GetChildIndex(coordinate, levelsLeft));
|
||||
}
|
||||
if (levelsLeft == 0)
|
||||
return node;
|
||||
else
|
||||
return node->AddNode(coordinate, levelsLeft - 1);
|
||||
}
|
||||
|
||||
U GetValueAt(size_t coordinate, unsigned8 levelsLeft) const
|
||||
{
|
||||
auto node = GetNodeAt(coordinate, levelsLeft);
|
||||
if (node == NULL) return GetValue();
|
||||
if (levelsLeft == 0) return node->GetValue();
|
||||
return node->GetValueAt(coordinate, levelsLeft - 1);
|
||||
}
|
||||
|
||||
U GetValue() const { return mMaterial; }
|
||||
|
||||
void SetValue(const T& material) { mMaterial = material; }
|
||||
|
||||
void Traverse(const std::function<void(const BinaryTreeNode<U>*)>& f) const
|
||||
{
|
||||
f.operator()(this);
|
||||
if (mChildren[0] != NULL) mChildren[0]->Traverse(f);
|
||||
if (mChildren[1] != NULL) mChildren[1]->Traverse(f);
|
||||
}
|
||||
|
||||
static bool Compare(BinaryTreeNode<U>* a, BinaryTreeNode<U>* b)
|
||||
{
|
||||
if (a->GetChild(0) != b->GetChild(0)) return (size_t)a->GetChild(0) < (size_t)b->GetChild(0);
|
||||
if (a->GetChild(1) != b->GetChild(1)) return (size_t)a->GetChild(1) < (size_t)b->GetChild(1);
|
||||
return a->GetValue() < b->GetValue();
|
||||
}
|
||||
|
||||
static bool Equals(BinaryTreeNode<U>* a, BinaryTreeNode<U>* b)
|
||||
{
|
||||
if (a == b) return true;
|
||||
if (a == NULL || b == NULL) return false;
|
||||
return a->GetChild(0) == b->GetChild(0) && a->GetChild(1) == b->GetChild(1) && a->GetValue() == b->GetValue();
|
||||
}
|
||||
};
|
||||
|
||||
unsigned8 mDepth;
|
||||
BinaryTreeNode<T>* mRoot;
|
||||
std::vector<BinaryTreeNode<T>*> mNodePool;
|
||||
|
||||
std::vector<BinaryTreeNode<T>*> GetNodes()
|
||||
{
|
||||
std::vector<BinaryTreeNode<T>*> res;
|
||||
std::function<void(const BinaryTreeNode<T>*)> nodeFinder = [&](const BinaryTreeNode<T>*) { res.push_back(this); };
|
||||
mRoot->Traverse(nodeFinder);
|
||||
CollectionHelper::Unique(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
void ShiftDown(unsigned8 levels)
|
||||
{
|
||||
// Shifts the root down the given number of levels
|
||||
if (levels == 0) return;
|
||||
|
||||
for (unsigned8 i = 0; i < levels; i++)
|
||||
{
|
||||
BinaryTreeNode<T>* newRoot = Create();
|
||||
newRoot->SetChild(mRoot, 0);
|
||||
BinaryTreeNode<T>* oldRoot = mRoot;
|
||||
mRoot = newRoot;
|
||||
mNodePool[0] = newRoot;
|
||||
mNodePool[mNodePool.size() - 1] = oldRoot;
|
||||
}
|
||||
}
|
||||
|
||||
static void CalculateNodeLevelsRecursive(BinaryTreeNode<T>* node, unsigned8 level, std::vector<unsigned8>& nodeLevels, const std::unordered_map<BinaryTreeNode<T>*, size_t>& nodeIndices)
|
||||
{
|
||||
assert(nodeIndices.find(node) != nodeIndices.end());
|
||||
auto nodeIndex = nodeIndices.find(node);
|
||||
nodeLevels[nodeIndex->second] = level;
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
auto childNode = node->GetChild(child);
|
||||
if (childNode != NULL)
|
||||
CalculateNodeLevelsRecursive(childNode, level + 1, nodeLevels, nodeIndices);
|
||||
}
|
||||
}
|
||||
|
||||
// Calculates the levels of all nodes. Levels are stored in the same order as the node pool.
|
||||
std::vector<unsigned8> CalculateNodeLevels() const
|
||||
{
|
||||
auto nodeIndices = CollectionHelper::GetIndexMap(mNodePool);
|
||||
std::vector<unsigned8> nodeLevels(mNodePool.size());
|
||||
CalculateNodeLevelsRecursive(mRoot, 0, nodeLevels, nodeIndices);
|
||||
return nodeLevels;
|
||||
}
|
||||
|
||||
inline static size_t GetNodePointer(const size_t& index, const size_t& pointerSize, const size_t& valueSize, const size_t& firstLeafIndex, const bool& onlyLeafsContainValues)
|
||||
{
|
||||
if (!onlyLeafsContainValues)
|
||||
return 2 + index * (pointerSize * 2 + valueSize);
|
||||
else
|
||||
{
|
||||
size_t pointer = 2 + std::min(index, firstLeafIndex) * (pointerSize * 2);
|
||||
if (index > firstLeafIndex)
|
||||
pointer += (index - firstLeafIndex - 1) * (pointerSize * 2 + valueSize);
|
||||
return pointer;
|
||||
}
|
||||
}
|
||||
public:
|
||||
BinaryTree()
|
||||
{
|
||||
mDepth = 0;
|
||||
mRoot = Create();
|
||||
}
|
||||
~BinaryTree()
|
||||
{
|
||||
tbb::parallel_for_each(mNodePool.begin(), mNodePool.end(), [](BinaryTreeNode<T>* node) { delete node; });
|
||||
mNodePool.clear();
|
||||
}
|
||||
|
||||
void AddLeafNode(size_t coordinate) { AddNode(coordinate, mDepth); }
|
||||
void AddNode(size_t coordinate, unsigned8 level)
|
||||
{
|
||||
assert(level <= mDepth);
|
||||
mRoot->AddNode(coordinate, level);
|
||||
}
|
||||
|
||||
T GetValueAtLeaf(size_t coordinate) const { return GetValueAtNode(coordinate, mDepth); }
|
||||
T GetValueAtNode(size_t coordinate, unsigned8 level) const
|
||||
{
|
||||
assert(level <= mDepth);
|
||||
return mRoot->GetValueAt(coordinate, level);
|
||||
}
|
||||
void SetValueAtLeaf(size_t coordinate, T value) { SetValueAtNode(coordinate, mDepth, value); }
|
||||
void SetValueAtNode(size_t coordinate, unsigned8 level, T value)
|
||||
{
|
||||
assert(level <= mDepth);
|
||||
auto node = mRoot->AddNode(coordinate, level);
|
||||
node->SetValue(value);
|
||||
}
|
||||
|
||||
BinaryTreeNode<T>* Create()
|
||||
{
|
||||
BinaryTreeNode<T>* newNode = new BinaryTreeNode<T>(this);
|
||||
mNodePool.push_back(newNode);
|
||||
return newNode;
|
||||
}
|
||||
|
||||
void SetDepth(unsigned8 wantedDepth, bool shiftExisting)
|
||||
{
|
||||
if (mDepth != wantedDepth)
|
||||
{
|
||||
if (wantedDepth < mDepth)
|
||||
ShaveUntil(wantedDepth);
|
||||
else
|
||||
ShiftDown(wantedDepth - mDepth);
|
||||
}
|
||||
mDepth = wantedDepth;
|
||||
}
|
||||
|
||||
// Deletes all nodes of which the level is greater than the given level
|
||||
void ShaveUntil(unsigned8 level)
|
||||
{
|
||||
// Calculate the level for all nodes
|
||||
std::vector<unsigned8> nodeLevels = CalculateNodeLevels();
|
||||
for (size_t i = 0; i < mNodePool.size(); i++)
|
||||
if (nodeLevels[i] > level)
|
||||
{
|
||||
delete mNodePool[i];
|
||||
mNodePool[i] = NULL;
|
||||
}
|
||||
else if (nodeLevels[i] == level)
|
||||
{
|
||||
// Set all pointers to NULL:
|
||||
for (unsigned8 child = 0; child < 2; child++) mNodePool[i]->SetChild(NULL, child);
|
||||
}
|
||||
mNodePool.erase(std::remove_if(mNodePool.begin(), mNodePool.end(), [](const BinaryTreeNode<T>* node) { return node == NULL; }));
|
||||
}
|
||||
|
||||
void ReplaceValues(const std::unordered_map<T, T>& replacers)
|
||||
{
|
||||
tbb::parallel_for(size_t(0), mNodePool.size(), [&](const size_t& i)
|
||||
{
|
||||
T curValue = mNodePool[i]->GetValue();
|
||||
auto replacer = replacers.find(curValue);
|
||||
if (replacer != replacers.end())
|
||||
mNodePool[i]->SetValue(replacer->second);
|
||||
});
|
||||
}
|
||||
|
||||
void Serialize(std::ostream& file) const
|
||||
{
|
||||
Serializer<unsigned8>::Serialize(mDepth, file);
|
||||
|
||||
// Write the number of nodes in the binary tree:
|
||||
Serializer<unsigned64>::Serialize((unsigned64)mNodePool.size(), file);
|
||||
|
||||
// Write the node materials
|
||||
for (size_t i = 0; i < mNodePool.size(); i++)
|
||||
Serializer<T>::Serialize(mNodePool[i]->GetValue(), file);
|
||||
|
||||
// Write (consequtively) the child pointers
|
||||
std::unordered_map<BinaryTreeNode<T>*, size_t> nodeIndices = CollectionHelper::GetIndexMap(mNodePool);
|
||||
for (size_t i = 0; i < mNodePool.size(); i++)
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
// Use 0 as NULL pointer, add 1 to all actual pointers
|
||||
BinaryTreeNode<T>* childNode = mNodePool[i]->GetChild(child);
|
||||
unsigned64 pointer = 0;
|
||||
if (childNode != NULL)
|
||||
{
|
||||
assert(nodeIndices.find(childNode) != nodeIndices.end());
|
||||
pointer = nodeIndices[childNode] + 1;
|
||||
}
|
||||
Serializer<unsigned64>::Serialize(pointer, file);
|
||||
}
|
||||
}
|
||||
|
||||
void Deserialize(std::istream& file)
|
||||
{
|
||||
if (mNodePool.size() > 1)
|
||||
{
|
||||
for (auto node = mNodePool.begin(); node != mNodePool.end(); node++)
|
||||
delete *node;
|
||||
mNodePool.resize(1);
|
||||
}
|
||||
|
||||
Serializer<unsigned8>::Deserialize(mDepth, file);
|
||||
|
||||
unsigned64 nodeCount;
|
||||
Serializer<unsigned64>::Deserialize(nodeCount, file);
|
||||
// The root is always already created, so create nodeCount - 1 nodes:
|
||||
for (size_t i = 0; i < nodeCount - 1; i++) Create();
|
||||
|
||||
// Deserialize the materials for each node
|
||||
T dummy;
|
||||
for (size_t i = 0; i < nodeCount; i++)
|
||||
{
|
||||
Serializer<T>::Deserialize(dummy, file);
|
||||
mNodePool[i]->SetValue(dummy);
|
||||
}
|
||||
|
||||
// Create all pointers
|
||||
for (size_t i = 0; i < mNodePool.size(); i++)
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
unsigned64 pointer;
|
||||
Serializer<unsigned64>::Deserialize(pointer, file);
|
||||
if (pointer != 0)
|
||||
mNodePool[i]->SetChild(mNodePool[pointer - 1], child);
|
||||
}
|
||||
}
|
||||
|
||||
// Converts the current binary tree to a DAG, meaning that all duplicate nodes are removed.
|
||||
void ToDAG()
|
||||
{
|
||||
// Fill the current layer with all leaf nodes
|
||||
std::vector<BinaryTreeNode<T>*> dagNodePool(1, mRoot);
|
||||
std::vector<BinaryTreeNode<T>*> currentLayer;
|
||||
std::unordered_set<BinaryTreeNode<T>*> nodesLeft;
|
||||
std::unordered_map<BinaryTreeNode<T>*, std::vector<BinaryTreeNode<T>*>> parentsMap;
|
||||
for (auto node : mNodePool)
|
||||
{
|
||||
if (node->IsLeaf()) currentLayer.push_back(node);
|
||||
else
|
||||
{
|
||||
nodesLeft.insert(node);
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
auto childNode = node->GetChild(child);
|
||||
if (childNode != NULL)
|
||||
{
|
||||
auto parent = parentsMap.find(childNode);
|
||||
if (parent == parentsMap.end())
|
||||
parentsMap.insert(std::make_pair(childNode, std::vector<BinaryTreeNode<T>*>(1, node)));
|
||||
else
|
||||
parent->second.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
while (!currentLayer.empty() && currentLayer[0] != mRoot)
|
||||
{
|
||||
// Find unique nodes and replace them
|
||||
tbb::parallel_sort(currentLayer.begin(), currentLayer.end(), [](BinaryTreeNode<T>* a, BinaryTreeNode<T>* b) { return BinaryTreeNode<T>::Compare(a, b); });
|
||||
BinaryTreeNode<T>* cur = NULL;
|
||||
std::vector<std::pair<BinaryTreeNode<T>*, BinaryTreeNode<T>*>> replacements;
|
||||
std::vector<BinaryTreeNode<T>*> nextLayer;
|
||||
size_t uniqueNodes = 0;
|
||||
for (auto node : currentLayer)
|
||||
{
|
||||
if (BinaryTreeNode<T>::Equals(node, cur))
|
||||
// Make sure that all nodes are replaced by their equals
|
||||
replacements.push_back(std::make_pair(node, cur));
|
||||
else
|
||||
{
|
||||
uniqueNodes++;
|
||||
cur = node;
|
||||
dagNodePool.push_back(cur);
|
||||
}
|
||||
auto parents = parentsMap.find(node);
|
||||
if (parents != parentsMap.end())
|
||||
nextLayer.insert(nextLayer.end(), parents->second.begin(), parents->second.end());
|
||||
}
|
||||
CollectionHelper::Unique(nextLayer, [](BinaryTreeNode<T>* a, BinaryTreeNode<T>* b) { return BinaryTreeNode<T>::Compare(a, b); });
|
||||
if (uniqueNodes != currentLayer.size())
|
||||
{
|
||||
for (auto replacement : replacements)
|
||||
{
|
||||
auto toReplace = replacement.first;
|
||||
auto replacer = replacement.second;
|
||||
auto parentsIt = parentsMap.find(toReplace);
|
||||
if (parentsIt == parentsMap.end())
|
||||
continue;
|
||||
std::vector<BinaryTreeNode<T>*> parents = parentsIt->second;
|
||||
for (auto parent : parents)
|
||||
{
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
if (parent->GetChild(child) == toReplace)
|
||||
parent->SetChild(replacer, child);
|
||||
}
|
||||
}
|
||||
delete toReplace;
|
||||
}
|
||||
}
|
||||
currentLayer = nextLayer;
|
||||
}
|
||||
mNodePool = dagNodePool;
|
||||
}
|
||||
|
||||
std::vector<unsigned8> Serialize(bool onlyLeafsContainValues) const
|
||||
{
|
||||
// The first byte contains the number of bytes per pointer
|
||||
unsigned8 pointerSize = GetSerializedPointerByteSize(onlyLeafsContainValues);
|
||||
unsigned8 valueSize = GetMaterialByteSize();
|
||||
|
||||
std::vector<unsigned8> res(GetSerializedByteCount(onlyLeafsContainValues), 0);
|
||||
res[0] = mDepth;
|
||||
res[1] = pointerSize;
|
||||
|
||||
std::vector<BinaryTreeNode<T>*> nodePoolCopy(mNodePool.size());
|
||||
std::copy(mNodePool.begin(), mNodePool.end(), nodePoolCopy.begin());
|
||||
size_t firstLeafIndex = ~size_t(0);
|
||||
if (onlyLeafsContainValues)
|
||||
{
|
||||
tbb::parallel_sort(nodePoolCopy.begin() + 1, nodePoolCopy.end(), [](BinaryTreeNode<T>* a, BinaryTreeNode<T>* b)
|
||||
{
|
||||
return !(a->IsLeaf()) && (b->IsLeaf());
|
||||
});
|
||||
|
||||
for (size_t i = 0; i < nodePoolCopy.size(); i++)
|
||||
{
|
||||
BinaryTreeNode<T>* node = nodePoolCopy[i];
|
||||
if (node->IsLeaf() && firstLeafIndex > i)
|
||||
firstLeafIndex = i;
|
||||
}
|
||||
}
|
||||
std::unordered_map<BinaryTreeNode<T>*, size_t> nodeIndices = CollectionHelper::GetIndexMap(nodePoolCopy);
|
||||
|
||||
for (size_t i = 0; i < nodePoolCopy.size(); i++)
|
||||
{
|
||||
BinaryTreeNode<T>* node = nodePoolCopy[i];
|
||||
size_t nodePointer = GetNodePointer(i, pointerSize, valueSize, firstLeafIndex, onlyLeafsContainValues);
|
||||
// Write the node pointers
|
||||
for (unsigned8 child = 0; child < 2; child++)
|
||||
{
|
||||
BinaryTreeNode<T>* childNode = node->GetChild(child);
|
||||
if (childNode != NULL)
|
||||
{
|
||||
assert(nodeIndices.find(childNode) != nodeIndices.end());
|
||||
size_t childIndex = nodeIndices[childNode];
|
||||
size_t pointer = GetNodePointer(childIndex, pointerSize, valueSize, firstLeafIndex, onlyLeafsContainValues);
|
||||
BitHelper::SplitInBytesAndMove(pointer, res, nodePointer + pointerSize * child, pointerSize);
|
||||
}
|
||||
}
|
||||
// Then write the node content/value
|
||||
if (!onlyLeafsContainValues || node->IsLeaf())
|
||||
{
|
||||
std::vector<unsigned8> serializedValue = node->GetValue().Serialize();
|
||||
std::move(serializedValue.begin(), serializedValue.end(), res.begin() + nodePointer + pointerSize * 2);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t GetNodeCount() const { return mNodePool.size(); }
|
||||
size_t GetLeafNodeCount() const {
|
||||
size_t leafNodeCount = 0;
|
||||
for (auto node : mNodePool) if (node->IsLeaf()) leafNodeCount++;
|
||||
return leafNodeCount;
|
||||
}
|
||||
|
||||
static size_t GetSerializedByteCount(const size_t& nodeCount, const size_t& leafCount, const size_t& pointerSize, const size_t& valueSize, const bool& onlyLeafsContainValues)
|
||||
{
|
||||
return (onlyLeafsContainValues ? leafCount : nodeCount) * valueSize + (2 * pointerSize) * nodeCount;
|
||||
}
|
||||
|
||||
size_t GetSerializedByteCount(bool onlyLeafsContainValues) const
|
||||
{
|
||||
return GetSerializedByteCount(GetNodeCount(), GetLeafNodeCount(), GetSerializedPointerByteSize(onlyLeafsContainValues), GetMaterialByteSize(), onlyLeafsContainValues);
|
||||
}
|
||||
|
||||
unsigned8 GetSerializedNodeByteSize(bool onlyLeafsContainValues) const
|
||||
{
|
||||
return 2 * GetSerializedPointerByteSize(onlyLeafsContainValues) + GetMaterialByteSize();
|
||||
}
|
||||
|
||||
unsigned8 GetMaterialByteSize() const
|
||||
{
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
unsigned8 GetSerializedPointerByteSize(bool onlyLeafsContainValues) const
|
||||
{
|
||||
// Count the number of leaf nodes:
|
||||
size_t leafNodeCount = 0;
|
||||
if (onlyLeafsContainValues)
|
||||
leafNodeCount = GetLeafNodeCount();
|
||||
bool fits = false;
|
||||
unsigned8 pointerSize = 0;
|
||||
while (!fits)
|
||||
{
|
||||
++pointerSize;
|
||||
size_t requiredBytes = GetSerializedByteCount(GetNodeCount(), leafNodeCount, pointerSize, GetMaterialByteSize(), onlyLeafsContainValues);
|
||||
fits = BitHelper::Exp2(8 * pointerSize) > requiredBytes;
|
||||
}
|
||||
return pointerSize;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user