Skip to content

Instantly share code, notes, and snippets.

@luhenry
Created November 23, 2020 20:31
Show Gist options
  • Save luhenry/52c0aec9e2ea04d4e3f08040c3da10d4 to your computer and use it in GitHub Desktop.
Save luhenry/52c0aec9e2ea04d4e3f08040c3da10d4 to your computer and use it in GitHub Desktop.
/*
* Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package benchmark.utf8;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CoderResult;
import java.util.HashMap;
import java.util.Random;
import jdk.incubator.vector.*;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
@State(Scope.Thread)
@BenchmarkMode(Mode.Throughput)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 8, time = 2)
public class DecodeBench {
// Size passed to randomBytesForString() and CharBuffer.allocate() in setup();
// the benchmark is parameterized over the listed values.
@Param({"16384", "65536"})
private int dataSize;
// Input bytes handed to the decoders; assigned in setup().
private ByteBuffer src;
// Output chars written by the decoders; assigned in setup().
private CharBuffer dst;
// 128-bit byte species: one vector holds 16 input bytes.
private static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128;
// 128-bit short species: one vector holds 8 16-bit lanes (one per decoded char).
private static final VectorSpecies<Short> S128 = ShortVector.SPECIES_128;
// Lookup table keyed by the continuation-byte bitmask of a 16-byte window;
// filled by computeLutEntry() and read by decodeArrayVectorized().
private static final HashMap<Long, DecoderLutEntry> lutTable = new HashMap<Long, DecoderLutEntry>();
/**
 * One row of the vectorized decoder's lookup table.
 *
 * <p>The arrays are stored by reference (no defensive copies), exactly as
 * handed to the constructor.
 */
private static class DecoderLutEntry {
    /** Shuffle indices selecting the last two bytes of every character. */
    public final byte[] shufAB;
    /** Shuffle indices selecting the leading byte of every 3-byte character. */
    public final byte[] shufC;
    /** Number of input bytes consumed when this entry is applied. */
    public final byte srcStep;
    /** Amount by which the output index is advanced when this entry is applied. */
    public final byte dstStep;
    /** Per input byte: mask of the leading header bits that are validated. */
    public final byte[] headerMask;
    /** Per output lane: bits that must be zero in the decoded value. */
    public final short[] zeroBits;

    /**
     * Creates an entry holding the given values.
     *
     * @param shufAB     shuffle indices for the low two bytes of each character
     * @param shufC      shuffle indices for the third byte of each character
     * @param srcStep    input bytes consumed
     * @param dstStep    output index advance
     * @param headerMask header-bit mask per input byte
     * @param zeroBits   must-be-zero bit mask per output lane
     */
    public DecoderLutEntry(byte[] shufAB, byte[] shufC,
                           byte srcStep, byte dstStep,
                           byte[] headerMask, short[] zeroBits) {
        this.shufAB = shufAB;
        this.shufC = shufC;
        this.srcStep = srcStep;
        this.dstStep = dstStep;
        this.headerMask = headerMask;
        this.zeroBits = zeroBits;
    }
}
/**
 * Builds the decoder lookup table by enumerating character-size sequences,
 * starting from an empty sequence.
 */
@Setup(Level.Trial)
public void setupVectorLut() {
    // Scratch array holding the size (1..3 bytes) of each character in the
    // sequence currently being enumerated.
    final int[] charSizes = new int[32];
    computeLutRecursive(charSizes, 0, 0);
}
/**
 * Enumerates every sequence of character sizes (1, 2 or 3 bytes each) and
 * emits a table entry as soon as a sequence covers at least 16 bytes.
 *
 * @param sizes scratch array; {@code sizes[0..num)} is the sequence so far
 * @param num   number of characters already placed in {@code sizes}
 * @param total sum of {@code sizes[0..num)}
 */
static void computeLutRecursive(int[] sizes, int num, int total) {
    if (total < 16) {
        // Still short of 16 bytes: try each possible size for the next character.
        int size = 1;
        while (size <= 3) {
            sizes[num] = size;
            computeLutRecursive(sizes, num + 1, total + size);
            size++;
        }
    } else {
        computeLutEntry(sizes, num);
    }
}
/**
 * Builds the lookup-table entry for one sequence of character sizes and
 * stores it in {@code lutTable} under the sequence's continuation-byte mask.
 *
 * <p>The key has bit {@code p} set when byte {@code p} of the 16-byte window
 * is a non-leading byte of a character. Only bits 0..15 are set: the last
 * character of a sequence may extend past the window (the sizes may sum to
 * 17 or 18), and those bytes must not contribute to the key.
 *
 * @param sizes character sizes in bytes (1..3); only {@code sizes[0..num)} is read
 * @param num   number of characters in the sequence; their sizes sum to at least 16
 */
static void computeLutEntry(int[] sizes, int num) {
    // Find the maximal number of chars to decode: by default everything
    // except the last character.
    int cnt = num - 1;
    int preSum = 0;
    for (int i = 0; i < cnt; i++) {
        preSum += sizes[i];
    }
    assert preSum < 16;
    // Generally a char is processed only if the next byte is still inside the
    // window. However, if the last char takes 3 bytes and fits the window
    // tightly, it can be taken too.
    if (preSum == 13 && preSum + sizes[cnt] == 16) {
        preSum += sizes[cnt++];
    }
    // No more than 8 chars fit into one 128-bit vector of 16-bit lanes.
    while (cnt > 8) {
        preSum -= sizes[--cnt];
    }

    // Generate the bitmask key. The first byte of a character is not
    // represented in the mask; bytes at position 16 or above are outside the
    // window and are skipped so the key always fits in 16 bits.
    long mask = 0;
    for (int i = 0, pos = 0; i < num; i++) {
        for (int j = 0; j < sizes[i]; j++, pos++) {
            if (j > 0 && pos < 16) {
                mask |= 1L << pos;
            }
        }
    }
    assert mask <= 0xFFFF;

    // Generate shuffle masks; 0xFF marks an unused lane.
    byte[] shufAB = new byte[16];
    byte[] shufC = new byte[16];
    for (int i = 0; i < 16; i++) {
        shufAB[i] = shufC[i] = (byte) 0xFF;
    }
    for (int i = 0, pos = 0; i < cnt; i++) {
        int sz = sizes[i];
        // j counts down while pos counts up, so j == 0 is the character's last byte.
        for (int j = sz - 1; j >= 0; j--, pos++) {
            if (j < 2) {
                shufAB[2 * i + j] = (byte) pos;
            } else {
                shufC[2 * i] = (byte) pos;
            }
        }
    }

    // Generate header masks for validation: the top `bits` bits of each byte.
    byte[] headerMask = new byte[16];
    for (int i = 0, pos = 0; i < cnt; i++) {
        int sz = sizes[i];
        for (int j = 0; j < sz; j++, pos++) {
            int bits;
            if (j > 0) {
                bits = 2;
            } else if (sz == 1) {
                bits = 1;
            } else if (sz == 2) {
                bits = 3;
            } else { // sz == 3
                bits = 4;
            }
            headerMask[pos] = (byte) -(1 << (8 - bits));
        }
    }

    // Generate, per output lane, the bits that must be zero in the decoded
    // value; lanes beyond cnt are treated like 1-byte characters.
    short[] zeroBits = new short[8];
    for (int i = 0; i < 8; i++) {
        int sz = i < cnt ? sizes[i] : 1;
        if (sz == 1) {
            zeroBits[i] = (short) 0xFF80;
        } else if (sz == 2) {
            zeroBits[i] = (short) 0xF800;
        } else { // sz == 3
            zeroBits[i] = (short) 0x0000;
        }
    }

    // Store the entry in the lookup table.
    lutTable.put(mask, new DecoderLutEntry(shufAB, shufC, (byte) preSum, (byte) cnt, headerMask, zeroBits));
}
/**
 * Creates the input buffer from randomBytesForString(dataSize) and allocates
 * an output buffer of dataSize chars.
 */
@Setup
public void setup() {
    this.src = randomBytesForString(this.dataSize);
    this.dst = CharBuffer.allocate(this.dataSize);
}
private static final Random RANDOM = new Random(0);

/**
 * Generates well-formed UTF-8 made of randomly chosen 1- to 4-byte characters.
 *
 * <p>Characters are appended until the next randomly sized character no
 * longer fits, so up to three bytes of the capacity may stay unused. The
 * returned buffer is flipped: position 0, limit at the end of the generated
 * data.
 *
 * @param dataSize capacity of the returned buffer in bytes
 * @return a heap buffer ready to be read from position 0
 */
private static ByteBuffer randomBytesForString(int dataSize) {
    ByteBuffer out = ByteBuffer.allocate(dataSize);
    // nextInt(4) is always in [0, 4); nextInt() % 4 can be negative, which
    // would yield sizes <= 0.
    for (int size = RANDOM.nextInt(4) + 1; size <= out.remaining(); size = RANDOM.nextInt(4) + 1) {
        // Pick a code point from the range that needs exactly `size` bytes, so
        // no overlong form, surrogate or value above U+10FFFF is produced.
        int cp;
        switch (size) {
            case 1: // U+0000..U+007F
                cp = RANDOM.nextInt(0x80);
                out.put((byte) cp);
                break;
            case 2: // U+0080..U+07FF
                cp = 0x80 + RANDOM.nextInt(0x800 - 0x80);
                out.put((byte) (0xC0 | (cp >> 6)));
                out.put((byte) (0x80 | (cp & 0x3F)));
                break;
            case 3: // U+0800..U+FFFF without the surrogates U+D800..U+DFFF
                cp = 0x800 + RANDOM.nextInt(0x10000 - 0x800 - 0x800);
                if (cp >= 0xD800) {
                    cp += 0x800; // skip over the surrogate range
                }
                out.put((byte) (0xE0 | (cp >> 12)));
                out.put((byte) (0x80 | ((cp >> 6) & 0x3F)));
                out.put((byte) (0x80 | (cp & 0x3F)));
                break;
            default: // size == 4: U+10000..U+10FFFF
                cp = 0x10000 + RANDOM.nextInt(0x110000 - 0x10000);
                out.put((byte) (0xF0 | (cp >> 18)));
                out.put((byte) (0x80 | ((cp >> 12) & 0x3F)));
                out.put((byte) (0x80 | ((cp >> 6) & 0x3F)));
                out.put((byte) (0x80 | (cp & 0x3F)));
                break;
        }
    }
    // Make the generated bytes readable: limit = bytes written, position = 0.
    out.flip();
    return out;
}
/**
 * Baseline benchmark: decodes the whole input with the scalar loop.
 *
 * @param bh sink that keeps the result and the output buffer alive
 */
@Benchmark
public void decode(Blackhole bh) {
    // The buffers are created once per trial and the decoder advances their
    // positions, so start every invocation from the beginning again.
    src.rewind();
    dst.clear();
    CoderResult cr = decodeArrayLoop(src, dst);
    bh.consume(cr);
    bh.consume(dst);
}
/**
 * Vectorized benchmark: runs the vector loop first, then lets the scalar
 * loop decode whatever the vector loop left unprocessed.
 *
 * @param bh sink that keeps the result and the output buffer alive
 */
@Benchmark
public void decodeVector(Blackhole bh) {
    // The buffers are created once per trial and the decoders advance their
    // positions, so start every invocation from the beginning again.
    src.rewind();
    dst.clear();
    decodeArrayVectorized(src, dst);
    CoderResult cr = decodeArrayLoop(src, dst);
    bh.consume(cr);
    bh.consume(dst);
}
/**
 * Decodes as much of src as possible 16 bytes at a time using lutTable, then
 * stores the reached positions back into both buffers.
 *
 * The loop stops ("backs off") at the first window that has no table entry
 * or fails validation; nothing is reported to the caller, who is expected to
 * continue from the updated buffer positions.
 */
private static void decodeArrayVectorized(ByteBuffer src, CharBuffer dst) {
byte[] sa = src.array();
int sp = src.arrayOffset() + src.position();
int sl = src.arrayOffset() + src.limit();
char[] da = dst.array();
int dp = dst.arrayOffset() + dst.position();
int dl = dst.arrayOffset() + dst.limit();
// Vectorized loop
// Both comparisons are strict, so a window that ends exactly at the limit
// is not processed here.
while (sp + B128.length() < sl && dp + S128.length() < dl) {
var bytes = ByteVector.fromArray(B128, sa, sp); // next 16 input bytes
/* Decode */
var continuationByteMask = bytes.lanewise(VectorOperators.AND, (byte)0xC0).compare(VectorOperators.EQ, (byte)0x80); // lanes holding 10xxxxxx bytes
final DecoderLutEntry lookup = lutTable.get(continuationByteMask.toLong()); // the mask bits form the table key
if (lookup == null) { // back off (1): no table entry for this byte pattern
break;
}
// Shuffle the 1st and 2nd bytes
// Lanes whose shuffle index is -1 are excluded through the mask argument.
var Rab = bytes.rearrange(ByteVector.fromArray(B128, lookup.shufAB, 0).toShuffle(), ByteVector.fromArray(B128, lookup.shufAB, 0).compare(VectorOperators.NE, -1)).reinterpretShape(S128, 0); // viewed as 8 shorts
// Shuffle the 3rd byte
var Rc = bytes.rearrange(ByteVector.fromArray(B128, lookup.shufC, 0).toShuffle(), ByteVector.fromArray(B128, lookup.shufC, 0).compare(VectorOperators.NE, -1)).reinterpretShape(S128, 0); // viewed as 8 shorts
// Extract the bits from each byte
// 7 bits from the low byte, 6 bits from the high byte moved down by 2, and
// the third byte shifted up into the top 4 bits of the lane.
var sum = Rab.lanewise(VectorOperators.AND, (short)0x007F)
.add(Rab.lanewise(VectorOperators.AND, (short)0x3F00).lanewise(VectorOperators.LSHR, 2))
.add(Rc.lanewise(VectorOperators.LSHL, 12)); // one 16-bit value per lane
/* Validate */
if (sum.lanewise(VectorOperators.AND, ShortVector.fromArray(S128, lookup.zeroBits, 0)).compare(VectorOperators.NE, 0).anyTrue()) { // back off (2): a bit listed in zeroBits is set
break;
}
// Check for surrogate code point
// After subtracting 0x6000, values 0xD800..0xDFFF become 0x7800..0x7FFF;
// presumably the comparison is signed, making these the only values above 0x77FF - TODO confirm.
if (sum.lanewise(VectorOperators.SUB, (short)0x6000).compare(VectorOperators.GT, 0x77FF).anyTrue()) { // back off (3): surrogate found
break;
}
var headerMask = ByteVector.fromArray(B128, lookup.headerMask, 0);
// Under each header mask, the byte must equal the mask shifted left by one
// (e.g. mask 0xC0 expects 0x80).
if (bytes.lanewise(VectorOperators.AND, headerMask).compare(VectorOperators.NE, headerMask.lanewise(VectorOperators.LSHL, 1)).anyTrue()) { // back off (4): unexpected header bits
break;
}
/* Advance */
// A full vector of lanes is stored; dp then advances by dstStep only.
((ShortVector)sum).intoCharArray(da, dp);
sp += lookup.srcStep;
dp += lookup.dstStep;
}
updatePositions(src, sp, dst, dp);
}
/**
 * Scalar UTF-8 decoder working directly on the backing arrays of src and dst.
 *
 * Decodes from src's position up to its limit into dst, and stores the
 * reached positions back into both buffers before returning (through xflow,
 * malformed or malformedForLength).
 *
 * Returns the CoderResult produced by xflow when the input is exhausted, the
 * input ends inside a sequence or the output is full, or a malformed-input
 * result when an invalid sequence is found.
 */
private static CoderResult decodeArrayLoop(ByteBuffer src, CharBuffer dst) {
// This method is optimized for ASCII input.
byte[] sa = src.array();
int sp = src.arrayOffset() + src.position();
int sl = src.arrayOffset() + src.limit();
char[] da = dst.array();
int dp = dst.arrayOffset() + dst.position();
int dl = dst.arrayOffset() + dst.limit();
// Upper bound for dp while every input byte maps to exactly one output char.
int dlASCII = dp + Math.min(sl - sp, dl - dp);
// ASCII only loop
while (dp < dlASCII && sa[sp] >= 0)
da[dp++] = (char) sa[sp++];
// General loop: dispatch on the (signed) leading byte of each sequence.
while (sp < sl) {
int b1 = sa[sp];
if (b1 >= 0) {
// 1 byte, 7 bits: 0xxxxxxx
if (dp >= dl)
return xflow(src, sp, sl, dst, dp, 1);
da[dp++] = (char) b1;
sp++;
} else if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) {
// 2 bytes, 11 bits: 110xxxxx 10xxxxxx
// [C2..DF] [80..BF]
if (sl - sp < 2 || dp >= dl)
return xflow(src, sp, sl, dst, dp, 2);
int b2 = sa[sp + 1];
// Now we check the first byte of 2-byte sequence as
// if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0)
// no longer need to check b1 against c1 & c0 for
// malformed as we did in previous version
// (b1 & 0x1e) == 0x0 || (b2 & 0xc0) != 0x80;
// only need to check the second byte b2.
if (isNotContinuation(b2))
return malformedForLength(src, sp, dst, dp, 1);
// XOR-ing with the shifted header constants cancels the 110/10 marker bits.
da[dp++] = (char) (((b1 << 6) ^ b2)
^
(((byte) 0xC0 << 6) ^
((byte) 0x80 << 0)));
sp += 2;
} else if ((b1 >> 4) == -2) {
// 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
int srcRemaining = sl - sp;
if (srcRemaining < 3 || dp >= dl) {
// Even when the sequence cannot be decoded here, a bad second byte is
// reported as malformed rather than as underflow/overflow.
if (srcRemaining > 1 && isMalformed3_2(b1, sa[sp + 1]))
return malformedForLength(src, sp, dst, dp, 1);
return xflow(src, sp, sl, dst, dp, 3);
}
int b2 = sa[sp + 1];
int b3 = sa[sp + 2];
if (isMalformed3(b1, b2, b3))
return malformed(src, sp, dst, dp, 3);
// Same XOR trick as above, with the 1110/10/10 marker bits.
char c = (char)
((b1 << 12) ^
(b2 << 6) ^
(b3 ^
(((byte) 0xE0 << 12) ^
((byte) 0x80 << 6) ^
((byte) 0x80 << 0))));
// A 3-byte sequence that decodes to a surrogate is rejected.
if (Character.isSurrogate(c))
return malformedForLength(src, sp, dst, dp, 3);
da[dp++] = c;
sp += 3;
} else if ((b1 >> 3) == -2) {
// 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
int srcRemaining = sl - sp;
// The result is written as two chars, hence the dl - dp < 2 check.
if (srcRemaining < 4 || dl - dp < 2) {
b1 &= 0xff;
if (b1 > 0xf4 ||
srcRemaining > 1 && isMalformed4_2(b1, sa[sp + 1] & 0xff))
return malformedForLength(src, sp, dst, dp, 1);
if (srcRemaining > 2 && isMalformed4_3(sa[sp + 2]))
return malformedForLength(src, sp, dst, dp, 2);
return xflow(src, sp, sl, dst, dp, 4);
}
int b2 = sa[sp + 1];
int b3 = sa[sp + 2];
int b4 = sa[sp + 3];
int uc = ((b1 << 18) ^
(b2 << 12) ^
(b3 << 6) ^
(b4 ^
(((byte) 0xF0 << 18) ^
((byte) 0x80 << 12) ^
((byte) 0x80 << 6) ^
((byte) 0x80 << 0))));
if (isMalformed4(b2, b3, b4) ||
// shortest form check
!Character.isSupplementaryCodePoint(uc)) {
return malformed(src, sp, dst, dp, 4);
}
da[dp++] = Character.highSurrogate(uc);
da[dp++] = Character.lowSurrogate(uc);
sp += 4;
} else
// Any other leading byte cannot start a sequence.
return malformed(src, sp, dst, dp, 1);
}
// All input consumed; xflow with nb == 0 returns UNDERFLOW.
return xflow(src, sp, sl, dst, dp, 0);
}
/**
 * Stores the reached positions into both buffers and reports why decoding
 * stopped.
 *
 * @param nb number of input bytes the pending sequence needs, or 0 if none
 * @return UNDERFLOW when {@code nb} is 0 or fewer than {@code nb} input bytes
 *         remain, OVERFLOW otherwise
 */
private static CoderResult xflow(Buffer src, int sp, int sl,
                                 Buffer dst, int dp, int nb) {
    updatePositions(src, sp, dst, dp);
    final boolean needsMoreInput = nb == 0 || sl - sp < nb;
    if (needsMoreInput) {
        return CoderResult.UNDERFLOW;
    }
    return CoderResult.OVERFLOW;
}
/**
 * Stores the reached positions into both buffers and returns a
 * malformed-input result of the given length.
 *
 * @param malformedNB length in bytes of the malformed input
 */
private static CoderResult malformedForLength(ByteBuffer src, int sp,
                                              CharBuffer dst, int dp,
                                              int malformedNB) {
    updatePositions(src, sp, dst, dp);
    final CoderResult result = CoderResult.malformedForLength(malformedNB);
    return result;
}
/**
 * Works out how many bytes of the sequence starting at {@code sp} are
 * malformed, then stores the reached positions into both buffers.
 *
 * @param nb nominal length in bytes of the sequence being decoded
 * @return the malformed-input result computed by malformedN
 */
private static CoderResult malformed(ByteBuffer src, int sp,
                                     CharBuffer dst, int dp,
                                     int nb) {
    // malformedN reads relatively from src, so point it at the sequence start.
    final int sequenceStart = sp - src.arrayOffset();
    src.position(sequenceStart);
    final CoderResult result = malformedN(src, nb);
    // malformedN advanced src while reading; set both positions from sp and dp.
    updatePositions(src, sp, dst, dp);
    return result;
}
/**
 * Determines the malformed length for a sequence of nominal length
 * {@code nb}, reading the bytes relatively from {@code src}'s current
 * position.
 *
 * @param nb nominal sequence length; values outside 1..4 trip an assertion
 *           and yield {@code null}
 */
private static CoderResult malformedN(ByteBuffer src, int nb) {
    if (nb == 1 || nb == 2) {
        // always 1
        return CoderResult.malformedForLength(1);
    }
    if (nb == 3) {
        int first = src.get();
        int second = src.get(); // no need to look at the third byte
        boolean onlyFirstByte =
            (first == (byte)0xe0 && (second & 0xe0) == 0x80) ||
            isNotContinuation(second);
        return CoderResult.malformedForLength(onlyFirstByte ? 1 : 2);
    }
    if (nb == 4) {
        // speed does not matter here
        int first = src.get() & 0xff;
        int second = src.get() & 0xff;
        if (first > 0xf4 ||
            (first == 0xf0 && (second < 0x90 || second > 0xbf)) ||
            (first == 0xf4 && (second & 0xf0) != 0x80) ||
            isNotContinuation(second)) {
            return CoderResult.malformedForLength(1);
        }
        if (isNotContinuation(src.get())) {
            return CoderResult.malformedForLength(2);
        }
        return CoderResult.malformedForLength(3);
    }
    assert false;
    return null;
}
/** Returns true unless the two bits selected by 0xC0 are {@code 10}. */
private static boolean isNotContinuation(int b) {
    final int topTwoBits = b & 0xc0;
    return topTwoBits != 0x80;
}
// Well-formed 3-byte sequences:
// [E0] [A0..BF] [80..BF]
// [E1..EF] [80..BF] [80..BF]
// b1 is compared against the sign-extended byte value of 0xE0.
private static boolean isMalformed3(int b1, int b2, int b3) {
    if (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) {
        return true; // E0 followed by 80..9F
    }
    if ((b2 & 0xc0) != 0x80) {
        return true; // second byte is not 10xxxxxx
    }
    return (b3 & 0xc0) != 0x80;
}
// Checks the first two bytes of a 3-byte sequence when the third byte is not
// examined. b1 is compared against the sign-extended byte value of 0xE0.
private static boolean isMalformed3_2(int b1, int b2) {
    if (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) {
        return true; // E0 followed by 80..9F
    }
    return (b2 & 0xc0) != 0x80;
}
// Well-formed 4-byte sequences:
// [F0] [90..BF] [80..BF] [80..BF]
// [F1..F3] [80..BF] [80..BF] [80..BF]
// [F4] [80..8F] [80..BF] [80..BF]
// Only the 10xxxxxx form of the three trailing bytes is checked here; the
// caller rejects [F0, 80..8F] and [F4, 90..] separately through
// Character.isSupplementaryCodePoint(uc).
private static boolean isMalformed4(int b2, int b3, int b4) {
    if ((b2 & 0xc0) != 0x80) {
        return true;
    }
    if ((b3 & 0xc0) != 0x80) {
        return true;
    }
    return (b4 & 0xc0) != 0x80;
}
// Checks the first two bytes of a 4-byte sequence; used when fewer than four
// input bytes are left (or the output is full).
// Both b1 and b2 must already be masked with 0xff by the caller.
private static boolean isMalformed4_2(int b1, int b2) {
    if (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) {
        return true; // F0 must be followed by 90..BF
    }
    if (b1 == 0xf4 && (b2 & 0xf0) != 0x80) {
        return true; // F4 must be followed by 80..8F
    }
    return (b2 & 0xc0) != 0x80;
}
// Tests whether b3 is malformed as the third byte of a 4-byte UTF-8
// sequence, i.e. is not of the form 10xxxxxx.
// Called after isMalformed4_2 has accepted the first two bytes, when fewer
// than four input bytes are left (or the output is full).
private static boolean isMalformed4_3(int b3) {
    final int topTwoBits = b3 & 0xc0;
    return topTwoBits != 0x80;
}
/**
 * Converts the array indices {@code sp} and {@code dp} back into buffer
 * positions by subtracting each buffer's array offset, and applies them.
 */
private static void updatePositions(Buffer src, int sp,
                                    Buffer dst, int dp) {
    final int srcPosition = sp - src.arrayOffset();
    src.position(srcPosition);
    final int dstPosition = dp - dst.arrayOffset();
    dst.position(dstPosition);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment