diff --git a/src/main/java/org/simdjson/StructuralIndexer.java b/src/main/java/org/simdjson/StructuralIndexer.java index c0eb4b0..053c927 100644 --- a/src/main/java/org/simdjson/StructuralIndexer.java +++ b/src/main/java/org/simdjson/StructuralIndexer.java @@ -2,7 +2,6 @@ import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.VectorSpecies; -import java.lang.invoke.MethodType; import static jdk.incubator.vector.VectorOperators.UNSIGNED_LE; diff --git a/src/main/java/org/simdjson/Utf8Validator.java b/src/main/java/org/simdjson/Utf8Validator.java index 2838d76..00f2f69 100644 --- a/src/main/java/org/simdjson/Utf8Validator.java +++ b/src/main/java/org/simdjson/Utf8Validator.java @@ -4,12 +4,11 @@ import java.util.Arrays; -public class Utf8Validator { - private static final VectorSpecies VECTOR_SPECIES = ByteVector.SPECIES_256; +class Utf8Validator { + + private static final VectorSpecies VECTOR_SPECIES = StructuralIndexer.SPECIES; private static final ByteVector INCOMPLETE_CHECK = getIncompleteCheck(); - private static final VectorShuffle SHIFT_FOUR_BYTES_FORWARD = VectorShuffle.iota(IntVector.SPECIES_256, - IntVector.SPECIES_256.elementSize() - 1, 1, true); - private static final ByteVector LOW_NIBBLE_MASK = ByteVector.broadcast(VECTOR_SPECIES, 0b0000_1111); + private static final byte LOW_NIBBLE_MASK = 0x0f; private static final ByteVector ALL_ASCII_MASK = ByteVector.broadcast(VECTOR_SPECIES, (byte) 0b1000_0000); /** @@ -19,9 +18,9 @@ public class Utf8Validator { * @throws JsonParsingException if the input is not valid UTF8 */ static void validate(byte[] inputBytes) { - long previousIncomplete = 0; - long errors = 0; - int previousFourUtf8Bytes = 0; + boolean previousIncomplete = false; + boolean errors = false; + ByteVector prevChunk = ByteVector.zero(VECTOR_SPECIES); int idx = 0; for (; idx < VECTOR_SPECIES.loopBound(inputBytes.length); idx += VECTOR_SPECIES.vectorByteSize()) { @@ -32,14 +31,12 @@ static void validate(byte[] inputBytes) { } else { previousIncomplete = isIncomplete(utf8Vector); - var fourBytesPrevious = fourBytesPreviousSlice(utf8Vector, previousFourUtf8Bytes); - - ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious); - ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious, firstCheck); + ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector, prevChunk); + ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector, prevChunk, firstCheck); - errors |= secondCheck.compare(VectorOperators.NE, 0).toLong(); + errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue(); } - previousFourUtf8Bytes = utf8Vector.reinterpretAsInts().lane(IntVector.SPECIES_256.length() - 1); + prevChunk = utf8Vector; } // if the input file doesn't align with the vector width, pad the missing bytes with zero @@ -48,66 +45,45 @@ static void validate(byte[] inputBytes) { if (!isAscii(lastVectorChunk)) { previousIncomplete = isIncomplete(lastVectorChunk); - var fourBytesPrevious = fourBytesPreviousSlice(lastVectorChunk, previousFourUtf8Bytes); + ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk, prevChunk); + ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk, prevChunk, firstCheck); - ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious); - ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious, firstCheck); - - errors |= secondCheck.compare(VectorOperators.NE, 0).toLong(); + errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue(); } - if ((errors | previousIncomplete) != 0) { + if (errors | previousIncomplete) { throw new JsonParsingException("Invalid UTF8"); } } - /* Shuffles the input forward by four bytes to make space for the previous four bytes. - The previous three bytes are required for validation, pulling in the last integer will give the previous four bytes. - The switch to integer vectors is to allow for integer shifting instead of the more expensive shuffle / slice operations */ - private static IntVector fourBytesPreviousSlice(ByteVector vectorChunk, int previousFourUtf8Bytes) { - return vectorChunk.reinterpretAsInts() - .rearrange(SHIFT_FOUR_BYTES_FORWARD) - .withLane(0, previousFourUtf8Bytes); - } - - // works similar to previousUtf8Vector.slice(VECTOR_SPECIES.length() - numOfBytesToInclude, utf8Vector) but without the performance cost - private static ByteVector previousVectorSlice(IntVector utf8Vector, IntVector fourBytesPrevious, int numOfPreviousBytes) { - return utf8Vector - .lanewise(VectorOperators.LSHL, Byte.SIZE * numOfPreviousBytes) - .or(fourBytesPrevious.lanewise(VectorOperators.LSHR, Byte.SIZE * (4 - numOfPreviousBytes))) - .reinterpretAsBytes(); - } - - private static ByteVector firstTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious) { + private static ByteVector firstTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk) { // shift the current input forward by 1 byte to include 1 byte from the previous input - var oneBytePrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 1); + var oneBytePrevious = concatenate(utf8Vector, prevChunk, 1); // high nibbles of the current input (e.g. 0xC3 >> 4 = 0xC) - ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4) - .reinterpretAsBytes().and(LOW_NIBBLE_MASK); + ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4); // high nibbles of the shifted input - ByteVector byte1HighNibbles = oneBytePrevious.reinterpretAsInts().lanewise(VectorOperators.LSHR, 4) - .reinterpretAsBytes().and(LOW_NIBBLE_MASK); + ByteVector byte1HighNibbles = oneBytePrevious.lanewise(VectorOperators.LSHR, 4); // low nibbles of the shifted input (e.g. 0xC3 & 0xF = 0x3) ByteVector byte1LowNibbles = oneBytePrevious.and(LOW_NIBBLE_MASK); - - ByteVector byte1HighState = byte1HighNibbles.selectFrom(LookupTable.byte1High); - ByteVector byte1LowState = byte1LowNibbles.selectFrom(LookupTable.byte1Low); - ByteVector byte2HighState = byte2HighNibbles.selectFrom(LookupTable.byte2High); - + ByteVector byte1HighState = byte2HighNibbles.selectFrom(LookupTable.byte2High); + ByteVector byte1LowState = byte1HighNibbles.selectFrom(LookupTable.byte1High); + ByteVector byte2HighState = byte1LowNibbles.selectFrom(LookupTable.byte1Low); return byte1HighState.and(byte1LowState).and(byte2HighState); } // All remaining checks are invalid 3–4 byte sequences, which either have too many continuations bytes or not enough - private static ByteVector lastTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious, ByteVector firstCheck) { + private static ByteVector lastTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk, ByteVector firstCheck) { // the minimum 3byte lead - 1110_0000 is always greater than the max 2byte lead - 110_11111 - ByteVector twoBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 2); + ByteVector twoBytesPrevious = concatenate(utf8Vector, prevChunk, 2); + VectorMask is3ByteLead = twoBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b110_11111); // the minimum 4byte lead - 1111_0000 is always greater than the max 3byte lead - 1110_1111 - ByteVector threeBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 3); + ByteVector threeBytesPrevious = concatenate(utf8Vector, prevChunk, 3); + VectorMask is4ByteLead = threeBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b1110_1111); // the firstCheck vector contains 0x80 values on continuation byte indexes @@ -115,6 +91,10 @@ private static ByteVector lastTwoByteSequenceCheck(IntVector utf8Vector, IntVect return firstCheck.add((byte) 0x80, is3ByteLead.or(is4ByteLead)); } + private static ByteVector concatenate(ByteVector curr, ByteVector prev, int byteCountFromPrev) { + return prev.slice(VECTOR_SPECIES.length() - byteCountFromPrev, curr); + } + /* checks that the previous vector isn't in an incomplete state. Previous vector is in an incomplete state if the last byte is smaller than 0xC0, or the second last byte is smaller than 0xE0, or the third last byte is smaller than 0xF0.*/ @@ -128,8 +108,8 @@ private static ByteVector getIncompleteCheck() { return ByteVector.fromArray(VECTOR_SPECIES, eofArray, 0); } - private static long isIncomplete(ByteVector utf8Vector) { - return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).toLong(); + private static boolean isIncomplete(ByteVector utf8Vector) { + return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).anyTrue(); } // ASCII will never exceed 01111_1111 diff --git a/src/test/java/org/simdjson/Utf8ValidatorTest.java b/src/test/java/org/simdjson/Utf8ValidatorTest.java index b129e86..45da0c1 100644 --- a/src/test/java/org/simdjson/Utf8ValidatorTest.java +++ b/src/test/java/org/simdjson/Utf8ValidatorTest.java @@ -1,6 +1,5 @@ package org.simdjson; -import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.VectorSpecies; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -8,11 +7,12 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Objects; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; class Utf8ValidatorTest { + private static final VectorSpecies VECTOR_SPECIES = StructuralIndexer.SPECIES;