1
0
mirror of https://github.com/facebook/zstd.git synced 2025-03-06 08:49:28 +02:00

Add support for in-place decompression

* Add a function ZSTD_decompressionMargin() and a macro
  ZSTD_DECOMPRESSION_MARGIN() that compute the decompression margin for
  in-place decompression. The function computes a tight margin that works
  in all cases, and the macro computes an upper bound that will only work
  if flush isn't used.
* When doing in-place decompression, make sure that our output buffer
  doesn't overlap with the input buffer. This ensures that we don't
  decide to use the portion of the output buffer that overlaps the input
  buffer for temporary memory, like for literals.
* Add a simple unit test.
* Add in-place decompression to the simple_round_trip and
  stream_round_trip fuzzers. This should help verify that our margin stays
  correct.
This commit is contained in:
Nick Terrell 2023-01-11 18:14:40 -08:00 committed by Nick Terrell
parent 423500d1ae
commit 5b266196a4
7 changed files with 227 additions and 10 deletions

View File

@ -341,6 +341,7 @@ MEM_STATIC ZSTD_sequenceLength ZSTD_getSequenceLength(seqStore_t const* seqStore
* `decompressedBound != ZSTD_CONTENTSIZE_ERROR`
*/
typedef struct {
/* Number of blocks in the frame. */
size_t nbBlocks;
/* Size in bytes of the compressed frame, or an error code (check with ZSTD_isError()). */
size_t compressedSize;
/* Upper bound on the frame's decompressed size, or ZSTD_CONTENTSIZE_ERROR on failure. */
unsigned long long decompressedBound;
} ZSTD_frameSizeInfo; /* decompress & legacy */

View File

@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
ip += 4;
}
frameSizeInfo.nbBlocks = nbBlocks;
frameSizeInfo.compressedSize = (size_t)(ip - ipstart);
frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN)
? zfh.frameContentSize
@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
return bound;
}
size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)
{
size_t margin = 0;
unsigned maxBlockSize = 0;
/* Iterate over each frame */
while (srcSize > 0) {
ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
size_t const compressedSize = frameSizeInfo.compressedSize;
unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
ZSTD_frameHeader zfh;
FORWARD_IF_ERROR(ZSTD_getFrameHeader(&zfh, src, srcSize), "");
if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
return ERROR(corruption_detected);
if (zfh.frameType == ZSTD_frame) {
/* Add the frame header to our margin */
margin += zfh.headerSize;
/* Add the checksum to our margin */
margin += zfh.checksumFlag ? 4 : 0;
/* Add 3 bytes per block */
margin += 3 * frameSizeInfo.nbBlocks;
/* Compute the max block size */
maxBlockSize = MAX(maxBlockSize, zfh.blockSizeMax);
} else {
assert(zfh.frameType == ZSTD_skippableFrame);
/* Add the entire skippable frame size to our margin. */
margin += compressedSize;
}
assert(srcSize >= compressedSize);
src = (const BYTE*)src + compressedSize;
srcSize -= compressedSize;
}
/* Add the max block size back to the margin. */
margin += maxBlockSize;
return margin;
}
/*-*************************************************************
* Frame decoding
@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
if (srcSize == 0) return 0;
RETURN_ERROR(dstBuffer_null, "");
}
ZSTD_memcpy(dst, src, srcSize);
ZSTD_memmove(dst, src, srcSize);
return srcSize;
}
@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
/* Loop on each block */
while (1) {
BYTE* oBlockEnd = oend;
size_t decodedSize;
blockProperties_t blockProperties;
size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSrcSize, &blockProperties);
@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
remainingSrcSize -= ZSTD_blockHeaderSize;
RETURN_ERROR_IF(cBlockSize > remainingSrcSize, srcSize_wrong, "");
if (ip >= op && ip < oBlockEnd) {
/* We are decompressing in-place. Limit the output pointer so that we
* don't overwrite the block that we are currently reading. This will
* fail decompression if the input & output pointers aren't spaced
* far enough apart.
*
* This is important to set, even when the pointers are far enough
* apart, because ZSTD_decompressBlock_internal() can decide to store
* literals in the output buffer, after the block it is decompressing.
* Since we don't want anything to overwrite our input, we have to tell
* ZSTD_decompressBlock_internal to never write past ip.
*
* See ZSTD_allocateLiteralsBuffer() for reference.
*/
oBlockEnd = op + (ip - op);
}
switch(blockProperties.blockType)
{
case bt_compressed:
decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oend-op), ip, cBlockSize, /* frame */ 1, not_streaming);
decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming);
break;
case bt_raw :
/* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
decodedSize = ZSTD_copyRawBlock(op, (size_t)(oend-op), ip, cBlockSize);
break;
case bt_rle :
decodedSize = ZSTD_setRleBlock(op, (size_t)(oend-op), *ip, blockProperties.origSize);
decodedSize = ZSTD_setRleBlock(op, (size_t)(oBlockEnd-op), *ip, blockProperties.origSize);
break;
case bt_reserved :
default:

View File

@ -242,6 +242,13 @@ MEM_STATIC ZSTD_frameSizeInfo ZSTD_findFrameSizeInfoLegacy(const void *src, size
frameSizeInfo.compressedSize = ERROR(srcSize_wrong);
frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
}
/* In all cases, decompressedBound == nbBlocks * ZSTD_BLOCKSIZE_MAX.
* So we can compute nbBlocks without having to change every function.
*/
if (frameSizeInfo.decompressedBound != ZSTD_CONTENTSIZE_ERROR) {
assert((frameSizeInfo.decompressedBound & (ZSTD_BLOCKSIZE_MAX - 1)) == 0);
frameSizeInfo.nbBlocks = (size_t)(frameSizeInfo.decompressedBound / ZSTD_BLOCKSIZE_MAX);
}
return frameSizeInfo;
}

View File

@ -1427,6 +1427,51 @@ ZSTDLIB_STATIC_API unsigned long long ZSTD_decompressBound(const void* src, size
* or an error code (if srcSize is too small) */
ZSTDLIB_STATIC_API size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize);
/*! ZSTD_decompressionMargin() :
 * Zstd supports in-place decompression, where the input and output buffers overlap.
 * In this case, the output buffer must be at least (Margin + Output_Size) bytes large,
 * and the input buffer must be at the end of the output buffer.
 *
 *  _______________________ Output Buffer ________________________
 * |                                                              |
 * |                                        ____ Input Buffer ____|
 * |                                       |                      |
 * v                                       v                      v
 * |---------------------------------------|-----------|----------|
 * ^                                                   ^          ^
 * |___________________ Output_Size ___________________|_ Margin _|
 *
 * NOTE: See also ZSTD_DECOMPRESSION_MARGIN().
 * NOTE: This applies only to single-pass decompression through ZSTD_decompress() or
 * ZSTD_decompressDCtx().
 * NOTE: This function supports multi-frame input.
 *
 * @param src The compressed frame(s)
 * @param srcSize The size of the compressed frame(s)
 * @returns The decompression margin or an error that can be checked with ZSTD_isError().
 */
ZSTDLIB_STATIC_API size_t ZSTD_decompressionMargin(const void* src, size_t srcSize);
/*! ZSTD_DECOMPRESSION_MARGIN() :
 * Similar to ZSTD_decompressionMargin(), but instead of computing the margin from
 * the compressed frame, compute it from the original size and the block size.
 * See ZSTD_decompressionMargin() for details.
 *
 * WARNING: This macro does not support multi-frame input, the input must be a single
 * zstd frame. If you need that support use the function, or implement it yourself.
 *
 * NOTE: Both arguments are expanded more than once, so pass expressions without
 * side effects. Each use is parenthesized, so compound expressions such as
 * `a + b` are safe.
 *
 * @param originalSize The original uncompressed size of the data.
 * @param blockSize The block size == MIN(windowSize, ZSTD_BLOCKSIZE_MAX).
 *                  Unless you explicitly set the windowLog smaller than
 *                  ZSTD_BLOCKSIZELOG_MAX you can just use ZSTD_BLOCKSIZE_MAX.
 */
#define ZSTD_DECOMPRESSION_MARGIN(originalSize, blockSize) ((size_t)( \
        ZSTD_FRAMEHEADERSIZE_MAX /* Frame header */ + \
        4 /* checksum */ + \
        ((originalSize) == 0 ? 0 : 3 * (((originalSize) + (blockSize) - 1) / (blockSize))) /* 3 bytes per block */ + \
        (blockSize) /* One block of margin */ \
    ))
typedef enum {
ZSTD_sf_noBlockDelimiters = 0, /* Representation of ZSTD_Sequence has no block delimiters, sequences only */
ZSTD_sf_explicitBlockDelimiters = 1 /* Representation of ZSTD_Sequence contains explicit block delimiters */

View File

@ -26,6 +26,23 @@
/* File-scope compression and decompression contexts; both start out NULL.
 * NOTE(review): presumably created on first use by the fuzzer entry point and
 * reused across inputs; confirm against LLVMFuzzerTestOneInput().
 */
static ZSTD_CCtx *cctx = NULL;
static ZSTD_DCtx *dctx = NULL;
/* Picks the margin to use for the in-place decompression round trip.
 * The starting point is the value reported by ZSTD_decompressionMargin().
 * When hasSmallBlocks is zero, the ZSTD_DECOMPRESSION_MARGIN() macro is also
 * evaluated (using the frame's blockSizeMax) and the smaller of the two values
 * is returned, so the macro's bound gets exercised whenever it is the tighter one.
 */
static size_t getDecompressionMargin(void const* compressed, size_t cSize, size_t srcSize, int hasSmallBlocks)
{
    size_t const exactMargin = ZSTD_decompressionMargin(compressed, cSize);
    ZSTD_frameHeader header;
    size_t macroMargin;

    /* With small blocks the macro is not consulted at all. */
    if (hasSmallBlocks) {
        return exactMargin;
    }

    FUZZ_ZASSERT(ZSTD_getFrameHeader(&header, compressed, cSize));
    macroMargin = ZSTD_DECOMPRESSION_MARGIN(srcSize, header.blockSizeMax);
    return (macroMargin < exactMargin) ? macroMargin : exactMargin;
}
static size_t roundTripTest(void *result, size_t resultCapacity,
void *compressed, size_t compressedCapacity,
const void *src, size_t srcSize,
@ -67,6 +84,25 @@ static size_t roundTripTest(void *result, size_t resultCapacity,
}
dSize = ZSTD_decompressDCtx(dctx, result, resultCapacity, compressed, cSize);
FUZZ_ZASSERT(dSize);
FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, result, dSize), "Corruption!");
{
size_t margin = getDecompressionMargin(compressed, cSize, srcSize, targetCBlockSize);
size_t const outputSize = srcSize + margin;
char* const output = (char*)FUZZ_malloc(outputSize);
char* const input = output + outputSize - cSize;
FUZZ_ASSERT(outputSize >= cSize);
memcpy(input, compressed, cSize);
dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
FUZZ_ZASSERT(dSize);
FUZZ_ASSERT_MSG(dSize == srcSize, "Incorrect regenerated size");
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, srcSize), "Corruption!");
free(output);
}
/* When superblock is enabled make sure we don't expand the block more than expected.
* NOTE: This test is currently disabled because superblock mode can arbitrarily
* expand the block in the worst case. Once superblock mode has been improved we can
@ -120,13 +156,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
FUZZ_ASSERT(dctx);
}
{
size_t const result =
roundTripTest(rBuf, rBufSize, cBuf, cBufSize, src, size, producer);
FUZZ_ZASSERT(result);
FUZZ_ASSERT_MSG(result == size, "Incorrect regenerated size");
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
}
free(rBuf);
free(cBuf);
FUZZ_dataProducer_free(producer);

View File

@ -166,6 +166,24 @@ int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
FUZZ_ZASSERT(rSize);
FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, rBuf, size), "Corruption!");
/* Test in-place decompression (note the macro doesn't work in this case) */
{
size_t const margin = ZSTD_decompressionMargin(cBuf, cSize);
size_t const outputSize = size + margin;
char* const output = (char*)FUZZ_malloc(outputSize);
char* const input = output + outputSize - cSize;
size_t dSize;
FUZZ_ASSERT(outputSize >= cSize);
memcpy(input, cBuf, cSize);
dSize = ZSTD_decompressDCtx(dctx, output, outputSize, input, cSize);
FUZZ_ZASSERT(dSize);
FUZZ_ASSERT_MSG(dSize == size, "Incorrect regenerated size");
FUZZ_ASSERT_MSG(!FUZZ_memcmp(src, output, size), "Corruption!");
free(output);
}
}
FUZZ_dataProducer_free(producer);

View File

@ -1220,6 +1220,60 @@ static int basicUnitTests(U32 const seed, double compressibility)
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : in-place decompression : ", testNb++);
cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize, -ZSTD_BLOCKSIZE_MAX);
CHECK_Z(cSize);
CHECK_LT(CNBuffSize, cSize);
{
size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
size_t const outputSize = (CNBuffSize + margin);
char* output = malloc(outputSize);
char* input = output + outputSize - cSize;
CHECK_LT(cSize, CNBuffSize + margin);
CHECK(output != NULL);
CHECK_Z(margin);
CHECK(margin <= ZSTD_DECOMPRESSION_MARGIN(CNBuffSize, ZSTD_BLOCKSIZE_MAX));
memcpy(input, compressedBuffer, cSize);
{
size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
CHECK_Z(dSize);
CHECK_EQ(dSize, CNBuffSize);
}
CHECK(!memcmp(output, CNBuffer, CNBuffSize));
free(output);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3i : in-place decompression with 2 frames : ", testNb++);
cSize = ZSTD_compress(compressedBuffer, compressedBufferSize, CNBuffer, CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
CHECK_Z(cSize);
{
size_t const cSize2 = ZSTD_compress((char*)compressedBuffer + cSize, compressedBufferSize - cSize, (char const*)CNBuffer + (CNBuffSize / 3), CNBuffSize / 3, -ZSTD_BLOCKSIZE_MAX);
CHECK_Z(cSize2);
cSize += cSize2;
}
{
size_t const srcSize = (CNBuffSize / 3) * 2;
size_t const margin = ZSTD_decompressionMargin(compressedBuffer, cSize);
size_t const outputSize = (CNBuffSize + margin);
char* output = malloc(outputSize);
char* input = output + outputSize - cSize;
CHECK_LT(cSize, CNBuffSize + margin);
CHECK(output != NULL);
CHECK_Z(margin);
memcpy(input, compressedBuffer, cSize);
{
size_t const dSize = ZSTD_decompress(output, outputSize, input, cSize);
CHECK_Z(dSize);
CHECK_EQ(dSize, srcSize);
}
CHECK(!memcmp(output, CNBuffer, srcSize));
free(output);
}
DISPLAYLEVEL(3, "OK \n");
DISPLAYLEVEL(3, "test%3d: superblock uncompressible data, too many nocompress superblocks : ", testNb++);
{
ZSTD_CCtx* const cctx = ZSTD_createCCtx();