@rishi93
Last active July 23, 2022 13:04
Segment Tree (Implementation in Java)
```java
class SegmentTree {
    int[] tree;

    SegmentTree(int n) {
        // n is the number of tree nodes to allocate, not the input length
        tree = new int[n];
    }

    // Builds a sum tree over arr[start..end]; the children of node are 2*node + 1 and 2*node + 2
    void build(int[] arr, int node, int start, int end) {
        if (start == end) {
            // Leaf: holds a single array element
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2 * node + 1, start, mid);
            build(arr, 2 * node + 2, mid + 1, end);
            // Internal node: sum of its two children
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    // Adds val to arr[index] (a point "add", not an assignment) and recomputes the sums on the path
    void update(int[] arr, int node, int index, int val, int start, int end) {
        if (start == end) {
            tree[node] += val;
            arr[start] += val;
        } else {
            int mid = (start + end) / 2;
            if (start <= index && index <= mid) {
                update(arr, 2 * node + 1, index, val, start, mid);
            } else {
                update(arr, 2 * node + 2, index, val, mid + 1, end);
            }
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }
    }

    // Returns the sum of arr[left..right]; [start, end] is the range covered by node
    int query(int node, int start, int end, int left, int right) {
        if (right < start || end < left) {
            // No overlap with the query range
            return 0;
        }
        if (left <= start && end <= right) {
            // Node range lies entirely inside the query range
            return tree[node];
        }
        // Partial overlap: combine the answers from both halves
        int mid = (start + end) / 2;
        int p1 = query(2 * node + 1, start, mid, left, right);
        int p2 = query(2 * node + 2, mid + 1, end, left, right);
        return p1 + p2;
    }
}

public class Test {
    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5, 6, 7, 8};
        int n = arr.length;
        int height = (int) (Math.log(n) / Math.log(2)) + 1;
        int tree_nodes = (int) Math.pow(2, height + 1);
        SegmentTree ob = new SegmentTree(tree_nodes);
        ob.build(arr, 0, 0, n - 1);
        // Print the backing array (unused slots stay 0)
        for (int i = 0; i < tree_nodes; i++) {
            System.out.print(ob.tree[i] + " ");
        }
        System.out.println();
        // Sum of arr[0..5] = 1 + 2 + 3 + 4 + 5 + 6 = 21
        System.out.println(ob.query(0, 0, n - 1, 0, 5));
    }
}
```
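To exercise `update` as well as `query`, here is a minimal sketch that wraps the same recursive sum tree behind a constructor and checks every range sum against a brute-force loop. `SegmentTreeDemo`, `matchesBruteForce`, `SumSegmentTree`, `add` and `sum` are hypothetical names, and allocating `4 * n` nodes is an assumption; it is always at least as large as the `Math.pow` size computed in `main`.

```java
// Hypothetical demo (not part of the gist). Declared first so that the
// single-file launcher (java File.java) finds this main method.
class SegmentTreeDemo {
    // Compares every range sum with a brute-force loop; true if all match
    static boolean matchesBruteForce(SumSegmentTree st, int[] arr) {
        for (int l = 0; l < arr.length; l++) {
            int expected = 0;
            for (int r = l; r < arr.length; r++) {
                expected += arr[r];
                if (st.sum(l, r) != expected) return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5, 6, 7, 8};
        SumSegmentTree st = new SumSegmentTree(arr);
        System.out.println(st.sum(0, 5));               // 21
        st.add(3, 10);                                  // logical arr[3] becomes 14
        arr[3] += 10;                                   // keep the plain copy in sync
        System.out.println(st.sum(0, 5));               // 31
        System.out.println(matchesBruteForce(st, arr)); // true
    }
}

// Hypothetical wrapper around the gist's recursive sum tree
// (children of node at 2*node + 1 and 2*node + 2), sized at 4 * n nodes (assumption).
class SumSegmentTree {
    private final int n;
    private final int[] tree;

    SumSegmentTree(int[] arr) {
        n = arr.length;
        tree = new int[4 * n];
        build(arr, 0, 0, n - 1);
    }

    private void build(int[] arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        int mid = (start + end) / 2;
        build(arr, 2 * node + 1, start, mid);
        build(arr, 2 * node + 2, mid + 1, end);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Point add: the element at index grows by val (same semantics as the gist's update)
    void add(int index, int val) {
        add(0, 0, n - 1, index, val);
    }

    private void add(int node, int start, int end, int index, int val) {
        if (start == end) {
            tree[node] += val;
            return;
        }
        int mid = (start + end) / 2;
        if (index <= mid) {
            add(2 * node + 1, start, mid, index, val);
        } else {
            add(2 * node + 2, mid + 1, end, index, val);
        }
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Sum of the elements in [left, right], inclusive
    int sum(int left, int right) {
        return sum(0, 0, n - 1, left, right);
    }

    private int sum(int node, int start, int end, int left, int right) {
        if (right < start || end < left) return 0;
        if (left <= start && end <= right) return tree[node];
        int mid = (start + end) / 2;
        return sum(2 * node + 1, start, mid, left, right)
                + sum(2 * node + 2, mid + 1, end, left, right);
    }
}
```

Unlike the gist's `update`, this wrapper does not mutate the caller's array, so the demo updates `arr[3]` itself to keep the brute-force copy in sync.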
@vorobii-vitalii

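Great implementation. There is one simplification that can be made in this code:

```java
int height = (int) (Math.log(n) / Math.log(2)) + 1;
int treeNodes = (int) Math.pow(2, height + 1);
```

can be replaced with

```java
int treeNodes = n * 4;
```

Here is why: the `(int)` cast floors the logarithm, so `height` $= \lfloor \log_2 n \rfloor + 1$ and the original size is

$$2^{\lfloor \log_2 n \rfloor + 2} \le 2^{\log_2 n + 2} = 2^{\log_2 n} \cdot 2^2 = 4n,$$

with equality when $n$ is a power of two. So `n * 4` never allocates fewer nodes than the original formula (the tree needs at most $2 \cdot 2^{\lceil \log_2 n \rceil} - 1 < 4n$ slots), and it avoids the floating-point `Math.log` and `Math.pow` calls.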
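The bound can also be checked numerically. This sketch uses hypothetical names (`TreeSizeCheck`, `floorLog2`, `gistSize`, `maxIndex`), swaps `Integer.numberOfLeadingZeros` in for the floating-point logarithm, walks the same midpoint recursion to find the largest node index actually touched, and confirms that needed <= 2^(floor(log2 n) + 2) <= 4 * n for every n up to 4096.

```java
class TreeSizeCheck {
    // floor(log2 n) for n >= 1, computed without floating point
    static int floorLog2(int n) {
        return 31 - Integer.numberOfLeadingZeros(n);
    }

    // Size the gist allocates: 2^(floor(log2 n) + 2)
    static int gistSize(int n) {
        return 1 << (floorLog2(n) + 2);
    }

    // Largest node index the gist's build recursion touches for [start, end]
    static int maxIndex(int node, int start, int end) {
        if (start == end) return node;
        int mid = (start + end) / 2;
        return Math.max(maxIndex(2 * node + 1, start, mid),
                        maxIndex(2 * node + 2, mid + 1, end));
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 1 << 12; n++) {
            int needed = maxIndex(0, 0, n - 1) + 1;
            int original = gistSize(n);
            if (needed > original || original > 4 * n) {
                throw new AssertionError("bound fails at n = " + n);
            }
        }
        System.out.println("needed <= 2^(floor(log2 n) + 2) <= 4n for all n <= 4096");
    }
}
```

For example, with n = 5 the recursion touches 9 slots, the original formula allocates 16 and `n * 4` allocates 20, so the change trades a little memory for simpler, exact integer arithmetic.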