
@buybackoff
Last active August 31, 2024 20:44
Avx2/branch-optimized binary search in .NET

Binary search is theoretically optimal in the number of comparisons, but it can be sped up substantially with AVX2 and branchless code, even in .NET Core.

Memory access is the limiting factor for binary search. Each time we access an element for comparison, a whole cache line is loaded. That means we can load a 32-byte vector almost for free and check whether it contains the target value. If it does not, we reduce the search space by 32/sizeof(T) elements instead of 1. This gives a good performance improvement (code in BinarySearch1.cs and results in Table 1 below).
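You can check the mask arithmetic in the pivot step without AVX2 hardware. For long, Avx2.CompareGreaterThan(valVec, vec) sets every bit of each 8-byte lane where value > element. Avx2.MoveMask then packs the top bit of each of the 32 bytes into an int, so each lane contributes 8 bits. Because the data is sorted, the set lanes form a prefix, and (32 - lzcnt(mask)) / 8 is the offset of the first element >= value. The sketch below is a scalar model of this, assuming a 4-lane Vector256<long>. PivotMaskDemo and its methods are hypothetical names.

```csharp
using System.Numerics;

public static class PivotMaskDemo
{
    // Scalar model of Avx2.MoveMask(Avx2.CompareGreaterThan(valVec, vec).AsByte())
    // for a 4-lane vector of longs: every lane where value > element contributes
    // 8 set bits (one per byte), with lane 0 in the lowest bits.
    public static int SimulateMoveMask(long[] vec, long value)
    {
        int mask = 0;
        for (int lane = 0; lane < 4; lane++)
        {
            if (value > vec[lane])
                mask |= 0xFF << (lane * 8);
        }
        return mask;
    }

    // Number of vector elements smaller than value, i.e. the offset of the
    // first element >= value (4 when all elements are smaller).
    public static int IndexFromMask(int mask)
    {
        int clz = BitOperations.LeadingZeroCount((uint)mask);
        return (32 - clz) / sizeof(long);
    }
}
```

A mask of 0 means every element is >= value, so value may still equal the first one. A mask of -1 means every element is smaller. This is why the gist's code treats those two cases separately.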

However, for larger N, reducing the search space by a few extra elements per step gains very little, and most of the gain comes from switching to linear search. I had an interesting discussion on Twitter (especially with @trav_downs) and tried widening the pivot area to two AVX2 vectors. After that, it became clear that switching to linear search sooner matters more than using AVX2 vectors as pivots.
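Here is a minimal scalar sketch of switching to linear search sooner. It keeps halving while the candidate range is larger than a cutoff, then scans the few remaining elements in order. The cutoff of 16 is an arbitrary illustrative value, not one taken from the gist, and HybridSearchDemo is a hypothetical name. The return value follows the same convention as Array.BinarySearch: the index, or the bitwise complement of the insertion point.

```csharp
public static class HybridSearchDemo
{
    // Illustrative cutoff (assumption): once at most this many candidates remain,
    // a linear scan replaces further halving.
    public const int LinearCutoff = 16;

    // Returns the index of value, or ~insertionIndex if it is absent.
    public static int Search(long[] data, long value)
    {
        int lo = 0;
        int hi = data.Length - 1;

        // Binary phase: each probe halves the candidate range [lo, hi].
        // Invariant: data[0..lo) < value and data(hi..end] > value.
        while (hi - lo + 1 > LinearCutoff)
        {
            int mid = (int)(((uint)lo + (uint)hi) >> 1);
            long x = data[mid];
            if (x == value)
                return mid;
            if (x < value)
                lo = mid + 1;
            else
                hi = mid - 1;
        }

        // Linear phase: the few remaining elements span only a couple of cache lines.
        while (lo <= hi && data[lo] < value)
            lo++;

        return lo <= hi && data[lo] == value ? lo : ~lo;
    }
}
```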

The linear search did not use AVX2, and AVX2 should definitely work for a linear scan, shouldn't it?! With a vectorized linear search and some additional branch optimizations, performance improves by a further 30-50% for the most relevant N (code in BinarySearch2.cs and results in Table 2 below).
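The gist vectorizes the linear phase with Avx2 directly. Below is a portable sketch of the same idea, not the gist's code. It uses System.Numerics.Vector<long> to skip whole vectors whose elements are all smaller than the target, then finishes with a short scalar scan. VectorLinearSearchDemo.LowerBound is a hypothetical helper that assumes ascending data and an inclusive range [lo, hi].

```csharp
using System.Numerics;

public static class VectorLinearSearchDemo
{
    // Returns the index of the first element >= value in data[lo..hi] (inclusive),
    // or hi + 1 if every element is smaller. Assumes data is sorted ascending
    // and hi < data.Length.
    public static int LowerBound(long[] data, int lo, int hi, long value)
    {
        var valVec = new Vector<long>(value);
        int width = Vector<long>.Count;

        // Skip whole vectors whose elements are all smaller than value.
        // The length check guarantees data[lo .. lo + width - 1] is in bounds.
        while (hi - lo + 1 >= width &&
               Vector.LessThanAll(new Vector<long>(data, lo), valVec))
        {
            lo += width;
        }

        // Finish with a scalar scan over at most one vector's worth of elements.
        while (lo <= hi && data[lo] < value)
            lo++;

        return lo;
    }
}
```

On AVX2 hardware, Vector<long>.Count is typically 4, which matches Vector256<long>.Count. The loop condition keeps the sketch correct for any width.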

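Compared with the classic binary search, the final version (Avx+) is 65% faster on average. The largest gains are at N = 512 (+157%) and N = 1024 (+94%). The Classic, Avx and Avx+ columns are throughput in MOPS, taken from Table 3 below; the remaining columns are relative differences: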

     N  Classic    Avx   Avx+  Avx/Classic  Avx+/Avx  Avx+/Classic
     1    569.0  390.3  630.0         -31%       61%           11%
     2    537.5  287.0  616.1         -47%      115%           15%
     4    286.0  209.5  629.5         -27%      201%          120%
     8    185.2  290.4  247.5          57%      -15%           34%
    16    120.4  215.3  199.4          79%       -7%           66%
    32     99.5  144.2  153.8          45%        7%           55%
    64     76.4  119.2  129.5          56%        9%           69%
   128     61.2  101.5  111.1          66%       10%           82%
   256     50.2   81.0   83.3          61%        3%           66%
   512     29.1   43.8   74.8          50%       71%          157%
  1024     22.5   31.0   43.7          38%       41%           94%
  4096     19.0   23.3   30.7          23%       32%           62%
 16384     17.7   20.5   28.3          16%       38%           60%
 65536     16.7   19.6   24.6          17%       26%           47%
131072     15.9   19.1   23.1          20%       21%           45%
   Avg                                 28%       41%           65%
   Min                                -47%      -15%           11%
   Max                                 79%      201%          157%

AVX-512, with the _mm256_cmpge_epi64_mask intrinsic, should improve this even more, but it was not available in .NET at the time of writing.

Benchmarks

  • Classic - classic scalar binary search
  • Avx - AVX2 vector used as the pivot
  • Avx+ - same as Avx, plus AVX2 for the linear search

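Results for long: searching for [0,1,2,3,...] in the array [0,2,4,6,...] (50% hit rate). MOPS is millions of searches per second (higher is better). The suffix of each case name is the array length N.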
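The sketch below models the benchmark input described above, assuming the queries cover 0 through 2N - 1. The even queries hit and the odd ones miss, which gives the 50% hit rate. MOPS is computed as searches per microsecond. BenchmarkInputDemo is a hypothetical helper, and Array.BinarySearch stands in for the searches being measured. This is not the gist's benchmark harness.

```csharp
using System;

public static class BenchmarkInputDemo
{
    // Sorted keys [0, 2, 4, ...] of length n.
    public static long[] MakeKeys(int n)
    {
        var keys = new long[n];
        for (int k = 0; k < n; k++)
            keys[k] = 2L * k;
        return keys;
    }

    // Searches for every query 0, 1, ..., 2n - 1 and counts how many are found.
    public static int CountHits(long[] keys)
    {
        int hits = 0;
        for (long q = 0; q < 2L * keys.Length; q++)
        {
            if (Array.BinarySearch(keys, q) >= 0)
                hits++;
        }
        return hits;
    }

    // Millions of operations per second = operations per microsecond.
    public static double Mops(long operations, TimeSpan elapsed) =>
        operations / elapsed.TotalMilliseconds / 1000.0;
}
```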

Table 1: use AVX2 for pivot

Case MOPS Elapsed GC0 GC1 GC2 Memory
BS_Classic_2 570.38 15 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_1 569.90 15 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1 505.60 17 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_2 504.03 17 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4 503.77 17 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_4 288.84 29 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_8 285.34 29 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16 186.61 45 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_8 183.58 46 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_32 151.42 55 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_64 117.92 71 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_16 111.73 75 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_128 95.59 88 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_32 90.19 93 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_64 71.43 117 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_256 63.05 133 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_128 60.40 139 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_256 47.07 178 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_512 40.02 210 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1024 31.25 268 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_512 28.62 293 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4096 22.84 367 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_1024 22.35 375 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16384 19.77 424 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_4096 19.03 441 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_16384 17.71 474 ms 0.0 0.0 0.0 0.000 MB

Table 2: use AVX2 for linear search

Case MOPS Elapsed GC0 GC1 GC2 Memory
BS_Avx+_4 519.18 40 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_1 507.19 41 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_2 498.41 42 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1 372.98 56 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_2 283.01 74 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_8 260.14 81 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_8 244.84 86 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4 241.00 87 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16 201.39 104 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_16 197.56 106 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_32 157.66 133 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_32 145.97 144 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_64 127.95 164 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_64 122.45 171 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_128 108.82 193 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_128 98.70 212 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_256 89.20 235 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_256 81.84 256 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_512 72.26 290 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_1024 44.33 473 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_512 36.27 578 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_4096 30.81 681 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1024 28.90 726 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_16384 28.55 735 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_65536 24.31 863 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_131072 22.53 931 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4096 22.53 931 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16384 19.56 1,072 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_65536 18.17 1,154 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_131072 17.37 1,207 ms 0.0 0.0 0.0 0.000 MB

Table 3: combined results (different run from table 2)

Case MOPS Elapsed GC0 GC1 GC2 Memory
BS_Avx+_1 629.99 33 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_4 629.49 33 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_2 616.09 34 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_1 569.00 37 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_2 537.47 39 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1 390.32 54 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_8 290.42 72 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_2 286.97 73 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_4 286.03 73 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_8 247.45 85 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16 215.32 97 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4 209.48 100 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_16 199.38 105 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_8 185.15 113 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_32 153.79 136 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_32 144.24 145 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_64 129.52 162 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_16 120.41 174 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_64 119.15 176 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_128 111.14 189 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_128 101.45 207 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_32 99.50 211 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_256 83.33 252 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_256 81.02 259 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_64 76.43 274 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_512 74.81 280 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_128 61.20 343 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_256 50.18 418 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_512 43.84 478 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_1024 43.72 480 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_1024 31.00 676 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_4096 30.71 683 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_512 29.13 720 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_16384 28.26 742 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_65536 24.60 853 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_4096 23.28 901 ms 0.0 0.0 0.0 0.000 MB
BS_Avx+_131072 23.08 909 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_1024 22.48 933 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_16384 20.52 1,022 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_65536 19.57 1,072 ms 0.0 0.0 0.0 0.000 MB
BS_Avx_131072 19.06 1,100 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_4096 18.95 1,106 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_16384 17.68 1,186 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_65536 16.71 1,255 ms 0.0 0.0 0.0 0.000 MB
BS_Classic_131072 15.87 1,321 ms 0.0 0.0 0.0 0.000 MB
/// <summary>
/// Performs binary search and returns the index of the value or, if it is not found,
/// the bitwise complement of the index of the first larger element.
/// Dispatches to AVX2-optimized search for sbyte, short, int and long keys
/// (Timestamp and DateTime are searched as long) when AVX2 is supported.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining
#if HAS_AGGR_OPT
            | MethodImplOptions.AggressiveOptimization
#endif
)]
[SuppressMessage("ReSharper", "HeapView.BoxingAllocation")] // false warnings for (type)(object)value pattern
public static int BinarySearch<T>(ref T vecStart, int length, T value, KeyComparer<T> comparer = default)
{
#if HAS_INTRINSICS
    if (Avx2.IsSupported)
    {
        if (typeof(T) == typeof(sbyte))
            return BinarySearchAvx(ref Unsafe.As<T, sbyte>(ref vecStart), length, (sbyte) (object) value);
        if (typeof(T) == typeof(short))
            return BinarySearchAvx(ref Unsafe.As<T, short>(ref vecStart), length, (short) (object) value);
        if (typeof(T) == typeof(int))
            return BinarySearchAvx(ref Unsafe.As<T, int>(ref vecStart), length, (int) (object) value);
        if (typeof(T) == typeof(long))
            return BinarySearchAvx(ref Unsafe.As<T, long>(ref vecStart), length, (long) (object) value);
        if (typeof(T) == typeof(Timestamp) || typeof(T) == typeof(DateTime))
        {
            // A boxed Timestamp/DateTime cannot be unboxed as long (InvalidCastException),
            // so reinterpret its 8 bytes instead. Raw comparison of DateTime values
            // assumes they all carry the same Kind bits.
            return BinarySearchAvx(ref Unsafe.As<T, long>(ref vecStart), length, Unsafe.As<T, long>(ref value));
        }
    }
#endif
    // This one is actually very hard to beat in general case
    // because of memory access (cache miss) costs. In the classic
    // algorithm every memory access is useful, i.e. it halves the
    // search space. K-ary search has K-2 useless memory accesses.
    // E.g. for SIMD-ized search with K = 4 we do 4 memory accesses
    // but reduce the search space to the same size as 2 accesses
    // in the classic algorithm. SIMD doesn't speed up memory access,
    // which is the main cost for a high number of items.
    return BinarySearchClassic(ref vecStart, length, value, comparer);
}
#if HAS_INTRINSICS
[MethodImpl(MethodImplOptions.AggressiveInlining
#if HAS_AGGR_OPT
            | MethodImplOptions.AggressiveOptimization
#endif
)]
internal static int BinarySearchAvx(ref long vecStart, int length, long value)
{
    unchecked
    {
        // The linear loop below reads vecStart[lo] before any bounds check,
        // so an empty range must be handled up front.
        if (length == 0)
            return ~0;

        int i;
        int c;
        int lo = 0;
        int hi = length - 1;
        var valVec = Vector256.Create(value);
        // Use a whole vector as the pivot while more than Count elements remain
        while (hi - lo > Vector256<long>.Count - 1)
        {
            // The vector [i, i + Count - 1] is centered on the midpoint and lies within [lo, hi]
            i = (int) (((uint) hi + (uint) lo) >> 1) - (Vector256<long>.Count >> 1);
            var vec = Unsafe.ReadUnaligned<Vector256<long>>(ref Unsafe.As<long, byte>(ref Unsafe.Add(ref vecStart, i)));
            // AVX-512 has _mm256_cmpge_epi64_mask, which would combine the two operations
            // and avoid the edge-case check in the `mask == 0` branch below
            var gt = Avx2.CompareGreaterThan(valVec, vec); // _mm256_cmpgt_epi64
            var mask = Avx2.MoveMask(gt.AsByte());
            if (mask == 0) // val is not greater than any element in vec
            {
                // but could be equal to the first element
                c = value.CompareTo(UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, i)));
                if (c == 0)
                {
                    lo = i;
                    goto RETURN;
                }
                hi = i - 1;
            }
            else if (mask == -1) // val is larger than all in vec
            {
                lo = i + Vector256<long>.Count;
            }
            else
            {
                // Lanes where val > element form a prefix, so the number of set
                // bytes divided by 8 is the offset of the first element >= val
                var clz = BitUtil.NumberOfLeadingZeros(mask);
                var index = (32 - clz) / Unsafe.SizeOf<long>();
                lo = i + index;
                c = value.CompareTo(UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, lo)));
                goto RETURN;
            }
        }
        while ((c = value.CompareTo(UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, lo)))) > 0
               & ++lo <= hi) // if using branchless & then need to correct lo below
        {
        }
        // correct back non-short-circuit & evaluation
        lo -= UnsafeEx.Clt(c, 1); // (int)(c < 1)
        RETURN:
        var ceq1 = -UnsafeEx.Ceq(c, 0); // (int)(c == 0)
        return (ceq1 & lo) | (~ceq1 & ~lo);
    }
}
#endif
/// <summary>
/// Performs classic binary search and returns the index of the value or, if it is not found,
/// the bitwise complement of the index of the first larger element.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining
#if HAS_AGGR_OPT
            | MethodImplOptions.AggressiveOptimization
#endif
)]
internal static int BinarySearchClassic<T>(ref T vecStart, int length, T value, KeyComparer<T> comparer = default)
{
    unchecked
    {
        int lo = 0;
        int hi = length - 1;
        // If length == 0, hi == -1, and loop will not be entered
        while (lo <= hi)
        {
            // PERF: `lo` or `hi` will never be negative inside the loop,
            // so computing median using uints is safe since we know
            // `length <= int.MaxValue`, and indices are >= 0
            // and thus cannot overflow an uint.
            // Saves one subtraction per loop compared to
            // `int i = lo + ((hi - lo) >> 1);`
            int i = (int) (((uint) hi + (uint) lo) >> 1);
            int c = comparer.Compare(value, UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, i)));
            if (c == 0)
            {
                return i;
            }
            if (c > 0)
            {
                lo = i + 1;
            }
            else
            {
                hi = i - 1;
            }
        }
        // If not found, return a negative number: the bitwise complement of the index
        // of the first element larger than value or, if there is no larger element,
        // the bitwise complement of `length`. In both cases that index is `lo` here.
        return ~lo;
    }
}
#if HAS_INTRINSICS
[MethodImpl(MethodImplOptions.AggressiveInlining
#if HAS_AGGR_OPT
            | MethodImplOptions.AggressiveOptimization
#endif
)]
internal static int BinarySearchAvx2(ref long vecStart, int length, long value)
{
    unchecked
    {
        int c;
        int lo = 0;
        int hi = length - 1;
        Vector256<long> vec;
        Vector256<long> gt;
        int mask;
        // The linear loop reads vecStart[lo] before any bounds check
        if (length == 0)
            return ~0;
        if (hi - lo < Vector256<long>.Count)
            goto LINEAR;
        var valVec = Vector256.Create(value);
        // Binary phase: use a vector as the pivot while at least 2 * Count + 1 elements remain
        while (hi - lo >= Vector256<long>.Count * 2)
        {
            var i = (int) (((uint) hi + (uint) lo - Vector256<long>.Count) >> 1);
            vec = Unsafe.ReadUnaligned<Vector256<long>>(ref Unsafe.As<long, byte>(ref Unsafe.Add(ref vecStart, i)));
            gt = Avx2.CompareGreaterThan(valVec, vec);
            mask = Avx2.MoveMask(gt.AsByte());
            if (mask != -1)
            {
                if (mask != 0)
                {
                    // val is inside the vector: lo becomes the first index with element >= val
                    int clz = (int) Lzcnt.LeadingZeroCount((uint) mask);
                    int index = (32 - clz) / Unsafe.SizeOf<long>();
                    lo = i + index;
                    c = value.CompareTo(UnsafeEx.ReadUnaligned<long>(ref Unsafe.Add<long>(ref vecStart, lo)));
                    goto RETURN;
                }
                // val is not greater than any element in vec;
                // not i - 1, because val could equal vec[i]
                hi = i;
            }
            else
            {
                // val is larger than all in vec
                lo = i + Vector256<long>.Count;
            }
        }
        // Vectorized linear phase: compare Count elements at a time, but only while
        // the whole vector fits inside [lo, hi], so nothing past the end is read
        while (hi - lo >= Vector256<long>.Count)
        {
            vec = Unsafe.ReadUnaligned<Vector256<long>>(ref Unsafe.As<long, byte>(ref Unsafe.Add(ref vecStart, lo)));
            gt = Avx2.CompareGreaterThan(valVec, vec); // _mm256_cmpgt_epi64
            mask = Avx2.MoveMask(gt.AsByte());
            var clz = (int) Lzcnt.LeadingZeroCount((uint) mask);
            lo += (32 - clz) / Unsafe.SizeOf<long>();
            if (mask != -1)
                break; // lo is now the first index with element >= val
        }
        LINEAR:
        while ((c = value.CompareTo(UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, lo)))) > 0
               && ++lo <= hi)
        {
        }
        RETURN:
        var ceq1 = -UnsafeEx.Ceq(c, 0); // -1 if found, 0 otherwise
        return (ceq1 & lo) | (~ceq1 & ~lo);
    }
}
#endif

In C#, a bool cannot be cast to int directly, and branch-free conversions are not trivial, e.g.:

var c = a.CompareTo(b) == 0; // bool
int cInt = *(byte*)&c;       // 1 or 0; requires an unsafe context

or similar manipulations.

It looks simple but is actually quite slow (when we count cycles, not milliseconds). This is probably due to poor codegen: for example, a local whose address is taken cannot live in a register, and the like.

In some cases we can write branchless code that is faster than code that uses an if statement. E.g. in binary search we return the current index lo if the values are equal, or its bitwise complement ~lo otherwise:

int lo = 0;
...
var c = a.CompareTo(b);
return c == 0 ? lo : ~lo;

Some modern compilers convert ternary expressions to cmov instructions, but this is hard to rely on. Even @lemire writes:

The condition move instructions are pretty much standard and old at this point. Sadly, I only know how to convince one compiler (GNU GCC) to reliably produce conditional move instructions. And I have no clue how to control how Java, Swift, Rust or any other language deals with branches.

So instead of hoping that the .NET JIT could or would emit cmov here, we can avoid branches with bitwise operations:

int lo = 0;
...
var ceq = a.CompareTo(b) == 0; // bool
int ceqInt = -*(byte*)&ceq;    // -1 if equal, 0 otherwise; requires an unsafe context
return (ceqInt & lo) | (~ceqInt & ~lo);
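Below is a runnable sketch of the bitwise select above, kept in safe C#. Instead of reinterpreting a bool, it computes the 0/1 equality flag with integer arithmetic. The resulting mask (-1 or 0) then picks lo or ~lo without a branch. BranchlessSelectDemo and EqualsZero are hypothetical names and not the Spreads.Native UnsafeEx API. The speed of this version relative to the approaches in the text has not been measured.

```csharp
public static class BranchlessSelectDemo
{
    // Returns 1 if c == 0, otherwise 0, using only integer arithmetic.
    // For any non-zero c, (c | -c) has the sign bit set (this also holds for
    // int.MinValue, whose negation wraps to itself), so the arithmetic shift
    // yields -1 and adding 1 gives 0; for c == 0 the result is 0 + 1 = 1.
    public static int EqualsZero(int c) => ((c | unchecked(-c)) >> 31) + 1;

    // Same result as `c == 0 ? lo : ~lo`, computed without a branch.
    public static int EncodeResult(int c, int lo)
    {
        int mask = -EqualsZero(c); // -1 (all bits set) if c == 0, else 0
        return (mask & lo) | (~mask & ~lo);
    }
}
```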

However, this cast from bool to int is slow no matter how I tried to do it (pointers/unsafe).

Interestingly, the IL instruction ceq already pushes an int32 that is exactly 1 or 0; it is C#-the-language that exposes the result as a bool.

To get access to the int result, we can write our own method directly in IL:

  .method public hidebysig static int32 Ceq(int32 first, int32 second) cil managed aggressiveinlining
  {
    ldarg.0
    ldarg.1
    ceq
    ret
  }

and similar methods for the greater-than and less-than comparisons (cgt, clt).

Exposing this method as UnsafeEx.Ceq(int first, int second): int, we can rewrite the code as:

int lo = 0;
...
var c = a.CompareTo(b);
var ceqInt = -UnsafeEx.Ceq(c, 0); // -1 if equal, 0 otherwise
return (ceqInt & lo) | (~ceqInt & ~lo);

and it finally performs very fast, beating the branchy version when the branch is not predictable.

Another example is an eager (non-short-circuit) & comparison followed by a correction via UnsafeEx.Clt, which eliminates one branch. Benchmarks are inconclusive for this usage, although the branchless version is definitely not slower even though it does more work.

while ((c = value.CompareTo(UnsafeEx.ReadUnaligned(ref Unsafe.Add(ref vecStart, lo)))) > 0
       & ++lo <= hi) // if using branchless & then need to correct lo below
{
}
        
// correct back non-short-circuit & evaluation
lo -= UnsafeEx.Clt(c, 1); // (int)(c < 1)
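Below is a safe-C# sketch of the eager & loop with its correction step, checked against the short-circuit && version. LessThanOne is a hypothetical arithmetic stand-in for UnsafeEx.Clt(c, 1), not the Spreads.Native method. The sketch assumes lo <= hi on entry, because the first compare reads data[lo].

```csharp
public static class EagerAndDemo
{
    // 1 if c < 1, else 0, without a bool: widen to long so that c - 1 cannot
    // overflow (e.g. for int.MinValue), then take the sign bit.
    public static int LessThanOne(int c) => (int)(unchecked((ulong)((long)c - 1)) >> 63);

    // Finds the first index in data[lo..hi] with element >= value using a
    // non-short-circuit &. Returns the last comparison result and the final lo.
    public static (int c, int lo) EagerScan(long[] data, int lo, int hi, long value)
    {
        int c;
        // Both operands are always evaluated, so lo is incremented even on the
        // final iteration where c <= 0 ...
        while ((c = value.CompareTo(data[lo])) > 0 & ++lo <= hi)
        {
        }
        // ... and is corrected back here (subtract 1 only when c <= 0).
        lo -= LessThanOne(c);
        return (c, lo);
    }

    // Reference version with the short-circuit && (no correction needed).
    public static (int c, int lo) ShortCircuitScan(long[] data, int lo, int hi, long value)
    {
        int c;
        while ((c = value.CompareTo(data[lo])) > 0 && ++lo <= hi)
        {
        }
        return (c, lo);
    }
}
```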

Such methods that return an int from a comparison would be a good addition to System.Runtime.CompilerServices.Unsafe. In the meantime, they are available in the Spreads.Native repo and NuGet package.

The code for Avx2- and branch-optimized binary search for sbyte, short, int and long is generated from a T4 template here: https://github.com/Spreads/Spreads/blob/master/src/Spreads.Core/Algorithms/VectorSearch.Avx.tt
