Last active July 1, 2018 01:09
-
-
Save sorawee/552cc671dcbad9b6255f033c21453ad1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <vector>
using namespace std;
// Returns the smallest power of two that is >= n (returns 1 for any n <= 1).
// Throws std::length_error if that power does not fit in an int (i.e. n > 2^30
// for a 32-bit int): doubling past that point would be signed overflow (UB) and,
// in practice, an infinite loop.
inline int smallest_pow2(int n) {
    int ret = 1;
    while (ret < n) {
        if (ret > std::numeric_limits<int>::max() / 2)
            throw std::length_error("smallest_pow2: result does not fit in int");
        ret <<= 1;
    }
    return ret;
}
// Number of integer cells shared by the closed intervals [a, b] and [x, y];
// 0 when the intervals do not overlap.
inline int intersect(int a, int b, int x, int y) {
    const int lo = (a > x) ? a : x;  // later of the two left edges
    const int hi = (b < y) ? b : y;  // earlier of the two right edges
    if (hi < lo) return 0;
    return hi - lo + 1;
}
// One segment-tree node covering the cell range [l, r].
//   val  - sum over the range, not counting adds still pending in `lazy`
//          on this node or its ancestors
//   lazy - pending amount to add to every cell in the range; 0 means none
// All fields default to 0, so a default-constructed Node is never indeterminate.
// It's actually possible to compute l and r from a node's index, so we could
// save 8 bytes per node if we really need it.
struct Node {
    int l = 0;
    int r = 0;
    int val = 0;
    int lazy = 0;
};
struct SegmentTree { | |
vector<Node> nodes; | |
SegmentTree(int n) { | |
int offset = smallest_pow2(n); | |
nodes.resize(2*offset); | |
// could save space by resizing to only offset+n+1, but will need more checks in various places | |
// to prevent index out of bound | |
/* | |
E.g., if 5 <= n <= 8, we want to start at the index offset = smallest_pow2(n) = 8 | |
1 | |
2 3 | |
4 5 6 7 | |
here>> 8 9 10 11 12 13 14 15 | |
*/ | |
// set the leaves (the bottommost row) | |
for (int i = 0; i < offset; ++i) { | |
nodes[offset + i].l = i; | |
nodes[offset + i].r = i; | |
nodes[offset + i].val = 0; // if initial values are provided, can set them here (for i < n) | |
nodes[offset + i].lazy = 0; // in the context of summation, lazy = 0 means no lazy value | |
} | |
// set the inner nodes | |
for (int i = offset - 1; i >= 1; --i) { | |
nodes[i].l = nodes[i*2].l; | |
nodes[i].r = nodes[i*2 + 1].r; | |
nodes[i].val = nodes[i*2].val + nodes[i*2 + 1].val; // can just set to 0 if everything is initially 0 | |
nodes[i].lazy = 0; | |
} | |
} | |
void propagate(int v) { | |
if (nodes[v].lazy == 0) return; // this line is simply to shortcut. It's not really needed... | |
nodes[v].val += nodes[v].lazy * (nodes[v].r - nodes[v].l + 1); | |
if (v*2 < int(nodes.size())) nodes[v*2].lazy += nodes[v].lazy; | |
if (v*2 + 1 < int(nodes.size())) nodes[v*2 + 1].lazy += nodes[v].lazy; | |
nodes[v].lazy = 0; | |
} | |
int query(int l, int r) { | |
return query_iter(1, l, r); | |
} | |
int query_iter(int v, int l, int r) { | |
if (not intersect(nodes[v].l, nodes[v].r, l, r)) return 0; | |
propagate(v); | |
if (l <= nodes[v].l and nodes[v].r <= r) return nodes[v].val; | |
return query_iter(v*2, l, r) + query_iter(v*2 + 1, l, r); | |
} | |
void update(int l, int r, int val) { | |
update_iter(1, l, r, val); | |
} | |
void update_iter(int v, int l, int r, int val) { | |
int intersecting_cells = intersect(nodes[v].l, nodes[v].r, l, r); | |
if (not intersecting_cells) return; | |
if (l <= nodes[v].l and nodes[v].r <= r) { | |
nodes[v].lazy += val; // only set lazy. No need to bother with val since propagation will deal with that. | |
return; | |
} | |
// can't update the entire node's lazy, so we need to maintain val explicitly | |
nodes[v].val += val * intersecting_cells; | |
update_iter(v*2, l, r, val); | |
update_iter(v*2 + 1, l, r, val); | |
} | |
}; | |
int main() { | |
SegmentTree st(10); | |
// 0 0 0 0 0 0 0 0 0 0 | |
st.update(2, 5, 3); | |
// 0 0 3 3 3 3 0 0 0 0 | |
cout << st.query(0, 9) << endl; // expect 12 | |
cout << st.query(1, 2) << endl; // expect 3 | |
st.update(1, 3, 7); | |
// 0 7 10 10 3 3 0 0 0 0 | |
st.update(5, 8, 1); | |
// 0 7 10 10 3 4 1 1 1 0 | |
cout << st.query(2, 3) << endl; // expect 20 | |
cout << st.query(1, 6) << endl; // expect 35 | |
cout << st.query(0, 9) << endl; // expect 37 | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.