Skip to content

Instantly share code, notes, and snippets.

@babhishek21
Last active May 23, 2023 16:21
Show Gist options
  • Save babhishek21/7198fcaabc9d61d171d8f418139c2372 to your computer and use it in GitHub Desktop.
Save babhishek21/7198fcaabc9d61d171d8f418139c2372 to your computer and use it in GitHub Desktop.
Range Sum Query for Mutable Arrays using Segment Trees
#include <bits/stdc++.h> // using GCC/G++11
using namespace std;
/**
 * Range Sum Query for Mutable Arrays using Segment Trees (LeetCode)
 * https://leetcode.com/problems/range-sum-query-mutable/
 *
 * Build a tree whose root represents the entire range. Each node's two children represent the
 * two halves of that node's range. This continues down the tree (height O(log n)) until we reach
 * the n individual leaves, each representing a single element.
 *
 * O(n) build
 * O(log n) per query
 * O(log n) per update
 *
 * Node values and results are plain `int`, so every range sum must fit in an `int`.
 */
class SegTree {
    // Implicit heap layout: the root is tree[1] and node k has children 2k and 2k+1, so
    // tree[0] is never used. With midpoint splitting, 4*n slots are always enough.
    std::vector<int> tree;
    int n = 0;  // number of elements in the underlying array

public:
    /// Builds the tree from the values of `nums` (which is neither retained nor modified).
    /// An empty `nums` produces an empty tree on which every query/update throws.
    SegTree(const std::vector<int> &nums) : n(static_cast<int>(nums.size())) {
        if(n == 0)
            return;  // building would read nums[0], which does not exist
        tree.assign(4 * nums.size(), 0);
        buildSegTree(nums, 1, 0, n-1);
    }

    /// Returns nums[i] + nums[i+1] + ... + nums[j] (both ends inclusive).
    /// @throws std::out_of_range unless 0 <= i <= j < size.
    int queryRangeSum(int i, int j) const {
        if(i < 0 || i > j || j >= n)
            throw std::out_of_range("SegTree::queryRangeSum: invalid range");
        return querySegTree(1, 0, n-1, i, j);
    }

    /// Sets nums[i] = val and refreshes the sums of every node covering index i.
    /// @throws std::out_of_range unless 0 <= i < size. Without this check an out-of-range
    /// index would silently overwrite the first or last leaf instead.
    void updateVal(int i, int val) {
        if(i < 0 || i >= n)
            throw std::out_of_range("SegTree::updateVal: index out of range");
        updateValSegTree(1, 0, n-1, i, val);
    }

    // --- Recursive helpers. Node `treeIndex` covers the array range [lo, hi]. ---
    // They perform no argument validation; callers must pass in-range indices.

    /// Fills node `treeIndex` (covering nums[lo..hi]) and its subtree; requires lo <= hi.
    void buildSegTree(const std::vector<int> &nums, int treeIndex, int lo, int hi) {
        if(lo == hi) {
            tree[treeIndex] = nums[lo];
            return;
        }
        int mid = lo + (hi-lo)/2;
        buildSegTree(nums, 2*treeIndex, lo, mid);
        buildSegTree(nums, 2*treeIndex+1, mid+1, hi);
        // A parent's value is the sum of its two children's ranges.
        tree[treeIndex] = tree[2*treeIndex] + tree[2*treeIndex+1];
    }

    /// Sum of arr[i..j]; requires lo <= i <= j <= hi.
    int querySegTree(int treeIndex, int lo, int hi, int i, int j) const {
        if(i == lo && j == hi)
            return tree[treeIndex];  // node's range matches the query exactly
        int mid = lo + (hi-lo)/2;
        if(i > mid)   // query lies entirely in the right half
            return querySegTree(2*treeIndex+1, mid+1, hi, i, j);
        if(j <= mid)  // query lies entirely in the left half
            return querySegTree(2*treeIndex, lo, mid, i, j);
        // Query straddles mid: split it at the midpoint and merge the two halves.
        return querySegTree(2*treeIndex, lo, mid, i, mid)
             + querySegTree(2*treeIndex+1, mid+1, hi, mid+1, j);
    }

    /// Sets leaf arrIndex to val, then recomputes sums on the path back up;
    /// requires lo <= arrIndex <= hi.
    void updateValSegTree(int treeIndex, int lo, int hi, int arrIndex, int val) {
        if(lo == hi) {
            tree[treeIndex] = val;
            return;
        }
        int mid = lo + (hi-lo)/2;
        if(arrIndex <= mid)
            updateValSegTree(2*treeIndex, lo, mid, arrIndex, val);
        else
            updateValSegTree(2*treeIndex+1, mid+1, hi, arrIndex, val);
        // Re-merge this node from its (possibly changed) children.
        tree[treeIndex] = tree[2*treeIndex] + tree[2*treeIndex+1];
    }
};
int main() {
vector<int> nums = {1,2,3,4,5,6,7,8,9};
cout << "Original input array: ";
for(auto &val: nums)
cout << val << " ";
cout << endl;
cout << "\nTesting Segment Trees:" << endl;
SegTree testst(nums);
cout << "GET (2, 7): " << testst.queryRangeSum(2, 7) << endl; // assert 33
cout << "UPDATE (4, -5)" << endl;
testst.updateVal(4, -5);
cout << "GET (1, 8): " << testst.queryRangeSum(1, 8) << endl; // assert 34
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment