Last active
June 30, 2019 06:10
-
-
Save m00nlight/cf89e14d93ed69c204f8 to your computer and use it in GitHub Desktop.
python binary index tree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
class BinaryIndexTree:
    """Binary indexed (Fenwick) tree over positions 1..n.

    Supports adding a delta to a single position and querying prefix or
    range sums.  Positions are 1-based; slot 0 of the backing list is
    unused padding so that the low-bit arithmetic works directly on the
    position.
    """

    def __init__(self, n):
        """Create a tree for positions 1..n, all initially zero."""
        self.sz = n
        self.vals = [0] * (n + 1)

    def update(self, idx, delta):
        """Add ``delta`` to the value at 1-based position ``idx``.

        Raises:
            IndexError: if ``idx`` is less than 1.  Position 0 has no low
                bit, so the walk below would never advance, and negative
                positions would silently write through Python's negative
                indexing before reaching 0.

        Positions greater than the tree size are ignored (the loop body
        never runs).
        """
        if idx < 1:
            raise IndexError(
                "BinaryIndexTree positions are 1-based, got %r" % (idx,))
        while idx <= self.sz:
            self.vals[idx] += delta
            # Jump to the next node whose range covers this position.
            idx += idx & (-idx)

    def accumulate(self, idx):
        """Return the sum of positions 1..idx (0 when ``idx`` <= 0)."""
        ret = 0
        while idx > 0:
            ret += self.vals[idx]
            # Strip the lowest set bit to move to the preceding range.
            idx -= idx & (-idx)
        return ret

    def range_sum(self, start, end):
        """Return a[start] + a[start+1] + ... + a[end] (both inclusive)."""
        return self.accumulate(end) - self.accumulate(start - 1)
def test():
    """Run a short update/query script against a 10-slot BinaryIndexTree.

    Each step applies a batch of point updates and then checks a batch of
    inclusive range sums.  Returns 'test pass' when every check holds; a
    failing check raises AssertionError.
    """
    tree = BinaryIndexTree(10)

    # A freshly built tree sums to zero everywhere.
    assert tree.range_sum(1, 5) == 0

    # (point updates as (idx, delta), then checks as ((start, end), expected))
    script = (
        ([(1, 3), (4, 6)], [((1, 4), 9)]),
        ([(3, 3)], [((1, 4), 12), ((3, 4), 9)]),
        ([(3, -2)], [((3, 3), 1), ((1, 4), 10)]),
    )
    for updates, checks in script:
        for position, amount in updates:
            tree.update(position, amount)
        for (lo, hi), expected in checks:
            assert tree.range_sum(lo, hi) == expected

    return 'test pass'
# Run the self-test when executed as a script.
if __name__ == '__main__':
    # Parenthesised call form works as a print statement on Python 2 and as
    # the print function on Python 3 (single argument, so output is the same).
    print(test())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@Haozun Thanks for pointing out the problem. I have fixed the mistake, but I think it happened because my indices start from 1, not 0.