Skip to content

Instantly share code, notes, and snippets.

@rohithpeddi
Created December 16, 2016 12:48
Show Gist options
  • Save rohithpeddi/0b7919c2c55dc013ae0fda5627f8876f to your computer and use it in GitHub Desktop.
Save rohithpeddi/0b7919c2c55dc013ae0fda5627f8876f to your computer and use it in GitHub Desktop.
Interval Tree implementation
class Interval1D implements Comparable<Interval1D> {
public final int min,max;
public Interval1D(int min, int max){
if(min<=max){
this.min = min;
this.max = max;
} else {
throw new RuntimeException("min is larger than max!");
}
}
public boolean contains(int x){
return (min<=x)&&(max>=x);
}
public boolean intersects(Interval1D that){
if(this.max<that.min) return false;
if(that.max<this.min) return false;
return true;
}
public int compareTo(Interval1D that){
if(this.max<that.max) return -1;
else if(this.max>that.max) return 1;
else if(this.min<that.min) return -1;
else if(this.min>that.min) return 1;
else return 0;
}
}
/**
 * A randomized interval search tree mapping {@link Interval1D} keys to values.
 *
 * <p>Nodes are kept in binary-search-tree order by left endpoint ({@code min}), with ties broken
 * by right endpoint ({@code max}). Every node also caches its subtree size and the largest right
 * endpoint in its subtree; the latter lets {@link #search(Interval1D)} prune whole subtrees.
 * Balance is kept probabilistically: a newly inserted key becomes the root of a subtree of size
 * {@code n} with probability {@code 1/(n+1)}, and deletion merges the two children randomly.
 *
 * <p>No synchronization is performed.
 *
 * @param <Value> type of the value associated with each interval
 */
public class IntervalTree<Value> {
    private Node root;

    /** One tree node: an interval, its value, and cached subtree aggregates. */
    private class Node {
        Interval1D interval;
        Value val;
        int N, max; // N = number of nodes in this subtree; max = largest interval.max in it
        Node left, right;

        public Node(Interval1D interval, Value val) {
            this.interval = interval;
            this.val = val;
            this.N = 1;
            this.max = interval.max;
        }
    }

    /**
     * Orders intervals by left endpoint, then by right endpoint.
     *
     * <p>The tree uses this private ordering instead of {@link Interval1D#compareTo} because the
     * pruning rule in {@link #search(Interval1D)} is only correct when keys are sorted by left
     * endpoint, and the tree must not depend on how {@code Interval1D} happens to order itself.
     */
    private static int compare(Interval1D a, Interval1D b) {
        int byMin = Integer.compare(a.min, b.min);
        return byMin != 0 ? byMin : Integer.compare(a.max, b.max);
    }

    /**************************************
     * SIZE IMPLEMENTATION
     **************************************/

    /** Returns the number of intervals stored in the tree. */
    public int size() {
        return size(root);
    }

    /** Returns the cached size of subtree {@code x}, or 0 for an empty subtree. */
    private int size(Node x) {
        return x == null ? 0 : x.N;
    }

    /**************************************
     * HEIGHT IMPLEMENTATION
     **************************************/

    /** Returns the number of nodes on the longest root-to-leaf path (0 for an empty tree). */
    public int height() {
        return height(root);
    }

    private int height(Node x) {
        if (x == null) {
            return 0;
        }
        return 1 + Math.max(height(x.left), height(x.right));
    }

    /**************************************
     * BST SEARCH IMPLEMENTATION
     **************************************/

    /**
     * Returns the value stored for exactly {@code interval} (same endpoints), or {@code null} if
     * the interval is absent or is mapped to {@code null}.
     */
    public Value get(Interval1D interval) {
        Node x = getNode(root, interval);
        return x == null ? null : x.val;
    }

    /**
     * Finds the node whose interval equals {@code interval}, descending with the same ordering
     * used by insertion and removal (smaller keys on the left).
     */
    private Node getNode(Node x, Interval1D interval) {
        while (x != null) {
            int cmp = compare(interval, x.interval);
            if (cmp < 0) {
                x = x.left;
            } else if (cmp > 0) {
                x = x.right;
            } else {
                return x;
            }
        }
        return null;
    }

    /**************************************
     * CONTAINS FUNCTION
     **************************************/

    /**
     * Returns whether an interval with the same endpoints is stored, even if its value is
     * {@code null}.
     */
    public boolean contains(Interval1D interval) {
        return getNode(root, interval) != null;
    }

    /**************************************
     * HELPER FUNCTIONS
     **************************************/

    /** Recomputes {@code x}'s cached size and subtree max from its children; no-op for null. */
    private void update(Node x) {
        if (x == null) {
            return;
        }
        x.N = 1 + size(x.left) + size(x.right);
        x.max = Maximum(x.interval.max, max(x.left), max(x.right));
    }

    /** Returns the subtree max of {@code x}, or {@code Integer.MIN_VALUE} for an empty subtree. */
    private int max(Node x) {
        return x == null ? Integer.MIN_VALUE : x.max;
    }

    /** Returns the largest of three ints. */
    public int Maximum(int a, int b, int c) {
        return Math.max(a, Math.max(b, c));
    }

    /** Rotates {@code h}'s left child up; fixes aggregates of the demoted node first. */
    private Node rotateRight(Node h) {
        Node x = h.left;
        h.left = x.right;
        x.right = h;
        update(h);
        update(x);
        return x;
    }

    /** Rotates {@code h}'s right child up; fixes aggregates of the demoted node first. */
    private Node rotateLeft(Node h) {
        Node x = h.right;
        h.right = x.left;
        x.left = h;
        update(h);
        update(x);
        return x;
    }

    /**************************************
     * PUT IMPLEMENTATION
     **************************************/

    /**
     * Inserts {@code interval} with value {@code val}. If an interval with the same endpoints is
     * already stored, the call is a no-op and the existing value is kept.
     */
    public void put(Interval1D interval, Value val) {
        if (contains(interval)) {
            return;
        }
        root = randomPut(root, interval, val);
    }

    /** Randomized BST insertion into subtree {@code x}; returns the new subtree root. */
    private Node randomPut(Node x, Interval1D interval, Value val) {
        if (x == null) {
            return new Node(interval, val);
        }
        // The new key must become this subtree's root with probability 1/(n+1), n = size(x).
        if (Math.random() * (size(x) + 1) < 1.0) {
            return rootPut(x, interval, val);
        }
        if (compare(interval, x.interval) < 0) {
            x.left = randomPut(x.left, interval, val);
        } else {
            x.right = randomPut(x.right, interval, val);
        }
        update(x);
        return x;
    }

    /** Inserts {@code interval} and rotates it up to become the root of subtree {@code x}. */
    private Node rootPut(Node x, Interval1D interval, Value val) {
        if (x == null) {
            return new Node(interval, val);
        }
        if (compare(interval, x.interval) < 0) {
            x.left = rootPut(x.left, interval, val);
            return rotateRight(x);
        }
        x.right = rootPut(x.right, interval, val);
        return rotateLeft(x);
    }

    /**************************************
     * DELETE IMPLEMENTATION
     **************************************/

    /**
     * Merges subtrees {@code a} and {@code b}, where every key in {@code a} precedes every key in
     * {@code b}. The merged root comes from {@code a} with probability
     * {@code size(a) / (size(a) + size(b))}.
     */
    private Node joinLR(Node a, Node b) {
        if (a == null) {
            return b;
        }
        if (b == null) {
            return a;
        }
        if (Math.random() * (size(a) + size(b)) < size(a)) {
            a.right = joinLR(a.right, b);
            update(a);
            return a;
        }
        b.left = joinLR(a, b.left);
        update(b);
        return b;
    }

    /**
     * Removes the interval with the same endpoints as {@code interval}, if present.
     *
     * @return the removed interval's value, or {@code null} if it was absent
     */
    public Value remove(Interval1D interval) {
        Value value = get(interval);
        root = remove(root, interval);
        return value;
    }

    private Node remove(Node h, Interval1D interval) {
        if (h == null) {
            return null;
        }
        int cmp = compare(interval, h.interval);
        if (cmp < 0) {
            h.left = remove(h.left, interval);
        } else if (cmp > 0) {
            h.right = remove(h.right, interval);
        } else {
            h = joinLR(h.left, h.right);
        }
        update(h);
        return h;
    }

    /**************************************
     * INTERVAL SEARCH IMPLEMENTATION
     **************************************/

    /** Returns some stored interval intersecting {@code interval}, or {@code null} if none does. */
    public Interval1D search(Interval1D interval) {
        return search(root, interval);
    }

    /**
     * Walks one root-to-leaf path looking for an intersecting interval.
     *
     * <p>If the left subtree's largest right endpoint is below {@code interval.min}, nothing on
     * the left can intersect, so go right. Otherwise go left: if nothing there intersects, the
     * left interval attaining that maximum must start after {@code interval.max}, and since keys
     * are ordered by left endpoint every interval on the right starts no earlier, so none of them
     * can intersect either.
     */
    private Interval1D search(Node x, Interval1D interval) {
        while (x != null) {
            if (interval.intersects(x.interval)) {
                return x.interval;
            } else if (x.left == null || x.left.max < interval.min) {
                x = x.right;
            } else {
                x = x.left;
            }
        }
        return null;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment