@ochaton
Last active November 21, 2019 08:53
First attempt to understand how table.sort works in Lua 5.1
-- I suggest running this script with Lua 5.3+ (it uses the // integer division operator)
-- The algorithm is ported from auxsort in https://www.lua.org/source/5.1/ltablib.c.html#auxsort
local function sort(a, comp, l, u)
  while l < u do -- a loop instead of a tail call on the larger part (see the end)
    if comp(a[u], a[l]) then -- a[u] < a[l]
      a[u], a[l] = a[l], a[u] -- swap
    end
    if u-l == 1 then -- only 2 elements
      break
    end
    local i = (l + u) // 2
    if comp(a[i], a[l]) then -- a[i] < a[l] (middle < left)
      a[i], a[l] = a[l], a[i] -- swap
    elseif comp(a[u], a[i]) then -- a[u] < a[i] (right < middle)
      a[i], a[u] = a[u], a[i] -- swap
    end
    if u-l == 2 then -- only 3 elements
      break
    end
    local pivot = a[i]
    a[u-1], a[i] = a[i], a[u-1] -- move the pivot to a[u-1]
    -- a[l] <= pivot == a[u-1] <= a[u], so only a[l+1..u-2] still has to be partitioned
    i = l
    local j = u-1
    while true do -- invariant: a[l..i] <= pivot <= a[j..u]
      repeat
        i = i + 1
        if i > u then error("invalid order function for sorting") end
      until not comp(a[i], pivot)
      -- a[i] is the first element to the right of the old i that is not less than pivot
      repeat
        j = j - 1
        if j < l then error("invalid order function for sorting") end
      until not comp(pivot, a[j])
      -- a[j] is the first element to the left of the old j that is not greater than pivot
      if j < i then
        --[[
          1) a[l..i-1] <= pivot <= a[i], l < i <= u
          2) a[j] <= pivot <= a[j+1..u], l <= j < u - 1
          If the comparator is a strict order [ comp(a, b) and comp(b, a) are never both true ]:
            Assume that j < i - 1:
              => j + 1 <= i - 1
              => j + 1 belongs to [l..i-1]
              => comp(a[j+1], pivot) == true (the first loop stepped over it)
              => comp(pivot, a[j+1]) == true (the second loop stepped over it and stopped only at j)
              Both cannot be true for a strict comparator, so j >= i - 1.
            Assume that j > i - 1:
              => j >= i
              => contradicts the condition j < i of this branch
            That means: j == i - 1
          Otherwise (non-strict comparator, e.g. a <= b):
            I need to sleep on it before I can say more.
        ]]
        break
      end
      -- here i <= j, a[i] >= pivot and a[j] <= pivot: swap them to restore the invariant
      a[i], a[j] = a[j], a[i]
    end
    -- a[i] is not less than pivot (pivot == a[u-1]) and a[l..i-1] are not
    -- greater than pivot, so swapping a[i] and a[u-1] puts the pivot in its final place
    a[u-1], a[i] = a[i], a[u-1] -- Note: a[u-1] == pivot
    -- a[l..i-1] <= pivot == a[i] == a[j+1] (j == i - 1 for a strict comparator, see above)
    -- pivot <= a[j+2..u] == a[i+1..u]
    -- So: a[l..i-1] <= pivot == a[i] <= a[i+1..u]
    if i-l < u-i then -- the left part a[l..i-1] is smaller than the right part a[i+1..u]
      j = l
      i = i - 1
      l = i + 2
      -- with p = the pivot index (the value of i before this branch):
      -- sort {l, p-1} recursively (now {j, i})
      -- and then sort {p+1, u} in this call (now {l, u})
    else
      j = i + 1
      i = u
      u = j - 2
      --[[
        You're looking at `u = j - 2` and thinking WTF?
        We've already proved that j == i - 1 after the `while true do ... end` loop.
        a[l..i-1] <= pivot, but that part is still unsorted,
        so {l, i - 1} still has to be sorted (here, by the next iteration of this loop).
        Continuing with the borders {l, j - 2} == {l, i - 3} would be a serious bug.
        But read the code again: j is reassigned before `u = j - 2` runs:
          _j = j      -- copy the old j to a temporary variable (_j == i - 1)
          j = i + 1   -- the new j, equal to _j + 2
          i = u
          u = j - 2   -- == i + 1 - 2 == i - 1 (using the old i)
        So the code is correct (and it does not even depend on j == i - 1).
      ]]
      -- with i being the pivot index before this branch:
      -- sort {i+1, u} recursively (now {j, i})
      -- and then sort {l, i-1} in this call (now {l, u})
    end
    sort(a, comp, j, i) -- recursive call on the smaller part
  end -- the loop continues with the larger part
end
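-- The `u = j - 2` puzzle above can also be checked mechanically. next_ranges
-- is a hypothetical helper (not part of the original gist): it replays the
-- exact reassignments from the end of sort() for a pivot that ended up at
-- index p, and returns the range passed to the recursive call followed by the
-- range the while loop continues with.
local function next_ranges(l, u, p)
  local i, j = p, nil
  if i - l < u - i then
    j = l
    i = i - 1
    l = i + 2
  else
    j = i + 1
    i = u
    u = j - 2 -- j was just reassigned, so this is p + 1 - 2 == p - 1
  end
  return j, i, l, u -- recursive range {j, i}, loop range {l, u}
end

-- For small l, u and every pivot index p between them, the two ranges are
-- exactly {l, p-1} and {p+1, u}, and the recursive call never gets the larger
-- one. That is what keeps the recursion depth logarithmic in the range size.
for l = 1, 5 do
  for u = l + 3, l + 10 do -- sort() only partitions when u - l >= 3
    for p = l + 1, u - 1 do
      local rl, ru, ll, lu = next_ranges(l, u, p)
      if p - l < u - p then
        assert(rl == l and ru == p - 1 and ll == p + 1 and lu == u)
      else
        assert(rl == p + 1 and ru == u and ll == l and lu == p - 1)
      end
      assert(ru - rl <= lu - ll) -- size of recursive range <= size of loop range
    end
  end
end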
local t = { 5, 2, 3, 3, 1, 2, 0 }
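-- `a <= b` is not a strict order (comp(x, x) is true), and with it this pass leaves the table unsorted
sort(t, function(a, b) return a <= b end, 1, #t)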
print(table.concat(t, " ")) --> 0 1 2 2 5 3 3
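-- The proof inside sort() assumes a strict comparator: comp(a, b) and
-- comp(b, a) are never both true (with a == b this also rules out
-- comp(x, x) == true). `a <= b` breaks this for every pair of equal values.
-- is_strict_on is a hypothetical helper, not part of the gist: it looks for a
-- counterexample among the given sample values by brute force. It can only
-- show that a comparator is NOT strict; passing on samples proves nothing
-- about other inputs (and it does not check transitivity).
local function is_strict_on(comp, values)
  for _, x in ipairs(values) do
    for _, y in ipairs(values) do
      if comp(x, y) and comp(y, x) then
        return false, x, y -- counterexample: both directions are true
      end
    end
  end
  return true
end

local samples = { 5, 2, 3, 3, 1, 2, 0 }
print((is_strict_on(function(a, b) return a <= b end, samples))) --> false
print((is_strict_on(function(a, b) return a < b end, samples)))  --> true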
-- a second pass happens to sort this particular input
sort(t, function(a, b) return a <= b end, 1, #t)
print(table.concat(t, " ")) --> 0 1 2 2 3 3 5
t = { 5, 2, 3, 3, 1, 2, 0 }
sort(t, function(a, b) return a < b end, 1, #t)
print(table.concat(t, " ")) --> 0 1 2 2 3 3 5
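-- The printed results can also be checked programmatically against the
-- postcondition the Lua 5.1 manual states for table.sort: after sorting,
-- not comp(a[k+1], a[k]) holds for every k. first_unsorted is a hypothetical
-- helper (not part of the gist); it returns the first k that violates the
-- postcondition, or nil when there is none. comp defaults to `<`.
local function first_unsorted(arr, comp)
  comp = comp or function(x, y) return x < y end
  for k = 1, #arr - 1 do
    if comp(arr[k + 1], arr[k]) then
      return k -- arr[k + 1] must not come before arr[k]
    end
  end
  return nil
end

print(first_unsorted({ 0, 1, 2, 2, 5, 3, 3 })) --> 5
print(first_unsorted({ 0, 1, 2, 2, 3, 3, 5 })) --> nil
-- With comp = `<=` the check fails on any two equal neighbours, which is
-- another symptom of `<=` not being a strict order:
print(first_unsorted({ 0, 1, 2, 2, 3, 3, 5 }, function(x, y) return x <= y end)) --> 3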