Created
August 20, 2016 08:05
-
-
Save fseasy/0a14e10253854f8a88f7b09c29d9232a to your computer and use it in GitHub Desktop.
segment tree 递归构建版 (segment tree — recursive-build version)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
/// Sum segment tree over a fixed-length array of ints (recursive build).
///
/// Supports point assignment (`updateNodeVal`) and inclusive range-sum
/// queries (`queryRangeSum`); each operation visits O(log n) tree nodes.
///
/// Layout: the tree is stored implicitly in `treeData`. Node i has children
/// 2i+1 and 2i+2; the root (index 0) covers raw indices [0, n-1]; a node
/// covering [l, r] with l < r splits at mid = l + (r - l) / 2 into [l, mid]
/// and [mid + 1, r]. Sums are stored as `int`, so callers must keep every
/// partial sum within `int` range.
class SumSegmentTree
{
    // Inclusive range [first, second] of raw-data indices covered by a node.
    using Range = std::pair<int, int>;

public:
    /// Builds the tree from a copy of `rawData`.
    SumSegmentTree(const std::vector<int> &rawData)
        : rawData(rawData)
    {
        buildTree();
    }

    /// Builds the tree, taking over the contents of `rawDataRref`.
    SumSegmentTree(std::vector<int> &&rawDataRref)
        : rawData(std::move(rawDataRref))
    {
        buildTree();
    }

    /// Assigns `newNodeVal` to raw element `rawDataIndex` and refreshes the
    /// sum of every tree node whose range contains that element.
    /// @throws std::out_of_range if `rawDataIndex` is not in [0, size).
    void updateNodeVal(int rawDataIndex, int newNodeVal)
    {
        if(!checkRawDataIndex(rawDataIndex))
        {
            std::string errMsg = std::string("raw data index ");
            errMsg += std::to_string(rawDataIndex);
            errMsg += " is out of range";
            throw std::out_of_range(errMsg);
        }
        rawData[rawDataIndex] = newNodeVal;
        // Recompute sums bottom-up rather than adding (new - old): that
        // difference can overflow int even when every stored sum fits.
        updateTreeNodeVal(Range(0, lastRawIndex()), 0, rawDataIndex);
    }

    /// Returns the sum of raw elements in [rangeStart, rangeEnd] (inclusive).
    /// Indices outside [0, size) contribute nothing, so an empty tree, a
    /// reversed range, or a range entirely out of bounds yields 0.
    int queryRangeSum(int rangeStart, int rangeEnd) const
    {
        if(rawData.empty()) { return 0; } // there are no tree nodes to read
        return queryRangeSumRecursively(Range(rangeStart, rangeEnd),
                                        Range(0, lastRawIndex()), 0);
    }

private:
    /// Index of the last raw element; only meaningful when rawData is non-empty.
    int lastRawIndex() const
    {
        return static_cast<int>(rawData.size()) - 1;
    }

    /// Sizes treeData and fills it from rawData; leaves it empty for empty input.
    void buildTree()
    {
        if(rawData.empty()) { return; }
        // Smallest power of two >= n. Midpoint splitting yields a tree of height
        // ceil(log2 n), so 2 * leafCapacity - 1 slots hold every node. Integer
        // arithmetic avoids the floating-point rounding of pow/ceil/log2.
        std::size_t leafCapacity = 1;
        while(leafCapacity < rawData.size()) { leafCapacity *= 2; }
        treeData.assign(2 * leafCapacity - 1, 0);
        buildTreeNodeRecursively(Range(0, lastRawIndex()), 0);
    }

    /// Fills the subtree rooted at `treeNodeIndex` (covering `nodeRange`)
    /// and returns that subtree's sum.
    int buildTreeNodeRecursively(const Range &nodeRange, int treeNodeIndex)
    {
        if(nodeRange.first == nodeRange.second)
        {
            // leaf node: copy the raw value
            treeData.at(treeNodeIndex) = rawData.at(nodeRange.first);
        }
        else
        {
            const int rangeMidVal = getRangeMidVal(nodeRange);
            const int leftSum = buildTreeNodeRecursively(
                Range(nodeRange.first, rangeMidVal), getTreeLeftChildIndex(treeNodeIndex));
            const int rightSum = buildTreeNodeRecursively(
                Range(rangeMidVal + 1, nodeRange.second), getTreeRightChildIndex(treeNodeIndex));
            treeData.at(treeNodeIndex) = leftSum + rightSum;
        }
        return treeData[treeNodeIndex];
    }

    static int getTreeLeftChildIndex(int parentIndex)
    {
        return parentIndex * 2 + 1;
    }

    static int getTreeRightChildIndex(int parentIndex)
    {
        return parentIndex * 2 + 2;
    }

    /// Midpoint of `range`, written so that first + second cannot overflow.
    static int getRangeMidVal(const Range &range)
    {
        return range.first + (range.second - range.first) / 2;
    }

    bool checkRawDataIndex(int index) const
    {
        return index >= 0 && static_cast<std::size_t>(index) < rawData.size();
    }

    /// Walks from the node covering `nodeRange` down to the leaf for
    /// `rawDataIndex`, then recomputes each visited node from its children.
    /// Precondition: nodeRange contains rawDataIndex (checked by the caller).
    void updateTreeNodeVal(const Range &nodeRange, int treeNodeIndex, int rawDataIndex)
    {
        if(nodeRange.first == nodeRange.second)
        {
            treeData[treeNodeIndex] = rawData[rawDataIndex];
            return;
        }
        const int rangeMidVal = getRangeMidVal(nodeRange);
        const int treeLeftChildIndex = getTreeLeftChildIndex(treeNodeIndex);
        const int treeRightChildIndex = getTreeRightChildIndex(treeNodeIndex);
        // Only the child whose range holds the index needs to change.
        if(rawDataIndex <= rangeMidVal)
        {
            updateTreeNodeVal(Range(nodeRange.first, rangeMidVal), treeLeftChildIndex, rawDataIndex);
        }
        else
        {
            updateTreeNodeVal(Range(rangeMidVal + 1, nodeRange.second), treeRightChildIndex, rawDataIndex);
        }
        treeData[treeNodeIndex] = treeData[treeLeftChildIndex] + treeData[treeRightChildIndex];
    }

    /// Sum of the part of `queryRange` that overlaps `nodeRange`.
    int queryRangeSumRecursively(const Range &queryRange, const Range &nodeRange, int treeNodeIndex) const
    {
        if(nodeRange.first >= queryRange.first && nodeRange.second <= queryRange.second)
        {
            // node range lies entirely inside the query range
            return treeData[treeNodeIndex];
        }
        if(nodeRange.first > queryRange.second || nodeRange.second < queryRange.first)
        {
            // node range is disjoint from the query range
            return 0;
        }
        // partial overlap: split the node range and combine both halves
        const int rangeMidVal = getRangeMidVal(nodeRange);
        return queryRangeSumRecursively(queryRange, Range(nodeRange.first, rangeMidVal),
                                        getTreeLeftChildIndex(treeNodeIndex)) +
               queryRangeSumRecursively(queryRange, Range(rangeMidVal + 1, nodeRange.second),
                                        getTreeRightChildIndex(treeNodeIndex));
    }

private:
    std::vector<int> rawData;   // current raw element values
    std::vector<int> treeData;  // implicit tree of range sums (see class comment)
};
int main(int argc, char *argv[]) | |
{ | |
SumSegmentTree segTree(vector<int>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); | |
cout << segTree.queryRangeSum(5, 6) << endl; | |
segTree.updateNodeVal(9, 27); | |
cout << segTree.queryRangeSum(0, 3) << endl; | |
system("pause"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment