diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index 856335e52..298278c99 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -322,13 +322,12 @@ static size_t ZSTD_resetCCtx_advanced (ZSTD_CCtx* zc, * Duplicate an existing context `srcCCtx` into another one `dstCCtx`. * Only works during stage ZSTDcs_init (i.e. after creation, but before first call to ZSTD_compressContinue()). * @return : 0, or an error code */ -size_t ZSTD_copyCCtx(ZSTD_CCtx* dstCCtx, const ZSTD_CCtx* srcCCtx) +size_t ZSTD_copyCCtx(ZSTD_CCtx* dstCCtx, const ZSTD_CCtx* srcCCtx, unsigned long long pledgedSrcSize) { if (srcCCtx->stage!=ZSTDcs_init) return ERROR(stage_wrong); memcpy(&dstCCtx->customMem, &srcCCtx->customMem, sizeof(ZSTD_customMem)); - ZSTD_resetCCtx_advanced(dstCCtx, srcCCtx->params, srcCCtx->frameContentSize, ZSTDcrp_noMemset); - dstCCtx->params.fParams.contentSizeFlag = 0; /* content size different from the one set during srcCCtx init */ + ZSTD_resetCCtx_advanced(dstCCtx, srcCCtx->params, pledgedSrcSize, ZSTDcrp_noMemset); /* copy tables */ { size_t const chainSize = (srcCCtx->params.cParams.strategy == ZSTD_fast) ? 0 : (1 << srcCCtx->params.cParams.chainLog); @@ -2740,7 +2739,7 @@ size_t ZSTD_freeCDict(ZSTD_CDict* cdict) size_t ZSTD_compressBegin_usingCDict(ZSTD_CCtx* cctx, const ZSTD_CDict* cdict, U64 pledgedSrcSize) { - if (cdict->dictContentSize) CHECK_F(ZSTD_copyCCtx(cctx, cdict->refContext)) + if (cdict->dictContentSize) CHECK_F(ZSTD_copyCCtx(cctx, cdict->refContext, pledgedSrcSize)) else CHECK_F(ZSTD_compressBegin_advanced(cctx, NULL, 0, cdict->refContext->params, pledgedSrcSize)); return 0; } diff --git a/lib/dictBuilder/zdict.c b/lib/dictBuilder/zdict.c index cfabb20ba..8a38aadeb 100644 --- a/lib/dictBuilder/zdict.c +++ b/lib/dictBuilder/zdict.c @@ -563,7 +563,7 @@ static void ZDICT_countEStats(EStats_ress_t esr, ZSTD_parameters params, size_t cSize; if (srcSize > blockSizeMax) srcSize = blockSizeMax; /* protection vs large samples */ - { size_t const errorCode = ZSTD_copyCCtx(esr.zc, esr.ref); + { size_t const errorCode = ZSTD_copyCCtx(esr.zc, esr.ref, 0); if (ZSTD_isError(errorCode)) { DISPLAYLEVEL(1, "warning : ZSTD_copyCCtx failed \n"); return; } } cSize = ZSTD_compressBlock(esr.zc, esr.workPlace, ZSTD_BLOCKSIZE_ABSOLUTEMAX, src, srcSize); diff --git a/lib/zstd.h b/lib/zstd.h index 31171d04d..d7eb9c01f 100644 --- a/lib/zstd.h +++ b/lib/zstd.h @@ -447,7 +447,7 @@ ZSTDLIB_API size_t ZSTD_sizeof_DStream(const ZSTD_DStream* zds); ZSTDLIB_API size_t ZSTD_compressBegin(ZSTD_CCtx* cctx, int compressionLevel); ZSTDLIB_API size_t ZSTD_compressBegin_usingDict(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, int compressionLevel); ZSTDLIB_API size_t ZSTD_compressBegin_advanced(ZSTD_CCtx* cctx, const void* dict, size_t dictSize, ZSTD_parameters params, unsigned long long pledgedSrcSize); -ZSTDLIB_API size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx); +ZSTDLIB_API size_t ZSTD_copyCCtx(ZSTD_CCtx* cctx, const ZSTD_CCtx* preparedCCtx, unsigned long long pledgedSrcSize); ZSTDLIB_API size_t ZSTD_compressContinue(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); ZSTDLIB_API size_t ZSTD_compressEnd(ZSTD_CCtx* cctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize); diff --git a/tests/fuzzer.c b/tests/fuzzer.c index b8f102a9c..ae8450e40 100644 --- a/tests/fuzzer.c +++ b/tests/fuzzer.c @@ -173,13 +173,13 @@ static int basicUnitTests(U32 seed, double compressibility) static const size_t dictSize = 551; DISPLAYLEVEL(4, "test%3i : copy context too soon : ", testNb++); - { size_t const copyResult = ZSTD_copyCCtx(ctxDuplicated, ctxOrig); + { size_t const copyResult = ZSTD_copyCCtx(ctxDuplicated, ctxOrig, 0); if (!ZSTD_isError(copyResult)) goto _output_error; } /* error must be detected */ DISPLAYLEVEL(4, "OK \n"); DISPLAYLEVEL(4, "test%3i : load dictionary into context : ", testNb++); CHECK( ZSTD_compressBegin_usingDict(ctxOrig, CNBuffer, dictSize, 2) ); - CHECK( ZSTD_copyCCtx(ctxDuplicated, ctxOrig) ); + CHECK( ZSTD_copyCCtx(ctxDuplicated, ctxOrig, CNBuffSize - dictSize) ); DISPLAYLEVEL(4, "OK \n"); DISPLAYLEVEL(4, "test%3i : compress with flat dictionary : ", testNb++); @@ -221,10 +221,10 @@ static int basicUnitTests(U32 seed, double compressibility) p.fParams.contentSizeFlag = 1; CHECK( ZSTD_compressBegin_advanced(ctxOrig, CNBuffer, dictSize, p, testSize-1) ); } - CHECK( ZSTD_copyCCtx(ctxDuplicated, ctxOrig) ); + CHECK( ZSTD_copyCCtx(ctxDuplicated, ctxOrig, testSize) ); - CHECKPLUS(r, ZSTD_compressContinue(ctxDuplicated, compressedBuffer, ZSTD_compressBound(testSize), - (const char*)CNBuffer + dictSize, CNBuffSize - dictSize), + CHECKPLUS(r, ZSTD_compressEnd(ctxDuplicated, compressedBuffer, ZSTD_compressBound(testSize), + (const char*)CNBuffer + dictSize, testSize), cSize = r); { ZSTD_frameParams fp; if (ZSTD_getFrameParams(&fp, compressedBuffer, cSize)) goto _output_error; @@ -674,9 +674,9 @@ static int fuzzerTests(U32 seed, U32 nbTests, unsigned startTest, U32 const maxD errorCode = ZSTD_compressBegin_advanced(refCtx, dict, dictSize, p, 0); CHECK (ZSTD_isError(errorCode), "ZSTD_compressBegin_advanced error : %s", ZSTD_getErrorName(errorCode)); } - { size_t const errorCode = ZSTD_copyCCtx(ctx, refCtx); - CHECK (ZSTD_isError(errorCode), "ZSTD_copyCCtx error : %s", ZSTD_getErrorName(errorCode)); } - } + { size_t const errorCode = ZSTD_copyCCtx(ctx, refCtx, 0); + CHECK (ZSTD_isError(errorCode), "ZSTD_copyCCtx error : %s", ZSTD_getErrorName(errorCode)); + } } XXH64_reset(&xxhState, 0); { U32 const nbChunks = (FUZ_rand(&lseed) & 127) + 2; U32 n;