Skip to content

Instantly share code, notes, and snippets.

@KirillLykov
Created March 12, 2018 20:50
Show Gist options
  • Save KirillLykov/9965b07c0e098f57722509cb7088dbac to your computer and use it in GitHub Desktop.
Save KirillLykov/9965b07c0e098f57722509cb7088dbac to your computer and use it in GitHub Desktop.
Solution for Codeforces Educational Round 36, problem F
/*
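Solution for Codeforces Educational Round 36, problem F
http://codeforces.com/problemset/problem/915/F
Uses a modified disjoint-set (union-find) structure that additionally stores
the size of each set. Vertex values are moved onto edges as
w(u,v) = max(a[u], a[v]). When edges are added in increasing order of w, an
edge joining components of sizes s1 and s2 is the maximum on exactly s1*s2
paths, so it contributes w*s1*s2 to the sum of path maxima. The sum of path
minima is obtained the same way with weights -min(a[u], a[v]); these negative
contributions subtract it from the total.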
*/
#include <bits/stdc++.h>
using namespace std;
/// Shorthand for the (at least 64-bit) signed integer type used for sums that overflow int.
using lint = long long int;
class disjoint_sets
{
vector<int> parent;
vector<int> rank;
vector<int> size;
public:
disjoint_sets(size_t sz) : parent(sz, -1), rank(sz, -1), size(sz, 1) {}
void make_set(int v)
{
parent[v] = v;
rank[v] = 0;
}
int find_set(int x)
{
if (parent[x] != x)
return parent[x] = find_set(parent[x]);
return x;
}
long long set_size(int x) {
return size[find_set(x)];
}
void union_sets(int l, int r)
{
int pl = find_set(l);
int pr = find_set(r);
if (pl != pr) {
if (rank[pl] < rank[pr])
swap(pl, pr);
parent[pr] = pl;
if (rank[pr] == rank[pl])
rank[pl] += 1;
size[pl] += size[pr];
}
}
};
/// Kruskal-style sweep: for each edge (w, u, v), in the given order, adds
/// w * |A| * |B|, where A and B are the current sets of u and v, then merges them.
///
/// @param ds     taken BY VALUE on purpose: each call sweeps its own copy, so the
///               caller can reuse one initialised structure for several sweeps.
/// @param edges  (w, u, v) tuples. The function does not sort them; main() sorts
///               them ascending by w before calling. The edges are assumed to form
///               a forest (u and v in different sets when the edge is processed).
/// @return       the accumulated sum of contributions.
lint solve(disjoint_sets ds, const std::vector< std::tuple<int,int,int> >& edges) {
    lint res = 0;
    for (const auto& e : edges) {
        const int w = std::get<0>(e);
        const int u = std::get<1>(e);
        const int v = std::get<2>(e);
        // Keep sizes in lint so the product below is computed in 64 bits.
        const lint szu = ds.set_size(u);
        const lint szv = ds.set_size(v);
        res += szu * szv * w;
        ds.union_sets(u, v);
    }
    return res;
}
int main(int, char**) {
std::ios::sync_with_stdio(false);
int n;
cin >> n;
vector<int> a(n);
for (int i = 0; i < n; ++i)
cin >> a[i];
vector< tuple<int,int,int> > maxEdges, minEdges;
for (int i = 0; i < n-1; ++i) {
int u, v;
cin >> u >> v;
--u; -- v;
maxEdges.push_back( make_tuple(max(a[u], a[v]), u, v) );
minEdges.push_back( make_tuple(-min(a[u], a[v]), u, v) ); // these edges have negative weight
// for sort and also contribute negatively to sum
}
disjoint_sets ds(n);
for (int i = 0; i < n; ++i)
ds.make_set(i);
sort(maxEdges.begin(), maxEdges.end());
lint sumMax = solve(ds, maxEdges);
sort(minEdges.begin(), minEdges.end());
lint sumMin = solve(ds, minEdges);
cout << sumMax + sumMin << endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment