Skip to content

Instantly share code, notes, and snippets.

@spdskatr
Last active May 14, 2019 07:29

Revisions

  1. spdskatr revised this gist May 14, 2019. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions treap.cpp
    Original file line number Diff line number Diff line change
    @@ -4,6 +4,7 @@

    using namespace std;

    int root, al = 1;
    struct tree { int val, data, l, r, cnt, rmq; } tr[1000005];

    void upd(int n) {
  2. spdskatr created this gist May 5, 2019.
    120 changes: 120 additions & 0 deletions treap.cpp
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,120 @@
    #include <cstdio>
    #include <cstdlib>
    #include <algorithm>

    using namespace std;

    struct tree { int val, data, l, r, cnt, rmq; } tr[1000005];

    void upd(int n) {
    if (n == 0) return;
    tr[n].cnt = tr[tr[n].l].cnt + tr[tr[n].r].cnt + 1;
    tr[n].rmq = max(tr[n].data, max(tr[tr[n].l].rmq, tr[tr[n].r].rmq));
    }

    void merge(int &n, int a, int b) {
    if (!a || !b) n = a ? a : b;
    else if (tr[a].val > tr[b].val) {
    merge(tr[a].r, tr[a].r, b); upd(a);
    n = a;
    } else {
    merge(tr[b].l, a, tr[b].l); upd(b);
    n = b;
    }
    }

    // X is number of nodes to split into the left subtree
    void split(int n, int x, int &l, int &r) {
    if (n == 0) l = r = 0;
    else if (x == tr[tr[n].l].cnt) {
    l = tr[n].l;
    tr[n].l = 0;
    r = n;
    } else if (x < tr[tr[n].l].cnt) {
    split(tr[n].l, x, l, tr[n].l);
    r = n;
    } else {
    split(tr[n].r, x - tr[tr[n].l].cnt - 1, tr[n].r, r);
    l = n;
    }
    upd(n);
    }

    void insert(int &n, int i, int pos) {
    if (n == 0) n = i;
    else if (tr[i].val > tr[n].val) {
    split(n, pos, tr[i].l, tr[i].r);
    n = i;
    } else {
    int c = tr[tr[n].l].cnt;
    if (pos <= c) {
    insert(tr[n].l, i, pos);
    } else {
    insert(tr[n].r, i, pos - c - 1);
    }
    }
    upd(n);
    }

    int erase(int &n, int idx) {
    int c = tr[tr[n].l].cnt, res = -1;
    if (n == 0) return -1;
    if (idx == c) res = n, merge(n, tr[n].l, tr[n].r);
    else if (idx < c) {
    res = erase(tr[n].l, idx);
    } else {
    res = erase(tr[n].r, idx-c-1);
    }
    upd(n);
    return res;
    }

    int range_max(int n, int l, int r) {
    if (n == 0 || r <= 0 || l >= tr[n].cnt) return 0;
    if (l <= 0 && r >= tr[n].cnt) return tr[n].rmq;
    int a = 0, c = tr[tr[n].l].cnt;
    if (l <= c && r > c) a = tr[n].data;
    return max(a, max(range_max(tr[n].l, l, r), range_max(tr[n].r, l - tr[tr[n].l].cnt - 1, r - tr[tr[n].l].cnt - 1)));
    }

    void disp(int n) {
    if (n == 0) return;
    disp(tr[n].l);
    printf("%d ", tr[n].data);
    disp(tr[n].r);
    }

    void disp_tr(int n) {
    if (n == 0) return;
    printf("(");
    disp_tr(tr[n].l);
    printf(")-%d-(", tr[n].data);
    disp_tr(tr[n].r);
    printf(")");
    }

    int main() {
    srand(6969);
    printf("Doing treap test\n");
    while (true) {
    int x, pos;
    scanf("%d %d", &x, &pos);
    if (pos == -1) return 0;
    if (x == -1) {
    // Remove
    int res = erase(root, pos);
    if (res == -1) printf("-1\n");
    else printf("%d\n", tr[res].data);
    if (res == -1) printf("Element not found\n");
    else printf("Erased element (%d): %d\n", res, tr[res].data);
    } else {
    tr[al].data = x;
    tr[al].val = rand();
    upd(al);
    insert(root, al, pos);
    al++;
    }
    printf("Order statistic at root (%d): %d\n", root, tr[root].cnt);
    disp(root);printf("\n");
    }
    }