Skip to content

Instantly share code, notes, and snippets.

@amoshyc
Last active August 29, 2015 14:27
Show Gist options
  • Save amoshyc/69a5e92872e46b8a77af to your computer and use it in GitHub Desktop.
Save amoshyc/69a5e92872e46b8a77af to your computer and use it in GitHub Desktop.
Poj 3685: Matrix

Poj 3685: Matrix

分析

跟前一題很類似…poj 3579

我們對 M-th smallest element 二分搜。

正負相關性判定

bool C(m) = whether m < M-th smallest element

C(m) 表:

1 1 1 1 0 0 0

目標是尋找最後一個 1

注意,不是第一個 0,而是最後一個 1 ,最後一個 1 才會有 M-1 個比它小的數

上下界判定

觀察公式,可知

下界發生在 M = 1 時,即求最小值,此時答案為 -100000 * N

上界發生在 M = N * N,即求最大值,此時答案為 N * N + 100000 * N + N * N + N * N

根據 C(m) 表,所以使用 [lb, ub)

lb = -100000 * N;
ub = N * N + 100000 * N + N * N + N * N + 1;

判斷函式實作

改為計算小於 m 的數有幾個,是否 < M。至於如何計算有幾個呢? 觀察 Aij 可發現 Aiji 成正相關:

Ex. N = 5

     3   -99993  -199987  -299979  -399969
100007       12   -99981  -199972  -299961
200013   100019       27   -99963  -199951
300021   200028   100037       48   -99939
400031   300039   200049   100061       75

於是我們可以得到以下策略:

分別計算每個 column 中有幾個 < m 的數(對 row 用二分搜), 累加每個 column,即得到總共有多少個 < m 的數

對 row 用二分搜,即對 i 二分搜,如果讓判斷函式為該數是否 < m,則判斷函式表為:

1
1
1
1
0
0
0

我們的目標是找最後一個 1 在哪,該值即為該 column 有多少個 < m 的數。 而解的範圍為 [0, N](沒有任何 ~ 全部都是),因為使用 [lb, ub),所以

lb = 0;
ub = N + 1;

答案為 lb

※ 實作時,注意計算小於 m 的數總共有幾個時,要用 long long

AC Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>

using namespace std;

typedef long long ll;

ll N, M;

ll f(ll i, ll j) {
    return i * i + 100000 * i + j * j - 100000 * j + i * j;
}

bool C(ll m) {
    ll cnt = 0;
    for (int j = 1; j <= N; j++) {
        ll lb_i = 0;
        ll ub_i = N + 1;
        while (ub_i - lb_i > 1) {
            ll mid_i = (lb_i + ub_i) / 2;
            if (f(mid_i, j) < m) lb_i = mid_i;
            else ub_i = mid_i;
        }
        cnt += lb_i;
    }

    return cnt < M;
}

ll solve() {
    ll lb = -100000 * N;
    ll ub = N * N + 100000 * N + N * N + N * N + 1;

    while(ub - lb > 1) {
        ll mid = (ub + lb) / 2;
        if (C(mid)) lb = mid;
        else ub = mid;
    }

    return lb;
}

int main() {
    int T;
    scanf("%d", &T);

    while (T--) {
        scanf("%lld %lld", &N, &M);
        printf("%lld\n", solve());
    }

    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment