diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c index 00c0d937e..18ee070cf 100644 --- a/lib/decompress/zstd_decompress.c +++ b/lib/decompress/zstd_decompress.c @@ -354,8 +354,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) skippableSize = MEM_readLE32((const BYTE *)src + 4) + ZSTD_skippableHeaderSize; if (srcSize < skippableSize) { - /* srcSize_wrong */ - return 0; + return ZSTD_CONTENTSIZE_ERROR; } src = (const BYTE *)src + skippableSize; @@ -384,7 +383,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) frameSrcSize = ZSTD_frameSrcSize(src, srcSize); } if (ZSTD_isError(frameSrcSize)) { - return 0; + return ZSTD_CONTENTSIZE_ERROR; } src = (const BYTE *)src + frameSrcSize; @@ -393,8 +392,7 @@ unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize) } if (srcSize) { - /* srcSize_wrong */ - return 0; + return ZSTD_CONTENTSIZE_ERROR; } return totalDstSize; @@ -1498,22 +1496,22 @@ static size_t ZSTD_frameSrcSize(const void *src, size_t srcSize) * `dctx` must be properly initialized */ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, - const void* src, size_t srcSize) + const void** srcPtr, size_t *srcSizePtr) { - const BYTE* ip = (const BYTE*)src; + const BYTE* ip = (const BYTE*)(*srcPtr); BYTE* const ostart = (BYTE* const)dst; BYTE* const oend = ostart + dstCapacity; BYTE* op = ostart; - size_t remainingSize = srcSize; + size_t remainingSize = *srcSizePtr; /* check */ - if (srcSize < ZSTD_frameHeaderSize_min+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong); + if (remainingSize < ZSTD_frameHeaderSize_min+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong); /* Frame Header */ - { size_t const frameHeaderSize = ZSTD_frameHeaderSize(src, ZSTD_frameHeaderSize_prefix); + { size_t const frameHeaderSize = ZSTD_frameHeaderSize(ip, ZSTD_frameHeaderSize_prefix); if (ZSTD_isError(frameHeaderSize)) return frameHeaderSize; - if (srcSize < frameHeaderSize+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong); - CHECK_F(ZSTD_decodeFrameHeader(dctx, src, frameHeaderSize)); + if (remainingSize < frameHeaderSize+ZSTD_blockHeaderSize) return ERROR(srcSize_wrong); + CHECK_F(ZSTD_decodeFrameHeader(dctx, ip, frameHeaderSize)); ip += frameHeaderSize; remainingSize -= frameHeaderSize; } @@ -1558,25 +1556,98 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx, if (remainingSize<4) return ERROR(checksum_wrong); checkRead = MEM_readLE32(ip); if (checkRead != checkCalc) return ERROR(checksum_wrong); + ip += 4; remainingSize -= 4; } - if (remainingSize) return ERROR(srcSize_wrong); + // Allow caller to get size read + *srcPtr = ip; + *srcSizePtr = remainingSize; return op-ostart; } +static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx, + void* dst, size_t dstCapacity, + const void* src, size_t srcSize, + const void *dict, size_t dictSize, + const ZSTD_DCtx* refContext) +{ + void* const dststart = dst; + while (srcSize >= ZSTD_frameHeaderSize_prefix) { + U32 magicNumber; + +#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1) + if (ZSTD_isLegacy(src, srcSize)) { + size_t const frameSize = ZSTD_frameSrcSizeLegacy(src, srcSize); + size_t decodedSize; + if (ZSTD_isError(frameSize)) return frameSize; + + decodedSize = ZSTD_decompressLegacy(dst, dstCapacity, src, frameSize, dict, dictSize); + + dst = (BYTE*)dst + decodedSize; + dstCapacity -= decodedSize; + + src = (const BYTE*)src + frameSize; + srcSize -= frameSize; + + continue; + } +#endif + + magicNumber = MEM_readLE32(src); + if (magicNumber != ZSTD_MAGICNUMBER) { + if ((magicNumber & 0xFFFFFFF0U) == ZSTD_MAGIC_SKIPPABLE_START) { + size_t skippableSize; + if (srcSize < ZSTD_skippableHeaderSize) + return ERROR(srcSize_wrong); + skippableSize = MEM_readLE32((const BYTE *)src + 4) + + ZSTD_skippableHeaderSize; + if (srcSize < skippableSize) { + return ERROR(srcSize_wrong); + } + + src = (const BYTE *)src + skippableSize; + srcSize -= skippableSize; + continue; + } else { + return ERROR(prefix_unknown); + } + } + + if (refContext) { + /* we were called from ZSTD_decompress_usingDDict */ + ZSTD_refDCtx(dctx, refContext); + } else { + /* this will initialize correctly with no dict if dict == NULL, so + * use this in all cases but ddict */ + CHECK_F(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize)); + } + ZSTD_checkContinuity(dctx, dst); + + { + const size_t res = ZSTD_decompressFrame(dctx, dst, dstCapacity, + &src, &srcSize); + if (ZSTD_isError(res)) return res; + /* don't need to bounds check this, ZSTD_decompressFrame will have + * already */ + dst = (BYTE*)dst + res; + dstCapacity -= res; + } + } + + if (srcSize) { + return ERROR(srcSize_wrong); + } + + return (BYTE*)dst - (BYTE*)dststart; +} size_t ZSTD_decompress_usingDict(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize, const void* dict, size_t dictSize) { -#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT==1) - if (ZSTD_isLegacy(src, srcSize)) return ZSTD_decompressLegacy(dst, dstCapacity, src, srcSize, dict, dictSize); -#endif - CHECK_F(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize)); - ZSTD_checkContinuity(dctx, dst); - return ZSTD_decompressFrame(dctx, dst, dstCapacity, src, srcSize); + return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, dict, dictSize, NULL); } @@ -1973,12 +2044,10 @@ size_t ZSTD_decompress_usingDDict(ZSTD_DCtx* dctx, const void* src, size_t srcSize, const ZSTD_DDict* ddict) { -#if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT==1) - if (ZSTD_isLegacy(src, srcSize)) return ZSTD_decompressLegacy(dst, dstCapacity, src, srcSize, ddict->dictContent, ddict->dictSize); -#endif - ZSTD_refDCtx(dctx, ddict->refContext); - ZSTD_checkContinuity(dctx, dst); - return ZSTD_decompressFrame(dctx, dst, dstCapacity, src, srcSize); + /* pass content and size in case legacy frames are encountered */ + return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, + ddict->dictContent, ddict->dictSize, + ddict->refContext); } diff --git a/lib/legacy/zstd_v07.c b/lib/legacy/zstd_v07.c index 4ee055ddd..6228ee870 100644 --- a/lib/legacy/zstd_v07.c +++ b/lib/legacy/zstd_v07.c @@ -3992,6 +3992,9 @@ size_t ZSTDv07_frameSrcSize(const void* src, size_t srcSize) ip += ZSTDv07_blockHeaderSize; remainingSize -= ZSTDv07_blockHeaderSize; + + if (blockProperties.blockType == bt_end) break; + if (cBlockSize > remainingSize) return ERROR(srcSize_wrong); ip += cBlockSize; diff --git a/tests/fuzzer.c b/tests/fuzzer.c index e7f25aae5..4239af455 100644 --- a/tests/fuzzer.c +++ b/tests/fuzzer.c @@ -167,6 +167,46 @@ static int basicUnitTests(U32 seed, double compressibility) if (ZSTD_getErrorCode(r) != ZSTD_error_srcSize_wrong) goto _output_error; } DISPLAYLEVEL(4, "OK \n"); + /* Simple API multiframe test */ + DISPLAYLEVEL(4, "test%3i : compress multiple frames : ", testNb++); + { size_t off = 0; + int i; + int const segs = 4; + /* only use the first half so we don't push against size limit of compressedBuffer */ + size_t const segSize = (CNBuffSize / 2) / segs; + for (i = 0; i < segs; i++) { + CHECK_V(r, + ZSTD_compress( + (BYTE *)compressedBuffer + off, CNBuffSize - off, + (BYTE *)CNBuffer + segSize * i, + segSize, 5)); + off += r; + if (i == segs/2) { + /* insert skippable frame */ + const U32 skipLen = 128 KB; + MEM_writeLE32((BYTE*)compressedBuffer + off, ZSTD_MAGIC_SKIPPABLE_START); + MEM_writeLE32((BYTE*)compressedBuffer + off + 4, skipLen); + off += skipLen + ZSTD_skippableHeaderSize; + } + } + cSize = off; + } + DISPLAYLEVEL(4, "OK \n"); + + DISPLAYLEVEL(4, "test%3i : get decompressed size of multiple frames : ", testNb++); + { unsigned long long const r = ZSTD_findDecompressedSize(compressedBuffer, cSize); + if (r != CNBuffSize / 2) goto _output_error; } + DISPLAYLEVEL(4, "OK \n"); + + DISPLAYLEVEL(4, "test%3i : decompress multiple frames : ", testNb++); + { CHECK_V(r, ZSTD_decompress(decodedBuffer, CNBuffSize, compressedBuffer, cSize)); + if (r != CNBuffSize / 2) goto _output_error; } + DISPLAYLEVEL(4, "OK \n"); + + DISPLAYLEVEL(4, "test%3i : check decompressed result : ", testNb++); + if (memcmp(decodedBuffer, CNBuffer, CNBuffSize / 2) != 0) goto _output_error; + DISPLAYLEVEL(4, "OK \n"); + /* Dictionary and CCtx Duplication tests */ { ZSTD_CCtx* const ctxOrig = ZSTD_createCCtx(); ZSTD_CCtx* const ctxDuplicated = ZSTD_createCCtx();