1
0
mirror of https://github.com/FFmpeg/FFmpeg.git synced 2025-02-09 14:14:39 +02:00

avfilter/vf_bm3d: switch to TX from lavu

This commit is contained in:
Paul B Mahol 2022-12-04 17:32:04 +01:00
parent 9bfae83856
commit 6c814093d8
2 changed files with 166 additions and 169 deletions

3
configure vendored
View File

@ -3629,8 +3629,6 @@ avgblur_vulkan_filter_deps="vulkan spirv_compiler"
azmq_filter_deps="libzmq" azmq_filter_deps="libzmq"
blackframe_filter_deps="gpl" blackframe_filter_deps="gpl"
blend_vulkan_filter_deps="vulkan spirv_compiler" blend_vulkan_filter_deps="vulkan spirv_compiler"
bm3d_filter_deps="avcodec"
bm3d_filter_select="dct"
boxblur_filter_deps="gpl" boxblur_filter_deps="gpl"
boxblur_opencl_filter_deps="opencl gpl" boxblur_opencl_filter_deps="opencl gpl"
bs2b_filter_deps="libbs2b" bs2b_filter_deps="libbs2b"
@ -7444,7 +7442,6 @@ enabled zlib && add_cppflags -DZLIB_CONST
# conditional library dependencies, in any order # conditional library dependencies, in any order
enabled amovie_filter && prepend avfilter_deps "avformat avcodec" enabled amovie_filter && prepend avfilter_deps "avformat avcodec"
enabled aresample_filter && prepend avfilter_deps "swresample" enabled aresample_filter && prepend avfilter_deps "swresample"
enabled bm3d_filter && prepend avfilter_deps "avcodec"
enabled cover_rect_filter && prepend avfilter_deps "avformat avcodec" enabled cover_rect_filter && prepend avfilter_deps "avformat avcodec"
enabled ebur128_filter && enabled swresample && prepend avfilter_deps "swresample" enabled ebur128_filter && enabled swresample && prepend avfilter_deps "swresample"
enabled elbg_filter && prepend avfilter_deps "avcodec" enabled elbg_filter && prepend avfilter_deps "avcodec"

View File

@ -25,17 +25,17 @@
/** /**
* @todo * @todo
* - non-power of 2 DCT
* - opponent color space * - opponent color space
* - temporal support * - temporal support
*/ */
#include <float.h> #include <float.h>
#include "libavutil/cpu.h"
#include "libavutil/imgutils.h" #include "libavutil/imgutils.h"
#include "libavutil/opt.h" #include "libavutil/opt.h"
#include "libavutil/pixdesc.h" #include "libavutil/pixdesc.h"
#include "libavcodec/avfft.h" #include "libavutil/tx.h"
#include "avfilter.h" #include "avfilter.h"
#include "filters.h" #include "filters.h"
#include "formats.h" #include "formats.h"
@ -69,16 +69,19 @@ typedef struct PosPairCode {
} PosPairCode; } PosPairCode;
typedef struct SliceContext { typedef struct SliceContext {
DCTContext *gdctf, *gdcti; AVTXContext *gdctf, *gdcti;
DCTContext *dctf, *dcti; av_tx_fn tx_fn_g, itx_fn_g;
FFTSample *bufferh; AVTXContext *dctf, *dcti;
FFTSample *bufferv; av_tx_fn tx_fn, itx_fn;
FFTSample *bufferz; float *bufferh;
FFTSample *buffer; float *buffert;
FFTSample *rbufferh; float *bufferv;
FFTSample *rbufferv; float *bufferz;
FFTSample *rbufferz; float *buffer;
FFTSample *rbuffer; float *rbufferh;
float *rbufferv;
float *rbufferz;
float *rbuffer;
float *num, *den; float *num, *den;
PosPairCode match_blocks[256]; PosPairCode match_blocks[256];
int nb_match_blocks; int nb_match_blocks;
@ -105,7 +108,7 @@ typedef struct BM3DContext {
int nb_planes; int nb_planes;
int planewidth[4]; int planewidth[4];
int planeheight[4]; int planeheight[4];
int group_bits; int pblock_size;
int pgroup_size; int pgroup_size;
SliceContext slices[MAX_NB_THREADS]; SliceContext slices[MAX_NB_THREADS];
@ -128,11 +131,12 @@ typedef struct BM3DContext {
#define OFFSET(x) offsetof(BM3DContext, x) #define OFFSET(x) offsetof(BM3DContext, x)
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM #define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM
static const AVOption bm3d_options[] = { static const AVOption bm3d_options[] = {
{ "sigma", "set denoising strength", { "sigma", "set denoising strength",
OFFSET(sigma), AV_OPT_TYPE_FLOAT, {.dbl=1}, 0, 99999.9, FLAGS }, OFFSET(sigma), AV_OPT_TYPE_FLOAT, {.dbl=1}, 0, 99999.9, FLAGS },
{ "block", "set log2(size) of local patch", { "block", "set size of local patch",
OFFSET(block_size), AV_OPT_TYPE_INT, {.i64=4}, 4, 6, FLAGS }, OFFSET(block_size), AV_OPT_TYPE_INT, {.i64=16}, 8, 64, FLAGS },
{ "bstep", "set sliding step for processing blocks", { "bstep", "set sliding step for processing blocks",
OFFSET(block_step), AV_OPT_TYPE_INT, {.i64=4}, 1, 64, FLAGS }, OFFSET(block_step), AV_OPT_TYPE_INT, {.i64=4}, 1, 64, FLAGS },
{ "group", "set maximal number of similar blocks", { "group", "set maximal number of similar blocks",
@ -273,9 +277,9 @@ static void do_block_matching_multi(BM3DContext *s, const uint8_t *src, int src_
double MSE2SSE = s->group_size * s->block_size * s->block_size * src_range * src_range / (s->max * s->max); double MSE2SSE = s->group_size * s->block_size * s->block_size * src_range * src_range / (s->max * s->max);
double distMul = 1. / MSE2SSE; double distMul = 1. / MSE2SSE;
double th_sse = th_mse * MSE2SSE; double th_sse = th_mse * MSE2SSE;
int i, index = sc->nb_match_blocks; int index = sc->nb_match_blocks;
for (i = 0; i < search_size; i++) { for (int i = 0; i < search_size; i++) {
PosCode pos = search_pos[i]; PosCode pos = search_pos[i];
double dist; double dist;
@ -316,10 +320,10 @@ static void block_matching_multi(BM3DContext *s, const uint8_t *ref, int ref_lin
int r = search_boundary(width - block_size, range, step, 0, y, x); int r = search_boundary(width - block_size, range, step, 0, y, x);
int t = search_boundary(0, range, step, 1, y, x); int t = search_boundary(0, range, step, 1, y, x);
int b = search_boundary(height - block_size, range, step, 1, y, x); int b = search_boundary(height - block_size, range, step, 1, y, x);
int j, i, index = 0; int index = 0;
for (j = t; j <= b; j += step) { for (int j = t; j <= b; j += step) {
for (i = l; i <= r; i += step) { for (int i = l; i <= r; i += step) {
PosCode pos; PosCode pos;
if (exclude_cur_pos > 0 && j == y && i == x) { if (exclude_cur_pos > 0 && j == y && i == x) {
@ -364,22 +368,18 @@ static void get_block_row(const uint8_t *srcp, int src_linesize,
int y, int x, int block_size, float *dst) int y, int x, int block_size, float *dst)
{ {
const uint8_t *src = srcp + y * src_linesize + x; const uint8_t *src = srcp + y * src_linesize + x;
int j;
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++)
dst[j] = src[j]; dst[j] = src[j];
}
} }
static void get_block_row16(const uint8_t *srcp, int src_linesize, static void get_block_row16(const uint8_t *srcp, int src_linesize,
int y, int x, int block_size, float *dst) int y, int x, int block_size, float *dst)
{ {
const uint16_t *src = (uint16_t *)srcp + y * src_linesize / 2 + x; const uint16_t *src = (uint16_t *)srcp + y * src_linesize / 2 + x;
int j;
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++)
dst[j] = src[j]; dst[j] = src[j];
}
} }
static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_linesize, static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_linesize,
@ -387,7 +387,8 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
int y, int x, int plane, int jobnr) int y, int x, int plane, int jobnr)
{ {
SliceContext *sc = &s->slices[jobnr]; SliceContext *sc = &s->slices[jobnr];
const int buffer_linesize = s->block_size * s->block_size; const int pblock_size = s->pblock_size;
const int buffer_linesize = s->pblock_size * s->pblock_size;
const int nb_match_blocks = sc->nb_match_blocks; const int nb_match_blocks = sc->nb_match_blocks;
const int block_size = s->block_size; const int block_size = s->block_size;
const int width = s->planewidth[plane]; const int width = s->planewidth[plane];
@ -395,54 +396,50 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
const int group_size = s->group_size; const int group_size = s->group_size;
float *buffer = sc->buffer; float *buffer = sc->buffer;
float *bufferh = sc->bufferh; float *bufferh = sc->bufferh;
float *buffert = sc->buffert;
float *bufferv = sc->bufferv; float *bufferv = sc->bufferv;
float *bufferz = sc->bufferz; float *bufferz = sc->bufferz;
float threshold[4]; float threshold[4];
float den_weight, num_weight; float den_weight, num_weight;
int retained = 0; int retained = 0;
int i, j, k;
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
const int y = sc->match_blocks[k].y; const int y = sc->match_blocks[k].y;
const int x = sc->match_blocks[k].x; const int x = sc->match_blocks[k].x;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + block_size * i); s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + pblock_size * i);
av_dct_calc(sc->dctf, bufferh + block_size * i); sc->tx_fn(sc->dctf, buffert, bufferh + pblock_size * i, sizeof(float));
for (int j = 0; j < block_size; j++)
bufferv[j * pblock_size + i] = buffert[j];
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { sc->tx_fn(sc->dctf, buffert, bufferv + i * pblock_size, sizeof(float));
bufferv[i * block_size + j] = bufferh[j * block_size + i]; memcpy(buffer + k * buffer_linesize + i * pblock_size,
} buffert, block_size * sizeof(float));
av_dct_calc(sc->dctf, bufferv + i * block_size);
}
for (i = 0; i < block_size; i++) {
memcpy(buffer + k * buffer_linesize + i * block_size,
bufferv + i * block_size, block_size * 4);
} }
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
for (k = 0; k < nb_match_blocks; k++) for (int k = 0; k < nb_match_blocks; k++)
bufferz[k] = buffer[buffer_linesize * k + i * block_size + j]; bufferz[k] = buffer[buffer_linesize * k + i * pblock_size + j];
if (group_size > 1) if (group_size > 1)
av_dct_calc(sc->gdctf, bufferz); sc->tx_fn_g(sc->gdctf, bufferz, bufferz, sizeof(float));
bufferz += pgroup_size; bufferz += pgroup_size;
} }
} }
threshold[0] = s->hard_threshold * s->sigma * M_SQRT2 * block_size * block_size * (1 << (s->depth - 8)) / 255.f; threshold[0] = s->hard_threshold * s->sigma * M_SQRT2 * 4.f * block_size * block_size * (1 << (s->depth - 8)) / 255.f;
threshold[1] = threshold[0] * sqrtf(2.f); threshold[1] = threshold[0] * sqrtf(2.f);
threshold[2] = threshold[0] * 2.f; threshold[2] = threshold[0] * 2.f;
threshold[3] = threshold[0] * sqrtf(8.f); threshold[3] = threshold[0] * sqrtf(8.f);
bufferz = sc->bufferz; bufferz = sc->bufferz;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
const float thresh = threshold[(j == 0) + (i == 0) + (k == 0)]; const float thresh = threshold[(j == 0) + (i == 0) + (k == 0)];
if (bufferz[k] > thresh || bufferz[k] < -thresh) { if (bufferz[k] > thresh || bufferz[k] < -thresh) {
@ -457,13 +454,12 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
bufferz = sc->bufferz; bufferz = sc->bufferz;
buffer = sc->buffer; buffer = sc->buffer;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
if (group_size > 1) if (group_size > 1)
av_dct_calc(sc->gdcti, bufferz); sc->itx_fn_g(sc->gdcti, bufferz, bufferz, sizeof(float));
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++)
buffer[buffer_linesize * k + i * block_size + j] = bufferz[k]; buffer[buffer_linesize * k + i * pblock_size + j] = bufferz[k];
}
bufferz += pgroup_size; bufferz += pgroup_size;
} }
} }
@ -472,27 +468,26 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
num_weight = den_weight; num_weight = den_weight;
buffer = sc->buffer; buffer = sc->buffer;
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
float *num = sc->num + y * width + x; float *num = sc->num + y * width + x;
float *den = sc->den + y * width + x; float *den = sc->den + y * width + x;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
memcpy(bufferv + i * block_size, memcpy(bufferv + i * pblock_size,
buffer + k * buffer_linesize + i * block_size, buffer + k * buffer_linesize + i * pblock_size,
block_size * 4); block_size * sizeof(float));
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
av_dct_calc(sc->dcti, bufferv + block_size * i); sc->itx_fn(sc->dcti, buffert, bufferv + i * pblock_size, sizeof(float));
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++)
bufferh[j * block_size + i] = bufferv[i * block_size + j]; bufferh[j * pblock_size + i] = buffert[j];
}
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
av_dct_calc(sc->dcti, bufferh + block_size * i); sc->itx_fn(sc->dcti, buffert, bufferh + pblock_size * i, sizeof(float));
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
num[j] += bufferh[i * block_size + j] * num_weight; num[j] += buffert[j] * num_weight;
den[j] += den_weight; den[j] += den_weight;
} }
num += width; num += width;
@ -506,7 +501,8 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
int y, int x, int plane, int jobnr) int y, int x, int plane, int jobnr)
{ {
SliceContext *sc = &s->slices[jobnr]; SliceContext *sc = &s->slices[jobnr];
const int buffer_linesize = s->block_size * s->block_size; const int pblock_size = s->pblock_size;
const int buffer_linesize = s->pblock_size * s->pblock_size;
const int nb_match_blocks = sc->nb_match_blocks; const int nb_match_blocks = sc->nb_match_blocks;
const int block_size = s->block_size; const int block_size = s->block_size;
const int width = s->planewidth[plane]; const int width = s->planewidth[plane];
@ -523,45 +519,44 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
float *rbufferz = sc->rbufferz; float *rbufferz = sc->rbufferz;
float den_weight, num_weight; float den_weight, num_weight;
float l2_wiener = 0; float l2_wiener = 0;
int i, j, k;
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
const int y = sc->match_blocks[k].y; const int y = sc->match_blocks[k].y;
const int x = sc->match_blocks[k].x; const int x = sc->match_blocks[k].x;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + block_size * i); s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + pblock_size * i);
s->get_block_row(ref, ref_linesize, y + i, x, block_size, rbufferh + block_size * i); s->get_block_row(ref, ref_linesize, y + i, x, block_size, rbufferh + pblock_size * i);
av_dct_calc(sc->dctf, bufferh + block_size * i); sc->tx_fn(sc->dctf, bufferh + pblock_size * i, bufferh + pblock_size * i, sizeof(float));
av_dct_calc(sc->dctf, rbufferh + block_size * i); sc->tx_fn(sc->dctf, rbufferh + pblock_size * i, rbufferh + pblock_size * i, sizeof(float));
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
bufferv[i * block_size + j] = bufferh[j * block_size + i]; bufferv[i * pblock_size + j] = bufferh[j * pblock_size + i];
rbufferv[i * block_size + j] = rbufferh[j * block_size + i]; rbufferv[i * pblock_size + j] = rbufferh[j * pblock_size + i];
} }
av_dct_calc(sc->dctf, bufferv + i * block_size); sc->tx_fn(sc->dctf, bufferv + i * pblock_size, bufferv + i * pblock_size, sizeof(float));
av_dct_calc(sc->dctf, rbufferv + i * block_size); sc->tx_fn(sc->dctf, rbufferv + i * pblock_size, rbufferv + i * pblock_size, sizeof(float));
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
memcpy(buffer + k * buffer_linesize + i * block_size, memcpy(buffer + k * buffer_linesize + i * pblock_size,
bufferv + i * block_size, block_size * 4); bufferv + i * pblock_size, block_size * sizeof(float));
memcpy(rbuffer + k * buffer_linesize + i * block_size, memcpy(rbuffer + k * buffer_linesize + i * pblock_size,
rbufferv + i * block_size, block_size * 4); rbufferv + i * pblock_size, block_size * sizeof(float));
} }
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
bufferz[k] = buffer[buffer_linesize * k + i * block_size + j]; bufferz[k] = buffer[buffer_linesize * k + i * pblock_size + j];
rbufferz[k] = rbuffer[buffer_linesize * k + i * block_size + j]; rbufferz[k] = rbuffer[buffer_linesize * k + i * pblock_size + j];
} }
if (group_size > 1) { if (group_size > 1) {
av_dct_calc(sc->gdctf, bufferz); sc->tx_fn_g(sc->gdctf, bufferz, bufferz, sizeof(float));
av_dct_calc(sc->gdctf, rbufferz); sc->tx_fn_g(sc->gdctf, rbufferz, rbufferz, sizeof(float));
} }
bufferz += pgroup_size; bufferz += pgroup_size;
rbufferz += pgroup_size; rbufferz += pgroup_size;
@ -571,9 +566,9 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
bufferz = sc->bufferz; bufferz = sc->bufferz;
rbufferz = sc->rbufferz; rbufferz = sc->rbufferz;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
const float ref_sqr = rbufferz[k] * rbufferz[k]; const float ref_sqr = rbufferz[k] * rbufferz[k];
float wiener_coef = ref_sqr / (ref_sqr + sigma_sqr); float wiener_coef = ref_sqr / (ref_sqr + sigma_sqr);
@ -589,12 +584,12 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
bufferz = sc->bufferz; bufferz = sc->bufferz;
buffer = sc->buffer; buffer = sc->buffer;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
if (group_size > 1) if (group_size > 1)
av_dct_calc(sc->gdcti, bufferz); sc->itx_fn_g(sc->gdcti, bufferz, bufferz, sizeof(float));
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
buffer[buffer_linesize * k + i * block_size + j] = bufferz[k]; buffer[buffer_linesize * k + i * pblock_size + j] = bufferz[k];
} }
bufferz += pgroup_size; bufferz += pgroup_size;
} }
@ -604,27 +599,27 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li
den_weight = 1.f / l2_wiener; den_weight = 1.f / l2_wiener;
num_weight = den_weight; num_weight = den_weight;
for (k = 0; k < nb_match_blocks; k++) { for (int k = 0; k < nb_match_blocks; k++) {
float *num = sc->num + y * width + x; float *num = sc->num + y * width + x;
float *den = sc->den + y * width + x; float *den = sc->den + y * width + x;
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
memcpy(bufferv + i * block_size, memcpy(bufferv + i * pblock_size,
buffer + k * buffer_linesize + i * block_size, buffer + k * buffer_linesize + i * pblock_size,
block_size * 4); block_size * sizeof(float));
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
av_dct_calc(sc->dcti, bufferv + block_size * i); sc->itx_fn(sc->dcti, bufferv + pblock_size * i, bufferv + pblock_size * i, sizeof(float));
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
bufferh[j * block_size + i] = bufferv[i * block_size + j]; bufferh[j * pblock_size + i] = bufferv[i * pblock_size + j];
} }
} }
for (i = 0; i < block_size; i++) { for (int i = 0; i < block_size; i++) {
av_dct_calc(sc->dcti, bufferh + block_size * i); sc->itx_fn(sc->dcti, bufferh + pblock_size * i, bufferh + pblock_size * i, sizeof(float));
for (j = 0; j < block_size; j++) { for (int j = 0; j < block_size; j++) {
num[j] += bufferh[i * block_size + j] * num_weight; num[j] += bufferh[i * pblock_size + j] * num_weight;
den[j] += den_weight; den[j] += den_weight;
} }
num += width; num += width;
@ -638,15 +633,14 @@ static void do_output(BM3DContext *s, uint8_t *dst, int dst_linesize,
{ {
const int height = s->planeheight[plane]; const int height = s->planeheight[plane];
const int width = s->planewidth[plane]; const int width = s->planewidth[plane];
int i, j, k;
for (i = 0; i < height; i++) { for (int i = 0; i < height; i++) {
for (j = 0; j < width; j++) { for (int j = 0; j < width; j++) {
uint8_t *dstp = dst + i * dst_linesize; uint8_t *dstp = dst + i * dst_linesize;
float sum_den = 0.f; float sum_den = 0.f;
float sum_num = 0.f; float sum_num = 0.f;
for (k = 0; k < nb_jobs; k++) { for (int k = 0; k < nb_jobs; k++) {
SliceContext *sc = &s->slices[k]; SliceContext *sc = &s->slices[k];
float num = sc->num[i * width + j]; float num = sc->num[i * width + j];
float den = sc->den[i * width + j]; float den = sc->den[i * width + j];
@ -666,15 +660,14 @@ static void do_output16(BM3DContext *s, uint8_t *dst, int dst_linesize,
const int height = s->planeheight[plane]; const int height = s->planeheight[plane];
const int width = s->planewidth[plane]; const int width = s->planewidth[plane];
const int depth = s->depth; const int depth = s->depth;
int i, j, k;
for (i = 0; i < height; i++) { for (int i = 0; i < height; i++) {
for (j = 0; j < width; j++) { for (int j = 0; j < width; j++) {
uint16_t *dstp = (uint16_t *)dst + i * dst_linesize / 2; uint16_t *dstp = (uint16_t *)dst + i * dst_linesize / 2;
float sum_den = 0.f; float sum_den = 0.f;
float sum_num = 0.f; float sum_num = 0.f;
for (k = 0; k < nb_jobs; k++) { for (int k = 0; k < nb_jobs; k++) {
SliceContext *sc = &s->slices[k]; SliceContext *sc = &s->slices[k];
float num = sc->num[i * width + j]; float num = sc->num[i * width + j];
float den = sc->den[i * width + j]; float den = sc->den[i * width + j];
@ -706,17 +699,16 @@ static int filter_slice(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
const int slice_start = (((height + block_step - 1) / block_step) * jobnr / nb_jobs) * block_step; const int slice_start = (((height + block_step - 1) / block_step) * jobnr / nb_jobs) * block_step;
const int slice_end = (jobnr == nb_jobs - 1) ? block_pos_bottom + block_step : const int slice_end = (jobnr == nb_jobs - 1) ? block_pos_bottom + block_step :
(((height + block_step - 1) / block_step) * (jobnr + 1) / nb_jobs) * block_step; (((height + block_step - 1) / block_step) * (jobnr + 1) / nb_jobs) * block_step;
int i, j;
memset(sc->num, 0, width * height * sizeof(FFTSample)); memset(sc->num, 0, width * height * sizeof(float));
memset(sc->den, 0, width * height * sizeof(FFTSample)); memset(sc->den, 0, width * height * sizeof(float));
for (j = slice_start; j < slice_end; j += block_step) { for (int j = slice_start; j < slice_end; j += block_step) {
if (j > block_pos_bottom) { if (j > block_pos_bottom) {
j = block_pos_bottom; j = block_pos_bottom;
} }
for (i = 0; i < block_pos_right + block_step; i += block_step) { for (int i = 0; i < block_pos_right + block_step; i += block_step) {
if (i > block_pos_right) { if (i > block_pos_right) {
i = block_pos_right; i = block_pos_right;
} }
@ -749,7 +741,7 @@ static int filter_frame(AVFilterContext *ctx, AVFrame **out, AVFrame *in, AVFram
if (!((1 << p) & s->planes) || ctx->is_disabled) { if (!((1 << p) & s->planes) || ctx->is_disabled) {
av_image_copy_plane((*out)->data[p], (*out)->linesize[p], av_image_copy_plane((*out)->data[p], (*out)->linesize[p],
in->data[p], in->linesize[p], in->data[p], in->linesize[p],
s->planewidth[p], s->planeheight[p]); s->planewidth[p] * (1 + (s->depth > 8)), s->planeheight[p]);
continue; continue;
} }
@ -773,7 +765,6 @@ static int config_input(AVFilterLink *inlink)
const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format); const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
AVFilterContext *ctx = inlink->dst; AVFilterContext *ctx = inlink->dst;
BM3DContext *s = ctx->priv; BM3DContext *s = ctx->priv;
int i, group_bits;
s->nb_threads = FFMIN(ff_filter_get_nb_threads(ctx), MAX_NB_THREADS); s->nb_threads = FFMIN(ff_filter_get_nb_threads(ctx), MAX_NB_THREADS);
s->nb_planes = av_pix_fmt_count_planes(inlink->format); s->nb_planes = av_pix_fmt_count_planes(inlink->format);
@ -783,43 +774,53 @@ static int config_input(AVFilterLink *inlink)
s->planeheight[0] = s->planeheight[3] = inlink->h; s->planeheight[0] = s->planeheight[3] = inlink->h;
s->planewidth[1] = s->planewidth[2] = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w); s->planewidth[1] = s->planewidth[2] = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w);
s->planewidth[0] = s->planewidth[3] = inlink->w; s->planewidth[0] = s->planewidth[3] = inlink->w;
s->pblock_size = FFALIGN(s->block_size * 2, av_cpu_max_align());
s->pgroup_size = FFALIGN(s->group_size * 2, av_cpu_max_align());
for (group_bits = 4; 1 << group_bits < s->group_size; group_bits++); for (int i = 0; i < s->nb_threads; i++) {
s->group_bits = group_bits;
s->pgroup_size = 1 << group_bits;
for (i = 0; i < s->nb_threads; i++) {
SliceContext *sc = &s->slices[i]; SliceContext *sc = &s->slices[i];
float iscale = 0.5f / s->block_size;
float scale = 1.f;
int ret;
sc->num = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(FFTSample)); sc->num = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(float));
sc->den = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(FFTSample)); sc->den = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(float));
if (!sc->num || !sc->den) if (!sc->num || !sc->den)
return AVERROR(ENOMEM); return AVERROR(ENOMEM);
sc->dctf = av_dct_init(av_log2(s->block_size), DCT_II); ret = av_tx_init(&sc->dctf, &sc->tx_fn, AV_TX_FLOAT_DCT, 0, s->block_size >> 0, &scale, 0);
sc->dcti = av_dct_init(av_log2(s->block_size), DCT_III); if (ret < 0)
if (!sc->dctf || !sc->dcti) return ret;
return AVERROR(ENOMEM);
if (s->group_bits > 1) { ret = av_tx_init(&sc->dcti, &sc->itx_fn, AV_TX_FLOAT_DCT, 1, s->block_size >> 1, &iscale, 0);
sc->gdctf = av_dct_init(s->group_bits, DCT_II); if (ret < 0)
sc->gdcti = av_dct_init(s->group_bits, DCT_III); return ret;
if (!sc->gdctf || !sc->gdcti)
return AVERROR(ENOMEM); if (s->group_size > 1) {
float iscale = 0.5f / s->group_size;
ret = av_tx_init(&sc->gdctf, &sc->tx_fn_g, AV_TX_FLOAT_DCT, 0, s->group_size >> 0, &scale, 0);
if (ret < 0)
return ret;
ret = av_tx_init(&sc->gdcti, &sc->itx_fn_g, AV_TX_FLOAT_DCT, 1, s->group_size >> 1, &iscale, 0);
if (ret < 0)
return ret;
} }
sc->buffer = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->buffer)); sc->buffer = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->buffer));
sc->bufferz = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->bufferz)); sc->bufferz = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->bufferz));
sc->bufferh = av_calloc(s->block_size * s->block_size, sizeof(*sc->bufferh)); sc->bufferh = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->bufferh));
sc->bufferv = av_calloc(s->block_size * s->block_size, sizeof(*sc->bufferv)); sc->bufferv = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->bufferv));
if (!sc->bufferh || !sc->bufferv || !sc->buffer || !sc->bufferz) sc->buffert = av_calloc(s->pblock_size, sizeof(*sc->buffert));
if (!sc->bufferh || !sc->bufferv || !sc->buffer || !sc->bufferz || !sc->buffert)
return AVERROR(ENOMEM); return AVERROR(ENOMEM);
if (s->mode == FINAL) { if (s->mode == FINAL) {
sc->rbuffer = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->rbuffer)); sc->rbuffer = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->rbuffer));
sc->rbufferz = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->rbufferz)); sc->rbufferz = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->rbufferz));
sc->rbufferh = av_calloc(s->block_size * s->block_size, sizeof(*sc->rbufferh)); sc->rbufferh = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->rbufferh));
sc->rbufferv = av_calloc(s->block_size * s->block_size, sizeof(*sc->rbufferv)); sc->rbufferv = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->rbufferv));
if (!sc->rbufferh || !sc->rbufferv || !sc->rbuffer || !sc->rbufferz) if (!sc->rbufferh || !sc->rbufferv || !sc->rbuffer || !sc->rbufferz)
return AVERROR(ENOMEM); return AVERROR(ENOMEM);
} }
@ -919,13 +920,12 @@ static av_cold int init(AVFilterContext *ctx)
return AVERROR_BUG; return AVERROR_BUG;
} }
s->block_size = 1 << s->block_size;
if (s->block_step > s->block_size) { if (s->block_step > s->block_size) {
av_log(ctx, AV_LOG_WARNING, "bstep: %d can't be bigger than block size. Changing to %d.\n", av_log(ctx, AV_LOG_WARNING, "bstep: %d can't be bigger than block size. Changing to %d.\n",
s->block_step, s->block_size); s->block_step, s->block_size);
s->block_step = s->block_size; s->block_step = s->block_size;
} }
if (s->bm_step > s->bm_range) { if (s->bm_step > s->bm_range) {
av_log(ctx, AV_LOG_WARNING, "mstep: %d can't be bigger than block matching range. Changing to %d.\n", av_log(ctx, AV_LOG_WARNING, "mstep: %d can't be bigger than block matching range. Changing to %d.\n",
s->bm_step, s->bm_range); s->bm_step, s->bm_range);
@ -1004,24 +1004,24 @@ static int config_output(AVFilterLink *outlink)
static av_cold void uninit(AVFilterContext *ctx) static av_cold void uninit(AVFilterContext *ctx)
{ {
BM3DContext *s = ctx->priv; BM3DContext *s = ctx->priv;
int i;
if (s->ref) if (s->ref)
ff_framesync_uninit(&s->fs); ff_framesync_uninit(&s->fs);
for (i = 0; i < s->nb_threads; i++) { for (int i = 0; i < s->nb_threads; i++) {
SliceContext *sc = &s->slices[i]; SliceContext *sc = &s->slices[i];
av_freep(&sc->num); av_freep(&sc->num);
av_freep(&sc->den); av_freep(&sc->den);
av_dct_end(sc->gdctf); av_tx_uninit(&sc->gdctf);
av_dct_end(sc->gdcti); av_tx_uninit(&sc->gdcti);
av_dct_end(sc->dctf); av_tx_uninit(&sc->dctf);
av_dct_end(sc->dcti); av_tx_uninit(&sc->dcti);
av_freep(&sc->buffer); av_freep(&sc->buffer);
av_freep(&sc->bufferh); av_freep(&sc->bufferh);
av_freep(&sc->buffert);
av_freep(&sc->bufferv); av_freep(&sc->bufferv);
av_freep(&sc->bufferz); av_freep(&sc->bufferz);
av_freep(&sc->rbuffer); av_freep(&sc->rbuffer);