Skip to content

Instantly share code, notes, and snippets.

@kazuki0824
Created March 20, 2019 05:43
Show Gist options
  • Save kazuki0824/d181a91ca24e9db1d3f468112fd21649 to your computer and use it in GitHub Desktop.
Save kazuki0824/d181a91ca24e9db1d3f468112fd21649 to your computer and use it in GitHub Desktop.
Generic Union-Find implementation
/*
* union_find.cpp
*
* Created on: 2019/03/19
* Author: maleicacid
*/
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <ext/pb_ds/assoc_container.hpp>
//#include <ext/pb_ds/tag_and_trait.hpp>
using namespace std;
using namespace __gnu_pbds;
/*
 * Disjoint-set (union-find) over arbitrary hashable keys of type T.
 *
 * Elements must be registered with makeSet() before they are queried or
 * merged. Sets are merged by size (the smaller set is attached under the
 * root of the larger one), and findSet() compresses paths as it walks.
 *
 * T must be copyable, comparable with != / ==, and hashable by
 * __gnu_pbds::gp_hash_table's default hash.
 */
template <typename T>
class UnionFind
{
public:
    // sz[r] is the number of elements in the set whose representative is r.
    // Entries exist only for current representatives (roots); a root's entry
    // is erased when its set is merged under another root.
    __gnu_pbds::gp_hash_table<T, long> sz;

    // rank[e] starts at 0 for every element added via makeSet(). It is
    // incremented on the surviving root when two sets of equal size are
    // merged. It is bookkeeping only and is NOT consulted when choosing the
    // new root (union is by size).
    __gnu_pbds::gp_hash_table<T, long> rank;

    UnionFind() = default;

    /*
     * Adds `value` as a new singleton set.
     * Calling it again for an element that is already present is a no-op.
     * Resetting it would detach the element from its current set and leave
     * the old root's size wrong.
     */
    void makeSet(const T& value)
    {
        if (m.find(value) != m.end())
            return;
        m[value] = value;
        sz[value] = 1;
        rank[value] = 0;
    }

    /*
     * Returns the representative (root) of the set containing `x`.
     * Compresses the path from `x` to the root.
     * Throws std::out_of_range if `x` was never added with makeSet().
     */
    T findSet(T x)
    {
        if (m.find(x) == m.end())
            throw std::out_of_range("UnionFind::findSet: element was never added with makeSet");

        // Pass 1: follow parent links until reaching an element that is its
        // own parent (the root). Every parent is itself a key of m, because
        // parents are only ever assigned from existing roots.
        T root = x;
        for (;;)
        {
            auto it = m.find(root);
            if (!(it->second != root))
                break;
            root = it->second;
        }

        // Pass 2: point every element on the walked path directly at the root.
        while (x != root)
        {
            auto it = m.find(x);
            T next = it->second;
            it->second = root;
            x = next;
        }
        return root;
    }

    /*
     * True if `x` and `y` are in the same set.
     * Throws std::out_of_range if either was never added.
     */
    bool isEqual(T x, T y)
    {
        return findSet(x) == findSet(y);
    }

    /*
     * Merges the sets containing `x` and `y`.
     * Returns the product of the two sets' sizes before the merge, i.e. the
     * number of element pairs that became newly connected. Returns 0 if `x`
     * and `y` were already in the same set.
     * Throws std::out_of_range if either was never added.
     */
    long unite(T x, T y)
    {
        const T rootX = findSet(x);
        const T rootY = findSet(y);
        if (rootX != rootY)
        {
            const std::pair<long, long> sizes = link(rootX, rootY);
            return sizes.first * sizes.second;
        }
        return 0;
    }

private:
    // m[e] is the parent of element e. A root is its own parent.
    __gnu_pbds::gp_hash_table<T, T> m;

    /*
     * Attaches one root under the other, by size.
     * On a size tie, `x` goes under `y` and rank[y] is bumped.
     * Both `x` and `y` must be distinct roots.
     * Returns the sizes of the two sets (x's, y's) before the merge.
     */
    std::pair<long, long> link(const T& x, const T& y)
    {
        const long sizeX = sz[x];
        const long sizeY = sz[y];
        if (sizeX > sizeY)
        {
            m[y] = x;
            sz[x] = sizeX + sizeY;
            sz.erase(y);
        }
        else
        {
            m[x] = y;
            if (sizeX == sizeY)
                rank[y]++;
            sz[y] = sizeX + sizeY;
            sz.erase(x);
        }
        return std::pair<long, long>(sizeX, sizeY);
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment