1
0
mirror of https://github.com/FFmpeg/FFmpeg.git synced 2024-12-23 12:43:46 +02:00

dnn: add a new interface DNNModel.get_output

for some cases (for example, super resolution), the DNN model changes
the frame size which impacts the filter behavior, so the filter needs
to know the out frame size at very beginning.

Currently, the filter reuses DNNModule.execute_model to query the
out frame size, it is not clear from interface perspective, so add
a new explict interface DNNModel.get_output for such query.
This commit is contained in:
Guo, Yejun 2020-09-11 22:15:04 +08:00
parent fce3e3e137
commit e71d73b096
6 changed files with 185 additions and 58 deletions

View File

@ -44,6 +44,10 @@ const AVClass dnn_native_class = {
.category = AV_CLASS_CATEGORY_FILTER, .category = AV_CLASS_CATEGORY_FILTER,
}; };
static DNNReturnType execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc);
static DNNReturnType get_input_native(void *model, DNNData *input, const char *input_name) static DNNReturnType get_input_native(void *model, DNNData *input, const char *input_name)
{ {
NativeModel *native_model = (NativeModel *)model; NativeModel *native_model = (NativeModel *)model;
@ -70,6 +74,25 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
return DNN_ERROR; return DNN_ERROR;
} }
static DNNReturnType get_output_native(void *model, const char *input_name, int input_width, int input_height,
const char *output_name, int *output_width, int *output_height)
{
DNNReturnType ret;
NativeModel *native_model = (NativeModel *)model;
AVFrame *in_frame = av_frame_alloc();
AVFrame *out_frame = av_frame_alloc();
in_frame->width = input_width;
in_frame->height = input_height;
ret = execute_model_native(native_model->model, input_name, in_frame, &output_name, 1, out_frame, 0);
*output_width = out_frame->width;
*output_height = out_frame->height;
av_frame_free(&out_frame);
av_frame_free(&in_frame);
return ret;
}
// Loads model and its parameters that are stored in a binary file with following structure: // Loads model and its parameters that are stored in a binary file with following structure:
// layers_num,layer_type,layer_parameterss,layer_type,layer_parameters... // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
// For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
@ -216,6 +239,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
} }
model->get_input = &get_input_native; model->get_input = &get_input_native;
model->get_output = &get_output_native;
model->userdata = userdata; model->userdata = userdata;
return model; return model;
@ -226,8 +250,9 @@ fail:
return NULL; return NULL;
} }
DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame, static DNNReturnType execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame) const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc)
{ {
NativeModel *native_model = (NativeModel *)model->model; NativeModel *native_model = (NativeModel *)model->model;
NativeContext *ctx = &native_model->ctx; NativeContext *ctx = &native_model->ctx;
@ -276,10 +301,12 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *inp
input.channels = oprd->dims[3]; input.channels = oprd->dims[3];
input.data = oprd->data; input.data = oprd->data;
input.dt = oprd->data_type; input.dt = oprd->data_type;
if (native_model->model->pre_proc != NULL) { if (do_ioproc) {
native_model->model->pre_proc(in_frame, &input, native_model->model->userdata); if (native_model->model->pre_proc != NULL) {
} else { native_model->model->pre_proc(in_frame, &input, native_model->model->userdata);
proc_from_frame_to_dnn(in_frame, &input, ctx); } else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
} }
if (nb_output != 1) { if (nb_output != 1) {
@ -322,21 +349,40 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *inp
output.channels = oprd->dims[3]; output.channels = oprd->dims[3];
output.dt = oprd->data_type; output.dt = oprd->data_type;
if (out_frame->width != output.width || out_frame->height != output.height) { if (do_ioproc) {
out_frame->width = output.width;
out_frame->height = output.height;
} else {
if (native_model->model->post_proc != NULL) { if (native_model->model->post_proc != NULL) {
native_model->model->post_proc(out_frame, &output, native_model->model->userdata); native_model->model->post_proc(out_frame, &output, native_model->model->userdata);
} else { } else {
proc_from_dnn_to_frame(out_frame, &output, ctx); proc_from_dnn_to_frame(out_frame, &output, ctx);
} }
} else {
out_frame->width = output.width;
out_frame->height = output.height;
} }
} }
return DNN_SUCCESS; return DNN_SUCCESS;
} }
DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{
NativeModel *native_model = (NativeModel *)model->model;
NativeContext *ctx = &native_model->ctx;
if (!in_frame) {
av_log(ctx, AV_LOG_ERROR, "in frame is NULL when execute model.\n");
return DNN_ERROR;
}
if (!out_frame) {
av_log(ctx, AV_LOG_ERROR, "out frame is NULL when execute model.\n");
return DNN_ERROR;
}
return execute_model_native(model, input_name, in_frame, output_names, nb_output, out_frame, 1);
}
int32_t calculate_operand_dims_count(const DnnOperand *oprd) int32_t calculate_operand_dims_count(const DnnOperand *oprd)
{ {
int32_t result = 1; int32_t result = 1;

View File

@ -63,6 +63,10 @@ static const AVOption dnn_openvino_options[] = {
AVFILTER_DEFINE_CLASS(dnn_openvino); AVFILTER_DEFINE_CLASS(dnn_openvino);
static DNNReturnType execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc);
static DNNDataType precision_to_datatype(precision_e precision) static DNNDataType precision_to_datatype(precision_e precision)
{ {
switch (precision) switch (precision)
@ -132,6 +136,25 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
return DNN_ERROR; return DNN_ERROR;
} }
static DNNReturnType get_output_ov(void *model, const char *input_name, int input_width, int input_height,
const char *output_name, int *output_width, int *output_height)
{
DNNReturnType ret;
OVModel *ov_model = (OVModel *)model;
AVFrame *in_frame = av_frame_alloc();
AVFrame *out_frame = av_frame_alloc();
in_frame->width = input_width;
in_frame->height = input_height;
ret = execute_model_ov(ov_model->model, input_name, in_frame, &output_name, 1, out_frame, 0);
*output_width = out_frame->width;
*output_height = out_frame->height;
av_frame_free(&out_frame);
av_frame_free(&in_frame);
return ret;
}
DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata) DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata)
{ {
char *all_dev_names = NULL; char *all_dev_names = NULL;
@ -191,6 +214,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
model->model = (void *)ov_model; model->model = (void *)ov_model;
model->get_input = &get_input_ov; model->get_input = &get_input_ov;
model->get_output = &get_output_ov;
model->options = options; model->options = options;
model->userdata = userdata; model->userdata = userdata;
@ -213,8 +237,9 @@ err:
return NULL; return NULL;
} }
DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame, static DNNReturnType execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame) const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc)
{ {
char *model_output_name = NULL; char *model_output_name = NULL;
char *all_output_names = NULL; char *all_output_names = NULL;
@ -252,10 +277,12 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_n
input.channels = dims.dims[1]; input.channels = dims.dims[1];
input.data = blob_buffer.buffer; input.data = blob_buffer.buffer;
input.dt = precision_to_datatype(precision); input.dt = precision_to_datatype(precision);
if (ov_model->model->pre_proc != NULL) { if (do_ioproc) {
ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata); if (ov_model->model->pre_proc != NULL) {
} else { ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata);
proc_from_frame_to_dnn(in_frame, &input, ctx); } else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
} }
ie_blob_free(&input_blob); ie_blob_free(&input_blob);
@ -308,15 +335,15 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_n
output.width = dims.dims[3]; output.width = dims.dims[3];
output.dt = precision_to_datatype(precision); output.dt = precision_to_datatype(precision);
output.data = blob_buffer.buffer; output.data = blob_buffer.buffer;
if (out_frame->width != output.width || out_frame->height != output.height) { if (do_ioproc) {
out_frame->width = output.width;
out_frame->height = output.height;
} else {
if (ov_model->model->post_proc != NULL) { if (ov_model->model->post_proc != NULL) {
ov_model->model->post_proc(out_frame, &output, ov_model->model->userdata); ov_model->model->post_proc(out_frame, &output, ov_model->model->userdata);
} else { } else {
proc_from_dnn_to_frame(out_frame, &output, ctx); proc_from_dnn_to_frame(out_frame, &output, ctx);
} }
} else {
out_frame->width = output.width;
out_frame->height = output.height;
} }
ie_blob_free(&output_blob); ie_blob_free(&output_blob);
} }
@ -324,6 +351,25 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_n
return DNN_SUCCESS; return DNN_SUCCESS;
} }
DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{
OVModel *ov_model = (OVModel *)model->model;
OVContext *ctx = &ov_model->ctx;
if (!in_frame) {
av_log(ctx, AV_LOG_ERROR, "in frame is NULL when execute model.\n");
return DNN_ERROR;
}
if (!out_frame) {
av_log(ctx, AV_LOG_ERROR, "out frame is NULL when execute model.\n");
return DNN_ERROR;
}
return execute_model_ov(model, input_name, in_frame, output_names, nb_output, out_frame, 1);
}
void ff_dnn_free_model_ov(DNNModel **model) void ff_dnn_free_model_ov(DNNModel **model)
{ {
if (*model){ if (*model){

View File

@ -55,6 +55,10 @@ static const AVClass dnn_tensorflow_class = {
.category = AV_CLASS_CATEGORY_FILTER, .category = AV_CLASS_CATEGORY_FILTER,
}; };
static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc);
static void free_buffer(void *data, size_t length) static void free_buffer(void *data, size_t length)
{ {
av_freep(&data); av_freep(&data);
@ -150,6 +154,25 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
return DNN_SUCCESS; return DNN_SUCCESS;
} }
static DNNReturnType get_output_tf(void *model, const char *input_name, int input_width, int input_height,
const char *output_name, int *output_width, int *output_height)
{
DNNReturnType ret;
TFModel *tf_model = (TFModel *)model;
AVFrame *in_frame = av_frame_alloc();
AVFrame *out_frame = av_frame_alloc();
in_frame->width = input_width;
in_frame->height = input_height;
ret = execute_model_tf(tf_model->model, input_name, in_frame, &output_name, 1, out_frame, 0);
*output_width = out_frame->width;
*output_height = out_frame->height;
av_frame_free(&out_frame);
av_frame_free(&in_frame);
return ret;
}
static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename) static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{ {
TFContext *ctx = &tf_model->ctx; TFContext *ctx = &tf_model->ctx;
@ -583,14 +606,16 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
model->model = (void *)tf_model; model->model = (void *)tf_model;
model->get_input = &get_input_tf; model->get_input = &get_input_tf;
model->get_output = &get_output_tf;
model->options = options; model->options = options;
model->userdata = userdata; model->userdata = userdata;
return model; return model;
} }
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame, static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame) const char **output_names, uint32_t nb_output, AVFrame *out_frame,
int do_ioproc)
{ {
TF_Output *tf_outputs; TF_Output *tf_outputs;
TFModel *tf_model = (TFModel *)model->model; TFModel *tf_model = (TFModel *)model->model;
@ -618,10 +643,12 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_n
} }
input.data = (float *)TF_TensorData(input_tensor); input.data = (float *)TF_TensorData(input_tensor);
if (tf_model->model->pre_proc != NULL) { if (do_ioproc) {
tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata); if (tf_model->model->pre_proc != NULL) {
} else { tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata);
proc_from_frame_to_dnn(in_frame, &input, ctx); } else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
} }
if (nb_output != 1) { if (nb_output != 1) {
@ -673,15 +700,15 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_n
output.data = TF_TensorData(output_tensors[i]); output.data = TF_TensorData(output_tensors[i]);
output.dt = TF_TensorType(output_tensors[i]); output.dt = TF_TensorType(output_tensors[i]);
if (out_frame->width != output.width || out_frame->height != output.height) { if (do_ioproc) {
out_frame->width = output.width;
out_frame->height = output.height;
} else {
if (tf_model->model->post_proc != NULL) { if (tf_model->model->post_proc != NULL) {
tf_model->model->post_proc(out_frame, &output, tf_model->model->userdata); tf_model->model->post_proc(out_frame, &output, tf_model->model->userdata);
} else { } else {
proc_from_dnn_to_frame(out_frame, &output, ctx); proc_from_dnn_to_frame(out_frame, &output, ctx);
} }
} else {
out_frame->width = output.width;
out_frame->height = output.height;
} }
} }
@ -696,6 +723,25 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_n
return DNN_SUCCESS; return DNN_SUCCESS;
} }
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{
TFModel *tf_model = (TFModel *)model->model;
TFContext *ctx = &tf_model->ctx;
if (!in_frame) {
av_log(ctx, AV_LOG_ERROR, "in frame is NULL when execute model.\n");
return DNN_ERROR;
}
if (!out_frame) {
av_log(ctx, AV_LOG_ERROR, "out frame is NULL when execute model.\n");
return DNN_ERROR;
}
return execute_model_tf(model, input_name, in_frame, output_names, nb_output, out_frame, 1);
}
void ff_dnn_free_model_tf(DNNModel **model) void ff_dnn_free_model_tf(DNNModel **model)
{ {
TFModel *tf_model; TFModel *tf_model;

View File

@ -51,6 +51,9 @@ typedef struct DNNModel{
// Gets model input information // Gets model input information
// Just reuse struct DNNData here, actually the DNNData.data field is not needed. // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name); DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
// Gets model output width/height with given input w/h
DNNReturnType (*get_output)(void *model, const char *input_name, int input_width, int input_height,
const char *output_name, int *output_width, int *output_height);
// set the pre process to transfer data from AVFrame to DNNData // set the pre process to transfer data from AVFrame to DNNData
// the default implementation within DNN is used if it is not provided by the filter // the default implementation within DNN is used if it is not provided by the filter
int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data); int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);

View File

@ -233,24 +233,15 @@ static int config_output(AVFilterLink *outlink)
DnnProcessingContext *ctx = context->priv; DnnProcessingContext *ctx = context->priv;
DNNReturnType result; DNNReturnType result;
AVFilterLink *inlink = context->inputs[0]; AVFilterLink *inlink = context->inputs[0];
AVFrame *out = NULL;
AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
// have a try run in case that the dnn model resize the frame // have a try run in case that the dnn model resize the frame
out = ff_get_video_buffer(inlink, inlink->w, inlink->h); result = ctx->model->get_output(ctx->model->model, ctx->model_inputname, inlink->w, inlink->h,
result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, fake_in, ctx->model_outputname, &outlink->w, &outlink->h);
(const char **)&ctx->model_outputname, 1, out); if (result != DNN_SUCCESS) {
if (result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n");
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
return AVERROR(EIO); return AVERROR(EIO);
} }
outlink->w = out->width;
outlink->h = out->height;
av_frame_free(&fake_in);
av_frame_free(&out);
prepare_uv_scale(outlink); prepare_uv_scale(outlink);
return 0; return 0;

View File

@ -111,23 +111,20 @@ static int config_output(AVFilterLink *outlink)
SRContext *ctx = context->priv; SRContext *ctx = context->priv;
DNNReturnType result; DNNReturnType result;
AVFilterLink *inlink = context->inputs[0]; AVFilterLink *inlink = context->inputs[0];
AVFrame *out = NULL; int out_width, out_height;
const char *model_output_name = "y";
// have a try run in case that the dnn model resize the frame // have a try run in case that the dnn model resize the frame
AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); result = ctx->model->get_output(ctx->model->model, "x", inlink->w, inlink->h,
out = ff_get_video_buffer(inlink, inlink->w, inlink->h); "y", &out_width, &out_height);
result = (ctx->dnn_module->execute_model)(ctx->model, "x", fake_in, if (result != DNN_SUCCESS) {
(const char **)&model_output_name, 1, out); av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n");
if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
return AVERROR(EIO); return AVERROR(EIO);
} }
if (fake_in->width != out->width || fake_in->height != out->height) { if (inlink->w != out_width || inlink->h != out_height) {
//espcn //espcn
outlink->w = out->width; outlink->w = out_width;
outlink->h = out->height; outlink->h = out_height;
if (inlink->format != AV_PIX_FMT_GRAY8){ if (inlink->format != AV_PIX_FMT_GRAY8){
const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format); const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
@ -141,15 +138,13 @@ static int config_output(AVFilterLink *outlink)
} }
} else { } else {
//srcnn //srcnn
outlink->w = out->width * ctx->scale_factor; outlink->w = out_width * ctx->scale_factor;
outlink->h = out->height * ctx->scale_factor; outlink->h = out_height * ctx->scale_factor;
ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format, ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format,
outlink->w, outlink->h, outlink->format, outlink->w, outlink->h, outlink->format,
SWS_BICUBIC, NULL, NULL, NULL); SWS_BICUBIC, NULL, NULL, NULL);
} }
av_frame_free(&fake_in);
av_frame_free(&out);
return 0; return 0;
} }