1
0
mirror of https://github.com/facebook/zstd.git synced 2025-03-06 16:56:49 +02:00

changed (partially) the decodeSequences flow logic

this allows detecting overflow events without a checksum.
This commit is contained in:
Yann Collet 2023-06-16 11:56:22 -07:00
parent 8b8b5f4d75
commit 02134fad12

View File

@ -1214,14 +1214,20 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, U16
typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e;
/**
* ZSTD_decodeSequence_old():
* @p longOffsets : tells the decoder to reload more bits while decoding large offsets
* only used in 32-bit mode
* @return : Sequence (litL + matchL + offset)
*/
FORCE_INLINE_TEMPLATE seq_t
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
ZSTD_decodeSequence_old(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
{
seq_t seq;
/*
* ZSTD_seqSymbol is a structure with a total of 64 bits wide. So it can be
* loaded in one operation and extracted its fields by simply shifting or
* bit-extracting on aarch64.
* ZSTD_seqSymbol is a 64 bits wide structure.
* It can be loaded in one operation
* and its fields extracted by simply shifting or bit-extracting on aarch64.
* GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh
* operations that cause performance drop. This can be avoided by using this
* ZSTD_memcpy hack.
@ -1330,6 +1336,132 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
return seq;
}
/**
* ZSTD_decodeSequence():
* Decodes one sequence (literal length, match length, offset) from the bitstream
* held in @p seqState, using the table entries currently selected by the
* LL / ML / Offset FSE states, and updates the previous-offset history
* (seqState->prevOffset[]).
* @p seqState : decoding state; its DStream and prevOffset[] are modified,
* and its FSE states are advanced unless @p isLastSeq is set
* @p longOffsets : tells the decoder to reload more bits while decoding large offsets
* only used in 32-bit mode
* @p isLastSeq : when non-zero, the FSE states are left untouched and the final
* bitstream reload is skipped (no further sequence will be decoded from them)
* @return : Sequence (litL + matchL + offset)
*/
FORCE_INLINE_TEMPLATE seq_t
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const int isLastSeq)
{
seq_t seq;
/*
* ZSTD_seqSymbol is a 64 bits wide structure.
* It can be loaded in one operation
* and its fields extracted by simply shifting or bit-extracting on aarch64.
* GCC doesn't recognize this and generates more unnecessary ldr/ldrb/ldrh
* operations that cause performance drop. This can be avoided by using this
* ZSTD_memcpy hack.
*/
#if defined(__aarch64__) && (defined(__GNUC__) && !defined(__clang__))
/* aarch64 + GCC : copy the three current table entries into locals */
ZSTD_seqSymbol llDInfoS, mlDInfoS, ofDInfoS;
ZSTD_seqSymbol* const llDInfo = &llDInfoS;
ZSTD_seqSymbol* const mlDInfo = &mlDInfoS;
ZSTD_seqSymbol* const ofDInfo = &ofDInfoS;
ZSTD_memcpy(llDInfo, seqState->stateLL.table + seqState->stateLL.state, sizeof(ZSTD_seqSymbol));
ZSTD_memcpy(mlDInfo, seqState->stateML.table + seqState->stateML.state, sizeof(ZSTD_seqSymbol));
ZSTD_memcpy(ofDInfo, seqState->stateOffb.table + seqState->stateOffb.state, sizeof(ZSTD_seqSymbol));
#else
/* other targets : point directly at the current table entries */
const ZSTD_seqSymbol* const llDInfo = seqState->stateLL.table + seqState->stateLL.state;
const ZSTD_seqSymbol* const mlDInfo = seqState->stateML.table + seqState->stateML.state;
const ZSTD_seqSymbol* const ofDInfo = seqState->stateOffb.table + seqState->stateOffb.state;
#endif
/* lengths start from their base values; additional bits, if any, are added further below */
seq.matchLength = mlDInfo->baseValue;
seq.litLength = llDInfo->baseValue;
/* snapshot every field needed from the three table entries before reading the bitstream */
{ U32 const ofBase = ofDInfo->baseValue;
BYTE const llBits = llDInfo->nbAdditionalBits;
BYTE const mlBits = mlDInfo->nbAdditionalBits;
BYTE const ofBits = ofDInfo->nbAdditionalBits;
BYTE const totalBits = llBits+mlBits+ofBits;
U16 const llNext = llDInfo->nextState;
U16 const mlNext = mlDInfo->nextState;
U16 const ofNext = ofDInfo->nextState;
U32 const llnbBits = llDInfo->nbBits;
U32 const mlnbBits = mlDInfo->nbBits;
U32 const ofnbBits = ofDInfo->nbBits;
assert(llBits <= MaxLLBits);
assert(mlBits <= MaxMLBits);
assert(ofBits <= MaxOff);
/*
* As gcc has better branch and block analyzers, sometimes it is only
* valuable to mark likeliness for clang, it gives around 3-4% of
* performance.
*/
/* sequence */
{ size_t offset;
if (ofBits > 1) {
/* more than 1 additional bit : the offset is read from the bitstream,
* then pushed at the front of the previous-offset history */
ZSTD_STATIC_ASSERT(ZSTD_lo_isLongOffset == 1);
ZSTD_STATIC_ASSERT(LONG_OFFSETS_MAX_EXTRA_BITS_32 == 5);
ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 > LONG_OFFSETS_MAX_EXTRA_BITS_32);
ZSTD_STATIC_ASSERT(STREAM_ACCUMULATOR_MIN_32 - LONG_OFFSETS_MAX_EXTRA_BITS_32 >= MaxMLBits);
if (MEM_32bits() && longOffsets && (ofBits >= STREAM_ACCUMULATOR_MIN_32)) {
/* Always read extra bits, this keeps the logic simple,
* avoids branches, and avoids accidentally reading 0 bits.
*/
/* 32-bit long-offset path : the offset bits are read in two parts
* (upper part, then the last extraBits), with a reload in between */
U32 const extraBits = LONG_OFFSETS_MAX_EXTRA_BITS_32;
offset = ofBase + (BIT_readBitsFast(&seqState->DStream, ofBits - extraBits) << extraBits);
BIT_reloadDStream(&seqState->DStream);
offset += BIT_readBitsFast(&seqState->DStream, extraBits);
} else {
offset = ofBase + BIT_readBitsFast(&seqState->DStream, ofBits/*>0*/); /* <= (ZSTD_WINDOWLOG_MAX-1) bits */
if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream);
}
/* shift the history : new offset becomes prevOffset[0], oldest entry is dropped */
seqState->prevOffset[2] = seqState->prevOffset[1];
seqState->prevOffset[1] = seqState->prevOffset[0];
seqState->prevOffset[0] = offset;
} else {
/* 0 or 1 additional bit : the offset is taken from the previous-offset history */
U32 const ll0 = (llDInfo->baseValue == 0); /* 1 when the literal length base value is 0 */
if (LIKELY((ofBits == 0))) {
/* ll0==0 : reuse prevOffset[0], history unchanged;
* ll0==1 : use prevOffset[1] and swap it with prevOffset[0] */
offset = seqState->prevOffset[ll0];
seqState->prevOffset[1] = seqState->prevOffset[!ll0];
seqState->prevOffset[0] = offset;
} else {
/* here `offset` is first an index into the history (base + ll0 + 1 bit read) */
offset = ofBase + ll0 + BIT_readBitsFast(&seqState->DStream, 1);
/* index 3 selects (prevOffset[0] - 1); any other index selects prevOffset[index] */
{ size_t temp = (offset==3) ? seqState->prevOffset[0] - 1 : seqState->prevOffset[offset];
temp += !temp; /* 0 is not valid; input is corrupted; force offset to 1 */
/* move the selected value to the front; prevOffset[2] is only overwritten when index != 1 */
if (offset != 1) seqState->prevOffset[2] = seqState->prevOffset[1];
seqState->prevOffset[1] = seqState->prevOffset[0];
seqState->prevOffset[0] = offset = temp;
} } }
seq.offset = offset;
}
/* match length : add the additional bits on top of the base value */
if (mlBits > 0)
seq.matchLength += BIT_readBitsFast(&seqState->DStream, mlBits/*>0*/);
/* conditional reloads before reading the literal length bits :
* 32-bit mode when mlBits+llBits reaches the threshold,
* 64-bit mode when totalBits reaches the threshold (marked UNLIKELY) */
if (MEM_32bits() && (mlBits+llBits >= STREAM_ACCUMULATOR_MIN_32-LONG_OFFSETS_MAX_EXTRA_BITS_32))
BIT_reloadDStream(&seqState->DStream);
if (MEM_64bits() && UNLIKELY(totalBits >= STREAM_ACCUMULATOR_MIN_64-(LLFSELog+MLFSELog+OffFSELog)))
BIT_reloadDStream(&seqState->DStream);
/* Ensure there are enough bits to read the rest of data in 64-bit mode. */
ZSTD_STATIC_ASSERT(16+LLFSELog+MLFSELog+OffFSELog < STREAM_ACCUMULATOR_MIN_64);
/* literal length : add the additional bits on top of the base value */
if (llBits > 0)
seq.litLength += BIT_readBitsFast(&seqState->DStream, llBits/*>0*/);
if (MEM_32bits())
BIT_reloadDStream(&seqState->DStream);
DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u",
(U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset);
if (!isLastSeq) {
/* don't update FSE state for last Sequence */
/* advance the three FSE states using the nextState / nbBits values captured above */
ZSTD_updateFseStateWithDInfo(&seqState->stateLL, &seqState->DStream, llNext, llnbBits); /* <= 9 bits */
ZSTD_updateFseStateWithDInfo(&seqState->stateML, &seqState->DStream, mlNext, mlnbBits); /* <= 9 bits */
if (MEM_32bits()) BIT_reloadDStream(&seqState->DStream); /* <= 18 bits */
ZSTD_updateFseStateWithDInfo(&seqState->stateOffb, &seqState->DStream, ofNext, ofnbBits); /* <= 8 bits */
/* final reload; skipped together with the state updates when isLastSeq is set */
BIT_reloadDStream(&seqState->DStream);
}
}
return seq;
}
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
MEM_STATIC int ZSTD_dictionaryIsActive(ZSTD_DCtx const* dctx, BYTE const* prefixStart, BYTE const* oLitEnd)
{
@ -1420,7 +1552,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
/* decompress without overrunning litPtr begins */
{
seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
seq_t sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
/* Align the decompression loop to 32 + 16 bytes.
*
* zstd compiled with gcc-9 on an Intel i9-9900k shows 10% decompression
@ -1495,7 +1627,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
if (UNLIKELY(!--nbSeq))
break;
BIT_reloadDStream(&(seqState.DStream));
sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
}
/* If there are more sequences, they will need to read literals from litExtraBuffer; copy over the remainder from dst and update litPtr and litEnd */
@ -1548,7 +1680,7 @@ ZSTD_decompressSequences_bodySplitLitBuffer( ZSTD_DCtx* dctx,
#endif
for (; ; ) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
seq_t const sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litBufferEnd, prefixStart, vBase, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize));
@ -1647,8 +1779,8 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx,
# endif
#endif
for ( ; ; ) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
for ( ; nbSeq ; nbSeq--) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, nbSeq==1);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize));
@ -1658,15 +1790,13 @@ ZSTD_decompressSequences_body(ZSTD_DCtx* dctx,
return oneSeqSize;
DEBUGLOG(6, "regenerated sequence size : %u", (U32)oneSeqSize);
op += oneSeqSize;
if (UNLIKELY(!--nbSeq))
break;
BIT_reloadDStream(&(seqState.DStream));
}
/* check if reached exact end */
DEBUGLOG(5, "ZSTD_decompressSequences_body: after decode loop, remaining nbSeq : %i", nbSeq);
RETURN_ERROR_IF(nbSeq, corruption_detected, "");
RETURN_ERROR_IF(BIT_reloadDStream(&seqState.DStream) < BIT_DStream_completed, corruption_detected, "");
DEBUGLOG(5, "bitStream : start=%p, ptr=%p, bitsConsumed=%u", seqState.DStream.start, seqState.DStream.ptr, seqState.DStream.bitsConsumed);
RETURN_ERROR_IF(!BIT_endOfDStream(&seqState.DStream), corruption_detected, "");
/* save reps for next block */
{ U32 i; for (i=0; i<ZSTD_REP_NUM; i++) dctx->entropy.rep[i] = (U32)(seqState.prevOffset[i]); }
}
@ -1763,7 +1893,7 @@ ZSTD_decompressSequencesLong_body(
/* prepare in advance */
for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNb<seqAdvance); seqNb++) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
seq_t const sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
prefetchPos = ZSTD_prefetchMatch(prefetchPos, sequence, prefixStart, dictEnd);
sequences[seqNb] = sequence;
}
@ -1771,7 +1901,7 @@ ZSTD_decompressSequencesLong_body(
/* decompress without stomping litBuffer */
for (; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb < nbSeq); seqNb++) {
seq_t sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
seq_t sequence = ZSTD_decodeSequence_old(&seqState, isLongOffset);
size_t oneSeqSize;
if (dctx->litBufferLocation == ZSTD_split && litPtr + sequences[(seqNb - ADVANCED_SEQS) & STORED_SEQS_MASK].litLength > dctx->litBufferEnd)