Skip to content

Instantly share code, notes, and snippets.

@fseasy
Created August 20, 2016 08:05
Show Gist options
  • Save fseasy/0a14e10253854f8a88f7b09c29d9232a to your computer and use it in GitHub Desktop.
Save fseasy/0a14e10253854f8a88f7b09c29d9232a to your computer and use it in GitHub Desktop.
segment tree 递归构建版
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using namespace std;
/// Sum segment tree over a vector of ints, built recursively.
///
/// Supports point assignment (updateNodeVal) and inclusive range-sum queries
/// (queryRangeSum); both visit O(log n) nodes.
///
/// Layout: nodes live in the flat array `treeData`. Node i has children
/// 2*i + 1 and 2*i + 2, and the root (index 0) covers raw indices
/// [0, rawData.size() - 1]. A node covering [lo, hi] with lo < hi splits at
/// mid = lo + (hi - lo) / 2 into [lo, mid] and [mid + 1, hi]. A node with
/// lo == hi is a leaf holding rawData[lo]. Every internal node stores the sum
/// of its two children. Sums are stored as int, so totals must fit in int.
class SumSegmentTree
{
public:
    /// Builds the tree from a copy of @p rawData.
    SumSegmentTree(const std::vector<int> &rawData)
        : rawData(rawData)
    {
        buildTree();
    }

    /// Builds the tree, moving @p rawDataRref into the object.
    SumSegmentTree(std::vector<int> &&rawDataRref)
        : rawData(std::move(rawDataRref))
    {
        buildTree();
    }

    /// Sets raw element @p rawDataIndex to @p newNodeVal and refreshes every
    /// sum on the root-to-leaf path that covers it.
    /// @throws std::out_of_range if @p rawDataIndex is not in [0, size).
    void updateNodeVal(int rawDataIndex, int newNodeVal)
    {
        if(!checkRawDataIndex(rawDataIndex))
        {
            std::string errMsg = std::string("raw data index ");
            errMsg += std::to_string(rawDataIndex);
            errMsg += " is out of range";
            throw std::out_of_range(errMsg);
        }
        rawData[rawDataIndex] = newNodeVal;
        // Recompute sums bottom-up instead of adding a delta: the delta
        // (newNodeVal - oldVal) can overflow int even when every stored sum fits.
        updateTreeNodeVal(rootRange(), 0, rawDataIndex);
    }

    /// Returns the sum of the elements in [rangeStart, rangeEnd] (inclusive).
    /// Indices outside [0, size) contribute nothing, so the result is 0 for an
    /// empty tree, for rangeStart > rangeEnd, or for a range disjoint from the data.
    int queryRangeSum(int rangeStart, int rangeEnd) const // [rangeStart, rangeEnd] -> inclusive
    {
        if(rawData.empty())
        {
            return 0; // treeData is empty: there is no root node to read
        }
        return queryRangeSumRecursively(Range(rangeStart, rangeEnd), rootRange(), 0);
    }

private:
    using Range = std::pair<int, int>; // inclusive [first, second] index range

    /// Range covered by the root node. Only meaningful when rawData is non-empty.
    Range rootRange() const
    {
        return Range(0, static_cast<int>(rawData.size()) - 1);
    }

    /// Sizes treeData and fills it from rawData. Leaves treeData empty for empty input.
    void buildTree()
    {
        if(rawData.empty())
        {
            return;
        }
        // A mid-split tree over n leaves fits in 2 * 2^ceil(log2 n) - 1 nodes.
        // Integer arithmetic avoids the floating-point pow/log2/ceil round trip.
        std::size_t leafCapacity = 1;
        while(leafCapacity < rawData.size())
        {
            leafCapacity *= 2;
        }
        treeData.assign(2 * leafCapacity - 1, 0);
        buildTreeNodeRecursively(rootRange(), 0);
    }

    /// Fills the subtree rooted at @p treeNodeIndex (covering @p nodeRange)
    /// and returns the sum stored at that node.
    int buildTreeNodeRecursively(const Range &nodeRange, int treeNodeIndex)
    {
        if(nodeRange.first == nodeRange.second)
        {
            // leaf node
            treeData.at(treeNodeIndex) = rawData.at(nodeRange.first);
        }
        else
        {
            const int rangeMidVal = getRangeMidVal(nodeRange);
            const int leftSum = buildTreeNodeRecursively(Range(nodeRange.first, rangeMidVal),
                                                         getTreeLeftChildIndex(treeNodeIndex));
            const int rightSum = buildTreeNodeRecursively(Range(rangeMidVal + 1, nodeRange.second),
                                                          getTreeRightChildIndex(treeNodeIndex));
            treeData.at(treeNodeIndex) = leftSum + rightSum;
        }
        return treeData[treeNodeIndex];
    }

    static int getTreeLeftChildIndex(int parentIndex)
    {
        return parentIndex * 2 + 1;
    }

    static int getTreeRightChildIndex(int parentIndex)
    {
        return parentIndex * 2 + 2;
    }

    /// Midpoint written as lo + (hi - lo) / 2 to avoid the overflow of (lo + hi) / 2.
    static int getRangeMidVal(const Range &range)
    {
        return range.first + (range.second - range.first) / 2;
    }

    /// True iff @p index is a valid position in rawData.
    bool checkRawDataIndex(int index) const
    {
        return index >= 0 && static_cast<std::size_t>(index) < rawData.size();
    }

    /// Copies rawData[rawDataIndex] into its leaf, then recomputes the sum of
    /// every node on the way back up. Precondition: rawDataIndex lies within
    /// nodeRange (guaranteed at the root by checkRawDataIndex).
    void updateTreeNodeVal(const Range &nodeRange, int treeNodeIndex, int rawDataIndex)
    {
        if(nodeRange.first == nodeRange.second)
        {
            treeData[treeNodeIndex] = rawData[rawDataIndex];
            return;
        }
        const int rangeMidVal = getRangeMidVal(nodeRange);
        const int treeLeftChildIndex = getTreeLeftChildIndex(treeNodeIndex);
        const int treeRightChildIndex = getTreeRightChildIndex(treeNodeIndex);
        // Only the child whose range contains rawDataIndex can change.
        if(rawDataIndex <= rangeMidVal)
        {
            updateTreeNodeVal(Range(nodeRange.first, rangeMidVal), treeLeftChildIndex, rawDataIndex);
        }
        else
        {
            updateTreeNodeVal(Range(rangeMidVal + 1, nodeRange.second), treeRightChildIndex, rawDataIndex);
        }
        treeData[treeNodeIndex] = treeData[treeLeftChildIndex] + treeData[treeRightChildIndex];
    }

    /// Sum of the part of @p queryRange that overlaps @p nodeRange, read from
    /// the subtree rooted at @p treeNodeIndex.
    int queryRangeSumRecursively(const Range &queryRange, const Range &nodeRange, int treeNodeIndex) const
    {
        if(nodeRange.first >= queryRange.first && nodeRange.second <= queryRange.second)
        {
            // node range lies entirely inside the query range
            return treeData[treeNodeIndex];
        }
        if(nodeRange.first > queryRange.second || nodeRange.second < queryRange.first)
        {
            // node range is disjoint from the query range
            return 0;
        }
        // Partial overlap: split the node range. A leaf is always either fully
        // inside or disjoint, so this point is only reached for internal nodes.
        const int rangeMidVal = getRangeMidVal(nodeRange);
        return queryRangeSumRecursively(queryRange, Range(nodeRange.first, rangeMidVal),
                                        getTreeLeftChildIndex(treeNodeIndex)) +
               queryRangeSumRecursively(queryRange, Range(rangeMidVal + 1, nodeRange.second),
                                        getTreeRightChildIndex(treeNodeIndex));
    }

    std::vector<int> rawData;  // current element values
    std::vector<int> treeData; // flat node array; empty iff rawData is empty
};
/// Demo: builds a sum segment tree over 1..10, prints the sum of elements
/// 5..6, sets element 9 to 27, then prints the sum of elements 0..3.
int main()
{
    SumSegmentTree segTree(std::vector<int>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    std::cout << segTree.queryRangeSum(5, 6) << '\n'; // 6 + 7 = 13
    segTree.updateNodeVal(9, 27);
    std::cout << segTree.queryRangeSum(0, 3) << '\n'; // 1 + 2 + 3 + 4 = 10
#ifdef _WIN32
    // "pause" is a Windows shell command (presumably here to keep the console
    // window open). Flush first so the results are shown before the prompt.
    std::cout << std::flush;
    std::system("pause");
#endif
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment