Skip to content

Instantly share code, notes, and snippets.

@kusano
Created May 25, 2014 14:16
Show Gist options
  • Save kusano/cbe0e89b600fb54e4509 to your computer and use it in GitHub Desktop.
Save kusano/cbe0e89b600fb54e4509 to your computer and use it in GitHub Desktop.
/*
O(n^2)ならば間に合う。グラフを辿る処理はO(n)なので、各ノードでO(n)の処理はでき
る。
tree1のノードiとtree2のノードjについて、ノードiをルートとする部分木のサイズと、
ノードjをルートとする部分木のサイズ、両方の部分木に共通して含まれるノードの個数
を求めておけば、ノードiとノードjからルートに向かうエッジをそれぞれ取り除いた場合
の、S(e1, e2)が計算できる。
両方の部分木に共通して含まれるノード数は、事前に部分木にノードが含まれるかどうか
を調べておけば良い。
*/
#include <string>
#include <vector>
#include <algorithm>
using namespace std;
int n;
// 隣接リスト
vector<vector<int> > T1, T2;
// S1[i]: ノードiをルートとする部分木のノード数
vector<int> S1, S2;
// S12[i][j]: tree1のノードiをルートとする部分木と、tree2のノードjをルートとす
// る部分木の両方に含まれるノード数
vector<vector<int> > S12;
// C[i][j]: T2のノードjはノードiの子孫か
vector<vector<bool> > C;
// S2とCを求める
void search2(int c, int p, vector<int> &path)
{
path.push_back(c);
S2[c] = 1;
for (int i=0; i<(int)T2[c].size(); i++)
if (T2[c][i] != p)
{
search2(T2[c][i], c, path);
S2[c] += S2[T2[c][i]];
}
for (int i=0; i<(int)path.size(); i++)
C[path[i]][c] = true;
path.pop_back();
}
// S1とS12を求める
void search1(int c, int p)
{
S1[c]++;
for (int i=0; i<n; i++)
if (C[i][c])
S12[c][i]++;
for (int i=0; i<(int)T1[c].size(); i++)
if (T1[c][i] != p)
{
search1(T1[c][i], c);
S1[c] += S1[T1[c][i]];
for (int j=0; j<n; j++)
S12[c][j] += S12[T1[c][i]][j];
}
}
class TreesAnalysis{public:
long long treeSimilarity( vector <int> tree1, vector <int> tree2 )
{
n = (int)tree1.size()+1;
T1 = T2 = vector<vector<int> >(n);
S1 = S2 = vector<int>(n);
S12 = vector<vector<int> >(n, vector<int>(n));
C = vector<vector<bool> >(n, vector<bool>(n));
for (int i=0; i<n-1; i++)
{
T1[i].push_back(tree1[i]);
T1[tree1[i]].push_back(i);
T2[i].push_back(tree2[i]);
T2[tree2[i]].push_back(i);
}
vector<int> path;
search2(0, -1, path);
search1(0, -1);
long long ans = 0;
for (int i=1; i<n; i++)
for (int j=1; j<n; j++)
{
int t = 0;
t = max(t, S12[i][j]);
t = max(t, S1[i] - S12[i][j]);
t = max(t, S2[j] - S12[i][j]);
t = max(t, n - S1[i] - S2[j] + S12[i][j]);
ans += (long long)t*t;
}
return ans;
}};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment