Skip to content

Instantly share code, notes, and snippets.

@morris821028
Created December 2, 2017 01:07
Show Gist options
  • Save morris821028/63f6d4d807d58a07e52ae03c8f07141c to your computer and use it in GitHub Desktop.
Save morris821028/63f6d4d807d58a07e52ae03c8f07141c to your computer and use it in GitHub Desktop.
/*
给出一个tree, 每一个node都有一个value, 问tree里面相同value的node连接成最长路径的大小(edge的数量).
tree是要自己建的, 输入是两个array, 第一个array表示每个node的value, 第二个array表示所有的边
例如[1,1,1], [1,2,1,3] -> 2, 说明这个tree有3个node, value都是1, 然后node1和node2有边, node1和node3也有边, 最长的路径是2 -> 1 -> 3或者反过来, 总共两条边, 输出2
*/
#include <stdio.h>
#include <stdlib.h>
#include <map>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 32768;
int32_t val[MAXN];
vector<vector<int>> g;
int dfs(int u, int p, int &ret) {
int mx1 = 0, mx2 = 0; // mx1 > mx2
for (auto v : g[u]) {
if (v == p)
continue;
int t = dfs(v, u, ret);
if (val[v] != val[u])
continue;
if (t > mx2) {
swap(t, mx2);
if (mx2 > mx1)
swap(mx2, mx1);
}
}
ret = max(ret, mx1+1);
ret = max(ret, mx2+1);
ret = max(ret, mx1+mx2+1);
return max(mx1, mx2)+1;
}
int main() {
int n;
while (scanf("%d", &n) == 1) {
for (int i = 1; i <= n; i++)
scanf("%d", &val[i]);
g = vector<vector<int>>(n+1, vector<int>());
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
int ret = 0;
dfs(1, -1, ret);
printf("%d\n", ret-1);
}
return 0;
}
/*
3
1 1 1
1 2
1 3
[1,1,1],
[1,2,1,3]
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment