Skip to content

Instantly share code, notes, and snippets.

@SuryaPratapK
Created February 13, 2020 15:25
Show Gist options
  • Save SuryaPratapK/f5cc6e770c4e5fdc201bff1076640910 to your computer and use it in GitHub Desktop.
Save SuryaPratapK/f5cc6e770c4e5fdc201bff1076640910 to your computer and use it in GitHub Desktop.
// C program to show segment tree operations like construction,
// query and update
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
/* Return the midpoint of the index range [s, e].
   Written as s + (e - s) / 2 rather than (s + e) / 2 so that the sum of two
   large indexes cannot overflow. */
int getMid(int s, int e)
{
    int halfSpan = (e - s) / 2;
    return s + halfSpan;
}
/* Recursively compute the XOR of the array values whose indexes fall inside
   the query range [qs, qe].
   st      --> pointer to the segment tree
   ss, se  --> first and last array index covered by node st[si]
   qs, qe  --> first and last array index of the query
   si      --> index of the current node in st (the root is 0; the children
               of node k are 2k+1 and 2k+2)
   Returns the XOR of the overlap between [ss, se] and [qs, qe]; 0 (the XOR
   identity) when they do not overlap. */
int getXorUtil(int *st, int ss, int se, int qs, int qe, int si)
{
    /* The node's whole segment lies inside the query: use its stored XOR. */
    const int fullyCovered = (qs <= ss) && (se <= qe);
    if (fullyCovered)
        return st[si];

    /* The node's segment and the query do not intersect at all. */
    const int disjoint = (qs > se) || (qe < ss);
    if (disjoint)
        return 0;

    /* Partial overlap: combine the answers from both halves. */
    {
        const int mid = getMid(ss, se);
        const int leftPart = getXorUtil(st, ss, mid, qs, qe, 2 * si + 1);
        const int rightPart = getXorUtil(st, mid + 1, se, qs, qe, 2 * si + 2);
        return leftPart ^ rightPart;
    }
}
/* Recursively fold an XOR mask into every node whose segment contains index i.
   st, si, ss and se have the same meaning as in getXorUtil().
   i    --> index (into the input array) of the element being changed
   diff --> XOR mask equal to (old_value ^ new_value). Each node stores the
            XOR of its segment, so XOR-ing it with old ^ new cancels the old
            element's contribution and adds the new one. (An arithmetic
            difference, as used for sum trees, would corrupt an XOR tree.) */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)
{
    /* Index i is not inside this node's segment: nothing to change here. */
    if (i < ss || i > se)
        return;

    st[si] ^= diff;

    if (se != ss)
    {
        int mid = getMid(ss, se);
        /* Only the child whose segment contains i needs updating. */
        if (i <= mid)
            updateValueUtil(st, ss, mid, i, diff, 2*si + 1);
        else
            updateValueUtil(st, mid+1, se, i, diff, 2*si + 2);
    }
}

/* Set arr[i] = new_val and update the segment tree st (built over the n
   elements of arr) so that it stays consistent.
   Prints "Invalid Input" and changes nothing when i is outside [0, n-1]. */
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
    if (i < 0 || i > n-1)
    {
        printf("Invalid Input");
        return;
    }
    /* XOR mask that turns the old value into the new one inside any XOR. */
    int diff = arr[i] ^ new_val;
    arr[i] = new_val;
    updateValueUtil(st, 0, n-1, i, diff, 0);
}
/* Return the XOR of arr[qs..qe] (qs = query start, qe = query end, both
   inclusive) using the segment tree st built over n elements; the work is
   delegated to getXorUtil().
   Prints "Invalid Input" and returns -1 when qs > qe or the range is not
   within [0, n-1].
   NOTE(review): -1 is also a legitimate XOR result when the array may hold
   negative values, so callers cannot always tell an error from a result. */
int getXor(int *st, int n, int qs, int qe)
{
    const int validRange = qs >= 0 && qe <= n - 1 && qs <= qe;
    if (!validRange)
    {
        printf("Invalid Input");
        return -1;
    }
    return getXorUtil(st, 0, n - 1, qs, qe, 0);
}
/* Recursively fill node st[si] (covering arr[ss..se]) and its descendants.
   A leaf stores its single array value; an internal node stores the XOR of
   its two children (2*si+1 and 2*si+2). Returns the value stored in st[si]. */
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
    int value;

    if (ss == se)
    {
        value = arr[ss];
    }
    else
    {
        const int mid = getMid(ss, se);
        const int leftXor = constructSTUtil(arr, ss, mid, st, si * 2 + 1);
        const int rightXor = constructSTUtil(arr, mid + 1, se, st, si * 2 + 2);
        value = leftXor ^ rightXor;
    }

    st[si] = value;
    return value;
}
/* Build an XOR segment tree for the n elements of arr.
   Allocates 2 * P - 1 ints, where P is the smallest power of two >= n (enough
   for every node of the tree), and fills them with constructSTUtil().
   The caller owns the returned buffer and must release it with free().
   Returns NULL when n <= 0 (there is nothing to build) or when the
   allocation fails. */
int *constructST(int arr[], int n)
{
    int leaves = 1;
    int max_size;
    int *st;

    /* log2(0) is -infinity; the original float-based sizing was undefined here. */
    if (n <= 0)
        return NULL;

    /* Integer equivalent of pow(2, ceil(log2(n))), free of floating-point
       rounding concerns. */
    while (leaves < n)
        leaves *= 2;
    max_size = 2 * leaves - 1;

    st = (int *)malloc(sizeof(int) * max_size);
    if (st == NULL)
        return NULL;

    constructSTUtil(arr, 0, n-1, st, 0);
    return st;
}
/* Driver: build a tree over a small array, query the XOR of arr[1..3],
   change arr[1] to 10, query again, and release the tree. */
int main()
{
    int arr[] = {1, 3, 5, 7, 9, 11};
    int n = sizeof(arr)/sizeof(arr[0]);

    int *st = constructST(arr, n);
    if (st == NULL)
    {
        fprintf(stderr, "Failed to build segment tree\n");
        return 1;
    }

    /* XOR of values in arr from index 1 to 3 */
    printf("Xor of values in given range = %d\n",
           getXor(st, n, 1, 3));

    /* Set arr[1] = 10 and update the corresponding tree nodes */
    updateValue(arr, st, n, 1, 10);

    /* Same query after the update */
    printf("Updated xor of values in given range = %d\n",
           getXor(st, n, 1, 3));

    /* The tree was heap-allocated by constructST(); release it. */
    free(st);
    return 0;
}
@ganeshkamath89
Copy link

Updated the code to use the same array as the one shown in the YouTube tutorial video.

#include <iostream>
#include <vector>
using namespace std;
// Midpoint of the index range [s, e]. Computed as s + (e - s) / 2 instead of
// (s + e) / 2 so the intermediate sum of two large indices cannot overflow.
int getMid(int s, int e)
{
	const int halfSpan = (e - s) / 2;
	return s + halfSpan;
}

// Replace prev_val by new_val inside every node on the path from node `si`
// (covering [L, R]) down to the leaf for index i.
// Each node holds the XOR of its segment, so XOR-ing it with
// prev_val ^ new_val cancels the old value and adds the new one.
// Does nothing when i lies outside [L, R].
void updateValueUtil(vector<int> &st, int si, int L, int R, int i, int prev_val, int new_val)
{
	if (i < L || i > R)
		return;

	const int delta = prev_val ^ new_val;
	int node = si;
	int lo = L;
	int hi = R;

	// Walk down the single root-to-leaf path whose segments contain i.
	for (;;)
	{
		st[node] ^= delta;
		if (lo == hi)
			break;

		const int mid = getMid(lo, hi);
		if (i <= mid)
		{
			node = node * 2 + 1;
			hi = mid;
		}
		else
		{
			node = node * 2 + 2;
			lo = mid + 1;
		}
	}
}

// Set arr[i] = new_val and patch the XOR segment tree st (built over the
// first n elements of arr) to match.
// Prints "Invalid input" and leaves both containers untouched when i is not
// in [0, n - 1].
void updateValue(vector<int> &arr, vector<int> &st, int n, int i, int new_val)
{
	const bool inRange = i >= 0 && i <= n - 1;
	if (!inRange)
	{
		cout << "Invalid input";
		return;
	}

	const int old_val = arr[i];
	arr[i] = new_val;
	// Start at the root (node 0), which covers [0, n - 1].
	updateValueUtil(st, 0, 0, n - 1, i, old_val, new_val);
}

// Recursively XOR together the array values in the query range [sL, sR].
// Node `si` covers [L, R]; its children are 2*si+1 and 2*si+2.
// Returns 0 (the XOR identity) when [L, R] and [sL, sR] do not overlap.
int getXorUtil(vector<int> &st, int L, int R, int sL, int sR, int si)
{
	// Node segment lies entirely inside the query.
	if (L >= sL && R <= sR)
		return st[si];

	// Node segment and query are disjoint.
	if (sR < L || sL > R)
		return 0;

	// Partial overlap: combine both halves.
	const int mid = getMid(L, R);
	const int fromLeft = getXorUtil(st, L, mid, sL, sR, si * 2 + 1);
	const int fromRight = getXorUtil(st, mid + 1, R, sL, sR, si * 2 + 2);
	return fromLeft ^ fromRight;
}

// Return the XOR of the elements with indices L..R (inclusive), using the
// segment tree st built over n elements.
// Prints "Invalid Input" and returns -1 when L > R or the range is not within
// [0, n - 1]. NOTE(review): -1 can also be a genuine XOR result when values
// may be negative.
// Reports through cout (like updateValue) rather than printf: this program
// never includes <cstdio>, so printf only compiled via a transitive include.
int getXor(vector<int> &st, int n, int L, int R)
{
	if (L < 0 || R > n - 1 || L > R)
	{
		cout << "Invalid Input";
		return -1;
	}
	return getXorUtil(st, 0, n - 1, L, R, 0);
}

// Recursively fill node `si` (covering arr[L..R]) and all of its descendants.
// A leaf stores its array value; an internal node stores the XOR of its
// children 2*si+1 and 2*si+2. Returns the value written to st[si].
int constructSTUtil(vector<int> &st, int si, vector<int> &arr, int L, int R)
{
	if (L != R)
	{
		const int mid = getMid(L, R);
		const int leftXor = constructSTUtil(st, 2 * si + 1, arr, L, mid);
		const int rightXor = constructSTUtil(st, 2 * si + 2, arr, mid + 1, R);
		st[si] = leftXor ^ rightXor;
	}
	else
	{
		st[si] = arr[L];
	}
	return st[si];
}

// Number of slots needed for a segment tree over n leaves: 2 * P - 1, where P
// is the smallest power of two >= n. Returns 0 when n <= 0.
// Uses integer doubling instead of pow(2, ceil(log2(n))). That avoids
// floating-point rounding, avoids log2(0) == -inf (whose int cast is
// undefined), and needs no <cmath>, which this program never included.
// Assumes n is small enough that 2 * P - 1 fits in an int.
int getMaxSize(int n)
{
	if (n <= 0)
		return 0;
	int leaves = 1;
	while (leaves < n)
		leaves *= 2;
	return 2 * leaves - 1;
}

// Build and return an XOR segment tree over all elements of arr.
// Slots not used by any node are zero-initialised.
// An empty arr yields an empty tree. Previously, arr.size() - 1 underflowed
// and constructSTUtil then read arr[0] out of bounds.
vector<int> constructST(vector<int> &arr)
{
	const int n = static_cast<int>(arr.size());
	if (n == 0)
		return vector<int>();

	vector<int> st(getMaxSize(n), 0);
	constructSTUtil(st, 0, arr, 0, n - 1);
	return st;
}

// Print the first max_size slots of st, each followed by a space, then end
// the line.
// Takes st by const reference to avoid copying the whole tree on every call.
// Clamps to st.size() so an oversized max_size cannot read past the end.
void printSegmentTree(const vector<int> &st, int max_size)
{
	const int stored = static_cast<int>(st.size());
	const int limit = max_size < stored ? max_size : stored;
	for (int i = 0; i < limit; i++)
	{
		cout << st[i] << " ";
	}
	cout << endl;
}

// Driver: build the tree for the array used in the tutorial video and print
// it with the XOR of arr[2..4]. Then set arr[3] to 11 and afterwards restore
// it to 7, printing the tree and the range XOR after each change.
int main()
{
	vector<int> arr{ 8, 5, 3, 7, 6 };
	const int n = static_cast<int>(arr.size());
	vector<int> st = constructST(arr);
	const int treeSize = getMaxSize(n);

	printSegmentTree(st, treeSize);
	cout << "Xor of values in given range = " << getXor(st, n, 2, 4) << endl;

	for (int value : { 11, 7 })
	{
		updateValue(arr, st, n, 3, value);
		printSegmentTree(st, treeSize);
		cout << "Updated xor of given range = " << getXor(st, n, 2, 4) << endl;
	}
	return 0;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment