# Gist dermesser/eb6caeffcb678153449913316239863f by @dermesser,
# created November 11, 2022 18:32.
#
# Find an integer in a sorted matrix of integers (but generalizable to any orderable type).
# > X = generate_sorted_matrix(10, 10)
# 10×10 Matrix{Int64}:
# 2 139 194 248 337 428 544 625 736 873
# 11 141 195 251 340 454 564 644 743 892
# 39 153 200 272 350 454 582 649 786 900
# 58 163 203 284 357 462 586 652 799 912
# 66 164 219 295 359 473 599 654 806 930
# 73 170 220 302 361 482 604 677 813 947
# 74 177 243 303 364 488 609 681 815 957
# 78 187 245 323 409 504 611 696 845 961
# 96 189 245 332 414 538 615 703 853 973
# 125 190 245 332 414 539 615 727 864 998
#
# > find_in_sorted_matrix(X, 703)
# Some((8, 9))    # (column, row): X[9, 8] == 703
"""
    generate_sorted_matrix(m, n) -> Matrix{Int}

Return an `m×n` matrix of random integers in `0:999` whose entries are sorted in
column-major order, i.e. `issorted(vec(X))` holds. Every column is ascending, and
each column starts at or above the value where the previous column ends.
"""
function generate_sorted_matrix(m, n)::Matrix{Int}
    # Sample from the range directly. `abs.(rand(Int, k)) .% 1000` is subtly wrong:
    # `abs(typemin(Int))` is still negative, so it could yield a negative entry, and
    # reducing modulo 1000 is also (slightly) non-uniform.
    v = sort!(rand(0:999, m * n))
    # `reshape` fills column by column, so the sorted order runs down each column.
    return reshape(v, m, n)
end
# `Option{T}` is the result type of a lookup that may fail: `Some(x)` wraps a found
# value of type `T`, and `nothing` signals absence. Unwrap with `something(x)`.
const Option = Union{Nothing, Some{T}} where {T}
"""
    find_in_sorted_array(V, n)

Binary-search the ascending-sorted vector `V` for the value `n`.

Return `Some(i)` with `V[i] == n` if `n` occurs in `V`, otherwise `nothing`.
If `n` occurs several times, any one of the matching indices may be returned.

- Works for any element type ordered by `<`.
- An empty `V` yields `nothing`.
- A value that is unordered relative to the elements (e.g. `NaN`) also yields `nothing`.

The return type `Union{Some{Int}, Nothing}` is the same type as `Option{Int}`.
"""
function find_in_sorted_array(V::AbstractVector, n)::Union{Some{Int}, Nothing}
    # Inclusive search window [lo, hi]. `firstindex`/`lastindex` (rather than
    # `1`/`length(V)`) keep this correct for vectors that are not 1-based.
    lo, hi = firstindex(V), lastindex(V)
    while lo <= hi
        # Integer midpoint, always inside [lo, hi]. `round(Int, (lo + hi) / 2)`
        # rounds ties to even and can produce 0 (e.g. for a one-element vector),
        # which would index out of bounds.
        mid = lo + (hi - lo) ÷ 2
        x = V[mid]
        if n < x
            hi = mid - 1
        elseif n > x
            lo = mid + 1
        elseif n == x
            return Some(mid)
        else
            # `n` is neither <, > nor == `x` (e.g. NaN), so it cannot be located.
            return nothing
        end
    end
    # Window is empty: `n` is not in `V`.
    return nothing
end
"""
    find_in_sorted_matrix(M, n)

Search `M` for the value `n`. `M` must be sorted in column-major order, i.e.
`vec(M)` is ascending, as produced by `generate_sorted_matrix`.

The search runs in two binary-search steps:

1. Find the column `j` whose range `[M[begin, j], M[end, j]]` contains `n`.
2. Search within that column with `find_in_sorted_array`.

Because `vec(M)` is sorted, if `n` occurs anywhere in `M`, it occurs in any column
whose range contains it. This costs O(log(ncols) + log(nrows)) comparisons.

Return `Some((column, row))` if `n` is found, and `nothing` if `n` is absent or
`M` is empty. Note the order of the tuple: `M[row, column] == n`.
"""
function find_in_sorted_matrix(M::AbstractMatrix, n)::Option{Tuple{Int, Int}}
    # An empty matrix cannot contain `n`. This guard also avoids evaluating
    # `M[begin, …]` on a matrix with zero rows.
    isempty(M) && return nothing

    # 1. Find the column. This is an inclusive window with an integer midpoint.
    #    The previous `round(Int, …)` midpoint could become 0 (ties round to even),
    #    e.g. for a single-column matrix or when `n` is below every entry.
    lo, hi = firstindex(M, 2), lastindex(M, 2)
    while lo <= hi
        mid = lo + (hi - lo) ÷ 2
        if n < M[begin, mid]
            hi = mid - 1
        elseif n > M[end, mid]
            lo = mid + 1
        else
            # 2. `n` lies within this column's range; it is the only candidate.
            row = find_in_sorted_array((@view M[:, mid]), n)
            return isnothing(row) ? nothing : Some((mid, something(row)))
        end
    end
    # `n` is below the first column, above the last, or in a gap between columns.
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment