Skip to content

Instantly share code, notes, and snippets.

@johnchen902
Last active January 9, 2019 11:05
Show Gist options
  • Save johnchen902/fc85ee4dd3b77b7efeca7c7b48742939 to your computer and use it in GitHub Desktop.
Save johnchen902/fc85ee4dd3b77b7efeca7c7b48742939 to your computer and use it in GitHub Desktop.
NTU PK 2018 -- A. f(Graph)
Display the source blob
Display the rendered blob
Raw
| Complexity | Time (s) |
| --- | --- |
| O(n^4/64) | 1.736 |
| O(n^3) | 1.164 |
| O(n^3 log n) | 2.680 |
// The following is copied from my in-contest code.
#include <cstdio>
#include <utility>
#include <algorithm>
#include <bitset>
using namespace std;
// File-local state shared with main() of the bitset-based solution.
namespace {
// Modulus applied to the accumulated answer before it is printed.
constexpr int mod = 998244353;
// n x n integer matrix read from stdin; main() later overwrites it in place
// with w[i][j] = min(w[i][j], w[i][k] + w[k][j]) over every k.
int w[500][500];
// Every index pair (i, j) with i < j; main() sorts it by ascending w[i][j].
// 500 * 499 / 2 is the number of such pairs for the largest supported n.
pair<int, int> p[500 * 499 / 2];
// used[u][v] is set once main() has visited the pair {u, v} (both
// directions are set together); the diagonal is set before the sweep.
bitset<500> used[500];
}
int main() {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", w[i] + j);
long ans = 0;
long sumw = 0;
for (int i = 0; i < n; i++)
for (int j = i + 1; j < n; j++)
sumw += w[i][j];
ans = sumw % mod * ((long) n * (n - 1) / 2 - 1) % mod;
// fprintf(stderr, "ans=%ld\n", ans);
for (int k = 0; k < n; k++)
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
w[i][j] = min(w[i][j], w[i][k] + w[k][j]);
for (int i = 0, k = 0; i < n; i++)
for (int j = i + 1; j < n; j++)
p[k++] = make_pair(i, j);
sort(p, p + n * (n - 1) / 2,
[](pair<int, int> l, pair<int, int> r) {
return w[l.first][l.second] < w[r.first][r.second];
});
for (int i = 0; i < n; i++)
used[i][i] = true;
for (int i = 0; i < n * (n - 1) / 2; i++) {
int u = p[i].first, v = p[i].second;
used[u][v] = used[v][u] = true;
int cnt = 0;
for (int a = 0; a < n; a++)
if (!used[u][a])
cnt += n - (used[v] | used[a]).count();
ans += 2L * w[u][v] * cnt % mod;
}
printf("%d\n", (int) (ans % mod));
}
#include <cstdio>
#include <utility>
#include <algorithm>
using namespace std;
// File-local state shared with main() of the incremental-count solution.
namespace {
// Modulus applied to the accumulated answer before it is printed.
constexpr int mod = 998244353;
// n x n integer matrix read from stdin; main() later overwrites it in place
// with w[i][j] = min(w[i][j], w[i][k] + w[k][j]) over every k.
int w[500][500];
// Every index pair (i, j) with i < j; main() sorts it by ascending w[i][j].
pair<int, int> p[500 * 499 / 2];
// used[u][v] is set by set_used(u, v) in main(); once set it is never cleared.
bool used[500][500];
// cnt[a][b] is the number of columns set in row a or row b of `used`,
// updated by set_used each time a new entry is marked.
int cnt[500][500]; // (used[a] | used[b]).count()
}
int main() {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
scanf("%d", w[i] + j);
long sumw = 0;
for (int i = 0; i < n; i++)
for (int j = i + 1; j < n; j++)
sumw += w[i][j];
long ans = sumw % mod * ((long) n * (n - 1) / 2 - 1) % mod;
// fprintf(stderr, "ans=%ld\n", ans);
for (int k = 0; k < n; k++)
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
w[i][j] = min(w[i][j], w[i][k] + w[k][j]);
for (int i = 0, k = 0; i < n; i++)
for (int j = i + 1; j < n; j++)
p[k++] = make_pair(i, j);
sort(p, p + n * (n - 1) / 2,
[](pair<int, int> l, pair<int, int> r) {
return w[l.first][l.second] < w[r.first][r.second];
});
auto set_used = [n](int u, int v) {
if (used[u][v])
return;
used[u][v] = true;
cnt[u][u]++;
for (int i = 0; i < n; i++)
if (!used[i][v]) {
cnt[i][u]++;
cnt[u][i]++;
}
};
for (int i = 0; i < n; i++)
set_used(i, i);
for (int i = 0; i < n * (n - 1) / 2; i++) {
int u = p[i].first, v = p[i].second;
set_used(u, v);
set_used(v, u);
int cnt1 = 0;
for (int a = 0; a < n; a++)
if (!used[u][a])
cnt1 += n - cnt[v][a];
ans += 2L * w[u][v] * cnt1 % mod;
}
printf("%d\n", (int) (ans % mod));
}
#include <cstdio>
#include <algorithm>
using namespace std;
// File-local state shared with main() of the sort-per-pair solution.
namespace {
// Modulus applied to the accumulated answer before it is printed.
constexpr int mod = 998244353;
// n x n integer matrix read from stdin; main() later overwrites it in place
// with w[i][j] = min(w[i][j], w[i][k] + w[k][j]) over every k.
int w[500][500];
// Scratch row: for the current pair (i, j), main() fills
// ww[k] = min(w[i][k], w[j][k]) and then sorts it ascending.
int ww[500];
}
// Sort-per-pair solution.
//
// Input (stdin):  n, followed by an n x n matrix of ints.
// Output (stdout): a single integer reduced modulo `mod`.
// Returns 1 without printing anything if the input is malformed or n does
// not fit the statically sized buffers.
//
// Steps:
//   1. Seed the answer with (sum of w[i][j] for i < j) * (C(n,2) - 1).
//   2. Replace w by its all-pairs minimum closure (triple-loop relaxation).
//   3. For every pair i < j, build ww[k] = min(w[i][k], w[j][k]), sort it
//      ascending, and add sum over k of (n - 1 - k) * ww[k].
int main() {
    // Largest n the static arrays can hold (derived from w's first extent).
    constexpr int max_n = static_cast<int>(sizeof(w) / sizeof(w[0]));

    int n = 0;
    if (scanf("%d", &n) != 1 || n < 0 || n > max_n)
        return 1;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            if (scanf("%d", &w[i][j]) != 1)
                return 1;

    // long long rather than long: long is only 32 bits on LLP64 / 32-bit
    // targets, where the sums and products below would overflow.
    long long sumw = 0;
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
            sumw += w[i][j];

    long long ans =
        sumw % mod * (static_cast<long long>(n) * (n - 1) / 2 - 1) % mod;

    // All-pairs minimum via intermediate index k.
    for (int k = 0; k < n; k++)
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                w[i][j] = std::min(w[i][j], w[i][k] + w[k][j]);

    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++) {
            for (int k = 0; k < n; k++)
                ww[k] = std::min(w[i][k], w[j][k]);
            std::sort(ww, ww + n);

            // The k-th smallest value is weighted by the number of entries
            // that come after it in sorted order.
            long long sum = 0;
            for (int k = 0; k < n; k++)
                sum += static_cast<long long>(n - 1 - k) * ww[k];
            ans += sum % mod;
        }

    printf("%d\n", static_cast<int>(ans % mod));
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment