Last active
November 28, 2021 17:59
-
-
Save comp500/d496e7484379bd264bc6816835252b7b to your computer and use it in GitHub Desktop.
Sodium SIMD optimisations microbenchmarks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.example; | |
import sun.misc.Unsafe; | |
import org.openjdk.jmh.annotations.*; | |
import java.lang.invoke.MethodHandles; | |
import java.lang.invoke.VarHandle; | |
import java.lang.reflect.Field; | |
import java.lang.reflect.InvocationTargetException; | |
import java.lang.reflect.Method; | |
import java.nio.ByteBuffer; | |
import java.nio.ByteOrder; | |
import java.nio.LongBuffer; | |
import java.util.Arrays; | |
import java.util.Random; | |
@Fork(value = 1, jvmArgsAppend = { | |
"--add-exports", | |
"java.base/jdk.internal.misc=ALL-UNNAMED", | |
"-XX:+UseSuperWord", | |
"-XX:+UnlockDiagnosticVMOptions", | |
"-XX:CompileCommand=print,*SectionCopyBenchmark.copyShift2D"}) | |
@Warmup(iterations = 5) | |
@Measurement(iterations = 10) | |
public class SectionCopyBenchmark { | |
private static class BlockState {} | |
@State(Scope.Thread) | |
public static class Context { | |
public final BlockState[] dest = new BlockState[4096]; | |
public final BlockState[] palette = new BlockState[16]; | |
public final long[] src = new long[256]; | |
public final int[] destIndexes = new int[4096]; | |
public final long[] shifts = new long[4096]; | |
public final long[] destIndexesLong = new long[4096]; | |
public final int[] destIndexesInt = new int[4096]; | |
public final LongBuffer destIndexesLongOffHeap = ByteBuffer.allocateDirect(4096 * 8).asLongBuffer(); | |
public final LongBuffer destIndexesLongOnHeap = LongBuffer.allocate(4096); | |
public final byte[] destIndexesByte = new byte[4096 * 8]; | |
public long startAddr = 0; | |
public final long[][] destIndexesLong2D = new long[16][256]; | |
@Setup | |
public void setup() throws ClassNotFoundException, InvocationTargetException, IllegalAccessException, NoSuchMethodException { | |
Arrays.fill(dest, null); | |
for (int i = 0; i < palette.length; i++) { | |
palette[i] = new BlockState(); | |
} | |
Random random = new Random(); | |
long prev = random.nextLong(); | |
for (int i = 0; i < src.length; i++) { | |
if (random.nextBoolean()) { | |
src[i] = random.nextLong(); | |
prev = src[i]; | |
} else { | |
src[i] = prev; | |
} | |
} | |
for (int i = 0; i < shifts.length; i += 16) { | |
for (int j = 0; j < 64; j += 4, i++) { | |
shifts[i] = 1L << j; | |
} | |
} | |
// Method address = Class.forName("sun.nio.ch.DirectBuffer").getMethod("address"); | |
// startAddr = (long) address.invoke(destIndexesLongOffHeap); | |
} | |
} | |
/* | |
With random data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.unrolled thrpt 10 184724.725 ± 1339.400 ops/s | |
With all zeros: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.unrolled thrpt 10 184770.122 ± 1621.971 ops/s | |
*/ | |
@Benchmark | |
public BlockState[] unrolled(Context context) { | |
long[] data = context.src; | |
BlockState[] palette = context.palette; | |
BlockState[] dst = context.dest; | |
for (int i = 0, blockIdx = 0; i < data.length; i++, blockIdx += 16) { | |
long v = data[i]; | |
dst[blockIdx] = palette[(int) (v & 15)]; | |
dst[blockIdx + 1] = palette[(int) ((v >>> 4 ) & 15)]; | |
dst[blockIdx + 2] = palette[(int) ((v >>> 8 ) & 15)]; | |
dst[blockIdx + 3] = palette[(int) ((v >>> 12) & 15)]; | |
dst[blockIdx + 4] = palette[(int) ((v >>> 16) & 15)]; | |
dst[blockIdx + 5] = palette[(int) ((v >>> 20) & 15)]; | |
dst[blockIdx + 6] = palette[(int) ((v >>> 24) & 15)]; | |
dst[blockIdx + 7] = palette[(int) ((v >>> 28) & 15)]; | |
dst[blockIdx + 8] = palette[(int) ((v >>> 32) & 15)]; | |
dst[blockIdx + 9] = palette[(int) ((v >>> 36) & 15)]; | |
dst[blockIdx + 10] = palette[(int) ((v >>> 40) & 15)]; | |
dst[blockIdx + 11] = palette[(int) ((v >>> 44) & 15)]; | |
dst[blockIdx + 12] = palette[(int) ((v >>> 48) & 15)]; | |
dst[blockIdx + 13] = palette[(int) ((v >>> 52) & 15)]; | |
dst[blockIdx + 14] = palette[(int) ((v >>> 56) & 15)]; | |
dst[blockIdx + 15] = palette[(int) ((v >>> 60) & 15)]; | |
} | |
return context.dest; | |
} | |
/* | |
With random data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.prevArraycopy thrpt 10 105549.357 ± 3117.402 ops/s | |
With 50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.prevArraycopy thrpt 10 107347.234 ± 4216.832 ops/s | |
With all zeros: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.prevArraycopy thrpt 10 127338.287 ± 2165.085 ops/s | |
*/ | |
@Benchmark | |
public BlockState[] prevArraycopy(Context context) { | |
long[] data = context.src; | |
BlockState[] palette = context.palette; | |
BlockState[] dst = context.dest; | |
int blockIdx = 0; | |
long v = data[0], prevV = data[0]; | |
for (int shift = 0; shift < 64; shift += 4, blockIdx++) { | |
dst[blockIdx] = palette[(int) ((v >>> shift) & 15)]; | |
} | |
for (int i = 1; i < data.length; i++) { | |
v = data[i]; | |
if (v == prevV) { | |
System.arraycopy(dst, blockIdx - 16, dst, blockIdx, 16); | |
blockIdx += 16; | |
} else { | |
prevV = v; | |
for (int shift = 0; shift < 64; shift += 4, blockIdx++) { | |
dst[blockIdx] = palette[(int) ((v >>> shift) & 15)]; | |
} | |
} | |
} | |
return context.dest; | |
} | |
/* | |
With all zeros: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.simpleLoop thrpt 10 130528.025 ± 1449.589 ops/s | |
With 50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.simpleLoop thrpt 10 128406.166 ± 7226.141 ops/s | |
*/ | |
@Benchmark | |
public BlockState[] simpleLoop(Context context) { | |
long[] data = context.src; | |
BlockState[] palette = context.palette; | |
BlockState[] dst = context.dest; | |
for (int i = 0, blockIdx = 0; i < data.length; i++) { | |
long v = data[i]; | |
for (int j = 0; j < 64; j += 4, blockIdx++) { | |
dst[blockIdx] = palette[(int) ((v >>> j) & 15)]; | |
} | |
} | |
return context.dest; | |
} | |
/* | |
With 50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.bytewise thrpt 10 130926.419 ± 3423.496 ops/s | |
*/ | |
@Benchmark | |
public BlockState[] bytewise(Context context) { | |
long[] data = context.src; | |
BlockState[] palette = context.palette; | |
BlockState[] dst = context.dest; | |
for (int i = 0, blockIdx = 0; i < data.length; i++) { | |
long v = data[i]; | |
for (int j = 0; j < 64; j += 8, blockIdx += 2) { | |
byte b = (byte) (v >>> j); | |
dst[blockIdx] = palette[b & 15]; | |
dst[blockIdx + 1] = palette[(b >>> 4) & 15]; | |
} | |
} | |
return context.dest; | |
} | |
/* | |
50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.bytewiseIndexes thrpt 10 377116.340 ± 3399.245 ops/s | |
*/ | |
@Benchmark | |
public int[] bytewiseIndexes(Context context) { | |
long[] data = context.src; | |
int[] dst = context.destIndexes; | |
for (int i = 0, blockIdx = 0; i < data.length; i++) { | |
long v = data[i]; | |
for (int j = 0; j < 64; j += 8, blockIdx += 2) { | |
byte b = (byte) (v >>> j); | |
dst[blockIdx] = b & 15; | |
dst[blockIdx + 1] = (b >>> 4) & 15; | |
} | |
} | |
return context.destIndexes; | |
} | |
/* | |
50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.simpleLoopIndexes thrpt 10 329252.592 ± 2189.486 ops/s | |
*/ | |
@Benchmark | |
public int[] simpleLoopIndexes(Context context) { | |
long[] data = context.src; | |
int[] dst = context.destIndexes; | |
for (int i = 0, blockIdx = 0; i < data.length; i++) { | |
long v = data[i]; | |
for (int j = 0; j < 64; j += 4, blockIdx++) { | |
dst[blockIdx] = (int) ((v >>> j) & 15); | |
} | |
} | |
return context.destIndexes; | |
} | |
/* | |
50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.unrolledIndexes thrpt 10 394397.562 ± 2150.307 ops/s | |
*/ | |
@Benchmark | |
public int[] unrolledIndexes(Context context) { | |
long[] data = context.src; | |
int[] dst = context.destIndexes; | |
for (int i = 0, blockIdx = 0; i < data.length; i++, blockIdx += 16) { | |
long v = data[i]; | |
dst[blockIdx] = (int) (v & 15); | |
dst[blockIdx + 1] = (int) ((v >>> 4 ) & 15); | |
dst[blockIdx + 2] = (int) ((v >>> 8 ) & 15); | |
dst[blockIdx + 3] = (int) ((v >>> 12) & 15); | |
dst[blockIdx + 4] = (int) ((v >>> 16) & 15); | |
dst[blockIdx + 5] = (int) ((v >>> 20) & 15); | |
dst[blockIdx + 6] = (int) ((v >>> 24) & 15); | |
dst[blockIdx + 7] = (int) ((v >>> 28) & 15); | |
dst[blockIdx + 8] = (int) ((v >>> 32) & 15); | |
dst[blockIdx + 9] = (int) ((v >>> 36) & 15); | |
dst[blockIdx + 10] = (int) ((v >>> 40) & 15); | |
dst[blockIdx + 11] = (int) ((v >>> 44) & 15); | |
dst[blockIdx + 12] = (int) ((v >>> 48) & 15); | |
dst[blockIdx + 13] = (int) ((v >>> 52) & 15); | |
dst[blockIdx + 14] = (int) ((v >>> 56) & 15); | |
dst[blockIdx + 15] = (int) ((v >>> 60) & 15); | |
} | |
return context.destIndexes; | |
} | |
// Note: doesn't do mask! | |
/* | |
With 50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.preCalculatedShifts thrpt 10 462235.488 ± 6357.533 ops/s | |
*/ | |
@Benchmark | |
public long[] preCalculatedShifts(Context context) { | |
long[] data = context.src; | |
long[] shifts = context.shifts; | |
long[] dst = context.destIndexesLong; | |
for (int blockIdx = 0; blockIdx < dst.length; blockIdx += 16) { | |
long v = data[blockIdx >>> 4]; | |
dst[blockIdx] = v >>> shifts[blockIdx]; | |
dst[blockIdx + 1] = v >>> shifts[blockIdx + 1]; | |
dst[blockIdx + 2] = v >>> shifts[blockIdx + 2]; | |
dst[blockIdx + 3] = v >>> shifts[blockIdx + 3]; | |
dst[blockIdx + 4] = v >>> shifts[blockIdx + 4]; | |
dst[blockIdx + 5] = v >>> shifts[blockIdx + 5]; | |
dst[blockIdx + 6] = v >>> shifts[blockIdx + 6]; | |
dst[blockIdx + 7] = v >>> shifts[blockIdx + 7]; | |
dst[blockIdx + 8] = v >>> shifts[blockIdx + 8]; | |
dst[blockIdx + 9] = v >>> shifts[blockIdx + 9]; | |
dst[blockIdx + 10] = v >>> shifts[blockIdx + 10]; | |
dst[blockIdx + 11] = v >>> shifts[blockIdx + 11]; | |
dst[blockIdx + 12] = v >>> shifts[blockIdx + 12]; | |
dst[blockIdx + 13] = v >>> shifts[blockIdx + 13]; | |
dst[blockIdx + 14] = v >>> shifts[blockIdx + 14]; | |
dst[blockIdx + 15] = v >>> shifts[blockIdx + 15]; | |
} | |
return dst; | |
} | |
// Note: changes layout of destIndexesLong... also doesn't work! It copies the data over the same location... | |
/* | |
50% repeated data, SIMD enabled: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2D thrpt 10 1862705.797 ± 51227.373 ops/s | |
50% repeated data, SIMD disabled: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2D thrpt 10 513844.888 ± 2668.240 ops/s | |
*/ | |
@Benchmark | |
public long[] constantShiftLoops(Context context) { | |
long[] data = context.src; | |
long[] dst = context.destIndexesLong; | |
for (int i = 0; i < 16; i++) { | |
copyShift(data, dst, i << 2, 15L << (i << 2)); | |
} | |
return dst; | |
} | |
private static void copyShift(long[] data, long[] dst, long shift, long mask) { | |
for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
// This isn't the right offset, and adding an offset makes the JVM not autovectorise it | |
dst[blockIdx] = (data[blockIdx] & mask) >>> shift; | |
} | |
} | |
// Note: changes layout of destIndexesInt | |
/* | |
This one doesn't get autovectorised :( | |
50% repeated data: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoopsInt thrpt 10 438808.426 ± 3617.427 ops/s | |
*/ | |
@Benchmark | |
public int[] constantShiftLoopsInt(Context context) { | |
long[] data = context.src; | |
int[] dst = context.destIndexesInt; | |
for (int i = 0; i < 16; i++) { | |
copyShiftInt(data, dst, i << 2, 15L << (i << 2)); | |
} | |
return dst; | |
} | |
private static void copyShiftInt(long[] data, int[] dst, long shift, long mask) { | |
for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
dst[blockIdx] = (int) ((data[blockIdx] & mask) >>> shift); | |
} | |
} | |
private static final Unsafe UNSAFE; | |
static { | |
try { | |
Field field = Unsafe.class.getDeclaredField("theUnsafe"); | |
field.setAccessible(true); | |
UNSAFE = (Unsafe) field.get(null); | |
} catch (IllegalAccessException | NoSuchFieldException e) { | |
throw new RuntimeException(e); | |
} | |
} | |
/* | |
Doesn't get autovectorised! | |
*/ | |
// @Benchmark | |
// public long[] constantShiftLoopsUnsafe(Context context) { | |
// long[] data = context.src; | |
// long[] dst = context.destIndexesLong; | |
// | |
// long addr = context.startAddr; | |
// | |
// for (int i = 0; i < 16; i++) { | |
// copyShiftUnsafe(i << 2, 15L << (i << 2), addr + (i * 256 * 8)); | |
// } | |
// | |
// return dst; | |
// } | |
// | |
// private static void copyShiftUnsafe(long shift, long mask, long addr) { | |
// for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
// //UNSAFE.putLong(addr + (blockIdx * 8), data[blockIdx] & mask >>> shift); | |
// UNSAFE.putLong(addr + (blockIdx * 8), (UNSAFE.getLong(addr + (blockIdx * 8)) & mask) >>> shift); | |
// } | |
// } | |
final static VarHandle VH_arr_view = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder()); | |
/* | |
Doesn't get autovectorised! | |
*/ | |
@Benchmark | |
public byte[] constantShiftLoopsVarHandle(Context context) { | |
long[] data = context.src; | |
byte[] dst = context.destIndexesByte; | |
for (int i = 0; i < 16; i++) { | |
copyShiftVarHandle(data, dst, i << 2, 15L << (i << 2)); | |
} | |
return dst; | |
} | |
private static void copyShiftVarHandle(long[] data, byte[] dst, long shift, long mask) { | |
for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
VH_arr_view.set(dst, blockIdx, (data[blockIdx] & mask) >>> shift); | |
} | |
} | |
/* | |
Doesn't get autovectorised! | |
*/ | |
@Benchmark | |
public LongBuffer constantShiftLoopsBuf(Context context) { | |
long[] data = context.src; | |
LongBuffer dst = context.destIndexesLongOnHeap; | |
for (int i = 0; i < 16; i++) { | |
copyShiftBuf(data, i << 2, 15L << (i << 2), dst); | |
} | |
return dst; | |
} | |
private static void copyShiftBuf(long[] data, long shift, long mask, LongBuffer dst) { | |
for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
dst.put(blockIdx, (data[blockIdx] & mask) >>> shift); | |
} | |
} | |
/* | |
50% repeated data, SIMD enabled: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2D thrpt 10 1867600.924 ± 51025.923 ops/s | |
50% repeated data, SIMD disabled: | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2D thrpt 10 513538.086 ± 2510.975 ops/s | |
*/ | |
@Benchmark | |
public long[][] constantShiftLoops2D(Context context) { | |
long[] data = context.src; | |
long[][] dst = context.destIndexesLong2D; | |
for (int i = 0; i < 16; i++) { | |
copyShift2D(data, dst[i], i << 2, 15L << (i << 2)); | |
} | |
return dst; | |
} | |
private static void copyShift2D(long[] data, long[] dst, long shift, long mask) { | |
for (int blockIdx = 0; blockIdx < 256; blockIdx += 1) { | |
dst[blockIdx] = (data[blockIdx] & mask) >>> shift; | |
} | |
} | |
/* | |
50% repeated data, SIMD enabled, [0, 256): | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2DSection thrpt 10 1790481.403 ± 9363.477 ops/s | |
50% repeated data, SIMD enabled, [0, 128) then [128, 256): | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2DSection thrpt 10 1687734.486 ± 24420.654 ops/s | |
50% repeated data, SIMD enabled, [0, 128): | |
Benchmark Mode Cnt Score Error Units | |
SectionCopyBenchmark.constantShiftLoops2DSection thrpt 10 3517598.227 ± 16300.049 ops/s | |
*/ | |
@Benchmark | |
public long[][] constantShiftLoops2DSection(Context context) { | |
long[] data = context.src; | |
long[][] dst = context.destIndexesLong2D; | |
for (int i = 0; i < 16; i++) { | |
copyShift2D(data, dst[i], i << 2, 15L << (i << 2), 0, 128); | |
} | |
return dst; | |
} | |
private static void copyShift2D(long[] data, long[] dst, long shift, long mask, int start, int end) { | |
for (int blockIdx = start; blockIdx < end; blockIdx += 1) { | |
dst[blockIdx] = (data[blockIdx] & mask) >>> shift; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment