Skip to content

Instantly share code, notes, and snippets.

@cppio
Last active September 18, 2023 23:11
Show Gist options
  • Select an option

  • Save cppio/0dd653c4c568e2d8d320c4bca97051a6 to your computer and use it in GitHub Desktop.

Select an option

Save cppio/0dd653c4c568e2d8d320c4bca97051a6 to your computer and use it in GitHub Desktop.
Many ways to write binary search
def bsearch0(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    back from ``hi``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] < x < l[hi:]
        mid = hi - (hi - lo) // 2 - 1
        assert lo <= mid < hi
        if x > l[mid]:
            lo = mid + 1
        elif x < l[mid]:
            hi = mid
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo.
    assert lo == hi
    return lo - 0.5
def bsearch1(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    forward from ``lo``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] < x < l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid < hi
        if x > l[mid]:
            lo = mid + 1
        elif x < l[mid]:
            hi = mid
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo.
    assert lo == hi
    return lo - 0.5
def bsearch2(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    forward from ``lo``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] < x < l[hi+1:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid <= hi
        if x > l[mid]:
            lo = mid + 1
        elif x < l[mid]:
            hi = mid - 1
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo == hi + 1.
    assert lo == hi + 1
    return hi + 0.5
def bsearch3(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    back from ``hi``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] < x < l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo <= mid <= hi
        if x > l[mid]:
            lo = mid + 1
        elif x < l[mid]:
            hi = mid - 1
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo == hi + 1.
    assert lo == hi + 1
    return hi + 0.5
def bsearch4(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured forward
    from ``lo``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] < x < l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo < mid < hi
        if x > l[mid]:
            lo = mid
        elif x < l[mid]:
            hi = mid
        else:
            return mid
    # Window is empty: x is absent and belongs at index hi == lo + 1.
    assert lo + 1 == hi
    return lo + 0.5
def bsearch5(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured back from
    ``hi``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] < x < l[hi:]
        mid = hi - (hi - lo) // 2
        assert lo < mid < hi
        if x > l[mid]:
            lo = mid
        elif x < l[mid]:
            hi = mid
        else:
            return mid
    # Window is empty: x is absent and belongs at index hi == lo + 1.
    assert lo + 1 == hi
    return lo + 0.5
def bsearch6(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured back from ``hi``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] < x < l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo < mid <= hi
        if x > l[mid]:
            lo = mid
        elif x < l[mid]:
            hi = mid - 1
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo + 1.
    assert lo == hi
    return lo + 0.5
def bsearch7(l, x):
    """Find ``x`` in the strictly increasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured forward from ``lo``.

    Returns the index of ``x`` if it is present; otherwise ``i - 0.5``,
    where ``i`` is the index at which ``x`` would have to be inserted.
    """
    # Precondition: strictly increasing, hence no duplicates.
    assert all(i < j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] < x < l[hi+1:]
        mid = lo + (hi - lo) // 2 + 1
        assert lo < mid <= hi
        if x > l[mid]:
            lo = mid
        elif x < l[mid]:
            hi = mid - 1
        else:
            return mid
    # Window is empty: x is absent and belongs at index lo + 1.
    assert lo == hi
    return lo + 0.5
def bsearch0L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    back from ``hi``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] < x <= l[hi:]
        mid = hi - (hi - lo) // 2 - 1
        assert lo <= mid < hi
        if x > l[mid]:
            lo = mid + 1
        else:
            hi = mid
    assert lo == hi
    # hi is the first index with l[hi] >= x; check whether it really is x.
    if hi == len(l) or x < l[hi]:
        return hi - 0.5
    return hi
def bsearch1L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    forward from ``lo``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] < x <= l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid < hi
        if x > l[mid]:
            lo = mid + 1
        else:
            hi = mid
    assert lo == hi
    # hi is the first index with l[hi] >= x; check whether it really is x.
    if hi == len(l) or x < l[hi]:
        return hi - 0.5
    return hi
def bsearch2L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    forward from ``lo``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] < x <= l[hi+1:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid <= hi
        if x > l[mid]:
            lo = mid + 1
        else:
            hi = mid - 1
    assert lo == hi + 1
    # hi + 1 is the first index with l[hi + 1] >= x; check whether it is x.
    if hi + 1 == len(l) or x < l[hi + 1]:
        return hi + 0.5
    return hi + 1
def bsearch3L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    back from ``hi``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] < x <= l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo <= mid <= hi
        if x > l[mid]:
            lo = mid + 1
        else:
            hi = mid - 1
    assert lo == hi + 1
    # hi + 1 is the first index with l[hi + 1] >= x; check whether it is x.
    if hi + 1 == len(l) or x < l[hi + 1]:
        return hi + 0.5
    return hi + 1
def bsearch4L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured forward
    from ``lo``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] < x <= l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo < mid < hi
        if x > l[mid]:
            lo = mid
        else:
            hi = mid
    assert lo + 1 == hi
    # hi is the first index with l[hi] >= x; check whether it really is x.
    if hi == len(l) or x < l[hi]:
        return hi - 0.5
    return hi
def bsearch5L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured back from
    ``hi``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] < x <= l[hi:]
        mid = hi - (hi - lo) // 2
        assert lo < mid < hi
        if x > l[mid]:
            lo = mid
        else:
            hi = mid
    assert lo + 1 == hi
    # hi is the first index with l[hi] >= x; check whether it really is x.
    if hi == len(l) or x < l[hi]:
        return hi - 0.5
    return hi
def bsearch6L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured back from ``hi``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] < x <= l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo < mid <= hi
        if x > l[mid]:
            lo = mid
        else:
            hi = mid - 1
    assert lo == hi
    # hi + 1 is the first index with l[hi + 1] >= x; check whether it is x.
    if hi + 1 == len(l) or x < l[hi + 1]:
        return hi + 0.5
    return hi + 1
def bsearch7L(l, x):
    """Find the leftmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured forward from ``lo``.

    Returns the smallest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] < x <= l[hi+1:]
        mid = lo + (hi - lo) // 2 + 1
        assert lo < mid <= hi
        if x > l[mid]:
            lo = mid
        else:
            hi = mid - 1
    assert lo == hi
    # hi + 1 is the first index with l[hi + 1] >= x; check whether it is x.
    if hi + 1 == len(l) or x < l[hi + 1]:
        return hi + 0.5
    return hi + 1
def bsearch0R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    back from ``hi``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] <= x < l[hi:]
        mid = hi - (hi - lo) // 2 - 1
        assert lo <= mid < hi
        if x < l[mid]:
            hi = mid
        else:
            lo = mid + 1
    assert lo == hi
    # lo - 1 is the last index with l[lo - 1] <= x; check whether it is x.
    if lo == 0 or l[lo - 1] < x:
        return lo - 0.5
    return lo - 1
def bsearch1R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``[lo, hi)``; the midpoint is measured
    forward from ``lo``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l)
    while lo < hi:
        # Invariant: l[:lo] <= x < l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid < hi
        if x < l[mid]:
            hi = mid
        else:
            lo = mid + 1
    assert lo == hi
    # lo - 1 is the last index with l[lo - 1] <= x; check whether it is x.
    if lo == 0 or l[lo - 1] < x:
        return lo - 0.5
    return lo - 1
def bsearch2R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    forward from ``lo``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] <= x < l[hi+1:]
        mid = lo + (hi - lo) // 2
        assert lo <= mid <= hi
        if x < l[mid]:
            hi = mid - 1
        else:
            lo = mid + 1
    assert lo == hi + 1
    # lo - 1 is the last index with l[lo - 1] <= x; check whether it is x.
    if lo == 0 or l[lo - 1] < x:
        return lo - 0.5
    return lo - 1
def bsearch3R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the closed window ``[lo, hi]``; the midpoint is measured
    back from ``hi``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = 0, len(l) - 1
    while lo <= hi:
        # Invariant: l[:lo] <= x < l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo <= mid <= hi
        if x < l[mid]:
            hi = mid - 1
        else:
            lo = mid + 1
    assert lo == hi + 1
    # lo - 1 is the last index with l[lo - 1] <= x; check whether it is x.
    if lo == 0 or l[lo - 1] < x:
        return lo - 0.5
    return lo - 1
def bsearch4R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured forward
    from ``lo``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] <= x < l[hi:]
        mid = lo + (hi - lo) // 2
        assert lo < mid < hi
        if x < l[mid]:
            hi = mid
        else:
            lo = mid
    assert lo + 1 == hi
    # lo is the last index with l[lo] <= x; check whether it really is x.
    if lo + 1 == 0 or l[lo] < x:
        return lo + 0.5
    return lo
def bsearch5R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the open window ``(lo, hi)``, starting from the out-of-range
    sentinels ``-1`` and ``len(l)``; the midpoint is measured back from
    ``hi``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l)
    while lo + 1 < hi:
        # Invariant: l[:lo+1] <= x < l[hi:]
        mid = hi - (hi - lo) // 2
        assert lo < mid < hi
        if x < l[mid]:
            hi = mid
        else:
            lo = mid
    assert lo + 1 == hi
    # lo is the last index with l[lo] <= x; check whether it really is x.
    if lo + 1 == 0 or l[lo] < x:
        return lo + 0.5
    return lo
def bsearch6R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured back from ``hi``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] <= x < l[hi+1:]
        mid = hi - (hi - lo) // 2
        assert lo < mid <= hi
        if x < l[mid]:
            hi = mid - 1
        else:
            lo = mid
    assert lo == hi
    # lo is the last index with l[lo] <= x; check whether it really is x.
    if lo + 1 == 0 or l[lo] < x:
        return lo + 0.5
    return lo
def bsearch7R(l, x):
    """Find the rightmost ``x`` in the non-decreasing list ``l``.

    Searches the half-open window ``(lo, hi]``, with ``lo`` starting at
    the sentinel ``-1``; the midpoint is measured forward from ``lo``.

    Returns the largest index holding ``x`` if it is present; otherwise
    ``i - 0.5``, where ``i`` is the index at which ``x`` would have to be
    inserted.
    """
    # Precondition: sorted; duplicates are allowed.
    assert all(i <= j for i, j in zip(l, l[1:]))
    lo, hi = -1, len(l) - 1
    while lo < hi:
        # Invariant: l[:lo+1] <= x < l[hi+1:]
        mid = lo + (hi - lo) // 2 + 1
        assert lo < mid <= hi
        if x < l[mid]:
            hi = mid - 1
        else:
            lo = mid
    assert lo == hi
    # lo is the last index with l[lo] <= x; check whether it really is x.
    if lo + 1 == 0 or l[lo] < x:
        return lo + 0.5
    return lo
# Exact-match variants; they require a strictly increasing list.
bsearch = [
    bsearch0, bsearch1, bsearch2, bsearch3,
    bsearch4, bsearch5, bsearch6, bsearch7,
]

# Leftmost-match variants; duplicates are allowed.
bsearchL = [
    bsearch0L, bsearch1L, bsearch2L, bsearch3L,
    bsearch4L, bsearch5L, bsearch6L, bsearch7L,
]

# Rightmost-match variants; duplicates are allowed.
bsearchR = [
    bsearch0R, bsearch1R, bsearch2R, bsearch3R,
    bsearch4R, bsearch5R, bsearch6R, bsearch7R,
]
# Self-test, run at import time.

# On lists without duplicates every variant must agree: a present value
# yields its index, an absent value yields (insertion index - 0.5).
for f in bsearch + bsearchL + bsearchR:
    for i in range(3):
        assert f([0, 1, 2], i) == i
    for i in range(4):
        assert f([0, 1, 2, 3], i) == i
    for i in range(4):
        assert f([1, 3, 5], i * 2) == i - 0.5
    for i in range(5):
        assert f([1, 3, 5, 7], i * 2) == i - 0.5

# With duplicates, the L variants must report the first index of a run.
for f in bsearchL:
    assert f([0] * 16, 0) == 0
    assert f([1] * 16, 0) == -0.5
    assert f([1] * 16, 2) == 15.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 0) == -0.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 1) == 0
    assert f([1] * 4 + [3] * 16 + [5] * 8, 2) == 3.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 3) == 4
    assert f([1] * 4 + [3] * 16 + [5] * 8, 4) == 19.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 5) == 20
    assert f([1] * 4 + [3] * 16 + [5] * 8, 6) == 27.5

# With duplicates, the R variants must report the last index of a run.
for f in bsearchR:
    assert f([0] * 16, 0) == 15
    assert f([1] * 16, 0) == -0.5
    assert f([1] * 16, 2) == 15.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 0) == -0.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 1) == 3
    assert f([1] * 4 + [3] * 16 + [5] * 8, 2) == 3.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 3) == 19
    assert f([1] * 4 + [3] * 16 + [5] * 8, 4) == 19.5
    assert f([1] * 4 + [3] * 16 + [5] * 8, 5) == 27
    assert f([1] * 4 + [3] * 16 + [5] * 8, 6) == 27.5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment