Created
November 23, 2020 20:31
-
-
Save luhenry/52c0aec9e2ea04d4e3f08040c3da10d4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Copyright (c) 2019, 2020, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package benchmark.utf8; | |
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CoderResult;
import java.util.HashMap;
import java.util.Random;

import jdk.incubator.vector.*;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
@State(Scope.Thread) | |
@BenchmarkMode(Mode.Throughput) | |
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) | |
@Warmup(iterations = 3, time = 3) | |
@Measurement(iterations = 8, time = 2) | |
public class DecodeBench { | |
@Param({"16384", "65536"}) | |
private int dataSize; | |
private ByteBuffer src; | |
private CharBuffer dst; | |
private static final VectorSpecies<Byte> B128 = ByteVector.SPECIES_128; | |
private static final VectorSpecies<Short> S128 = ShortVector.SPECIES_128; | |
private static final HashMap<Long, DecoderLutEntry> lutTable = new HashMap<Long, DecoderLutEntry>(); | |
private static class DecoderLutEntry { | |
public final byte[] shufAB; // shuffling mask to get lower two bytes of symbols | |
public final byte[] shufC; // shuffling mask to get third bytes of symbols | |
public final byte srcStep; // number of bytes processed in input buffer | |
public final byte dstStep; // number of symbols produced in output buffer (doubled) | |
public final byte[] headerMask; // mask of "111..10" bits required in each byte | |
public final short[] zeroBits; | |
public DecoderLutEntry(byte[] _shufAB, byte[] _shufC, | |
byte _srcStep, byte _dstStep, | |
byte[] _headerMask, short[] _zeroBits) { | |
shufAB = _shufAB; | |
shufC = _shufC; | |
srcStep = _srcStep; | |
dstStep = _dstStep; | |
headerMask = _headerMask; | |
zeroBits = _zeroBits; | |
} | |
// @Override | |
// public String toString() { | |
// return String.format("shufAB = %s, shufC = %s, srcStep = %d, dstStep = %d, headerMask = %s, zeroBits = %s", | |
// arrayToString(shufAB), arrayToString(shufC), srcStep, dstStep, arrayToString(headerMask), arrayToString(zeroBits)); | |
// } | |
} | |
@Setup(Level.Trial) | |
public void setupVectorLut() { | |
int[] sizes = new int[32]; | |
computeLutRecursive(sizes, 0, 0); //10609 entries total | |
// for (var entry : lutTable.entrySet()) { | |
// System.out.println("" + entry.getKey() + " -> " + entry.getValue()); | |
// } | |
} | |
static void computeLutRecursive(int[] sizes, int num, int total) { | |
if (total >= 16) { | |
computeLutEntry(sizes, num); | |
return; | |
} | |
for (int size = 1; size <= 3; size++) { | |
sizes[num] = size; | |
computeLutRecursive(sizes, num + 1, total + size); | |
} | |
} | |
static void computeLutEntry(int[] sizes, int num) { | |
//find maximal number of chars to decode | |
int cnt = num - 1; | |
int preSum = 0; | |
for (int i = 0; i < cnt; i++) | |
preSum += sizes[i]; | |
assert preSum < 16; | |
// Note: generally, we can process a char only if the next byte is within XMM register | |
// However, if the last char takes 3 bytes and fits the register tightly, we can take it too | |
if (preSum == 13 && preSum + sizes[cnt] == 16) | |
preSum += sizes[cnt++]; | |
//still cannot process more that 8 chars per register | |
while (cnt > 8) | |
preSum -= sizes[--cnt]; | |
//generate bitmask | |
long mask = 0; | |
for (int i = 0, pos = 0; i < num; i++) { | |
for (int j = 0; j < sizes[i]; j++, pos++) { | |
// The first byte is not represented in the mask | |
if (j > 0) { | |
mask |= 1 << pos; | |
} | |
} | |
} | |
assert mask <= 0xFFFF; | |
//generate shuffle masks | |
byte[] shufAB = new byte[16]; | |
byte[] shufC = new byte[16]; | |
for (int i = 0; i < 16; i++) | |
shufAB[i] = shufC[i] = (byte)0xFF; | |
for (int i = 0, pos = 0; i < cnt; i++) { | |
int sz = sizes[i]; | |
for (int j = sz-1; j >= 0; j--, pos++) { | |
if (j < 2) | |
shufAB[2 * i + j] = (byte)pos; | |
else | |
shufC[2 * i] = (byte)pos; | |
} | |
} | |
//generate header masks for validation | |
byte[] headerMask = new byte[16]; | |
for (int i = 0, pos = 0; i < cnt; i++) { | |
int sz = sizes[i]; | |
for (int j = 0; j < sz; j++, pos++) { | |
int bits; | |
if (j > 0) bits = 2; | |
else if (sz == 1) bits = 1; | |
else if (sz == 2) bits = 3; | |
else /*sz == 3*/ bits = 4; | |
headerMask[pos] = (byte)-(1 << (8 - bits)); | |
} | |
} | |
//generate min symbols values for validation | |
short[] zeroBits = new short[8]; | |
for (int i = 0; i < 8; i++) { | |
int sz = i < cnt ? sizes[i] : 1; | |
if (sz == 1) zeroBits[i] = (short)(0xFF80); | |
else if (sz == 2) zeroBits[i] = (short)(0xF800); | |
else /*sz == 3*/ zeroBits[i] = (short)(0x0000); | |
} | |
//store info into the lookup table | |
lutTable.put(mask, new DecoderLutEntry(shufAB, shufC, (byte)preSum, (byte)cnt, headerMask, zeroBits)); | |
} | |
@Setup | |
public void setup() { | |
src = randomBytesForString(dataSize); | |
dst = CharBuffer.allocate(dataSize); | |
} | |
private static final Random RANDOM = new Random(0); | |
private static ByteBuffer randomBytesForString(int dataSize) { | |
ByteBuffer out = ByteBuffer.allocate(dataSize); | |
for (int i = 0, size = (RANDOM.nextInt() % 4) + 1; i + size <= dataSize; i += size, size = (RANDOM.nextInt() % 4) + 1) { | |
switch (size) { | |
case 1: | |
out.put((byte)((0b0 << 7) | (RANDOM.nextInt() & 0b01111111))); | |
break; | |
case 2: | |
out.put((byte)((0b110 << 5) | (RANDOM.nextInt() & 0b00011111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
break; | |
case 3: | |
out.put((byte)((0b1110 << 4) | (RANDOM.nextInt() & 0b00001111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
break; | |
case 4: | |
out.put((byte)((0b11110 << 3) | (RANDOM.nextInt() & 0b00000111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
out.put((byte)((0b10 << 6) | (RANDOM.nextInt() & 0b00111111))); | |
break; | |
} | |
} | |
return out; | |
} | |
@Benchmark | |
public void decode(Blackhole bh) { | |
CoderResult cr = decodeArrayLoop(src, dst); | |
bh.consume(cr); | |
bh.consume(dst); | |
} | |
@Benchmark | |
public void decodeVector(Blackhole bh) { | |
decodeArrayVectorized(src, dst); | |
CoderResult cr = decodeArrayLoop(src, dst); | |
bh.consume(cr); | |
bh.consume(dst); | |
} | |
private static void decodeArrayVectorized(ByteBuffer src, CharBuffer dst) { | |
byte[] sa = src.array(); | |
int sp = src.arrayOffset() + src.position(); | |
int sl = src.arrayOffset() + src.limit(); | |
char[] da = dst.array(); | |
int dp = dst.arrayOffset() + dst.position(); | |
int dl = dst.arrayOffset() + dst.limit(); | |
// Vectorized loop | |
while (sp + B128.length() < sl && dp + S128.length() < dl) { | |
var bytes = ByteVector.fromArray(B128, sa, sp); // System.out.println("bytes = " + arrayToString((byte[])bytes.toArray())); | |
/* Decode */ | |
var continuationByteMask = bytes.lanewise(VectorOperators.AND, (byte)0xC0).compare(VectorOperators.EQ, (byte)0x80); // System.out.println("continuationByteMask = " + arrayToString(continuationByteMask.toArray())); | |
final DecoderLutEntry lookup = lutTable.get(continuationByteMask.toLong()); // System.out.println("" + continuationByteMask.toLong() + " -> " + lookup); | |
if (lookup == null) { // System.out.println("back off (1)"); | |
break; | |
} | |
// Shuffle the 1st and 2nd bytes | |
var Rab = bytes.rearrange(ByteVector.fromArray(B128, lookup.shufAB, 0).toShuffle(), ByteVector.fromArray(B128, lookup.shufAB, 0).compare(VectorOperators.NE, -1)).reinterpretShape(S128, 0); // System.out.println("Rab = " + arrayToString((short[])Rab.toArray())); | |
// Shuffle the 3rd byte | |
var Rc = bytes.rearrange(ByteVector.fromArray(B128, lookup.shufC, 0).toShuffle(), ByteVector.fromArray(B128, lookup.shufC, 0).compare(VectorOperators.NE, -1)).reinterpretShape(S128, 0); // System.out.println("Rc = " + arrayToString((short[])Rc.toArray())); | |
// Extract the bits from each byte | |
var sum = Rab.lanewise(VectorOperators.AND, (short)0x007F) | |
.add(Rab.lanewise(VectorOperators.AND, (short)0x3F00).lanewise(VectorOperators.LSHR, 2)) | |
.add(Rc.lanewise(VectorOperators.LSHL, 12)); // System.out.println("sum = " + arrayToString((short[])sum.toArray())); | |
/* Validate */ | |
if (sum.lanewise(VectorOperators.AND, ShortVector.fromArray(S128, lookup.zeroBits, 0)).compare(VectorOperators.NE, 0).anyTrue()) { // System.out.println("back off (2)"); | |
break; | |
} | |
// Check for surrogate code point | |
if (sum.lanewise(VectorOperators.SUB, (short)0x6000).compare(VectorOperators.GT, 0x77FF).anyTrue()) { // System.out.println("back off (3)"); | |
break; | |
} | |
var headerMask = ByteVector.fromArray(B128, lookup.headerMask, 0); | |
if (bytes.lanewise(VectorOperators.AND, headerMask).compare(VectorOperators.NE, headerMask.lanewise(VectorOperators.LSHL, 1)).anyTrue()) { // System.out.println("back off (4)"); | |
break; | |
} | |
/* Advance */ | |
((ShortVector)sum).intoCharArray(da, dp); | |
sp += lookup.srcStep; | |
dp += lookup.dstStep; | |
} | |
updatePositions(src, sp, dst, dp); | |
} | |
private static CoderResult decodeArrayLoop(ByteBuffer src, CharBuffer dst) { | |
// This method is optimized for ASCII input. | |
byte[] sa = src.array(); | |
int sp = src.arrayOffset() + src.position(); | |
int sl = src.arrayOffset() + src.limit(); | |
char[] da = dst.array(); | |
int dp = dst.arrayOffset() + dst.position(); | |
int dl = dst.arrayOffset() + dst.limit(); | |
int dlASCII = dp + Math.min(sl - sp, dl - dp); | |
// ASCII only loop | |
while (dp < dlASCII && sa[sp] >= 0) | |
da[dp++] = (char) sa[sp++]; | |
while (sp < sl) { | |
int b1 = sa[sp]; | |
if (b1 >= 0) { | |
// 1 byte, 7 bits: 0xxxxxxx | |
if (dp >= dl) | |
return xflow(src, sp, sl, dst, dp, 1); | |
da[dp++] = (char) b1; | |
sp++; | |
} else if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) { | |
// 2 bytes, 11 bits: 110xxxxx 10xxxxxx | |
// [C2..DF] [80..BF] | |
if (sl - sp < 2 || dp >= dl) | |
return xflow(src, sp, sl, dst, dp, 2); | |
int b2 = sa[sp + 1]; | |
// Now we check the first byte of 2-byte sequence as | |
// if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) | |
// no longer need to check b1 against c1 & c0 for | |
// malformed as we did in previous version | |
// (b1 & 0x1e) == 0x0 || (b2 & 0xc0) != 0x80; | |
// only need to check the second byte b2. | |
if (isNotContinuation(b2)) | |
return malformedForLength(src, sp, dst, dp, 1); | |
da[dp++] = (char) (((b1 << 6) ^ b2) | |
^ | |
(((byte) 0xC0 << 6) ^ | |
((byte) 0x80 << 0))); | |
sp += 2; | |
} else if ((b1 >> 4) == -2) { | |
// 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx | |
int srcRemaining = sl - sp; | |
if (srcRemaining < 3 || dp >= dl) { | |
if (srcRemaining > 1 && isMalformed3_2(b1, sa[sp + 1])) | |
return malformedForLength(src, sp, dst, dp, 1); | |
return xflow(src, sp, sl, dst, dp, 3); | |
} | |
int b2 = sa[sp + 1]; | |
int b3 = sa[sp + 2]; | |
if (isMalformed3(b1, b2, b3)) | |
return malformed(src, sp, dst, dp, 3); | |
char c = (char) | |
((b1 << 12) ^ | |
(b2 << 6) ^ | |
(b3 ^ | |
(((byte) 0xE0 << 12) ^ | |
((byte) 0x80 << 6) ^ | |
((byte) 0x80 << 0)))); | |
if (Character.isSurrogate(c)) | |
return malformedForLength(src, sp, dst, dp, 3); | |
da[dp++] = c; | |
sp += 3; | |
} else if ((b1 >> 3) == -2) { | |
// 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx | |
int srcRemaining = sl - sp; | |
if (srcRemaining < 4 || dl - dp < 2) { | |
b1 &= 0xff; | |
if (b1 > 0xf4 || | |
srcRemaining > 1 && isMalformed4_2(b1, sa[sp + 1] & 0xff)) | |
return malformedForLength(src, sp, dst, dp, 1); | |
if (srcRemaining > 2 && isMalformed4_3(sa[sp + 2])) | |
return malformedForLength(src, sp, dst, dp, 2); | |
return xflow(src, sp, sl, dst, dp, 4); | |
} | |
int b2 = sa[sp + 1]; | |
int b3 = sa[sp + 2]; | |
int b4 = sa[sp + 3]; | |
int uc = ((b1 << 18) ^ | |
(b2 << 12) ^ | |
(b3 << 6) ^ | |
(b4 ^ | |
(((byte) 0xF0 << 18) ^ | |
((byte) 0x80 << 12) ^ | |
((byte) 0x80 << 6) ^ | |
((byte) 0x80 << 0)))); | |
if (isMalformed4(b2, b3, b4) || | |
// shortest form check | |
!Character.isSupplementaryCodePoint(uc)) { | |
return malformed(src, sp, dst, dp, 4); | |
} | |
da[dp++] = Character.highSurrogate(uc); | |
da[dp++] = Character.lowSurrogate(uc); | |
sp += 4; | |
} else | |
return malformed(src, sp, dst, dp, 1); | |
} | |
return xflow(src, sp, sl, dst, dp, 0); | |
} | |
private static CoderResult xflow(Buffer src, int sp, int sl, | |
Buffer dst, int dp, int nb) { | |
updatePositions(src, sp, dst, dp); | |
return (nb == 0 || sl - sp < nb) | |
? CoderResult.UNDERFLOW : CoderResult.OVERFLOW; | |
} | |
private static CoderResult malformedForLength(ByteBuffer src, | |
int sp, | |
CharBuffer dst, | |
int dp, | |
int malformedNB) | |
{ | |
updatePositions(src, sp, dst, dp); | |
return CoderResult.malformedForLength(malformedNB); | |
} | |
private static CoderResult malformed(ByteBuffer src, int sp, | |
CharBuffer dst, int dp, | |
int nb) | |
{ | |
src.position(sp - src.arrayOffset()); | |
CoderResult cr = malformedN(src, nb); | |
updatePositions(src, sp, dst, dp); | |
return cr; | |
} | |
private static CoderResult malformedN(ByteBuffer src, int nb) { | |
switch (nb) { | |
case 1: | |
case 2: // always 1 | |
return CoderResult.malformedForLength(1); | |
case 3: | |
int b1 = src.get(); | |
int b2 = src.get(); // no need to lookup b3 | |
return CoderResult.malformedForLength( | |
((b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) || | |
isNotContinuation(b2)) ? 1 : 2); | |
case 4: // we don't care the speed here | |
b1 = src.get() & 0xff; | |
b2 = src.get() & 0xff; | |
if (b1 > 0xf4 || | |
(b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) || | |
(b1 == 0xf4 && (b2 & 0xf0) != 0x80) || | |
isNotContinuation(b2)) | |
return CoderResult.malformedForLength(1); | |
if (isNotContinuation(src.get())) | |
return CoderResult.malformedForLength(2); | |
return CoderResult.malformedForLength(3); | |
default: | |
assert false; | |
return null; | |
} | |
} | |
private static boolean isNotContinuation(int b) { | |
return (b & 0xc0) != 0x80; | |
} | |
// [E0] [A0..BF] [80..BF] | |
// [E1..EF] [80..BF] [80..BF] | |
private static boolean isMalformed3(int b1, int b2, int b3) { | |
return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) || | |
(b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80; | |
} | |
// only used when there is only one byte left in src buffer | |
private static boolean isMalformed3_2(int b1, int b2) { | |
return (b1 == (byte)0xe0 && (b2 & 0xe0) == 0x80) || | |
(b2 & 0xc0) != 0x80; | |
} | |
// [F0] [90..BF] [80..BF] [80..BF] | |
// [F1..F3] [80..BF] [80..BF] [80..BF] | |
// [F4] [80..8F] [80..BF] [80..BF] | |
// only check 80-be range here, the [0xf0,0x80...] and [0xf4,0x90-...] | |
// will be checked by Character.isSupplementaryCodePoint(uc) | |
private static boolean isMalformed4(int b2, int b3, int b4) { | |
return (b2 & 0xc0) != 0x80 || (b3 & 0xc0) != 0x80 || | |
(b4 & 0xc0) != 0x80; | |
} | |
// only used when there is less than 4 bytes left in src buffer. | |
// both b1 and b2 should be "& 0xff" before passed in. | |
private static boolean isMalformed4_2(int b1, int b2) { | |
return (b1 == 0xf0 && (b2 < 0x90 || b2 > 0xbf)) || | |
(b1 == 0xf4 && (b2 & 0xf0) != 0x80) || | |
(b2 & 0xc0) != 0x80; | |
} | |
// tests if b1 and b2 are malformed as the first 2 bytes of a | |
// legal`4-byte utf-8 byte sequence. | |
// only used when there is less than 4 bytes left in src buffer, | |
// after isMalformed4_2 has been invoked. | |
private static boolean isMalformed4_3(int b3) { | |
return (b3 & 0xc0) != 0x80; | |
} | |
private static void updatePositions(Buffer src, int sp, | |
Buffer dst, int dp) { | |
src.position(sp - src.arrayOffset()); | |
dst.position(dp - dst.arrayOffset()); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment