
@lxl66566
Last active November 23, 2024 17:51
Kth Largest Element in an Array, benchmark

I never understood why the standard solution for the k-th largest element is supposed to be O(n). The quickselect code recurses after each partition, which at a glance looks like O(n log n), and every editorial just says the proof is in Introduction to Algorithms, go read the book. Since I couldn't follow the book, I decided to benchmark it instead and infer the time complexity from how the running time grows with the input size.

  1. Generate a sufficiently large pool of random numbers. All subsequent tests run on this same, reasonably random data.
  2. The find_kth_largest function is the LeetCode editorial solution (the quickselect-based selection method) ported to Rust.

Note: find_kth_largest itself does no clone; I clone the data before the benchmark.

The results (sorted here; cargo test itself prints them in the order 10 100 1000 10000 50 500 5000):

test tests::bench_find_kth_largest_10    ... bench:          17.07 ns/iter (+/- 2.93)
test tests::bench_find_kth_largest_50    ... bench:         374.54 ns/iter (+/- 96.44)
test tests::bench_find_kth_largest_100   ... bench:       1,800.08 ns/iter (+/- 251.31)
test tests::bench_find_kth_largest_500   ... bench:      32,918.46 ns/iter (+/- 2,464.11)
test tests::bench_find_kth_largest_1000  ... bench:     125,544.36 ns/iter (+/- 14,766.95)
test tests::bench_find_kth_largest_5000  ... bench:   2,960,778.11 ns/iter (+/- 203,536.92)
test tests::bench_find_kth_largest_10000 ... bench:  11,616,331.11 ns/iter (+/- 382,382.89)

I then plotted the data: take the logarithm of both the input size (x-axis) and the time (y-axis), and fit:

from math import log

import matplotlib.pyplot as plt

a = [
    [10, 17.07],
    [100, 1800.08],
    [1000, 125544.36],
    [10000, 11616331.11],
    [50, 374.54],
    [500, 32918.46],
    [5000, 2960778.11],
]

a.sort(key=lambda x: x[0])
x = [log(i[0]) for i in a]
y = [log(i[1]) for i in a]

plt.plot(x, y)
plt.show()

The result:

Figure_1

The plot is roughly a straight line with slope about 1.9, i.e. $y \approx x^{1.9}$, so the measured complexity of this algorithm is about $n^{1.9}$.
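Rather than eyeballing the slope, it can be computed with an ordinary least-squares fit over the log-log points. A minimal sketch in Rust, using the (n, ns/iter) pairs from the benchmark table above:

```rust
// Least-squares slope of ln(time) versus ln(n); a slope of p means time ~ n^p.
fn log_log_slope(points: &[(f64, f64)]) -> f64 {
    let n = points.len() as f64;
    let (mut sx, mut sy, mut sxx, mut sxy) = (0.0, 0.0, 0.0, 0.0);
    for &(x, y) in points {
        let (lx, ly) = (x.ln(), y.ln());
        sx += lx;
        sy += ly;
        sxx += lx * lx;
        sxy += lx * ly;
    }
    (n * sxy - sx * sy) / (n * sxx - sx * sx)
}

fn main() {
    // (input size, ns/iter) from the quickselect benchmark output
    let data = [
        (10.0, 17.07),
        (50.0, 374.54),
        (100.0, 1800.08),
        (500.0, 32918.46),
        (1000.0, 125544.36),
        (5000.0, 2960778.11),
        (10000.0, 11616331.11),
    ];
    println!("slope = {:.2}", log_log_slope(&data)); // ~1.9
}
```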

Something felt off about those numbers. So I decided to try the heap-based selection method, which is a clear-cut O(n log n) algorithm.

I changed only the algorithm code below and the set of test sizes; nothing else was modified:

// Sift the element at index i down until the max-heap property holds
// for the subtree rooted at i (only the first heap_size elements count).
fn max_heapify(a: &mut [i32], i: usize, heap_size: usize) {
    let l = 2 * i + 1;
    let r = 2 * i + 2;
    let mut largest = i;

    if l < heap_size && a[l] > a[largest] {
        largest = l;
    }
    if r < heap_size && a[r] > a[largest] {
        largest = r;
    }
    if largest != i {
        a.swap(i, largest);
        max_heapify(a, largest, heap_size);
    }
}

// Build a max-heap in place.
fn build_max_heap(a: &mut [i32], heap_size: usize) {
    for i in (0..heap_size / 2).rev() {
        max_heapify(a, i, heap_size);
    }
}

// Pop the maximum k - 1 times; the root then holds the k-th largest.
fn find_kth_largest(nums: &mut [i32], k: usize) -> i32 {
    let mut heap_size = nums.len();
    build_max_heap(nums, heap_size);
    for i in (nums.len() - k + 1..nums.len()).rev() {
        nums.swap(0, i);
        heap_size -= 1;
        max_heapify(nums, 0, heap_size);
    }
    nums[0]
}

The results:

test tests::bench_find_kth_largest_10      ... bench:          34.38 ns/iter (+/- 2.59)
test tests::bench_find_kth_largest_100     ... bench:         108.42 ns/iter (+/- 19.10)
test tests::bench_find_kth_largest_1000    ... bench:         581.40 ns/iter (+/- 94.61)
test tests::bench_find_kth_largest_10000   ... bench:       4,709.71 ns/iter (+/- 265.75)
test tests::bench_find_kth_largest_100000  ... bench:      45,676.72 ns/iter (+/- 2,213.69)
test tests::bench_find_kth_largest_1000000 ... bench:     456,662.50 ns/iter (+/- 51,854.12)
test tests::bench_find_kth_largest_50      ... bench:          78.74 ns/iter (+/- 14.73)
test tests::bench_find_kth_largest_500     ... bench:         316.42 ns/iter (+/- 225.63)
test tests::bench_find_kth_largest_5000    ... bench:       2,438.85 ns/iter (+/- 104.89)
test tests::bench_find_kth_largest_50000   ... bench:      23,093.09 ns/iter (+/- 2,560.42)
test tests::bench_find_kth_largest_500000  ... bench:     228,612.50 ns/iter (+/- 18,253.94)

Fitting:

from math import log

import matplotlib.pyplot as plt

a = [
    [10, 34.38],
    [100, 108.42],
    [1000, 581.40],
    [10000, 4709.71],
    [100000, 45676.72],
    [1000000, 456662.50],
    [50, 78.74],
    [500, 316.42],
    [5000, 2438.85],
    [50000, 23093.09],
    [500000, 228612.50],
]

a.sort(key=lambda x: x[0])
x = [i[0] for i in a]
y = [i[1] for i in a]

# n*log(n) reference curve; the 1/130 constant is an eyeballed fit
y2 = [i * log(i) / 130 for i in x]

plt.plot(x, y, label="test")
plt.plot(x, y2, label="y=1/130 xlog x")
plt.legend()
plt.show()

The result:

Figure_2

It does look reasonably close to O(n log n).
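Strictly speaking, with k fixed at 5, build_max_heap is O(n) and only k - 1 elements are ever popped, so this heap version is really O(n + k log n), which for constant k is effectively linear; over this range an n log n curve and a straight line are hard to tell apart. The timings bear this out: each tenfold increase in n multiplies the time by close to ten as n grows. A quick sketch checking the ratios, using the numbers from the table above:

```rust
// For an O(n)-dominated method, multiplying n by 10 should multiply
// the running time by roughly 10.
fn main() {
    // (input size, ns/iter) from the heap benchmark output above
    let data = [
        (1_000.0_f64, 581.40_f64),
        (10_000.0, 4_709.71),
        (100_000.0, 45_676.72),
        (1_000_000.0, 456_662.50),
    ];
    for w in data.windows(2) {
        let (n0, t0) = w[0];
        let (n1, t1) = w[1];
        // the time ratios approach 10 as n grows
        println!("n x{:.0} -> time x{:.2}", n1 / n0, t1 / t0);
    }
}
```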

At this point I was thoroughly confused. Why does this happen? Can anyone explain?

After asking an AI, I finally understood:

  • Worst case: every partition splits off just one element, leaving $n-1$ elements to recurse on, so the total cost is $O(n) + O(n-1) + \dots + O(1) = O(n^2)$.
    • The worst case occurs when each chosen pivot is close to the minimum or maximum of the current subarray.
  • Average case: when pivots are uniformly distributed, each partition splits the array roughly in half, and only one half is recursed into, giving $T(n) = T(n/2) + O(n)$. By the Master Theorem, $T(n) = O(n)$.
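The average-case bound also falls out of simply expanding the recurrence, without the Master Theorem: each level recurses into only one half-sized subproblem, so the per-level costs telescope into a geometric series:

```latex
T(n) = cn + c\frac{n}{2} + c\frac{n}{4} + \dots + c
     = c \sum_{i=0}^{\log_2 n} \frac{n}{2^i}
     < 2cn = O(n)
```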

So the quickselect-based selection method isn't wrong; my measurements simply landed near its worst-case complexity. But my data was pseudorandom, generated with Rng, so why this happened I couldn't tell.
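My own guess at the cause, for what it's worth: b.iter invokes the closure many times on the same owned buffer, and quickselect partitions in place, so from the second iteration onward the input it sees is increasingly ordered, and a first-element pivot is at its worst on ordered input. The effect can be made visible by counting comparisons on fresh shuffled data versus fully sorted data (the worst case for this pivot choice). The quickselect below mirrors the one in the appendix, with a comparison counter threaded through:

```rust
// Hoare-partition quickselect (first element as pivot) that counts
// element comparisons via `comps`.
fn quickselect(nums: &mut [i32], l: usize, r: usize, k: usize, comps: &mut u64) -> i32 {
    if l == r {
        return nums[k];
    }
    let pivot = nums[l];
    let mut i = l as isize - 1;
    let mut j = r as isize + 1;
    while i < j {
        i += 1;
        while { *comps += 1; nums[i as usize] < pivot } {
            i += 1;
        }
        j -= 1;
        while { *comps += 1; nums[j as usize] > pivot } {
            j -= 1;
        }
        if i < j {
            nums.swap(i as usize, j as usize);
        }
    }
    let j = j as usize;
    if k <= j {
        quickselect(nums, l, j, k, comps)
    } else {
        quickselect(nums, j + 1, r, k, comps)
    }
}

fn main() {
    let n = 2000usize;

    // Deterministic Fisher-Yates shuffle of 0..n using a tiny xorshift PRNG,
    // so the example needs no external crates.
    let mut shuffled: Vec<i32> = (0..n as i32).collect();
    let mut state: u64 = 0x9E3779B97F4A7C15;
    for i in (1..n).rev() {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        let j = (state % (i as u64 + 1)) as usize;
        shuffled.swap(i, j);
    }
    let mut sorted: Vec<i32> = (0..n as i32).collect();

    let k = 5; // find the 5th largest, as in the benchmark

    let mut comps_random = 0u64;
    let got = quickselect(&mut shuffled, 0, n - 1, n - k, &mut comps_random);
    assert_eq!(got, (n - k) as i32);

    let mut comps_sorted = 0u64;
    let got = quickselect(&mut sorted, 0, n - 1, n - k, &mut comps_sorted);
    assert_eq!(got, (n - k) as i32);

    println!("random: {comps_random} comparisons, sorted: {comps_sorted}");
    // sorted input costs orders of magnitude more comparisons
    assert!(comps_sorted > 10 * comps_random);
}
```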

#![feature(test)]
extern crate test;

use rand::Rng;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

const FILE_PATH: &str = "numbers.bin";
const N: usize = 10_000_000; // 10 million

fn find_kth_largest(nums: &mut [i32], k: usize) -> i32 {
    let n = nums.len();
    quickselect(nums, 0, n - 1, n - k)
}

// Hoare-partition quickselect with the first element as pivot.
fn quickselect(nums: &mut [i32], l: usize, r: usize, k: usize) -> i32 {
    if l == r {
        return nums[k];
    }
    let partition = nums[l];
    let mut i = l as isize - 1;
    let mut j = r as isize + 1;
    while i < j {
        i += 1;
        while nums[i as usize] < partition {
            i += 1;
        }
        j -= 1;
        while nums[j as usize] > partition {
            j -= 1;
        }
        if i < j {
            nums.swap(i as usize, j as usize);
        }
    }
    let j = j as usize; // convert `j` back to usize
    if k <= j {
        quickselect(nums, l, j, k)
    } else {
        quickselect(nums, j + 1, r, k)
    }
}

// Generate n random i32 values and persist them to a file
fn generate_and_save_random_numbers(n: usize) -> std::io::Result<()> {
    if Path::new(FILE_PATH).exists() {
        println!("File {} already exists, skipping", FILE_PATH);
        return Ok(());
    }
    let mut rng = rand::thread_rng();
    let file = File::create(FILE_PATH)?;
    let mut writer = BufWriter::new(file);
    for _ in 0..n {
        let num: i32 = rng.gen();
        writer.write_all(&num.to_le_bytes())?;
    }
    writer.flush()?;
    println!("Generated and saved {} random numbers to {}", n, FILE_PATH);
    Ok(())
}

// Read all i32 values from the file into a Vec
fn load_numbers_from_file(path: &str) -> std::io::Result<Vec<i32>> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let mut buffer = Vec::new();
    reader.read_to_end(&mut buffer)?;
    // every 4 bytes is one i32
    let numbers: Vec<_> = buffer
        .chunks_exact(4)
        .map(|chunk| i32::from_le_bytes(chunk.try_into().unwrap()))
        .collect();
    println!(
        "Loaded {} numbers from file, example: {:?}",
        numbers.len(),
        numbers.iter().take(10).collect::<Vec<_>>()
    );
    Ok(numbers)
}

fn main() -> std::io::Result<()> {
    generate_and_save_random_numbers(N)?;
    // let numbers = load_numbers_from_file(FILE_PATH)?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use paste::paste;
    use test::{black_box, Bencher};

    #[test]
    fn test_find_kth_largest() {
        let mut numbers = vec![2, 1, 9, 5, 4, 3, 8, 7, 6, 10];
        assert_eq!(find_kth_largest(&mut numbers, 5), 6);
    }

    macro_rules! generate_benchmarks {
        ($($n:expr,)*) => {
            $(
                paste! {
                    #[bench]
                    fn [<bench_find_kth_largest_ $n>](b: &mut Bencher) {
                        let numbers = load_numbers_from_file(FILE_PATH).unwrap();
                        let slice = &numbers[..$n];
                        // NOTE: cloned once, outside b.iter; every iteration
                        // mutates the same buffer in place.
                        let mut owned = slice.to_vec();
                        b.iter(|| black_box(find_kth_largest(&mut owned, 5)));
                    }
                }
            )*
        };
    }

    generate_benchmarks!(10, 50, 100, 500, 1000, 5000, 10000,);
}