diff --git a/lib/common/zstd_internal.h b/lib/common/zstd_internal.h index 48558873d..12e1106a1 100644 --- a/lib/common/zstd_internal.h +++ b/lib/common/zstd_internal.h @@ -341,6 +341,7 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore * `decompressedBound != ZSTD_CONTENTSIZE_ERROR` */ typedef struct { + size_t nbBlocks; size_t compressedSize; unsigned long long decompressedBound; } ZSTD_frameSizeInfo; /* decompress & legacy */ diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index f00ef3a67..4559451d0 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize ip += 4; } + frameSizeInfo.nbBlocks = nbBlocks; frameSizeInfo.compressedSize = (size_t)(ip - ipstart); frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) ? zfh.frameContentSize @@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize) return bound; } +size_t ZSTD_decompressionMargin(void const* src, size_t srcSize) +{ + size_t margin = 0; + unsigned maxBlockSize = 0; + + /* Iterate over each frame */ + while (srcSize > 0) { + ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize); + size_t const compressedSize = frameSizeInfo.compressedSize; + unsigned long long const decompressedBound = frameSizeInfo.decompressedBound; + ZSTD_frameHeader zfh; + + FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), ""); + if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR) + return ERROR(corruption_detected); + + if (zfh.frameType == ZSTD_frame) { + /* Add the frame header to our margin */ + margin += zfh.headerSize; + /* Add the checksum to our margin */ + margin += zfh.checksumFlag ? 
4 : 0; + /* Add 3 bytes per block */ + margin += 3 * frameSizeInfo.nbBlocks; + + /* Compute the max block size */ + maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax); + } else { + assert(zfh.frameType == ZSTD_skippableFrame); + /* Add the entire skippable frame size to our margin. */ + margin += compressedSize; + } + + assert(srcSize >= compressedSize); + src = (const BYTE*)src + compressedSize; + srcSize -= compressedSize; + } + + /* Add the max block size back to the margin. */ + margin += maxBlockSize; + + return margin; +} /*-************************************************************* * Frame decoding @@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity, if (srcSize == 0) return 0; RETURN_ERROR(dstBuffer_null, ""); } - ZSTD_memcpy(dst, src, srcSize); + ZSTD_memmove(dst, src, srcSize); return srcSize; } @@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, /* Loop on each block */ while (1) { + BYTE* oBlockEnd = oend; size_t decodedSize; blockProperties_t blockProperties; size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSrcSize, &blockProperties); @@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, remainingSrcSize -= ZSTD_blockHeaderSize; RETURN_ERROR_IF(cBlockSize > remainingSrcSize, srcSize_wrong, ""); + if (ip >= op && ip < oBlockEnd) { + /* We are decompressing in-place. Limit the output pointer so that we + * don't overwrite the block that we are currently reading. This will + * fail decompression if the input & output pointers aren't spaced + * far enough apart. + * + * This is important to set, even when the pointers are far enough + * apart, because ZSTD_decompressBlock_internal() can decide to store + * literals in the output buffer, after the block it is decompressing. + * Since we don't want anything to overwrite our input, we have to tell + * ZSTD_decompressBlock_internal to never write past ip. + * + * See ZSTD_allocateLiteralsBuffer() for reference. 
+ */ + oBlockEnd = op + (ip - op); + } + switch(blockProperties.blockType) { case bt_compressed: - decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oend-op), ip, cBlockSize, /* frame */ 1, not_streaming); + decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming); break; case bt_raw : + /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */ decodedSize = ZSTD_copyRawBlock(op, (size_t)(oend-op), ip, cBlockSize); break; case bt_rle : - decodedSize = ZSTD_setRleBlock(op, (size_t)(oend-op), *ip, blockProperties.origSize); + decodedSize = ZSTD_setRleBlock(op, (size_t)(oBlockEnd-op), *ip, blockProperties.origSize); break; case bt_reserved : default: diff --git a/lib/legacy/zstd_legacy.h b/lib/legacy/zstd_legacy.h index 9f53d4cbd..dd173251d 100644 --- a/lib/legacy/zstd_legacy.h +++ b/lib/legacy/zstd_legacy.h @@ -242,6 +242,13 @@ MEM_STATIC ZSTD_frameSizeInfo ZSTD_findFrameSizeInfoLegacy(const void *src, size frameSizeInfo.compressedSize = ERROR(srcSize_wrong); frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR; } + /* In all cases, decompressedBound == nbBlocks * ZSTD_BLOCKSIZE_MAX. + * So we can compute nbBlocks without having to change every function. + */ + if (frameSizeInfo.decompressedBound != ZSTD_CONTENTSIZE_ERROR) { + assert((frameSizeInfo.decompressedBound & (ZSTD_BLOCKSIZE_MAX - 1)) == 0); + frameSizeInfo.nbBlocks = (size_t)(frameSizeInfo.decompressedBound / ZSTD_BLOCKSIZE_MAX); + } return frameSizeInfo; } diff --git a/lib/zstd.h b/lib/zstd.h index 480d65f67..22c8bba5b 100644 --- a/lib/zstd.h +++ b/lib/zstd.h @@ -1427,6 +1427,51 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size * or an error code (if srcSize is too small) */ ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize); +/*! 
ZSTD_decompressionMargin() : + * Zstd supports in-place decompression, where the input and output buffers overlap. + * In this case, the output buffer must be at least (Margin + Output_Size) bytes large, + * and the input buffer must be at the end of the output buffer. + * + * _______________________ Output Buffer ________________________ + * | | + * | ____ Input Buffer ____| + * | | | + * v v v + * |---------------------------------------|-----------|----------| + * ^ ^ ^ + * |___________________ Output_Size ___________________|_ Margin _| + * + * NOTE: See also ZSTD_DECOMPRESSION_MARGIN(). + * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or + * ZSTD_decompressDCtx(). + * NOTE: This function supports multi-frame input. + * + * @param src The compressed frame(s) + * @param srcSize The size of the compressed frame(s) + * @returns The decompression margin or an error that can be checked with ZSTD_isError(). + */ +ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize); + +/*! ZSTD_DECOMPRESSION_MARGIN() : + * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from + * the compressed frame, compute it from the original size and the block size. + * See ZSTD_decompressionMargin() for details. + * + * WARNING: This macro does not support multi-frame input; the input must be a single + * zstd frame. If you need that support, use the function or implement it yourself. + * + * @param originalSize The original uncompressed size of the data. + * @param blockSize The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX). + * Unless you explicitly set the windowLog smaller than + * ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX. + */ +#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)( \ + ZSTD_FRAMEHEADERSIZE_MAX /* Frame header */ + \ + 4 /* checksum */ + \ + ((originalSize) == 0 ? 
0 : 3 * (((originalSize) + (blockSize) - 1) / (blockSize))) /* 3 bytes per block */ + \ + (blockSize) /* One block of margin */ \ + )) + typedef enum { ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */ ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */ diff --git a/tests/fuzz/simple_round_trip.c b/tests/fuzz/simple_round_trip.c index 23a805af2..c2c69d950 100644 --- a/tests/fuzz/simple_round_trip.c +++ b/tests/fuzz/simple_round_trip.c @@ -26,6 +26,23 @@ static ZSTD_CCtx *cctx = NULL; static ZSTD_DCtx *dctx = NULL; +static size_t getDecompressionMargin(void const* compressed, size_t cSize, size_t srcSize, int hasSmallBlocks) +{ + size_t margin = ZSTD_decompressionMargin(compressed, cSize); + if (!hasSmallBlocks) { + /* The macro should be correct in this case, but it may be smaller + * because of e.g. block splitting, so take the smaller of the two. + */ + ZSTD_frameHeader zfh; + size_t marginM; + FUZZ_ZASSERT(ZSTD_getFrameHeader(&zfh, compressed, cSize)); + marginM = ZSTD_DECOMPRESSION_MARGIN(srcSize, zfh.blockSizeMax); + if (marginM < margin) + margin = marginM; + } + return margin; +} + static size_t roundTripTest(void *result, size_t resultCapacity, void *compressed, size_t compressedCapacity, const void *src, size_t srcSize, @@ -67,6 +84,25 @@ static size_t roundTripTest(void *result, size_t resultCapacity, } dSize = ZSTD_decompressDCtx(dctx, result, resultCapacity, compressed, cSize); FUZZ_ZASSERT(dSize); + FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size"); + FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, result, dSize), "Corruption!"); + + { + size_t margin = getDecompressionMargin(compressed, cSize, srcSize, targetCBlockSize); + size_t const outputSize = srcSize + margin; + char* const output = (char*)FUZZ_malloc(outputSize); + char* const input = output + outputSize - cSize; + FUZZ_ASSERT(outputSize >= cSize); + memcpy(input, compressed, 
cSize); + + dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize); + FUZZ_ZASSERT(dSize); + FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size"); + FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, srcSize), "Corruption!"); + + free(output); + } + /* When superblock is enabled make sure we don't expand the block more than expected. * NOTE: This test is currently disabled because superblock mode can arbitrarily * expand the block in the worst case. Once superblock mode has been improved we can @@ -120,13 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size) FUZZ_ASSERT(dctx); } - { - size_t const result = - roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer); - FUZZ_ZASSERT(result); - FUZZ_ASSERT_MSG(result == size, "Incorrect regenerated size"); - FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!"); - } + roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer); free(rBuf); free(cBuf); FUZZ_dataProducer_free(producer); diff --git a/tests/fuzz/stream_round_trip.c b/tests/fuzz/stream_round_trip.c index 8a28907b6..fae9ccbf4 100644 --- a/tests/fuzz/stream_round_trip.c +++ b/tests/fuzz/stream_round_trip.c @@ -166,6 +166,24 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size) FUZZ_ZASSERT(rSize); FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size"); FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!"); + + /* Test in-place decompression (note the macro doesn't work in this case) */ + { + size_t const margin = ZSTD_decompressionMargin(cBuf, cSize); + size_t const outputSize = size + margin; + char* const output = (char*)FUZZ_malloc(outputSize); + char* const input = output + outputSize - cSize; + size_t dSize; + FUZZ_ASSERT(outputSize >= cSize); + memcpy(input, cBuf, cSize); + + dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize); + FUZZ_ZASSERT(dSize); + FUZZ_ASSERT_MSG(dSize == size, "Incorrect regenerated size"); + FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, 
size), "Corruption!"); + + free(output); + } } FUZZ_dataProducer_free(producer); diff --git a/tests/fuzzer.c b/tests/fuzzer.c index 3ad8ced5e..ce0bea573 100644 --- a/tests/fuzzer.c +++ b/tests/fuzzer.c @@ -1220,6 +1220,60 @@ static int basicUnitTests(U32 const seed, double compressibility) } DISPLAYLEVEL(3, "OK \n"); + DISPLAYLEVEL(3, "test%3i : in-place decompression : ", testNb++); + cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize, -ZSTD_BLOCKSIZE_MAX); + CHECK_Z(cSize); + CHECK_LT(CNBuffSize, cSize); + { + size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize); + size_t const outputSize = (CNBuffSize + margin); + char* output = malloc(outputSize); + char* input = output + outputSize - cSize; + CHECK_LT(cSize, CNBuffSize + margin); + CHECK(output != NULL); + CHECK_Z(margin); + CHECK(margin <= ZSTD_DECOMPRESSION_MARGIN(CNBuffSize, ZSTD_BLOCKSIZE_MAX)); + memcpy(input, compressedBuffer, cSize); + + { + size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize); + CHECK_Z(dSize); + CHECK_EQ(dSize, CNBuffSize); + } + CHECK(!memcmp(output, CNBuffer, CNBuffSize)); + free(output); + } + DISPLAYLEVEL(3, "OK \n"); + + DISPLAYLEVEL(3, "test%3i : in-place decompression with 2 frames : ", testNb++); + cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX); + CHECK_Z(cSize); + { + size_t const cSize2 = ZSTD_compress((char*)compressedBuffer + cSize, compressedBufferSize - cSize, (char const*)CNBuffer + (CNBuffSize / 3), CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX); + CHECK_Z(cSize2); + cSize += cSize2; + } + { + size_t const srcSize = (CNBuffSize / 3) * 2; + size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize); + size_t const outputSize = (CNBuffSize + margin); + char* output = malloc(outputSize); + char* input = output + outputSize - cSize; + CHECK_LT(cSize, CNBuffSize + margin); + CHECK(output != NULL); + CHECK_Z(margin); + memcpy(input, 
compressedBuffer, cSize); + + { + size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize); + CHECK_Z(dSize); + CHECK_EQ(dSize, srcSize); + } + CHECK(!memcmp(output, CNBuffer, srcSize)); + free(output); + } + DISPLAYLEVEL(3, "OK \n"); + DISPLAYLEVEL(3, "test%3d: superblock uncompressible data, too many nocompress superblocks : ", testNb++); { ZSTD_CCtx* const cctx = ZSTD_createCCtx();