1
0
mirror of https://github.com/FFmpeg/FFmpeg.git synced 2024-12-28 20:53:54 +02:00
FFmpeg/libavcodec/jpegxl_parser.c
Michael Niedermayer 850ab8f6da
avcodec/jpegxl_parser: Check get_vlc2()
Fixes: shift exponent -1 is negative
Fixes: 63889/clusterfuzz-testcase-minimized-ffmpeg_DEMUXER_fuzzer-6009343056936960

Found-by: continuous fuzzing process https://github.com/google/oss-fuzz/tree/master/projects/ffmpeg
Signed-off-by: Michael Niedermayer <michael@niedermayer.cc>
2023-12-29 19:21:26 +01:00

1539 lines
48 KiB
C

/**
* JPEG XL parser
* Copyright (c) 2023 Leo Izen <leo.izen@gmail.com>
*
* This file is part of FFmpeg.
*
* FFmpeg is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* FFmpeg is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with FFmpeg; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#include <errno.h>
#include <stdint.h>
#include <string.h>
#include "libavutil/attributes.h"
#include "libavutil/error.h"
#include "libavutil/intmath.h"
#include "libavutil/macros.h"
#include "libavutil/mem.h"
#include "libavutil/pixfmt.h"
#include "bytestream.h"
#include "codec_id.h"
#define UNCHECKED_BITSTREAM_READER 0
#define BITSTREAM_READER_LE
#include "get_bits.h"
#include "jpegxl.h"
#include "jpegxl_parse.h"
#include "parser.h"
#include "vlc.h"
/*
 * Bit masks for the 64-bit flags field read by the frame header parser.
 * Only JXL_FLAG_USE_LF_FRAME is tested in the visible part of this file.
 */
#define JXL_FLAG_NOISE 1
#define JXL_FLAG_PATCHES 2
#define JXL_FLAG_SPLINES 16
#define JXL_FLAG_USE_LF_FRAME 32
#define JXL_FLAG_SKIP_ADAPTIVE_LF_SMOOTH 128
/* upper bound on alphabet_size for prefix-code distributions; also sizes the scratch buffer in read_vlc_prefix() */
#define MAX_PREFIX_ALPHABET_SIZE (1u << 15)
/* ff_log2(x) + 1 for nonzero x, and 0 for x == 0 (presumably the bit length of x; assumes ff_log2() is floor(log2())) */
#define clog1p(x) (ff_log2(x) + !!(x))
/*
 * Undo the sign packing used by the bitstream: even inputs map to x / 2 and
 * odd inputs map to -(x + 1) / 2 (0, 1, 2, 3 -> 0, -1, 1, -2).
 * The arithmetic is done in int64_t so that unsigned arguments, which is how
 * the frame header parser calls it, still produce negative results instead of
 * wrapping around. The result has type int32_t. The argument is evaluated
 * more than once, so it must not have side effects.
 */
#define unpack_signed(x) ((int32_t)((((x) & 1) ? -(int64_t)(x) - 1 : (int64_t)(x)) / 2))
/* integer division rounding up, valid for x >= 1 and y >= 1 */
#define div_ceil(x, y) (((x) - 1) / (y) + 1)
/* initializer for one VLCElem table entry: symbol a with code length b */
#define vlm(a,b) {.sym = (a), .len = (b)}
/*
 * Parameters of one hybrid-uint coding, filled by read_hybrid_uint_conf()
 * and consumed by read_hybrid_uint().
 */
typedef struct JXLHybridUintConf {
/* tokens smaller than (1 << split_exponent) are the decoded value itself */
int split_exponent;
/* number of bits below the implicit leading one that are taken from the token */
uint32_t msb_in_token;
/* number of least significant bits of the value that are taken from the token */
uint32_t lsb_in_token;
} JXLHybridUintConf;
/*
 * One clustered symbol distribution. Depending on the owning bundle it is
 * decoded either with a prefix code (vlc / default_symbol) or with ANS
 * (freq plus the alias table in cutoffs / symbols / offsets).
 */
typedef struct JXLSymbolDistribution {
/* how tokens of this distribution expand into integers */
JXLHybridUintConf config;
/* 12 - log_alphabet_size of the bundle, set by gen_alias_map() */
int log_bucket_size;
/* this is the actual size of the alphabet */
int alphabet_size;
/* ceil(log(alphabet_size)) */
int log_alphabet_size;
/* for prefix code distributions */
VLC vlc;
/* symbol returned without reading any bits when vlc.bits == 0 */
uint32_t default_symbol;
/*
 * each (1 << log_alphabet_size) length
 * with log_alphabet_size <= 8
 */
/* frequencies associated with this Distribution */
uint32_t freq[258];
/* cutoffs for using the symbol table */
uint16_t cutoffs[258];
/* the symbol table for this distribution */
uint16_t symbols[258];
/* the offset for symbols */
uint16_t offsets[258];
/* if this distribution contains only one symbol this is its index, otherwise -1 */
int uniq_pos;
} JXLSymbolDistribution;
/*
 * A set of clustered distributions together with the context-to-cluster map
 * and the optional LZ77 parameters, as read by read_distribution_bundle().
 */
typedef struct JXLDistributionBundle {
/* lz77 flags */
int lz77_enabled;
/* tokens >= lz77_min_symbol start a length/distance copy */
uint32_t lz77_min_symbol;
/* added to every decoded copy length */
uint32_t lz77_min_length;
/* hybrid-uint configuration used for the copy length */
JXLHybridUintConf lz_len_conf;
/* one entry for each distribution */
uint8_t *cluster_map;
/* length of cluster_map (includes the extra lz77 distance context when lz77 is enabled) */
int num_dist;
/* one for each cluster */
JXLSymbolDistribution *dists;
int num_clusters;
/* whether to use brotli prefixes or ans */
int use_prefix_code;
/* bundle log alphabet size, dist ones may be smaller */
int log_alphabet_size;
} JXLDistributionBundle;
/*
 * State of one entropy-coded stream: the ANS state, the LZ77 copy state and
 * the distribution bundle used to decode symbols.
 */
typedef struct JXLEntropyDecoder {
/* state is a positive 32-bit integer, or -1 if unset */
int64_t state;
/* lz77 values */
/* number of values still to be replayed from the window */
uint32_t num_to_copy;
/* next window position to replay (masked with 0xFFFFF on use) */
uint32_t copy_pos;
/* number of values written to the window so far */
uint32_t num_decoded;
/* length is (1 << 20) */
/* if lz77 is enabled for this bundle */
/* if lz77 is disabled it's NULL */
uint32_t *window;
/* primary bundle associated with this distribution */
JXLDistributionBundle bundle;
/* for av_log */
void *logctx;
} JXLEntropyDecoder;
/* Per-frame values collected by parse_frame_header(), which zeroes the struct before filling it. */
typedef struct JXLFrame {
/* 2-bit frame type read from the frame header */
FFJXLFrameType type;
/* 1-bit encoding selector read from the frame header (compared against JPEGXL_ENC_MODULAR / JPEGXL_ENC_VARDCT) */
FFJXLFrameEncoding encoding;
/* defaults to 1; read from the header for regular and skip-progressive frames, 0 for other explicit frame types */
int is_last;
/* NOTE(review): presumably whether the frame covers the whole canvas; its assignment is outside this excerpt */
int full_frame;
/* NOTE(review): length fields are computed outside this excerpt; confirm units before relying on them */
uint32_t total_length;
uint32_t body_length;
} JXLFrame;
/* Parsed codestream-level state: the image metadata plus the frame header currently being parsed. */
typedef struct JXLCodestream {
/* image metadata; read by populate_fields() and parse_frame_header() */
FFJXLMetadata meta;
/* frame header values, rewritten by each parse_frame_header() call */
JXLFrame frame;
} JXLCodestream;
/*
 * Private state of the JPEG XL parser.
 * NOTE(review): the integer bookkeeping fields below are maintained by code
 * outside this excerpt; their exact meaning should be confirmed there.
 */
typedef struct JXLParseContext {
ParseContext pc;
JXLCodestream codestream;
/* using ISOBMFF-based container */
int container;
int skip;
int copied;
int collected_size;
int codestream_length;
int skipped_icc;
int next;
/* scratch buffer of 4096 bytes plus the padding required by the bit reader */
uint8_t cs_buffer[4096 + AV_INPUT_BUFFER_PADDING_SIZE];
} JXLParseContext;
/*
 * used for reading brotli prefixes
 *
 * Fixed 4-bit lookup table (16 entries of {symbol, length}) that
 * read_vlc_prefix() uses to read the level-1 code lengths, which are in 0..5.
 */
static const VLCElem level0_table[16] = {
vlm(0, 2), vlm(4, 2), vlm(3, 2), vlm(2, 3), vlm(0, 2), vlm(4, 2), vlm(3, 2), vlm(1, 4),
vlm(0, 2), vlm(4, 2), vlm(3, 2), vlm(2, 3), vlm(0, 2), vlm(4, 2), vlm(3, 2), vlm(5, 4),
};
/*
 * prefix table for populating ANS distribution
 *
 * Fixed 7-bit lookup table (128 entries of {symbol, length}) used by
 * populate_distribution() to read per-symbol log-counts in 0..13.
 * The value 13 is treated there as a run-length marker.
 */
static const VLCElem dist_prefix_table[128] = {
vlm(10, 3), vlm(12, 7), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(0, 5), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(11, 6), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(0, 5), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(13, 7), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(0, 5), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(11, 6), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
vlm(10, 3), vlm(0, 5), vlm(7, 3), vlm(3, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(5, 4),
vlm(10, 3), vlm(4, 4), vlm(7, 3), vlm(1, 4), vlm(6, 3), vlm(8, 3), vlm(9, 3), vlm(2, 4),
};
/*
 * Transmission order of the 18 level-1 code lengths: the i-th length read by
 * read_vlc_prefix() belongs to level-1 symbol prefix_codelen_map[i].
 */
static const uint8_t prefix_codelen_map[18] = {
1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15,
};
/**
 * Read a variable-length integer in the range 0..255.
 * A single zero bit encodes 0; otherwise a 3-bit count n follows and the
 * value is (1 << n) plus n further bits.
 * Used when populating the ANS frequency tables.
 */
static av_always_inline uint8_t jxl_u8(GetBitContext *gb)
{
    if (get_bits1(gb)) {
        const int nbits = get_bits(gb, 3);
        const unsigned low = get_bitsz(gb, nbits);

        return low | (1 << nbits);
    }

    return 0;
}
/*
 * read a U32(c_i + u(u_i)): a 2-bit selector picks one of four
 * (constant, bit count) pairs; the result is the constant plus that many
 * bits read from the stream (nothing is read when the bit count is 0)
 */
static av_always_inline uint32_t jxl_u32(GetBitContext *gb,
                                         uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3,
                                         uint32_t u0, uint32_t u1, uint32_t u2, uint32_t u3)
{
    const uint32_t base[4]  = {c0, c1, c2, c3};
    const uint32_t extra[4] = {u0, u1, u2, u3};
    const uint32_t sel      = get_bits(gb, 2);
    uint32_t value          = base[sel];

    if (extra[sel])
        value += get_bits_long(gb, extra[sel]);

    return value;
}
/*
 * read a U64(): a 2-bit selector chooses between 0, a 4-bit value plus 1,
 * an 8-bit value plus 17, or a 12-bit value extended by continuation-flagged
 * groups of 8 bits (the last group, at bit 60, is 4 bits wide)
 */
static uint64_t jxl_u64(GetBitContext *gb)
{
    const unsigned selector = get_bits(gb, 2);
    uint64_t value, pos;

    if (selector == 1)
        return 1 + get_bits(gb, 4);
    if (selector == 2)
        return 17 + get_bits(gb, 8);
    if (selector != 3)
        return 0;

    value = get_bits(gb, 12);
    /* each set continuation bit announces another group of payload bits */
    for (pos = 12; get_bits1(gb); pos += 8) {
        if (pos >= 60) {
            value |= (uint64_t)get_bits(gb, 4) << pos;
            break;
        }
        value |= (uint64_t)get_bits(gb, 8) << pos;
    }

    return value;
}
/**
 * Read a hybrid-uint configuration.
 *
 * The split exponent is read first; when it equals log_alphabet_size both
 * in-token bit counts are zero and nothing else is read. Otherwise the msb
 * and lsb counts follow, each validated against the split exponent.
 *
 * @return 0 on success, AVERROR_INVALIDDATA if the bit counts do not fit
 */
static int read_hybrid_uint_conf(GetBitContext *gb, JXLHybridUintConf *conf, int log_alphabet_size)
{
    const int split = get_bitsz(gb, clog1p(log_alphabet_size));
    uint32_t msb, lsb;

    conf->split_exponent = split;
    if (split == log_alphabet_size) {
        conf->lsb_in_token = 0;
        conf->msb_in_token = 0;
        return 0;
    }

    msb = get_bitsz(gb, clog1p(split));
    conf->msb_in_token = msb;
    if (msb > split)
        return AVERROR_INVALIDDATA;

    lsb = get_bitsz(gb, clog1p(split - msb));
    conf->lsb_in_token = lsb;
    if (msb + lsb > split)
        return AVERROR_INVALIDDATA;

    return 0;
}
/**
 * Expand a token into its integer value according to conf.
 *
 * Tokens below (1 << split_exponent) are returned unchanged. Larger tokens
 * carry some high and low bits of the value; the remaining middle bits are
 * read from the bitstream.
 *
 * @param[out] hybrid_uint receives the decoded value
 * @return 0 on success, AVERROR_INVALIDDATA if 32 or more extra bits would be needed
 */
static int read_hybrid_uint(GetBitContext *gb, const JXLHybridUintConf *conf, uint32_t token, uint32_t *hybrid_uint)
{
    const uint32_t msb = conf->msb_in_token, lsb = conf->lsb_in_token;
    const uint32_t threshold = 1 << conf->split_exponent;
    uint32_t extra_bits, low_part, high_part;

    if (token < threshold) {
        *hybrid_uint = token;
        return 0;
    }

    extra_bits = conf->split_exponent - lsb - msb + ((token - threshold) >> (msb + lsb));
    if (extra_bits >= 32)
        return AVERROR_INVALIDDATA;

    low_part  = token & ((1 << lsb) - 1);
    /* in-token high bits, with the implicit leading one placed above them */
    high_part = ((token >> lsb) & ((1 << msb) - 1)) | (1 << msb);

    *hybrid_uint = (((high_part << extra_bits) | get_bits_long(gb, extra_bits)) << lsb) | low_part;

    return 0;
}
/*
 * Read one prefix-coded symbol. A distribution whose VLC has zero bits
 * holds a single symbol, which is returned without consuming input.
 */
static inline uint32_t read_prefix_symbol(GetBitContext *gb, const JXLSymbolDistribution *dist)
{
    const VLC *code = &dist->vlc;

    if (code->bits)
        return get_vlc2(gb, code->table, code->bits, 1);

    return dist->default_symbol;
}
/*
 * Decode one ANS symbol using the alias table of dist and advance the
 * decoder state. The state is loaded from 32 bits of input on first use and
 * refilled with 16 bits whenever it drops below 1 << 16.
 */
static uint32_t read_ans_symbol(GetBitContext *gb, JXLEntropyDecoder *dec, const JXLSymbolDistribution *dist)
{
    uint32_t slot, bucket, within, symbol, offset;

    if (dec->state < 0)
        dec->state = get_bits_long(gb, 32);

    /* the low 12 bits of the state select a bucket and a position inside it */
    slot   = dec->state & 0xFFF;
    bucket = slot >> dist->log_bucket_size;
    within = slot & ((1 << dist->log_bucket_size) - 1);

    if (within >= dist->cutoffs[bucket]) {
        symbol = dist->symbols[bucket];
        offset = dist->offsets[bucket] + within;
    } else {
        symbol = bucket;
        offset = within;
    }

    dec->state = dist->freq[symbol] * (dec->state >> 12) + offset;
    if (dec->state < (1 << 16))
        dec->state = (dec->state << 16) | get_bits(gb, 16);
    dec->state &= 0xFFFFFFFF;

    return symbol;
}
/**
 * Decode the next integer of an entropy-coded stream.
 *
 * While an LZ77 copy is pending, values are replayed from the window.
 * Otherwise a token is read from the distribution selected by context; a
 * token at or above lz77_min_symbol starts a new length/distance copy, any
 * other token is expanded with the distribution's hybrid-uint configuration.
 *
 * @param context         index into bundle->cluster_map
 * @param[out] hybrid_uint receives the decoded value
 * @return 0 on success, a negative AVERROR code on failure
 */
static int decode_hybrid_varlen_uint(GetBitContext *gb, JXLEntropyDecoder *dec,
                                     const JXLDistributionBundle *bundle,
                                     uint32_t context, uint32_t *hybrid_uint)
{
    const JXLSymbolDistribution *sym_dist, *pair_dist;
    uint32_t tok, back;
    int err;

    /* a copy is in progress: replay one value and record it again */
    if (dec->num_to_copy > 0) {
        *hybrid_uint = dec->window[dec->copy_pos++ & 0xFFFFF];
        dec->num_to_copy--;
        dec->window[dec->num_decoded++ & 0xFFFFF] = *hybrid_uint;
        return 0;
    }

    if (context >= bundle->num_dist)
        return AVERROR(EINVAL);
    if (bundle->cluster_map[context] >= bundle->num_clusters)
        return AVERROR_INVALIDDATA;

    sym_dist = &bundle->dists[bundle->cluster_map[context]];
    tok = bundle->use_prefix_code ? read_prefix_symbol(gb, sym_dist)
                                  : read_ans_symbol(gb, dec, sym_dist);

    if (!bundle->lz77_enabled || tok < bundle->lz77_min_symbol) {
        /* plain literal */
        err = read_hybrid_uint(gb, &sym_dist->config, tok, hybrid_uint);
        if (err < 0)
            return err;
        if (bundle->lz77_enabled)
            dec->window[dec->num_decoded++ & 0xFFFFF] = *hybrid_uint;
        return 0;
    }

    /* length/distance pair: the distance uses the last context of the bundle */
    pair_dist = &bundle->dists[bundle->cluster_map[bundle->num_dist - 1]];

    err = read_hybrid_uint(gb, &bundle->lz_len_conf, tok - bundle->lz77_min_symbol, &dec->num_to_copy);
    if (err < 0)
        return err;
    dec->num_to_copy += bundle->lz77_min_length;

    tok = bundle->use_prefix_code ? read_prefix_symbol(gb, pair_dist)
                                  : read_ans_symbol(gb, dec, pair_dist);
    err = read_hybrid_uint(gb, &pair_dist->config, tok, &back);
    if (err < 0)
        return err;

    back++;
    /* never reach back further than what was decoded or than the window size */
    back = FFMIN3(back, dec->num_decoded, 1 << 20);
    dec->copy_pos = dec->num_decoded - back;

    /* re-enter to emit the first copied value */
    return decode_hybrid_varlen_uint(gb, dec, bundle, context, hybrid_uint);
}
/**
 * Read one ANS symbol distribution from the bitstream into dist->freq.
 *
 * Three encodings are handled: a simple code with one or two symbols, a flat
 * code that spreads 1 << 12 evenly over the alphabet, and a general code made
 * of prefix-coded log-counts with optional run-length repeats. In each case
 * the frequencies written add up to 1 << 12 and dist->alphabet_size is
 * checked against 1 << log_alphabet_size.
 *
 * @param gb                bit reader
 * @param dist              distribution to fill (freq, alphabet_size, uniq_pos)
 * @param log_alphabet_size log2 of the maximum table size allowed
 * @return 0 on success, AVERROR_INVALIDDATA on malformed input
 */
static int populate_distribution(GetBitContext *gb, JXLSymbolDistribution *dist, int log_alphabet_size)
{
int len = 0, shift, omit_log = -1, omit_pos = -1;
int prev = 0, num_same = 0;
uint32_t total_count = 0;
uint8_t logcounts[258] = { 0 };
uint8_t same[258] = { 0 };
const int table_size = 1 << log_alphabet_size;
/* -1 means "more than one symbol"; overwritten below for single-symbol cases */
dist->uniq_pos = -1;
if (get_bits1(gb)) {
/* simple code */
if (get_bits1(gb)) {
/* two distinct symbols sharing the 1 << 12 total */
uint8_t v1 = jxl_u8(gb);
uint8_t v2 = jxl_u8(gb);
if (v1 == v2)
return AVERROR_INVALIDDATA;
dist->freq[v1] = get_bits(gb, 12);
dist->freq[v2] = (1 << 12) - dist->freq[v1];
/* v1 never occurs, so v2 is effectively the only symbol */
if (!dist->freq[v1])
dist->uniq_pos = v2;
dist->alphabet_size = 1 + FFMAX(v1, v2);
} else {
/* a single symbol with the whole probability mass */
uint8_t x = jxl_u8(gb);
dist->freq[x] = 1 << 12;
dist->uniq_pos = x;
dist->alphabet_size = 1 + x;
}
if (dist->alphabet_size > table_size)
return AVERROR_INVALIDDATA;
return 0;
}
if (get_bits1(gb)) {
/* flat code */
dist->alphabet_size = jxl_u8(gb) + 1;
if (dist->alphabet_size > table_size)
return AVERROR_INVALIDDATA;
for (int i = 0; i < dist->alphabet_size; i++)
dist->freq[i] = (1 << 12) / dist->alphabet_size;
/* hand out the division remainder one unit at a time to the first symbols */
for (int i = 0; i < (1 << 12) % dist->alphabet_size; i++)
dist->freq[i]++;
return 0;
}
/* general code: len is a unary count of at most 3 set bits */
do {
if (!get_bits1(gb))
break;
} while (++len < 3);
shift = (get_bitsz(gb, len) | (1 << len)) - 1;
if (shift > 13)
return AVERROR_INVALIDDATA;
dist->alphabet_size = jxl_u8(gb) + 3;
if (dist->alphabet_size > table_size)
return AVERROR_INVALIDDATA;
/* first pass: read the log-count of every symbol and locate the omitted one */
for (int i = 0; i < dist->alphabet_size; i++) {
logcounts[i] = get_vlc2(gb, dist_prefix_table, 7, 1);
if (logcounts[i] == 13) {
/*
 * run-length marker: remember the run at its first position and
 * skip the covered entries (i += rle + 3 plus the loop increment)
 * NOTE(review): same[] is uint8_t, so rle + 5 wraps for rle > 250;
 * confirm this cannot matter given the alphabet_size bound.
 */
int rle = jxl_u8(gb);
same[i] = rle + 5;
i += rle + 3;
continue;
}
/* strict comparison keeps the first position with the largest log-count */
if (logcounts[i] > omit_log) {
omit_log = logcounts[i];
omit_pos = i;
}
}
/* there must be an omitted symbol, and it may not be directly followed by a run marker */
if (omit_pos < 0 || omit_pos + 1 < dist->alphabet_size && logcounts[omit_pos + 1] == 13)
return AVERROR_INVALIDDATA;
/* second pass: turn log-counts and runs into frequencies */
for (int i = 0; i < dist->alphabet_size; i++) {
if (same[i]) {
/* a run starts here and repeats the frequency just before it */
num_same = same[i] - 1;
prev = i > 0 ? dist->freq[i - 1] : 0;
}
if (num_same) {
dist->freq[i] = prev;
num_same--;
} else {
/* the omitted symbol is filled in after the loop; zero log-count means frequency 0 */
if (i == omit_pos || !logcounts[i])
continue;
if (logcounts[i] == 1) {
dist->freq[i] = 1;
} else {
/* leading one at bit logcounts[i] - 1, followed by bitcount explicitly coded bits */
int bitcount = FFMIN(FFMAX(0, shift - ((12 - logcounts[i] + 1) >> 1)), logcounts[i] - 1);
dist->freq[i] = (1 << (logcounts[i] - 1)) + (get_bitsz(gb, bitcount) << (logcounts[i] - 1 - bitcount));
}
}
total_count += dist->freq[i];
}
/* the omitted symbol receives whatever is left of the 1 << 12 total */
dist->freq[omit_pos] = (1 << 12) - total_count;
return 0;
}
/*
 * Release everything owned by a distribution bundle: the per-cluster VLC
 * tables (prefix-code bundles only), the distribution array and the
 * cluster map. Both pointers are reset to NULL by av_freep().
 */
static void dist_bundle_close(JXLDistributionBundle *bundle)
{
    JXLSymbolDistribution *const table = bundle->dists;

    if (table && bundle->use_prefix_code) {
        int idx = 0;

        while (idx < bundle->num_clusters)
            ff_vlc_free(&table[idx++].vlc);
    }

    av_freep(&bundle->dists);
    av_freep(&bundle->cluster_map);
}
/*
 * Forward declaration: read_dist_clustering() reads a nested bundle through
 * this function, which in turn calls read_dist_clustering().
 */
static int read_distribution_bundle(GetBitContext *gb, JXLEntropyDecoder *dec,
JXLDistributionBundle *bundle, int num_dist, int disallow_lz77);
/**
 * Read the context-to-cluster map of a bundle and derive num_clusters.
 *
 * bundle->num_dist must already be set. The map is allocated here and is
 * released by dist_bundle_close(); on error it is left attached to the
 * bundle so that the caller's cleanup frees it.
 *
 * @return 0 on success, a negative AVERROR code on failure
 */
static int read_dist_clustering(GetBitContext *gb, JXLEntropyDecoder *dec, JXLDistributionBundle *bundle)
{
int ret;
bundle->cluster_map = av_malloc(bundle->num_dist);
if (!bundle->cluster_map)
return AVERROR(ENOMEM);
/* a single context trivially maps to a single cluster; nothing is read */
if (bundle->num_dist == 1) {
bundle->cluster_map[0] = 0;
bundle->num_clusters = 1;
return 0;
}
if (get_bits1(gb)) {
/* simple clustering */
/* every entry is stored directly with a fixed width of 0..3 bits */
uint32_t nbits = get_bits(gb, 2);
for (int i = 0; i < bundle->num_dist; i++)
bundle->cluster_map[i] = get_bitsz(gb, nbits);
} else {
/* complex clustering */
/* the map is itself entropy coded with a nested single-context bundle */
int use_mtf = get_bits1(gb);
JXLDistributionBundle nested = { 0 };
/* num_dist == 1 prevents this from recursing again */
ret = read_distribution_bundle(gb, dec, &nested, 1, bundle->num_dist <= 2);
if (ret < 0) {
dist_bundle_close(&nested);
return ret;
}
for (int i = 0; i < bundle->num_dist; i++) {
uint32_t clust;
ret = decode_hybrid_varlen_uint(gb, dec, &nested, 0, &clust);
if (ret < 0) {
dist_bundle_close(&nested);
return ret;
}
/* stored as uint8_t, so only the low 8 bits of clust are kept */
bundle->cluster_map[i] = clust;
}
/* the nested stream is finished: force the ANS state to be reloaded for the next stream */
dec->state = -1;
/* it's not going to necessarily be zero after reading */
dec->num_to_copy = 0;
dist_bundle_close(&nested);
if (use_mtf) {
/* undo the move-to-front transform: entries are indices into a list that starts as 0..255 */
uint8_t mtf[256];
for (int i = 0; i < 256; i++)
mtf[i] = i;
for (int i = 0; i < bundle->num_dist; i++) {
int index = bundle->cluster_map[i];
bundle->cluster_map[i] = mtf[index];
if (index) {
/* move the value just used to the front, shifting the ones before it back */
int value = mtf[index];
for (int j = index; j > 0; j--)
mtf[j] = mtf[j - 1];
mtf[0] = value;
}
}
}
}
/* the cluster count is the largest index in the map plus one */
for (int i = 0; i < bundle->num_dist; i++) {
if (bundle->cluster_map[i] >= bundle->num_clusters)
bundle->num_clusters = bundle->cluster_map[i] + 1;
}
if (bundle->num_clusters > bundle->num_dist)
return AVERROR_INVALIDDATA;
return 0;
}
/**
 * Build the alias table (cutoffs, symbols, offsets) that read_ans_symbol()
 * uses, from the frequencies stored in dist->freq.
 *
 * The 1 << 12 probability space is split into 1 << log_alphabet_size buckets
 * of bucket_size entries each. Symbols whose frequency exceeds bucket_size
 * donate their excess to buckets that are not yet full.
 *
 * @param dec               not used by this function
 * @param dist              distribution whose table is generated
 * @param log_alphabet_size log2 of the number of buckets
 * @return 0 on success, AVERROR_INVALIDDATA if the excess cannot be placed
 */
static int gen_alias_map(JXLEntropyDecoder *dec, JXLSymbolDistribution *dist, int log_alphabet_size)
{
uint32_t bucket_size, table_size;
/* stacks of bucket indices that currently hold more / less than bucket_size */
uint8_t overfull[256], underfull[256];
int overfull_pos = 0, underfull_pos = 0;
dist->log_bucket_size = 12 - log_alphabet_size;
bucket_size = 1 << dist->log_bucket_size;
table_size = 1 << log_alphabet_size;
if (dist->uniq_pos >= 0) {
/* single-symbol distribution: every bucket redirects to that symbol */
for (int i = 0; i < table_size; i++) {
dist->symbols[i] = dist->uniq_pos;
dist->offsets[i] = bucket_size * i;
dist->cutoffs[i] = 0;
}
return 0;
}
/* start with each symbol occupying its own bucket and classify the buckets */
for (int i = 0; i < dist->alphabet_size; i++) {
dist->cutoffs[i] = dist->freq[i];
dist->symbols[i] = i;
if (dist->cutoffs[i] > bucket_size)
overfull[overfull_pos++] = i;
else if (dist->cutoffs[i] < bucket_size)
underfull[underfull_pos++] = i;
}
/* buckets beyond the alphabet are empty and can take excess */
for (int i = dist->alphabet_size; i < table_size; i++) {
dist->cutoffs[i] = 0;
underfull[underfull_pos++] = i;
}
while (overfull_pos) {
int o, u, by;
/* this should be impossible */
if (!underfull_pos)
return AVERROR_INVALIDDATA;
u = underfull[--underfull_pos];
o = overfull[--overfull_pos];
/* fill the rest of bucket u with part of symbol o */
by = bucket_size - dist->cutoffs[u];
dist->cutoffs[o] -= by;
dist->symbols[u] = o;
dist->offsets[u] = dist->cutoffs[o];
/* o may have changed category after giving part of itself away */
if (dist->cutoffs[o] < bucket_size)
underfull[underfull_pos++] = o;
else if (dist->cutoffs[o] > bucket_size)
overfull[overfull_pos++] = o;
}
for (int i = 0; i < table_size; i++) {
if (dist->cutoffs[i] == bucket_size) {
/* exactly full: with cutoff 0 the alias path is always taken and maps to i itself */
dist->symbols[i] = i;
dist->offsets[i] = 0;
dist->cutoffs[i] = 0;
} else {
/* read_ans_symbol() adds the in-bucket position, which is >= cutoff on the alias path */
dist->offsets[i] -= dist->cutoffs[i];
}
}
return 0;
}
/**
 * Read a "simple" prefix code with 1 to 4 explicitly listed symbols and
 * build the VLC for it.
 *
 * With a single symbol no VLC is built: dist->vlc.bits is set to 0 and the
 * symbol is stored in dist->default_symbol.
 *
 * @return 0 or the result of ff_vlc_init_from_lengths() on success paths,
 *         a negative AVERROR code on failure
 */
static int read_simple_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolDistribution *dist)
{
int nsym, tree_select, bits;
int8_t lens[4];
int16_t symbols[4];
nsym = 1 + get_bits(gb, 2);
/* each symbol is stored with dist->log_alphabet_size bits */
for (int i = 0; i < nsym; i++)
symbols[i] = get_bitsz(gb, dist->log_alphabet_size);
/* tree_select is only read, and only used, for four symbols */
if (nsym == 4)
tree_select = get_bits1(gb);
switch (nsym) {
case 1:
dist->vlc.bits = 0;
dist->default_symbol = symbols[0];
return 0;
case 2:
/* two 1-bit codes, symbols in ascending order */
bits = 1;
lens[0] = 1, lens[1] = 1, lens[2] = 0, lens[3] = 0;
if (symbols[1] < symbols[0])
FFSWAP(int16_t, symbols[0], symbols[1]);
break;
case 3:
/* lengths 1, 2, 2; only the two 2-bit symbols are ordered */
bits = 2;
lens[0] = 1, lens[1] = 2, lens[2] = 2, lens[3] = 0;
if (symbols[2] < symbols[1])
FFSWAP(int16_t, symbols[1], symbols[2]);
break;
case 4:
if (tree_select) {
/* lengths 1, 2, 3, 3; only the two 3-bit symbols are ordered */
bits = 3;
lens[0] = 1, lens[1] = 2, lens[2] = 3, lens[3] = 3;
if (symbols[3] < symbols[2])
FFSWAP(int16_t, symbols[2], symbols[3]);
} else {
/* four 2-bit codes; repeat the pairwise swaps until all four are ascending */
bits = 2;
lens[0] = 2, lens[1] = 2, lens[2] = 2, lens[3] = 2;
while (1) {
if (symbols[1] < symbols[0])
FFSWAP(int16_t, symbols[0], symbols[1]);
if (symbols[3] < symbols[2])
FFSWAP(int16_t, symbols[2], symbols[3]);
if (symbols[1] <= symbols[2])
break;
FFSWAP(int16_t, symbols[1], symbols[2]);
}
}
break;
default:
// Challenge Complete! How did we get here?
/* unreachable: nsym is 1 + a 2-bit value */
return AVERROR_BUG;
}
return ff_vlc_init_from_lengths(&dist->vlc, bits, nsym, lens, 1, symbols,
2, 2, 0, VLC_INIT_LE, dec->logctx);
}
/**
 * Read the prefix code of one distribution and build dist->vlc.
 *
 * An alphabet of size 1 needs no code. Otherwise a 2-bit field selects
 * either the simple form (see read_simple_vlc_prefix()) or the complex form,
 * in which a first-level code describes the code lengths of the alphabet.
 *
 * @return a non-negative value on success, a negative AVERROR code on failure
 */
static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolDistribution *dist)
{
int8_t level1_lens[18] = { 0 };
int8_t level1_lens_s[18] = { 0 };
int16_t level1_syms[18] = { 0 };
uint32_t level1_codecounts[19] = { 0 };
uint8_t *buf = NULL;
int8_t *level2_lens, *level2_lens_s;
int16_t *level2_syms;
uint32_t *level2_codecounts;
int repeat_count_prev = 0, repeat_count_zero = 0, prev = 8;
int total_code = 0, len, hskip, num_codes = 0, ret;
VLC level1_vlc = { 0 };
if (dist->alphabet_size == 1) {
/* only symbol 0 exists, no bits are ever read for it */
dist->vlc.bits = 0;
dist->default_symbol = 0;
return 0;
}
hskip = get_bits(gb, 2);
if (hskip == 1)
return read_simple_vlc_prefix(gb, dec, dist);
/* the first hskip level-1 lengths are not transmitted and stay 0 */
level1_codecounts[0] = hskip;
for (int i = hskip; i < 18; i++) {
len = level1_lens[prefix_codelen_map[i]] = get_vlc2(gb, level0_table, 4, 1);
if (len < 0) {
ret = AVERROR_INVALIDDATA;
goto end;
}
level1_codecounts[len]++;
if (len) {
/* code space used, in units of 1/32 */
total_code += (32 >> len);
num_codes++;
}
if (total_code >= 32) {
/* code space exhausted: all remaining lengths are implicitly 0 */
level1_codecounts[0] += 18 - i - 1;
break;
}
}
/* with two or more codes the code space must be filled exactly; at least one code is required */
if (total_code != 32 && num_codes >= 2 || num_codes < 1) {
ret = AVERROR_INVALIDDATA;
goto end;
}
/* counting sort by length: prefix sums give the end position of each length class */
for (int i = 1; i < 19; i++)
level1_codecounts[i] += level1_codecounts[i - 1];
/* walking backwards keeps symbols of equal length in ascending order */
for (int i = 17; i >= 0; i--) {
int idx = --level1_codecounts[level1_lens[i]];
level1_lens_s[idx] = level1_lens[i];
level1_syms[idx] = i;
}
ret = ff_vlc_init_from_lengths(&level1_vlc, 5, 18, level1_lens_s, 1, level1_syms, 2, 2,
0, VLC_INIT_LE, dec->logctx);
if (ret < 0)
goto end;
/*
 * one zeroed allocation carved into four arrays: lengths, sorted lengths,
 * sorted symbols and length counts (MAX_PREFIX_ALPHABET_SIZE + 1 entries)
 */
buf = av_mallocz(MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t) + sizeof(int16_t) + sizeof(uint32_t))
+ sizeof(uint32_t));
if (!buf) {
ret = AVERROR(ENOMEM);
goto end;
}
level2_lens = (int8_t *)buf;
level2_lens_s = (int8_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * sizeof(int8_t));
level2_syms = (int16_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t)));
level2_codecounts = (uint32_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t) + sizeof(int16_t)));
total_code = 0;
for (int i = 0; i < dist->alphabet_size; i++) {
len = get_vlc2(gb, level1_vlc.table, 5, 1);
if (len < 0) {
ret = AVERROR_INVALIDDATA;
goto end;
}
if (get_bits_left(gb) < 0) {
ret = AVERROR_BUFFER_TOO_SMALL;
goto end;
}
if (len == 16) {
/* repeat the previous nonzero length; consecutive 16s extend the run */
int extra = 3 + get_bits(gb, 2);
if (repeat_count_prev)
extra += 4 * (repeat_count_prev - 2) - repeat_count_prev;
/* never run past the end of the alphabet */
extra = FFMIN(extra, dist->alphabet_size - i);
for (int j = 0; j < extra; j++)
level2_lens[i + j] = prev;
total_code += (32768 >> prev) * extra;
i += extra - 1;
repeat_count_prev += extra;
repeat_count_zero = 0;
level2_codecounts[prev] += extra;
} else if (len == 17) {
/* repeat length 0; the zeroed buffer already holds those lengths */
int extra = 3 + get_bits(gb, 3);
if (repeat_count_zero > 0)
extra += 8 * (repeat_count_zero - 2) - repeat_count_zero;
extra = FFMIN(extra, dist->alphabet_size - i);
i += extra - 1;
repeat_count_prev = 0;
repeat_count_zero += extra;
level2_codecounts[0] += extra;
} else {
/* literal code length */
level2_lens[i] = len;
repeat_count_prev = repeat_count_zero = 0;
if (len) {
/* code space used, in units of 1/32768 */
total_code += (32768 >> len);
prev = len;
}
level2_codecounts[len]++;
}
if (total_code >= 32768) {
/* code space exhausted: the remaining symbols have length 0 */
level2_codecounts[0] += dist->alphabet_size - i - 1;
break;
}
}
/* an incomplete code is only accepted when at most one symbol has a nonzero length */
if (total_code != 32768 && level2_codecounts[0] < dist->alphabet_size - 1) {
ret = AVERROR_INVALIDDATA;
goto end;
}
/* same counting sort as for level 1, now over the whole alphabet */
for (int i = 1; i < dist->alphabet_size + 1; i++)
level2_codecounts[i] += level2_codecounts[i - 1];
for (int i = dist->alphabet_size - 1; i >= 0; i--) {
int idx = --level2_codecounts[level2_lens[i]];
level2_lens_s[idx] = level2_lens[i];
level2_syms[idx] = i;
}
ret = ff_vlc_init_from_lengths(&dist->vlc, 15, dist->alphabet_size, level2_lens_s,
1, level2_syms, 2, 2, 0, VLC_INIT_LE, dec->logctx);
end:
/* buf may still be NULL here; the level-1 table is only needed while reading lengths */
av_freep(&buf);
ff_vlc_free(&level1_vlc);
return ret;
}
/**
 * Read a complete distribution bundle: the optional LZ77 parameters, the
 * context-to-cluster map and one symbol distribution per cluster.
 *
 * @param gb            bit reader positioned at the start of the bundle
 * @param dec           decoder owning the shared LZ77 window and the log context
 * @param bundle        bundle to fill; callers in this file pass it zero-initialized
 * @param num_dist      number of contexts, must be positive
 * @param disallow_lz77 if nonzero, a bundle that enables LZ77 is rejected
 * @return 0 on success, a negative AVERROR code on failure. On failure the
 *         partially filled bundle is left for the caller to release with
 *         dist_bundle_close().
 */
static int read_distribution_bundle(GetBitContext *gb, JXLEntropyDecoder *dec,
                                    JXLDistributionBundle *bundle, int num_dist, int disallow_lz77)
{
    int ret;

    if (num_dist <= 0)
        return AVERROR(EINVAL);

    bundle->num_dist = num_dist;
    bundle->lz77_enabled = get_bits1(gb);
    if (bundle->lz77_enabled) {
        if (disallow_lz77)
            return AVERROR_INVALIDDATA;
        bundle->lz77_min_symbol = jxl_u32(gb, 224, 512, 4096, 8, 0, 0, 0, 15);
        bundle->lz77_min_length = jxl_u32(gb, 3, 4, 5, 9, 0, 0, 2, 8);
        /* LZ77 adds one trailing context, used for the copy distance */
        bundle->num_dist++;
        ret = read_hybrid_uint_conf(gb, &bundle->lz_len_conf, 8);
        if (ret < 0)
            return ret;
    }

    if (bundle->lz77_enabled && !dec->window) {
        /*
         * The window must start out zeroed: a length/distance pair decoded
         * before any literal copies from position 0, which would otherwise
         * read uninitialized memory.
         */
        dec->window = av_calloc(1 << 20, sizeof(uint32_t));
        if (!dec->window)
            return AVERROR(ENOMEM);
    }

    ret = read_dist_clustering(gb, dec, bundle);
    if (ret < 0)
        return ret;
    if (get_bits_left(gb) < 0)
        return AVERROR_BUFFER_TOO_SMALL;

    bundle->dists = av_calloc(bundle->num_clusters, sizeof(JXLSymbolDistribution));
    if (!bundle->dists)
        return AVERROR(ENOMEM);

    bundle->use_prefix_code = get_bits1(gb);
    /* prefix codes always use 15; ANS reads a 2-bit field giving 5..8 */
    bundle->log_alphabet_size = bundle->use_prefix_code ? 15 : 5 + get_bits(gb, 2);

    /* the hybrid-uint configurations of all clusters come first */
    for (int i = 0; i < bundle->num_clusters; i++) {
        ret = read_hybrid_uint_conf(gb, &bundle->dists[i].config, bundle->log_alphabet_size);
        if (ret < 0)
            return ret;
        if (get_bits_left(gb) < 0)
            return AVERROR_BUFFER_TOO_SMALL;
    }

    if (bundle->use_prefix_code) {
        /* all alphabet sizes are stored before the first prefix code */
        for (int i = 0; i < bundle->num_clusters; i++) {
            JXLSymbolDistribution *dist = &bundle->dists[i];
            if (get_bits1(gb)) {
                int n = get_bits(gb, 4);
                dist->alphabet_size = 1 + (1 << n) + get_bitsz(gb, n);
                if (dist->alphabet_size > MAX_PREFIX_ALPHABET_SIZE)
                    return AVERROR_INVALIDDATA;
            } else {
                dist->alphabet_size = 1;
            }
            dist->log_alphabet_size = clog1p(dist->alphabet_size - 1);
        }
        for (int i = 0; i < bundle->num_clusters; i++) {
            ret = read_vlc_prefix(gb, dec, &bundle->dists[i]);
            if (ret < 0)
                return ret;
            if (get_bits_left(gb) < 0)
                return AVERROR_BUFFER_TOO_SMALL;
        }
    } else {
        for (int i = 0; i < bundle->num_clusters; i++) {
            ret = populate_distribution(gb, &bundle->dists[i], bundle->log_alphabet_size);
            if (ret < 0)
                return ret;
            if (get_bits_left(gb) < 0)
                return AVERROR_BUFFER_TOO_SMALL;
        }
        /* the alias tables are derived only after every frequency table was read */
        for (int i = 0; i < bundle->num_clusters; i++) {
            ret = gen_alias_map(dec, &bundle->dists[i], bundle->log_alphabet_size);
            if (ret < 0)
                return ret;
        }
    }

    return 0;
}
/* Free the LZ77 window and the bundle of a decoder; a NULL decoder is ignored. */
static void entropy_decoder_close(JXLEntropyDecoder *dec)
{
    if (dec) {
        av_freep(&dec->window);
        dist_bundle_close(&dec->bundle);
    }
}
/**
 * Reset dec and read its distribution bundle from the bitstream.
 *
 * @param avctx    logging context stored in the decoder
 * @param num_dist number of contexts of the stream
 * @return 0 on success; on failure the decoder is closed and the negative
 *         AVERROR code is returned
 */
static int entropy_decoder_init(void *avctx, GetBitContext *gb, JXLEntropyDecoder *dec, int num_dist)
{
    int err;

    memset(dec, 0, sizeof(*dec));
    /* -1 marks the ANS state as not yet loaded */
    dec->state  = -1;
    dec->logctx = avctx;

    err = read_distribution_bundle(gb, dec, &dec->bundle, num_dist, 0);
    if (err >= 0)
        return 0;

    entropy_decoder_close(dec);
    return err;
}
/*
 * Decode one value with the decoder's own bundle.
 * Returns the value (0..UINT32_MAX) or a negative AVERROR code.
 */
static int64_t entropy_decoder_read_symbol(GetBitContext *gb, JXLEntropyDecoder *dec, uint32_t context)
{
    uint32_t value;
    int64_t result = decode_hybrid_varlen_uint(gb, dec, &dec->bundle, context, &value);

    if (result >= 0)
        result = value;

    return result;
}
/*
 * Select the entropy context for byte number i of the encoded ICC stream
 * from the two previously decoded bytes b1 (most recent) and b2.
 * The first 129 bytes (i <= 128) always use context 0; afterwards the
 * result is 1 + class(b1) + 8 * class(b2), with 8 classes for b1 and 5 for b2.
 */
static inline uint32_t icc_context(uint64_t i, uint32_t b1, uint32_t b2)
{
    const int b1_alpha   = (b1 >= 'a' && b1 <= 'z') || (b1 >= 'A' && b1 <= 'Z');
    const int b1_numeric = (b1 >= '0' && b1 <= '9') || b1 == '.' || b1 == ',';
    const int b2_alpha   = (b2 >= 'a' && b2 <= 'z') || (b2 >= 'A' && b2 <= 'Z');
    const int b2_numeric = (b2 >= '0' && b2 <= '9') || b2 == '.' || b2 == ',';
    uint32_t class1 = 7, class2 = 4;

    if (i <= 128)
        return 0;

    if (b1_alpha)
        class1 = 0;
    else if (b1_numeric)
        class1 = 1;
    else if (b1 < 2)
        class1 = 2 + b1;
    else if (b1 < 16)
        class1 = 4;
    else if (b1 == 255)
        class1 = 6;
    else if (b1 > 240 && b1 < 255)
        class1 = 5;

    if (b2_alpha)
        class2 = 0;
    else if (b2_numeric)
        class2 = 1;
    else if (b2 < 16)
        class2 = 2;
    else if (b2 > 240)
        class2 = 3;

    return class2 * 8 + class1 + 1;
}
/* Context for a table-of-contents value: clog1p(x), capped at 7. */
static inline uint32_t toc_context(uint32_t x)
{
    const uint32_t bits = clog1p(x);

    return bits < 7 ? bits : 7;
}
/**
 * Export stream properties from the parsed JPEG XL metadata: dimensions and
 * pixel format go to the parser context, colorspace, primaries and transfer
 * characteristic go to the codec context. Anything that has no direct
 * mapping is reported as unspecified.
 */
static void populate_fields(AVCodecParserContext *s, AVCodecContext *avctx, const FFJXLMetadata *meta)
{
    const int is_gray = meta->csp == JPEGXL_CS_GRAY;
    int primaries = AVCOL_PRI_UNSPECIFIED;
    int transfer  = AVCOL_TRC_UNSPECIFIED;
    int format;

    s->width  = meta->width;
    s->height = meta->height;

    if (meta->csp == JPEGXL_CS_RGB || meta->csp == JPEGXL_CS_XYB)
        avctx->colorspace = AVCOL_SPC_RGB;
    else
        avctx->colorspace = AVCOL_SPC_UNSPECIFIED;

    /* primaries are only mapped for a D65 white point, plus the DCI/P3 pair */
    if (meta->wp == JPEGXL_WP_D65) {
        if (meta->primaries == JPEGXL_PR_SRGB)
            primaries = AVCOL_PRI_BT709;
        else if (meta->primaries == JPEGXL_PR_P3)
            primaries = AVCOL_PRI_SMPTE432;
        else if (meta->primaries == JPEGXL_PR_2100)
            primaries = AVCOL_PRI_BT2020;
    } else if (meta->wp == JPEGXL_WP_DCI && meta->primaries == JPEGXL_PR_P3) {
        primaries = AVCOL_PRI_SMPTE431;
    }
    avctx->color_primaries = primaries;

    if (meta->trc > JPEGXL_TR_GAMMA) {
        /* values above JPEGXL_TR_GAMMA hold an enumerated curve offset by that constant */
        FFJXLTransferCharacteristic trc = meta->trc - JPEGXL_TR_GAMMA;
        switch (trc) {
        case JPEGXL_TR_BT709:
            transfer = AVCOL_TRC_BT709;
            break;
        case JPEGXL_TR_LINEAR:
            transfer = AVCOL_TRC_LINEAR;
            break;
        case JPEGXL_TR_SRGB:
            transfer = AVCOL_TRC_IEC61966_2_1;
            break;
        case JPEGXL_TR_PQ:
            transfer = AVCOL_TRC_SMPTEST2084;
            break;
        case JPEGXL_TR_DCI:
            transfer = AVCOL_TRC_SMPTE428;
            break;
        case JPEGXL_TR_HLG:
            transfer = AVCOL_TRC_ARIB_STD_B67;
            break;
        default:
            break;
        }
    } else if (meta->trc > 0) {
        /* otherwise a positive value is matched against two narrow numeric windows */
        if (meta->trc > 45355 && meta->trc < 45555)
            transfer = AVCOL_TRC_GAMMA22;
        else if (meta->trc > 35614 && meta->trc < 35814)
            transfer = AVCOL_TRC_GAMMA28;
    }
    avctx->color_trc = transfer;

    if (meta->bit_depth <= 8) {
        if (is_gray)
            format = meta->have_alpha ? AV_PIX_FMT_YA8 : AV_PIX_FMT_GRAY8;
        else
            format = meta->have_alpha ? AV_PIX_FMT_RGBA : AV_PIX_FMT_RGB24;
    } else if (meta->bit_depth <= 16) {
        if (is_gray)
            format = meta->have_alpha ? AV_PIX_FMT_YA16 : AV_PIX_FMT_GRAY16;
        else
            format = meta->have_alpha ? AV_PIX_FMT_RGBA64 : AV_PIX_FMT_RGB48;
    } else {
        /* deeper than 16 bits: float formats; gray with alpha maps to AV_PIX_FMT_NONE */
        if (is_gray)
            format = meta->have_alpha ? AV_PIX_FMT_NONE : AV_PIX_FMT_GRAYF32;
        else
            format = meta->have_alpha ? AV_PIX_FMT_RGBAF32 : AV_PIX_FMT_RGBF32;
    }
    s->format = format;
}
/**
 * Advance the bit reader past an entropy-coded ICC profile.
 *
 * The encoded size is read first and must be in 1..(1 << 22). Every encoded
 * byte is then decoded with a 41-context entropy decoder and discarded; the
 * leading bytes are additionally interpreted as a 7-bit-per-byte varint
 * giving the output size, which must be at least 132 if further bytes follow.
 *
 * @param ctx not used by this function
 * @return 0 on success, a negative AVERROR code on failure
 */
static int skip_icc_profile(void *avctx, JXLParseContext *ctx, GetBitContext *gb)
{
    JXLEntropyDecoder dec = { 0 };
    const uint64_t enc_size = jxl_u64(gb);
    uint64_t output_size = 0;
    uint32_t prev1 = 0, prev2 = 0;
    int size_shift = 0;
    int64_t status;

    if (!enc_size || enc_size > (1 << 22))
        return AVERROR_INVALIDDATA;

    status = entropy_decoder_init(avctx, gb, &dec, 41);
    if (status < 0)
        goto done;

    if (get_bits_left(gb) < 0) {
        status = AVERROR_BUFFER_TOO_SMALL;
        goto done;
    }

    for (uint64_t pos = 0; pos < enc_size; pos++) {
        status = entropy_decoder_read_symbol(gb, &dec, icc_context(pos, prev1, prev2));
        if (status < 0)
            goto done;
        /* every decoded value must be a byte */
        if (status > 255) {
            status = AVERROR_INVALIDDATA;
            goto done;
        }
        if (get_bits_left(gb) < 0) {
            status = AVERROR_BUFFER_TOO_SMALL;
            goto done;
        }

        prev2 = prev1;
        prev1 = status;

        if (size_shift >= 63) {
            /* the size varint is complete; the declared size must be at least 132 */
            if (output_size < 132) {
                status = AVERROR_INVALIDDATA;
                goto done;
            }
            continue;
        }

        output_size += (status & UINT64_C(0x7F)) << size_shift;
        if (status & 0x80) {
            size_shift += 7;
            if (size_shift > 56) {
                status = AVERROR_INVALIDDATA;
                goto done;
            }
        } else {
            /* 63 marks the varint as finished */
            size_shift = 63;
        }
    }
    status = 0;

done:
    entropy_decoder_close(&dec);

    return status;
}
/**
 * Skip an extensions block: a 64-bit mask, one length per set bit, then the
 * payload whose size in bits is the sum of those lengths.
 *
 * The running total is bounded after every addition so that the sum of
 * attacker-controlled 64-bit lengths cannot wrap around and slip past the
 * size check.
 *
 * @return 0 on success, AVERROR_BUFFER_TOO_SMALL if the reader runs out of
 *         data or the payload does not fit in what is left
 */
static int skip_extensions(GetBitContext *gb)
{
    uint64_t extensions = jxl_u64(gb), extensions_len = 0;

    if (get_bits_left(gb) < 0)
        return AVERROR_BUFFER_TOO_SMALL;

    if (!extensions)
        return 0;

    for (int i = 0; i < 64; i++) {
        if (extensions & (UINT64_C(1) << i)) {
            const uint64_t len = jxl_u64(gb);

            /* extensions_len <= INT_MAX holds here, so the subtraction cannot underflow */
            if (len > INT_MAX - extensions_len)
                return AVERROR_BUFFER_TOO_SMALL;
            extensions_len += len;
        }
        if (get_bits_left(gb) < 0)
            return AVERROR_BUFFER_TOO_SMALL;
    }

    /* get_bits_left() is known to be non-negative at this point */
    if (extensions_len > INT_MAX || get_bits_left(gb) < extensions_len)
        return AVERROR_BUFFER_TOO_SMALL;

    skip_bits_long(gb, extensions_len);

    return 0;
}
static int parse_frame_header(void *avctx, JXLParseContext *ctx, GetBitContext *gb)
{
int all_default, do_yCbCr = 0, num_passes = 1, ret;
int group_size_shift = 1, lf_level = 0, save_as_ref = 0;
int have_crop = 0, full_frame = 1, resets_canvas = 1, upsampling = 1;
JXLFrame *frame = &ctx->codestream.frame;
const FFJXLMetadata *meta = &ctx->codestream.meta;
int32_t x0 = 0, y0 = 0;
uint32_t duration = 0, width = meta->coded_width, height = meta->coded_height;
uint32_t name_len, num_groups, num_lf_groups, group_dim, lf_group_dim, toc_count;
uint64_t flags = 0;
int start_len = get_bits_count(gb);
memset(frame, 0, sizeof(*frame));
frame->is_last = 1;
all_default = get_bits1(gb);
if (!all_default) {
frame->type = get_bits(gb, 2);
frame->encoding = get_bits1(gb);
flags = jxl_u64(gb);
if (!meta->xyb_encoded)
do_yCbCr = get_bits1(gb);
if (!(flags & JXL_FLAG_USE_LF_FRAME)) {
if (do_yCbCr)
skip_bits(gb, 6); // jpeg upsampling
upsampling = jxl_u32(gb, 1, 2, 4, 8, 0, 0, 0, 0);
skip_bits_long(gb, 2 * meta->num_extra_channels);
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
if (frame->encoding == JPEGXL_ENC_MODULAR)
group_size_shift = get_bits(gb, 2);
else if (meta->xyb_encoded)
skip_bits(gb, 6); // xqm and bqm scales
if (frame->type != JPEGXL_FRAME_REFERENCE_ONLY) {
num_passes = jxl_u32(gb, 1, 2, 3, 4, 0, 0, 0, 3);
if (num_passes != 1) {
int num_ds = jxl_u32(gb, 0, 1, 2, 3, 0, 0, 0, 1);
skip_bits(gb, 2 * (num_passes - 1)); // shift
skip_bits(gb, 2 * num_ds); // downsample
for (int i = 0; i < num_ds; i++)
jxl_u32(gb, 0, 1, 2, 0, 0, 0, 0, 3);
}
}
if (frame->type == JPEGXL_FRAME_LF)
lf_level = 1 + get_bits(gb, 2);
else
have_crop = get_bits1(gb);
if (have_crop) {
if (frame->type != JPEGXL_FRAME_REFERENCE_ONLY) {
uint32_t ux0 = jxl_u32(gb, 0, 256, 2304, 18688, 8, 11, 14, 30);
uint32_t uy0 = jxl_u32(gb, 0, 256, 2304, 18688, 8, 11, 14, 30);
x0 = unpack_signed(ux0);
y0 = unpack_signed(uy0);
}
width = jxl_u32(gb, 0, 256, 2304, 18688, 8, 11, 14, 30);
height = jxl_u32(gb, 0, 256, 2304, 18688, 8, 11, 14, 30);
full_frame = x0 <= 0 && y0 <= 0 && width + x0 >= meta->coded_width
&& height + y0 >= meta->coded_height;
}
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
if (frame->type == JPEGXL_FRAME_REGULAR || frame->type == JPEGXL_FRAME_SKIP_PROGRESSIVE) {
for (int i = 0; i <= meta->num_extra_channels; i++) {
int mode = jxl_u32(gb, 0, 1, 2, 3, 0, 0, 0, 2);
if (meta->num_extra_channels && (mode == JPEGXL_BM_BLEND || mode == JPEGXL_BM_MULADD))
jxl_u32(gb, 0, 1, 2, 3, 0, 0, 0, 2);
if (meta->num_extra_channels && (mode == JPEGXL_BM_BLEND || mode == JPEGXL_BM_MULADD
|| mode == JPEGXL_BM_MUL))
skip_bits1(gb);
if (!i)
resets_canvas = mode == JPEGXL_BM_REPLACE && full_frame;
if (!resets_canvas)
skip_bits(gb, 2);
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
if (meta->animation_offset)
duration = jxl_u32(gb, 0, 1, 0, 0, 0, 0, 8, 32);
if (meta->have_timecodes)
skip_bits_long(gb, 32);
frame->is_last = get_bits1(gb);
} else {
frame->is_last = 0;
}
if (frame->type != JPEGXL_FRAME_LF && !frame->is_last)
save_as_ref = get_bits(gb, 2);
if (frame->type == JPEGXL_FRAME_REFERENCE_ONLY ||
(resets_canvas && !frame->is_last && (!duration || save_as_ref)
&& frame->type != JPEGXL_FRAME_LF))
skip_bits1(gb); // save before color transform
name_len = 8 * jxl_u32(gb, 0, 0, 16, 48, 0, 4, 5, 10);
if (get_bits_left(gb) < name_len)
return AVERROR_BUFFER_TOO_SMALL;
skip_bits_long(gb, name_len);
}
if (!all_default) {
int restd = get_bits1(gb), gab = 1;
if (!restd)
gab = get_bits1(gb);
if (gab && !restd && get_bits1(gb))
// gab custom
skip_bits_long(gb, 16 * 6);
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
if (!restd) {
int epf = get_bits(gb, 2);
if (epf) {
if (frame->encoding == JPEGXL_ENC_VARDCT && get_bits1(gb)) {
skip_bits_long(gb, 16 * 8); // custom epf sharpness
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
if (get_bits1(gb)) {
skip_bits_long(gb, 3 * 16 + 32); // custom epf weight
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
if (get_bits1(gb)) { // custom epf sigma
if (frame->encoding == JPEGXL_ENC_VARDCT)
skip_bits(gb, 16);
skip_bits_long(gb, 16 * 3);
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
if (frame->encoding == JPEGXL_ENC_MODULAR)
skip_bits(gb, 16);
}
ret = skip_extensions(gb);
if (ret < 0)
return ret;
}
ret = skip_extensions(gb);
if (ret < 0)
return ret;
}
width = div_ceil(div_ceil(width, upsampling), 1 << (3 * lf_level));
height = div_ceil(div_ceil(height, upsampling), 1 << (3 * lf_level));
group_dim = 128 << group_size_shift;
lf_group_dim = group_dim << 3;
num_groups = div_ceil(width, group_dim) * div_ceil(height, group_dim);
num_lf_groups = div_ceil(width, lf_group_dim) * div_ceil(height, lf_group_dim);
if (num_groups == 1 && num_passes == 1)
toc_count = 1;
else
toc_count = 2 + num_lf_groups + num_groups * num_passes;
// permuted toc
if (get_bits1(gb)) {
JXLEntropyDecoder dec;
uint32_t end, lehmer = 0;
ret = entropy_decoder_init(avctx, gb, &dec, 8);
if (ret < 0)
return ret;
if (get_bits_left(gb) < 0) {
entropy_decoder_close(&dec);
return AVERROR_BUFFER_TOO_SMALL;
}
end = entropy_decoder_read_symbol(gb, &dec, toc_context(toc_count));
if (end > toc_count) {
entropy_decoder_close(&dec);
return AVERROR_INVALIDDATA;
}
for (uint32_t i = 0; i < end; i++) {
lehmer = entropy_decoder_read_symbol(gb, &dec, toc_context(lehmer));
if (get_bits_left(gb) < 0) {
entropy_decoder_close(&dec);
return AVERROR_BUFFER_TOO_SMALL;
}
}
entropy_decoder_close(&dec);
}
align_get_bits(gb);
for (uint32_t i = 0; i < toc_count; i++) {
frame->body_length += 8 * jxl_u32(gb, 0, 1024, 17408, 4211712, 10, 14, 22, 30);
if (get_bits_left(gb) < 0)
return AVERROR_BUFFER_TOO_SMALL;
}
align_get_bits(gb);
frame->total_length = frame->body_length + get_bits_count(gb) - start_len;
return 0;
}
/**
 * Step over ISOBMFF boxes until a JPEG XL signature is reached.
 *
 * Scanning starts ctx->skip bytes into buf. Every box that is stepped over
 * has its full size (header included) added to ctx->skip.
 *
 * @param ctx      parser context; ctx->skip is read and advanced
 * @param buf      start of the data collected so far
 * @param buf_size number of bytes available at buf
 * @return 0 when the read position sits on a codestream or container
 *         signature, AVERROR_BUFFER_TOO_SMALL when more input is required,
 *         AVERROR_INVALIDDATA when a box declares an unusable size
 */
static int skip_boxes(JXLParseContext *ctx, const uint8_t *buf, int buf_size)
{
    GetByteContext gb;

    if (ctx->skip > buf_size)
        return AVERROR_BUFFER_TOO_SMALL;

    bytestream2_init(&gb, buf + ctx->skip, buf_size - ctx->skip);

    for (;;) {
        uint64_t box_size;
        int header_len = 8;

        /* either signature marks the start of the next image: stop here */
        if (bytestream2_peek_le16(&gb) == FF_JPEGXL_CODESTREAM_SIGNATURE_LE ||
            bytestream2_peek_le64(&gb) == FF_JPEGXL_CONTAINER_SIGNATURE_LE)
            return 0;

        /* compact box header: 32-bit size followed by a 4-byte tag */
        if (bytestream2_get_bytes_left(&gb) < 8)
            return AVERROR_BUFFER_TOO_SMALL;
        box_size = bytestream2_get_be32(&gb);
        bytestream2_skip(&gb, 4); // tag

        /* a size field of 1 announces a 64-bit size after the tag */
        if (box_size == 1) {
            if (bytestream2_get_bytes_left(&gb) < 8)
                return AVERROR_BUFFER_TOO_SMALL;
            box_size = bytestream2_get_be64(&gb);
            header_len = 16;
        }

        /*
         * invalid ISOBMFF size: zero, not larger than its own header, or
         * large enough to push ctx->skip past INT_MAX
         */
        if (!box_size || box_size <= header_len || box_size > INT_MAX - ctx->skip)
            return AVERROR_INVALIDDATA;

        ctx->skip += box_size;
        bytestream2_skip(&gb, box_size - header_len);

        /* nothing left after this box: wait for more input */
        if (bytestream2_get_bytes_left(&gb) <= 0)
            return AVERROR_BUFFER_TOO_SMALL;
    }
}
/**
 * Try to work out how much of the buffered input belongs to the current image.
 *
 * The first ctx->skip bytes of buf are ignored. In container mode the
 * codestream header is first gathered into ctx->cs_buffer; otherwise the
 * input is treated as a bare codestream and its frame headers are walked
 * until the frame flagged as last.
 *
 * @param s        parser context handed to populate_fields()
 * @param avctx    codec context handed to populate_fields() and the header parsers
 * @param ctx      JPEG XL parser state, updated as parsing progresses
 * @param buf      start of the data collected so far
 * @param buf_size number of bytes available at buf
 * @return in container mode, the size reported by
 *         ff_jpegxl_collect_codestream_header(); for a bare codestream, the
 *         accumulated codestream length converted from bits to bytes;
 *         a negative AVERROR code on failure, including
 *         AVERROR_BUFFER_TOO_SMALL when more input is needed
 */
static int try_parse(AVCodecParserContext *s, AVCodecContext *avctx, JXLParseContext *ctx,
                     const uint8_t *buf, int buf_size)
{
    const uint8_t *cs_buffer;
    int cs_buflen, boxed, ret;
    GetBitContext gb;

    if (ctx->skip > buf_size)
        return AVERROR_BUFFER_TOO_SMALL;

    buf      += ctx->skip;
    buf_size -= ctx->skip;

    boxed = ctx->container || AV_RL64(buf) == FF_JPEGXL_CONTAINER_SIGNATURE_LE;
    if (!boxed) {
        /* bare codestream: parse straight from the input */
        cs_buffer = buf;
        cs_buflen = buf_size;
    } else {
        /* usable part of the staging buffer; the padding tail stays untouched */
        const size_t cs_capacity = sizeof(ctx->cs_buffer) - AV_INPUT_BUFFER_PADDING_SIZE;

        ctx->container = 1;
        ret = ff_jpegxl_collect_codestream_header(buf, buf_size, ctx->cs_buffer,
                                                  cs_capacity, &ctx->copied);
        if (ret < 0)
            return ret;
        ctx->collected_size = ret;
        if (!ctx->copied) {
            /* no codestream bytes yet: remember how far we got and retry later */
            ctx->skip += ret;
            return AVERROR_BUFFER_TOO_SMALL;
        }
        cs_buffer = ctx->cs_buffer;
        cs_buflen = FFMIN(cs_capacity, ctx->copied);
    }

    /* the codestream header only needs parsing once per image */
    if (!ctx->codestream_length) {
        ret = ff_jpegxl_parse_codestream_header(cs_buffer, cs_buflen, &ctx->codestream.meta, 0);
        if (ret < 0)
            return ret;
        ctx->codestream_length = ret;
        populate_fields(s, avctx, &ctx->codestream.meta);
    }

    /* in container mode the frame headers are not walked here */
    if (ctx->container)
        return ctx->collected_size;

    ret = init_get_bits8(&gb, cs_buffer, cs_buflen);
    if (ret < 0)
        return ret;

    /* resume right after what has already been accounted for (in bits) */
    skip_bits_long(&gb, ctx->codestream_length);

    if (ctx->codestream.meta.have_icc_profile && !ctx->skipped_icc) {
        ret = skip_icc_profile(avctx, ctx, &gb);
        if (ret < 0)
            return ret;
        ctx->skipped_icc = 1;
        align_get_bits(&gb);
        ctx->codestream_length = get_bits_count(&gb);
    }

    if (get_bits_left(&gb) <= 0)
        return AVERROR_BUFFER_TOO_SMALL;

    for (;;) {
        ret = parse_frame_header(avctx, ctx, &gb);
        if (ret < 0)
            return ret;

        ctx->codestream_length += ctx->codestream.frame.total_length;
        if (ctx->codestream.frame.is_last)
            break;

        /* the whole frame body must be buffered before moving to the next header */
        if (get_bits_left(&gb) <= ctx->codestream.frame.body_length)
            return AVERROR_BUFFER_TOO_SMALL;
        skip_bits_long(&gb, ctx->codestream.frame.body_length);
    }

    /* codestream_length is counted in bits; callers want bytes */
    return ctx->codestream_length / 8;
}
static int jpegxl_parse(AVCodecParserContext *s, AVCodecContext *avctx,
const uint8_t **poutbuf, int *poutbuf_size,
const uint8_t *buf, int buf_size)
{
JXLParseContext *ctx = s->priv_data;
int next = END_NOT_FOUND, ret;
const uint8_t *pbuf = ctx->pc.buffer;
int pindex = ctx->pc.index;
*poutbuf_size = 0;
*poutbuf = NULL;
if (!ctx->pc.index) {
if (ctx->pc.overread)
goto flush;
pbuf = buf;
pindex = buf_size;
}
if ((!ctx->container || !ctx->codestream_length) && !ctx->next) {
ret = try_parse(s, avctx, ctx, pbuf, pindex);
if (ret < 0)
goto flush;
ctx->next = ret;
if (ctx->container)
ctx->skip += ctx->next;
}
if (ctx->container && ctx->next >= 0) {
ret = skip_boxes(ctx, pbuf, pindex);
if (ret < 0) {
if (ret == AVERROR_INVALIDDATA)
ctx->next = -1;
goto flush;
}
ctx->next = ret + ctx->skip;
}
if (ctx->next >= 0)
next = ctx->next - ctx->pc.index;
flush:
if (next > buf_size)
next = END_NOT_FOUND;
ret = ff_combine_frame(&ctx->pc, next, &buf, &buf_size);
if (ret < 0)
return buf_size;
*poutbuf = buf;
*poutbuf_size = buf_size;
ctx->codestream_length = 0;
ctx->collected_size = 0;
ctx->container = 0;
ctx->copied = 0;
ctx->skip = 0;
ctx->skipped_icc = 0;
ctx->next = 0;
memset(&ctx->codestream, 0, sizeof(ctx->codestream));
return next;
}
/*
 * Parser descriptor for AV_CODEC_ID_JPEGXL: each instance carries a
 * JXLParseContext as private data, input is split by jpegxl_parse() and
 * the instance is torn down with ff_parse_close().
 */
const AVCodecParser ff_jpegxl_parser = {
.codec_ids = { AV_CODEC_ID_JPEGXL },
.priv_data_size = sizeof(JXLParseContext),
.parser_parse = jpegxl_parse,
.parser_close = ff_parse_close,
};