mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2024-12-23 12:43:46 +02:00
lavfi/dnn_backend_tensorflow: add multiple outputs support
Signed-off-by: Ting Fu <ting.fu@intel.com>
This commit is contained in:
parent
f02928eb5a
commit
1b1064054c
@ -155,7 +155,7 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// currently only NHWC is supported
|
||||
av_assert0(dims[0] == 1);
|
||||
av_assert0(dims[0] == 1 || dims[0] == -1);
|
||||
input->height = dims[1];
|
||||
input->width = dims[2];
|
||||
input->channels = dims[3];
|
||||
@ -707,7 +707,7 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
|
||||
TF_Output *tf_outputs;
|
||||
TFModel *tf_model = model->model;
|
||||
TFContext *ctx = &tf_model->ctx;
|
||||
DNNData input, output;
|
||||
DNNData input, *outputs;
|
||||
TF_Tensor **output_tensors;
|
||||
TF_Output tf_input;
|
||||
TF_Tensor *input_tensor;
|
||||
@ -738,14 +738,6 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
|
||||
}
|
||||
}
|
||||
|
||||
if (nb_output != 1) {
|
||||
// currently, the filter does not need multiple outputs,
|
||||
// so we just pending the support until we really need it.
|
||||
TF_DeleteTensor(input_tensor);
|
||||
avpriv_report_missing_feature(ctx, "multiple outputs");
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
|
||||
if (tf_outputs == NULL) {
|
||||
TF_DeleteTensor(input_tensor);
|
||||
@ -785,23 +777,31 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < nb_output; ++i) {
|
||||
output.height = TF_Dim(output_tensors[i], 1);
|
||||
output.width = TF_Dim(output_tensors[i], 2);
|
||||
output.channels = TF_Dim(output_tensors[i], 3);
|
||||
output.data = TF_TensorData(output_tensors[i]);
|
||||
output.dt = TF_TensorType(output_tensors[i]);
|
||||
outputs = av_malloc_array(nb_output, sizeof(*outputs));
|
||||
if (!outputs) {
|
||||
TF_DeleteTensor(input_tensor);
|
||||
av_freep(&tf_outputs);
|
||||
av_freep(&output_tensors);
|
||||
av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *outputs\n"); \
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
if (do_ioproc) {
|
||||
if (tf_model->model->frame_post_proc != NULL) {
|
||||
tf_model->model->frame_post_proc(out_frame, &output, tf_model->model->filter_ctx);
|
||||
} else {
|
||||
ff_proc_from_dnn_to_frame(out_frame, &output, ctx);
|
||||
}
|
||||
for (uint32_t i = 0; i < nb_output; ++i) {
|
||||
outputs[i].height = TF_Dim(output_tensors[i], 1);
|
||||
outputs[i].width = TF_Dim(output_tensors[i], 2);
|
||||
outputs[i].channels = TF_Dim(output_tensors[i], 3);
|
||||
outputs[i].data = TF_TensorData(output_tensors[i]);
|
||||
outputs[i].dt = TF_TensorType(output_tensors[i]);
|
||||
}
|
||||
if (do_ioproc) {
|
||||
if (tf_model->model->frame_post_proc != NULL) {
|
||||
tf_model->model->frame_post_proc(out_frame, outputs, tf_model->model->filter_ctx);
|
||||
} else {
|
||||
out_frame->width = output.width;
|
||||
out_frame->height = output.height;
|
||||
ff_proc_from_dnn_to_frame(out_frame, outputs, ctx);
|
||||
}
|
||||
} else {
|
||||
out_frame->width = outputs[0].width;
|
||||
out_frame->height = outputs[0].height;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < nb_output; ++i) {
|
||||
@ -812,6 +812,7 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
|
||||
TF_DeleteTensor(input_tensor);
|
||||
av_freep(&output_tensors);
|
||||
av_freep(&tf_outputs);
|
||||
av_freep(&outputs);
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,39 @@
|
||||
*/
|
||||
|
||||
#include "dnn_filter_common.h"
|
||||
#include "libavutil/avstring.h"
|
||||
|
||||
#define MAX_SUPPORTED_OUTPUTS_NB 4
|
||||
|
||||
static char **separate_output_names(const char *expr, const char *val_sep, int *separated_nb)
|
||||
{
|
||||
char *val, **parsed_vals = NULL;
|
||||
int val_num = 0;
|
||||
if (!expr || !val_sep || !separated_nb) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
parsed_vals = av_mallocz_array(MAX_SUPPORTED_OUTPUTS_NB, sizeof(*parsed_vals));
|
||||
if (!parsed_vals) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
do {
|
||||
val = av_get_token(&expr, val_sep);
|
||||
if(val) {
|
||||
parsed_vals[val_num] = val;
|
||||
val_num++;
|
||||
}
|
||||
if (*expr) {
|
||||
expr++;
|
||||
}
|
||||
} while(*expr);
|
||||
|
||||
parsed_vals[val_num] = NULL;
|
||||
*separated_nb = val_num;
|
||||
|
||||
return parsed_vals;
|
||||
}
|
||||
|
||||
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
|
||||
{
|
||||
@ -28,8 +61,10 @@ int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *fil
|
||||
av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
|
||||
return AVERROR(EINVAL);
|
||||
}
|
||||
if (!ctx->model_outputname) {
|
||||
av_log(filter_ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
|
||||
|
||||
ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
|
||||
if (!ctx->model_outputnames) {
|
||||
av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
|
||||
return AVERROR(EINVAL);
|
||||
}
|
||||
|
||||
@ -91,15 +126,15 @@ DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
|
||||
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
|
||||
{
|
||||
return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
|
||||
ctx->model_outputname, output_width, output_height);
|
||||
(const char *)ctx->model_outputnames[0], output_width, output_height);
|
||||
}
|
||||
|
||||
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
|
||||
{
|
||||
DNNExecBaseParams exec_params = {
|
||||
.input_name = ctx->model_inputname,
|
||||
.output_names = (const char **)&ctx->model_outputname,
|
||||
.nb_output = 1,
|
||||
.output_names = (const char **)ctx->model_outputnames,
|
||||
.nb_output = ctx->nb_outputs,
|
||||
.in_frame = in_frame,
|
||||
.out_frame = out_frame,
|
||||
};
|
||||
@ -110,8 +145,8 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
|
||||
{
|
||||
DNNExecBaseParams exec_params = {
|
||||
.input_name = ctx->model_inputname,
|
||||
.output_names = (const char **)&ctx->model_outputname,
|
||||
.nb_output = 1,
|
||||
.output_names = (const char **)ctx->model_outputnames,
|
||||
.nb_output = ctx->nb_outputs,
|
||||
.in_frame = in_frame,
|
||||
.out_frame = out_frame,
|
||||
};
|
||||
@ -123,8 +158,8 @@ DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_f
|
||||
DNNExecClassificationParams class_params = {
|
||||
{
|
||||
.input_name = ctx->model_inputname,
|
||||
.output_names = (const char **)&ctx->model_outputname,
|
||||
.nb_output = 1,
|
||||
.output_names = (const char **)ctx->model_outputnames,
|
||||
.nb_output = ctx->nb_outputs,
|
||||
.in_frame = in_frame,
|
||||
.out_frame = out_frame,
|
||||
},
|
||||
|
@ -30,10 +30,12 @@ typedef struct DnnContext {
|
||||
char *model_filename;
|
||||
DNNBackendType backend_type;
|
||||
char *model_inputname;
|
||||
char *model_outputname;
|
||||
char *model_outputnames_string;
|
||||
char *backend_options;
|
||||
int async;
|
||||
|
||||
char **model_outputnames;
|
||||
uint32_t nb_outputs;
|
||||
DNNModule *dnn_module;
|
||||
DNNModel *model;
|
||||
} DnnContext;
|
||||
@ -41,7 +43,7 @@ typedef struct DnnContext {
|
||||
#define DNN_COMMON_OPTIONS \
|
||||
{ "model", "path to model file", OFFSET(model_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "input", "input name of the model", OFFSET(model_inputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "output", "output name of the model", OFFSET(model_outputname), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "output", "output name of the model", OFFSET(model_outputnames_string), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "backend_configs", "backend configs", OFFSET(backend_options), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "options", "backend configs", OFFSET(backend_options), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
|
||||
{ "async", "use DNN async inference", OFFSET(async), AV_OPT_TYPE_BOOL, { .i64 = 1}, 0, 1, FLAGS},
|
||||
|
@ -50,7 +50,7 @@ static const AVOption derain_options[] = {
|
||||
#endif
|
||||
{ "model", "path to model file", OFFSET(dnnctx.model_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
|
||||
{ "input", "input name of the model", OFFSET(dnnctx.model_inputname), AV_OPT_TYPE_STRING, { .str = "x" }, 0, 0, FLAGS },
|
||||
{ "output", "output name of the model", OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING, { .str = "y" }, 0, 0, FLAGS },
|
||||
{ "output", "output name of the model", OFFSET(dnnctx.model_outputnames_string), AV_OPT_TYPE_STRING, { .str = "y" }, 0, 0, FLAGS },
|
||||
{ NULL }
|
||||
};
|
||||
|
||||
|
@ -54,7 +54,7 @@ static const AVOption sr_options[] = {
|
||||
{ "scale_factor", "scale factor for SRCNN model", OFFSET(scale_factor), AV_OPT_TYPE_INT, { .i64 = 2 }, 2, 4, FLAGS },
|
||||
{ "model", "path to model file specifying network architecture and its parameters", OFFSET(dnnctx.model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
|
||||
{ "input", "input name of the model", OFFSET(dnnctx.model_inputname), AV_OPT_TYPE_STRING, { .str = "x" }, 0, 0, FLAGS },
|
||||
{ "output", "output name of the model", OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING, { .str = "y" }, 0, 0, FLAGS },
|
||||
{ "output", "output name of the model", OFFSET(dnnctx.model_outputnames_string), AV_OPT_TYPE_STRING, { .str = "y" }, 0, 0, FLAGS },
|
||||
{ NULL }
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user