Created
June 28, 2023 10:58
-
-
Save appgurueu/f2192d081cd6e1d7dba4ee5f8a3fe1b6 to your computer and use it in GitHub Desktop.
Count inversions summing up to `x`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Count the inversions of `list` whose two elements sum to exactly `x`.
-- An inversion is a pair of positions `i < j` with `list[i] > list[j]`.
-- Uses a modified mergesort: while merging two sorted halves, every element
-- taken from the right half ahead of the left head is inverted with all
-- elements still remaining in the left half, so only those need checking.
-- The caller's `list` is only read; sorting happens in scratch tables.
-- @tparam table list sequence of distinct numbers
-- @tparam number x target sum
-- @treturn number count of inversions whose elements sum to `x`
local function count_x_inversions(
	list, -- of distinct nums
	x -- target sum
)
	-- Merge the sorted sequences `left` and `right` into `result` and return
	-- the number of cross inversions (left element, right element) summing to `x`.
	local function merge(result, left, right)
		local left_len, right_len = #left, #right
		-- Map each value of the sorted left half to its position in that half.
		local left_idx = {}
		for idx = 1, left_len do
			local v = left[idx]
			assert(not left_idx[v], "nums aren't distinct")
			left_idx[v] = idx
		end
		local inversions = 0
		local i, j, k = 1, 1, 1
		while i <= left_len and j <= right_len do
			-- Compare "head" elements, insert the "winner"
			if right[j] < left[i] then
				-- `right[j]` precedes everything in `left[i..]`, so it forms an
				-- inversion with each of them. A plain inversion count would add
				-- `left_len - i + 1` here. Since the numbers are distinct, at most
				-- *one* of those partners can sum with `right[j]` to exactly `x`:
				-- look it up and check that it has not been merged out yet.
				-- A binary search on `left[i..]` would also work; it makes this step
				-- worst case O(log n) instead of expected O(1) / worst case O(n),
				-- i.e. hard O(n (log n)^2) total vs. expected O(n log n).
				local partner_idx = left_idx[x - right[j]]
				if partner_idx and partner_idx >= i then
					inversions = inversions + 1
				end
				result[k] = right[j]
				j = j + 1
			else
				result[k] = left[i]
				i = i + 1
			end
			k = k + 1
		end
		-- Append the leftovers. The loop above ends only once one half is
		-- exhausted, so at most one of these two loops copies anything and
		-- both may start writing at `k`.
		for offset = 0, left_len - i do
			result[k + offset] = left[i + offset]
		end
		for offset = 0, right_len - j do
			result[k + offset] = right[j + offset]
		end
		return inversions
	end

	-- Sort `list[from..to]` into `sorted` (filled from index 1) and return the
	-- number of inversions summing to `x` found inside that range.
	local function mergesort(sorted, from, to)
		if from > to then
			return 0 -- empty range
		end
		if from == to then
			sorted[1] = list[from]
			return 0
		end
		local mid = math.floor((to + from) / 2)
		local left, right = {}, {}
		local left_inversions = mergesort(left, from, mid)
		local right_inversions = mergesort(right, mid + 1, to)
		return merge(sorted, left, right) + left_inversions + right_inversions
	end

	-- Sort into a throwaway table so the caller's list keeps its order.
	return mergesort({}, 1, #list)
end
--- Brute-force reference: count inversions of `list` summing to `x`.
-- Checks every pair of positions `i < j` and counts those where
-- `list[i] > list[j]` and `list[i] + list[j] == x`. Quadratic in `#list`;
-- does not modify `list`.
-- @tparam table list sequence of numbers
-- @tparam number x target sum
-- @treturn number count of qualifying pairs
local function count_x_inversions_naive(list, x)
	local n = #list
	local inversions = 0
	for i = 1, n do
		local earlier = list[i]
		for j = i + 1, n do
			local later = list[j]
			if earlier > later and earlier + later == x then
				inversions = inversions + 1
			end
		end
	end
	return inversions
end
-- Tests: compare the mergesort-based counter against the brute-force reference.
do
	-- Small fixed case
	assert(count_x_inversions({2, 0, 5, 1}, 3) == count_x_inversions_naive({2, 0, 5, 1}, 3))
	-- Fuzzing because I'm lazy
	local function shuffle(
		list -- list to be shuffled in-place
	)
		-- Swap each position with a uniformly chosen position at or after it.
		for index = 1, #list - 1 do
			local index_2 = math.random(index, #list)
			list[index], list[index_2] = list[index_2], list[index]
		end
	end
	for _ = 1, 10 do
		-- Random permutation of 1..n
		local t = {}
		local n = math.random(10, 1e3)
		for i = 1, n do
			t[i] = i
		end
		shuffle(t)
		-- Half of the time use the sum `n + 1`, otherwise a random sum
		-- between roughly n/4 and 3n/4.
		local x
		if math.random() < 0.5 then
			x = 1 + n
		else
			x = math.random(math.ceil(n/4), math.floor(3*n/4))
		end
		-- Evaluate the reference first so that it sees the shuffled order
		-- even if count_x_inversions reorders `t` while sorting.
		local naive = count_x_inversions_naive(t, x)
		local fast = count_x_inversions(t, x)
		if naive ~= fast then
			-- Report the failing case so a fuzz failure can be investigated.
			error(string.format(
				"count mismatch for n=%d, x=%d: naive=%d, mergesort=%d",
				n, x, naive, fast
			))
		end
	end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment