Skip to content

Instantly share code, notes, and snippets.

@henrybear327
Created April 2, 2018 02:02
Show Gist options
  • Save henrybear327/2c7f21ef346ddd43e85688e326589763 to your computer and use it in GitHub Desktop.
Save henrybear327/2c7f21ef346ddd43e85688e326589763 to your computer and use it in GitHub Desktop.
total_distance.cpp
#include <bits/stdc++.h>
using namespace std;
// Returns the sum of path lengths over all ORDERED pairs of distinct nodes in a
// weighted tree rooted at node 1 (i.e. twice the unordered-pair sum).
//   n   : number of nodes, labelled 1..n.
//   par : par[i] is the parent of node i for 2 <= i <= n (entries 0 and 1 are
//         ignored); every par[i] must lie in [1, n] and describe a tree.
//   w   : w[i] is the weight of the edge (i, par[i]) for 2 <= i <= n.
// Key observation: edge (i, par[i]) lies on the path of exactly
// cnt * (n - cnt) unordered pairs, where cnt is the size of the subtree below
// it, so the answer is 2 * sum(w[i] * cnt[i] * (n - cnt[i])).
// All arithmetic is 64-bit; the result is assumed to fit in long long.
static long long total_pair_distance(int n, const std::vector<int>& par,
                                     const std::vector<long long>& w)
{
    if (n < 2)
        return 0;  // no edges, no pairs

    std::vector<int> deg(n + 1, 0);        // children not yet processed
    std::vector<long long> cnt(n + 1, 1);  // subtree size, accumulated bottom-up
    for (int i = 2; i <= n; i++)
        deg[par[i]]++;

    // Peel the tree from the leaves upward: a node is ready once every child
    // has contributed its subtree size.
    std::queue<int> q;
    for (int i = 2; i <= n; i++) {
        if (deg[i] == 0)
            q.push(i);
    }

    long long total = 0;
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        // 64-bit product: cnt * (n - cnt) * w easily exceeds INT_MAX.
        total += (n - cnt[u]) * w[u] * cnt[u];
        const int p = par[u];
        cnt[p] += cnt[u];
        // The root (node 1) has no parent edge, so it is never enqueued.
        if (--deg[p] == 0 && p != 1)
            q.push(p);
    }
    return 2 * total;
}

// Reads one test case from stdin and prints the sum of distances over all
// ordered pairs of nodes.
// Input format: n, then par[2..n], then the edge weights w[2..n].
// On truncated or malformed input (read failure, parent out of [1, n]) the
// case is abandoned without printing.
void solve()
{
    int n = 0;
    if (scanf("%d", &n) != 1)
        return;
    if (n < 1) {
        printf("0\n");  // empty tree: no pairs
        return;
    }

    std::vector<int> par(n + 1, 1);
    std::vector<long long> w(n + 1, 0);
    for (int i = 2; i <= n; i++) {
        if (scanf("%d", &par[i]) != 1 || par[i] < 1 || par[i] > n)
            return;  // guard against out-of-bounds indexing below
    }
    for (int i = 2; i <= n; i++) {
        if (scanf("%lld", &w[i]) != 1)
            return;
    }

    printf("%lld\n", total_pair_distance(n, par, w));
}
int main()
{
    // Crucial observation used by solve(): the number of pairs whose path
    // crosses a given edge is
    //   (# of nodes on one side of the edge) * (# of nodes on the other side).
    int ncase = 0;
    if (scanf("%d", &ncase) != 1)
        return 1;  // missing test-case count: malformed input
    // "> 0" also stops immediately on a negative count instead of looping
    // until signed overflow.
    while (ncase-- > 0)
        solve();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment