#pragma once #include "../Defines.h" #include "../BitHelper.h" #include "../CollectionHelper.h" #include "../Serializer.h" #include "../../inc/tbb/parallel_for_each.h" #include "../../inc/tbb/parallel_for.h" #include #include #include #include template class BinaryTree { private: template struct BinaryTreeNode { private: BinaryTreeNode* mChildren[2]; BinaryTree* mTree; T mMaterial; inline unsigned8 GetChildIndex(size_t coordinate, unsigned8 levelsLeft) const { return BitHelper::GetLS(coordinate, levelsLeft) ? 1 : 0; } inline BinaryTreeNode* GetNodeAt(size_t coordinate, unsigned8 levelsLeft) const { return mChildren[GetChildIndex(coordinate, levelsLeft)]; } public: BinaryTreeNode(BinaryTree* tree) : mTree(tree) { mChildren[0] = mChildren[1] = NULL; } ~BinaryTreeNode() {} BinaryTreeNode* GetChild(unsigned8 index) const { assert(index < 2); return mChildren[index]; } void SetChild(BinaryTreeNode* child, unsigned8 index) { mChildren[index] = child; } size_t GetChildCount() const { return (mChildren[0] == NULL ? 0 : 1) + (mChildren[1] == NULL ? 0 : 1); } bool IsLeaf() const { return mChildren[0] == NULL && mChildren[1] == NULL; } BinaryTreeNode* AddNode(size_t coordinate, unsigned8 levelsLeft) { auto node = GetNodeAt(coordinate, levelsLeft); if (node == NULL) { node = mTree->Create(); SetChild(node, GetChildIndex(coordinate, levelsLeft)); } if (levelsLeft == 0) return node; else return node->AddNode(coordinate, levelsLeft - 1); } U GetValueAt(size_t coordinate, unsigned8 levelsLeft) const { auto node = GetNodeAt(coordinate, levelsLeft); if (node == NULL) return GetValue(); if (levelsLeft == 0) return node->GetValue(); return node->GetValueAt(coordinate, levelsLeft - 1); } U GetValue() const { return mMaterial; } void SetValue(const T& material) { mMaterial = material; } void Traverse(const std::function*)>& f) const { f.operator()(this); if (mChildren[0] != NULL) mChildren[0]->Traverse(f); if (mChildren[1] != NULL) mChildren[1]->Traverse(f); } static bool Compare(BinaryTreeNode* a, BinaryTreeNode* b) { if (a->GetChild(0) != b->GetChild(0)) return (size_t)a->GetChild(0) < (size_t)b->GetChild(0); if (a->GetChild(1) != b->GetChild(1)) return (size_t)a->GetChild(1) < (size_t)b->GetChild(1); return a->GetValue() < b->GetValue(); } static bool Equals(BinaryTreeNode* a, BinaryTreeNode* b) { if (a == b) return true; if (a == NULL || b == NULL) return false; return a->GetChild(0) == b->GetChild(0) && a->GetChild(1) == b->GetChild(1) && a->GetValue() == b->GetValue(); } }; unsigned8 mDepth; BinaryTreeNode* mRoot; std::vector*> mNodePool; std::vector*> GetNodes() { std::vector*> res; std::function*)> nodeFinder = [&](const BinaryTreeNode*) { res.push_back(this); }; mRoot->Traverse(nodeFinder); CollectionHelper::Unique(res); return res; } void ShiftDown(unsigned8 levels) { // Shifts the root down the given number of levels if (levels == 0) return; for (unsigned8 i = 0; i < levels; i++) { BinaryTreeNode* newRoot = Create(); newRoot->SetChild(mRoot, 0); BinaryTreeNode* oldRoot = mRoot; mRoot = newRoot; mNodePool[0] = newRoot; mNodePool[mNodePool.size() - 1] = oldRoot; } } static void CalculateNodeLevelsRecursive(BinaryTreeNode* node, unsigned8 level, std::vector& nodeLevels, const std::unordered_map*, size_t>& nodeIndices) { assert(nodeIndices.find(node) != nodeIndices.end()); auto nodeIndex = nodeIndices.find(node); nodeLevels[nodeIndex->second] = level; for (unsigned8 child = 0; child < 2; child++) { auto childNode = node->GetChild(child); if (childNode != NULL) CalculateNodeLevelsRecursive(childNode, level + 1, nodeLevels, nodeIndices); } } // Calculates the levels of all nodes. Levels are stored in the same order as the node pool. std::vector CalculateNodeLevels() const { auto nodeIndices = CollectionHelper::GetIndexMap(mNodePool); std::vector nodeLevels(mNodePool.size()); CalculateNodeLevelsRecursive(mRoot, 0, nodeLevels, nodeIndices); return nodeLevels; } inline static size_t GetNodePointer(const size_t& index, const size_t& pointerSize, const size_t& valueSize, const size_t& firstLeafIndex, const bool& onlyLeafsContainValues) { if (!onlyLeafsContainValues) return 2 + index * (pointerSize * 2 + valueSize); else { size_t pointer = 2 + std::min(index, firstLeafIndex) * (pointerSize * 2); if (index > firstLeafIndex) pointer += (index - firstLeafIndex - 1) * (pointerSize * 2 + valueSize); return pointer; } } public: BinaryTree() { mDepth = 0; mRoot = Create(); } ~BinaryTree() { tbb::parallel_for_each(mNodePool.begin(), mNodePool.end(), [](BinaryTreeNode* node) { delete node; }); mNodePool.clear(); } void AddLeafNode(size_t coordinate) { AddNode(coordinate, mDepth); } void AddNode(size_t coordinate, unsigned8 level) { assert(level <= mDepth); mRoot->AddNode(coordinate, level); } T GetValueAtLeaf(size_t coordinate) const { return GetValueAtNode(coordinate, mDepth); } T GetValueAtNode(size_t coordinate, unsigned8 level) const { assert(level <= mDepth); return mRoot->GetValueAt(coordinate, level); } void SetValueAtLeaf(size_t coordinate, T value) { SetValueAtNode(coordinate, mDepth, value); } void SetValueAtNode(size_t coordinate, unsigned8 level, T value) { assert(level <= mDepth); auto node = mRoot->AddNode(coordinate, level); node->SetValue(value); } BinaryTreeNode* Create() { BinaryTreeNode* newNode = new BinaryTreeNode(this); mNodePool.push_back(newNode); return newNode; } void SetDepth(unsigned8 wantedDepth, bool shiftExisting) { if (mDepth != wantedDepth) { if (wantedDepth < mDepth) ShaveUntil(wantedDepth); else ShiftDown(wantedDepth - mDepth); } mDepth = wantedDepth; } // Deletes all nodes of which the level is greater than the given level void ShaveUntil(unsigned8 level) { // Calculate the level for all nodes std::vector nodeLevels = CalculateNodeLevels(); for (size_t i = 0; i < mNodePool.size(); i++) if (nodeLevels[i] > level) { delete mNodePool[i]; mNodePool[i] = NULL; } else if (nodeLevels[i] == level) { // Set all pointers to NULL: for (unsigned8 child = 0; child < 2; child++) mNodePool[i]->SetChild(NULL, child); } mNodePool.erase(std::remove_if(mNodePool.begin(), mNodePool.end(), [](const BinaryTreeNode* node) { return node == NULL; })); } void ReplaceValues(const std::unordered_map& replacers) { tbb::parallel_for(size_t(0), mNodePool.size(), [&](const size_t& i) { T curValue = mNodePool[i]->GetValue(); auto replacer = replacers.find(curValue); if (replacer != replacers.end()) mNodePool[i]->SetValue(replacer->second); }); } void Serialize(std::ostream& file) const { Serializer::Serialize(mDepth, file); // Write the number of nodes in the binary tree: Serializer::Serialize((unsigned64)mNodePool.size(), file); // Write the node materials for (size_t i = 0; i < mNodePool.size(); i++) Serializer::Serialize(mNodePool[i]->GetValue(), file); // Write (consequtively) the child pointers std::unordered_map*, size_t> nodeIndices = CollectionHelper::GetIndexMap(mNodePool); for (size_t i = 0; i < mNodePool.size(); i++) for (unsigned8 child = 0; child < 2; child++) { // Use 0 as NULL pointer, add 1 to all actual pointers BinaryTreeNode* childNode = mNodePool[i]->GetChild(child); unsigned64 pointer = 0; if (childNode != NULL) { assert(nodeIndices.find(childNode) != nodeIndices.end()); pointer = nodeIndices[childNode] + 1; } Serializer::Serialize(pointer, file); } } void Deserialize(std::istream& file) { if (mNodePool.size() > 1) { for (auto node = mNodePool.begin(); node != mNodePool.end(); node++) delete *node; mNodePool.resize(1); } Serializer::Deserialize(mDepth, file); unsigned64 nodeCount; Serializer::Deserialize(nodeCount, file); // The root is always already created, so create nodeCount - 1 nodes: for (size_t i = 0; i < nodeCount - 1; i++) Create(); // Deserialize the materials for each node T dummy; for (size_t i = 0; i < nodeCount; i++) { Serializer::Deserialize(dummy, file); mNodePool[i]->SetValue(dummy); } // Create all pointers for (size_t i = 0; i < mNodePool.size(); i++) for (unsigned8 child = 0; child < 2; child++) { unsigned64 pointer; Serializer::Deserialize(pointer, file); if (pointer != 0) mNodePool[i]->SetChild(mNodePool[pointer - 1], child); } } // Converts the current binary tree to a DAG, meaning that all duplicate nodes are removed. void ToDAG() { // Fill the current layer with all leaf nodes std::vector*> dagNodePool(1, mRoot); std::vector*> currentLayer; std::unordered_set*> nodesLeft; std::unordered_map*, std::vector*>> parentsMap; for (auto node : mNodePool) { if (node->IsLeaf()) currentLayer.push_back(node); else { nodesLeft.insert(node); for (unsigned8 child = 0; child < 2; child++) { auto childNode = node->GetChild(child); if (childNode != NULL) { auto parent = parentsMap.find(childNode); if (parent == parentsMap.end()) parentsMap.insert(std::make_pair(childNode, std::vector*>(1, node))); else parent->second.push_back(node); } } } } while (!currentLayer.empty() && currentLayer[0] != mRoot) { // Find unique nodes and replace them tbb::parallel_sort(currentLayer.begin(), currentLayer.end(), [](BinaryTreeNode* a, BinaryTreeNode* b) { return BinaryTreeNode::Compare(a, b); }); BinaryTreeNode* cur = NULL; std::vector*, BinaryTreeNode*>> replacements; std::vector*> nextLayer; size_t uniqueNodes = 0; for (auto node : currentLayer) { if (BinaryTreeNode::Equals(node, cur)) // Make sure that all nodes are replaced by their equals replacements.push_back(std::make_pair(node, cur)); else { uniqueNodes++; cur = node; dagNodePool.push_back(cur); } auto parents = parentsMap.find(node); if (parents != parentsMap.end()) nextLayer.insert(nextLayer.end(), parents->second.begin(), parents->second.end()); } CollectionHelper::Unique(nextLayer, [](BinaryTreeNode* a, BinaryTreeNode* b) { return BinaryTreeNode::Compare(a, b); }); if (uniqueNodes != currentLayer.size()) { for (auto replacement : replacements) { auto toReplace = replacement.first; auto replacer = replacement.second; auto parentsIt = parentsMap.find(toReplace); if (parentsIt == parentsMap.end()) continue; std::vector*> parents = parentsIt->second; for (auto parent : parents) { for (unsigned8 child = 0; child < 2; child++) { if (parent->GetChild(child) == toReplace) parent->SetChild(replacer, child); } } delete toReplace; } } currentLayer = nextLayer; } mNodePool = dagNodePool; } std::vector Serialize(bool onlyLeafsContainValues) const { // The first byte contains the number of bytes per pointer unsigned8 pointerSize = GetSerializedPointerByteSize(onlyLeafsContainValues); unsigned8 valueSize = GetMaterialByteSize(); std::vector res(GetSerializedByteCount(onlyLeafsContainValues), 0); res[0] = mDepth; res[1] = pointerSize; std::vector*> nodePoolCopy(mNodePool.size()); std::copy(mNodePool.begin(), mNodePool.end(), nodePoolCopy.begin()); size_t firstLeafIndex = ~size_t(0); if (onlyLeafsContainValues) { tbb::parallel_sort(nodePoolCopy.begin() + 1, nodePoolCopy.end(), [](BinaryTreeNode* a, BinaryTreeNode* b) { return !(a->IsLeaf()) && (b->IsLeaf()); }); for (size_t i = 0; i < nodePoolCopy.size(); i++) { BinaryTreeNode* node = nodePoolCopy[i]; if (node->IsLeaf() && firstLeafIndex > i) firstLeafIndex = i; } } std::unordered_map*, size_t> nodeIndices = CollectionHelper::GetIndexMap(nodePoolCopy); for (size_t i = 0; i < nodePoolCopy.size(); i++) { BinaryTreeNode* node = nodePoolCopy[i]; size_t nodePointer = GetNodePointer(i, pointerSize, valueSize, firstLeafIndex, onlyLeafsContainValues); // Write the node pointers for (unsigned8 child = 0; child < 2; child++) { BinaryTreeNode* childNode = node->GetChild(child); if (childNode != NULL) { assert(nodeIndices.find(childNode) != nodeIndices.end()); size_t childIndex = nodeIndices[childNode]; size_t pointer = GetNodePointer(childIndex, pointerSize, valueSize, firstLeafIndex, onlyLeafsContainValues); BitHelper::SplitInBytesAndMove(pointer, res, nodePointer + pointerSize * child, pointerSize); } } // Then write the node content/value if (!onlyLeafsContainValues || node->IsLeaf()) { std::vector serializedValue = node->GetValue().Serialize(); std::move(serializedValue.begin(), serializedValue.end(), res.begin() + nodePointer + pointerSize * 2); } } return res; } size_t GetNodeCount() const { return mNodePool.size(); } size_t GetLeafNodeCount() const { size_t leafNodeCount = 0; for (auto node : mNodePool) if (node->IsLeaf()) leafNodeCount++; return leafNodeCount; } static size_t GetSerializedByteCount(const size_t& nodeCount, const size_t& leafCount, const size_t& pointerSize, const size_t& valueSize, const bool& onlyLeafsContainValues) { return (onlyLeafsContainValues ? leafCount : nodeCount) * valueSize + (2 * pointerSize) * nodeCount; } size_t GetSerializedByteCount(bool onlyLeafsContainValues) const { return GetSerializedByteCount(GetNodeCount(), GetLeafNodeCount(), GetSerializedPointerByteSize(onlyLeafsContainValues), GetMaterialByteSize(), onlyLeafsContainValues); } unsigned8 GetSerializedNodeByteSize(bool onlyLeafsContainValues) const { return 2 * GetSerializedPointerByteSize(onlyLeafsContainValues) + GetMaterialByteSize(); } unsigned8 GetMaterialByteSize() const { return sizeof(T); } unsigned8 GetSerializedPointerByteSize(bool onlyLeafsContainValues) const { // Count the number of leaf nodes: size_t leafNodeCount = 0; if (onlyLeafsContainValues) leafNodeCount = GetLeafNodeCount(); bool fits = false; unsigned8 pointerSize = 0; while (!fits) { ++pointerSize; size_t requiredBytes = GetSerializedByteCount(GetNodeCount(), leafNodeCount, pointerSize, GetMaterialByteSize(), onlyLeafsContainValues); fits = BitHelper::Exp2(8 * pointerSize) > requiredBytes; } return pointerSize; } };