1
0
mirror of https://github.com/FFmpeg/FFmpeg.git synced 2025-01-13 21:28:01 +02:00

dnn: put DNNModel.set_input and DNNModule.execute_model together

suppose we have a detect and classify filter in the future, the
detect filter generates some bounding boxes (BBox) as AVFrame sidedata,
and the classify filter executes DNN model for each BBox. For each
BBox, we need to crop the AVFrame, copy data to DNN model input and do
the model execution. So we have to save the in_frame at DNNModel.set_input
and use it at DNNModule.execute_model, such saving is not feasible
when we support async execute_model.

This patch sets the in_frame as execution_model parameter, and so
all the information are put together within the same function for
each inference. It also makes easy to support BBox async inference.
This commit is contained in:
Guo, Yejun 2020-09-10 22:29:57 +08:00
parent 2003e32f62
commit fce3e3e137
10 changed files with 157 additions and 236 deletions

View File

@ -70,64 +70,6 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
return DNN_ERROR; return DNN_ERROR;
} }
static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name)
{
NativeModel *native_model = (NativeModel *)model;
NativeContext *ctx = &native_model->ctx;
DnnOperand *oprd = NULL;
DNNData input;
if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
return DNN_ERROR;
}
/* inputs */
for (int i = 0; i < native_model->operands_num; ++i) {
oprd = &native_model->operands[i];
if (strcmp(oprd->name, input_name) == 0) {
if (oprd->type != DOT_INPUT) {
av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
return DNN_ERROR;
}
break;
}
oprd = NULL;
}
if (!oprd) {
av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
return DNN_ERROR;
}
oprd->dims[1] = frame->height;
oprd->dims[2] = frame->width;
av_freep(&oprd->data);
oprd->length = calculate_operand_data_length(oprd);
if (oprd->length <= 0) {
av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
return DNN_ERROR;
}
oprd->data = av_malloc(oprd->length);
if (!oprd->data) {
av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
return DNN_ERROR;
}
input.height = oprd->dims[1];
input.width = oprd->dims[2];
input.channels = oprd->dims[3];
input.data = oprd->data;
input.dt = oprd->data_type;
if (native_model->model->pre_proc != NULL) {
native_model->model->pre_proc(frame, &input, native_model->model->userdata);
} else {
proc_from_frame_to_dnn(frame, &input, ctx);
}
return DNN_SUCCESS;
}
// Loads model and its parameters that are stored in a binary file with following structure: // Loads model and its parameters that are stored in a binary file with following structure:
// layers_num,layer_type,layer_parameterss,layer_type,layer_parameters... // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
// For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
@ -273,7 +215,6 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
return NULL; return NULL;
} }
model->set_input = &set_input_native;
model->get_input = &get_input_native; model->get_input = &get_input_native;
model->userdata = userdata; model->userdata = userdata;
@ -285,26 +226,66 @@ fail:
return NULL; return NULL;
} }
DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{ {
NativeModel *native_model = (NativeModel *)model->model; NativeModel *native_model = (NativeModel *)model->model;
NativeContext *ctx = &native_model->ctx; NativeContext *ctx = &native_model->ctx;
int32_t layer; int32_t layer;
DNNData output; DNNData input, output;
DnnOperand *oprd = NULL;
if (nb_output != 1) {
// currently, the filter does not need multiple outputs,
// so we just pending the support until we really need it.
av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
return DNN_ERROR;
}
if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
return DNN_ERROR; return DNN_ERROR;
} }
if (!native_model->operands[0].data) {
av_log(ctx, AV_LOG_ERROR, "Empty model input data\n"); for (int i = 0; i < native_model->operands_num; ++i) {
oprd = &native_model->operands[i];
if (strcmp(oprd->name, input_name) == 0) {
if (oprd->type != DOT_INPUT) {
av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name);
return DNN_ERROR;
}
break;
}
oprd = NULL;
}
if (!oprd) {
av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
return DNN_ERROR;
}
oprd->dims[1] = in_frame->height;
oprd->dims[2] = in_frame->width;
av_freep(&oprd->data);
oprd->length = calculate_operand_data_length(oprd);
if (oprd->length <= 0) {
av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n");
return DNN_ERROR;
}
oprd->data = av_malloc(oprd->length);
if (!oprd->data) {
av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n");
return DNN_ERROR;
}
input.height = oprd->dims[1];
input.width = oprd->dims[2];
input.channels = oprd->dims[3];
input.data = oprd->data;
input.dt = oprd->data_type;
if (native_model->model->pre_proc != NULL) {
native_model->model->pre_proc(in_frame, &input, native_model->model->userdata);
} else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
if (nb_output != 1) {
// currently, the filter does not need multiple outputs,
// so we just pending the support until we really need it.
av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
return DNN_ERROR; return DNN_ERROR;
} }

View File

@ -128,7 +128,8 @@ typedef struct NativeModel{
DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata); DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata);
DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame);
void ff_dnn_free_model_native(DNNModel **model); void ff_dnn_free_model_native(DNNModel **model);

View File

@ -48,7 +48,6 @@ typedef struct OVModel{
ie_network_t *network; ie_network_t *network;
ie_executable_network_t *exe_network; ie_executable_network_t *exe_network;
ie_infer_request_t *infer_request; ie_infer_request_t *infer_request;
ie_blob_t *input_blob;
} OVModel; } OVModel;
#define APPEND_STRING(generated_string, iterate_string) \ #define APPEND_STRING(generated_string, iterate_string) \
@ -133,49 +132,6 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
return DNN_ERROR; return DNN_ERROR;
} }
static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name)
{
OVModel *ov_model = (OVModel *)model;
OVContext *ctx = &ov_model->ctx;
IEStatusCode status;
dimensions_t dims;
precision_e precision;
ie_blob_buffer_t blob_buffer;
DNNData input;
status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
if (status != OK)
goto err;
status |= ie_blob_get_dims(ov_model->input_blob, &dims);
status |= ie_blob_get_precision(ov_model->input_blob, &precision);
if (status != OK)
goto err;
status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
if (status != OK)
goto err;
input.height = dims.dims[2];
input.width = dims.dims[3];
input.channels = dims.dims[1];
input.data = blob_buffer.buffer;
input.dt = precision_to_datatype(precision);
if (ov_model->model->pre_proc != NULL) {
ov_model->model->pre_proc(frame, &input, ov_model->model->userdata);
} else {
proc_from_frame_to_dnn(frame, &input, ctx);
}
return DNN_SUCCESS;
err:
if (ov_model->input_blob)
ie_blob_free(&ov_model->input_blob);
av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
return DNN_ERROR;
}
DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata) DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata)
{ {
char *all_dev_names = NULL; char *all_dev_names = NULL;
@ -234,7 +190,6 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
goto err; goto err;
model->model = (void *)ov_model; model->model = (void *)ov_model;
model->set_input = &set_input_ov;
model->get_input = &get_input_ov; model->get_input = &get_input_ov;
model->options = options; model->options = options;
model->userdata = userdata; model->userdata = userdata;
@ -258,7 +213,8 @@ err:
return NULL; return NULL;
} }
DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{ {
char *model_output_name = NULL; char *model_output_name = NULL;
char *all_output_names = NULL; char *all_output_names = NULL;
@ -269,7 +225,39 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
OVContext *ctx = &ov_model->ctx; OVContext *ctx = &ov_model->ctx;
IEStatusCode status; IEStatusCode status;
size_t model_output_count = 0; size_t model_output_count = 0;
DNNData output; DNNData input, output;
ie_blob_t *input_blob = NULL;
status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &input_blob);
if (status != OK) {
av_log(ctx, AV_LOG_ERROR, "Failed to get input blob\n");
return DNN_ERROR;
}
status |= ie_blob_get_dims(input_blob, &dims);
status |= ie_blob_get_precision(input_blob, &precision);
if (status != OK) {
av_log(ctx, AV_LOG_ERROR, "Failed to get input blob dims/precision\n");
return DNN_ERROR;
}
status = ie_blob_get_buffer(input_blob, &blob_buffer);
if (status != OK) {
av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n");
return DNN_ERROR;
}
input.height = dims.dims[2];
input.width = dims.dims[3];
input.channels = dims.dims[1];
input.data = blob_buffer.buffer;
input.dt = precision_to_datatype(precision);
if (ov_model->model->pre_proc != NULL) {
ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata);
} else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
ie_blob_free(&input_blob);
if (nb_output != 1) { if (nb_output != 1) {
// currently, the filter does not need multiple outputs, // currently, the filter does not need multiple outputs,
@ -330,6 +318,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output
proc_from_dnn_to_frame(out_frame, &output, ctx); proc_from_dnn_to_frame(out_frame, &output, ctx);
} }
} }
ie_blob_free(&output_blob);
} }
return DNN_SUCCESS; return DNN_SUCCESS;
@ -339,8 +328,6 @@ void ff_dnn_free_model_ov(DNNModel **model)
{ {
if (*model){ if (*model){
OVModel *ov_model = (OVModel *)(*model)->model; OVModel *ov_model = (OVModel *)(*model)->model;
if (ov_model->input_blob)
ie_blob_free(&ov_model->input_blob);
if (ov_model->infer_request) if (ov_model->infer_request)
ie_infer_request_free(&ov_model->infer_request); ie_infer_request_free(&ov_model->infer_request);
if (ov_model->exe_network) if (ov_model->exe_network)

View File

@ -31,7 +31,8 @@
DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata); DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata);
DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame);
void ff_dnn_free_model_ov(DNNModel **model); void ff_dnn_free_model_ov(DNNModel **model);

View File

@ -45,8 +45,6 @@ typedef struct TFModel{
TF_Graph *graph; TF_Graph *graph;
TF_Session *session; TF_Session *session;
TF_Status *status; TF_Status *status;
TF_Output input;
TF_Tensor *input_tensor;
} TFModel; } TFModel;
static const AVClass dnn_tensorflow_class = { static const AVClass dnn_tensorflow_class = {
@ -152,48 +150,33 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
return DNN_SUCCESS; return DNN_SUCCESS;
} }
static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name) static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{ {
TFModel *tf_model = (TFModel *)model;
TFContext *ctx = &tf_model->ctx; TFContext *ctx = &tf_model->ctx;
DNNData input; TF_Buffer *graph_def;
TF_ImportGraphDefOptions *graph_opts;
TF_SessionOptions *sess_opts; TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); const TF_Operation *init_op;
if (get_input_tf(model, &input, input_name) != DNN_SUCCESS) graph_def = read_graph(model_filename);
return DNN_ERROR; if (!graph_def){
input.height = frame->height; av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
input.width = frame->width;
// Input operation
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
if (!tf_model->input.oper){
av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
return DNN_ERROR; return DNN_ERROR;
} }
tf_model->input.index = 0; tf_model->graph = TF_NewGraph();
if (tf_model->input_tensor){ tf_model->status = TF_NewStatus();
TF_DeleteTensor(tf_model->input_tensor); graph_opts = TF_NewImportGraphDefOptions();
} TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
tf_model->input_tensor = allocate_input_tensor(&input); TF_DeleteImportGraphDefOptions(graph_opts);
if (!tf_model->input_tensor){ TF_DeleteBuffer(graph_def);
av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n"); if (TF_GetCode(tf_model->status) != TF_OK){
TF_DeleteGraph(tf_model->graph);
TF_DeleteStatus(tf_model->status);
av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
return DNN_ERROR; return DNN_ERROR;
} }
input.data = (float *)TF_TensorData(tf_model->input_tensor);
if (tf_model->model->pre_proc != NULL) {
tf_model->model->pre_proc(frame, &input, tf_model->model->userdata);
} else {
proc_from_frame_to_dnn(frame, &input, ctx);
}
// session
if (tf_model->session){
TF_CloseSession(tf_model->session, tf_model->status);
TF_DeleteSession(tf_model->session, tf_model->status);
}
init_op = TF_GraphOperationByName(tf_model->graph, "init");
sess_opts = TF_NewSessionOptions(); sess_opts = TF_NewSessionOptions();
tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status); tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status);
TF_DeleteSessionOptions(sess_opts); TF_DeleteSessionOptions(sess_opts);
@ -219,33 +202,6 @@ static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input
return DNN_SUCCESS; return DNN_SUCCESS;
} }
static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename)
{
TFContext *ctx = &tf_model->ctx;
TF_Buffer *graph_def;
TF_ImportGraphDefOptions *graph_opts;
graph_def = read_graph(model_filename);
if (!graph_def){
av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename);
return DNN_ERROR;
}
tf_model->graph = TF_NewGraph();
tf_model->status = TF_NewStatus();
graph_opts = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status);
TF_DeleteImportGraphDefOptions(graph_opts);
TF_DeleteBuffer(graph_def);
if (TF_GetCode(tf_model->status) != TF_OK){
TF_DeleteGraph(tf_model->graph);
TF_DeleteStatus(tf_model->status);
av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n");
return DNN_ERROR;
}
return DNN_SUCCESS;
}
#define NAME_BUFFER_SIZE 256 #define NAME_BUFFER_SIZE 256
static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op, static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op,
@ -626,7 +582,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
} }
model->model = (void *)tf_model; model->model = (void *)tf_model;
model->set_input = &set_input_tf;
model->get_input = &get_input_tf; model->get_input = &get_input_tf;
model->options = options; model->options = options;
model->userdata = userdata; model->userdata = userdata;
@ -634,13 +589,40 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
return model; return model;
} }
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame)
{ {
TF_Output *tf_outputs; TF_Output *tf_outputs;
TFModel *tf_model = (TFModel *)model->model; TFModel *tf_model = (TFModel *)model->model;
TFContext *ctx = &tf_model->ctx; TFContext *ctx = &tf_model->ctx;
DNNData output; DNNData input, output;
TF_Tensor **output_tensors; TF_Tensor **output_tensors;
TF_Output tf_input;
TF_Tensor *input_tensor;
if (get_input_tf(tf_model, &input, input_name) != DNN_SUCCESS)
return DNN_ERROR;
input.height = in_frame->height;
input.width = in_frame->width;
tf_input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
if (!tf_input.oper){
av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name);
return DNN_ERROR;
}
tf_input.index = 0;
input_tensor = allocate_input_tensor(&input);
if (!input_tensor){
av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
return DNN_ERROR;
}
input.data = (float *)TF_TensorData(input_tensor);
if (tf_model->model->pre_proc != NULL) {
tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata);
} else {
proc_from_frame_to_dnn(in_frame, &input, ctx);
}
if (nb_output != 1) { if (nb_output != 1) {
// currently, the filter does not need multiple outputs, // currently, the filter does not need multiple outputs,
@ -674,7 +656,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
} }
TF_SessionRun(tf_model->session, NULL, TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1, &tf_input, &input_tensor, 1,
tf_outputs, output_tensors, nb_output, tf_outputs, output_tensors, nb_output,
NULL, 0, NULL, tf_model->status); NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK) { if (TF_GetCode(tf_model->status) != TF_OK) {
@ -708,6 +690,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output
TF_DeleteTensor(output_tensors[i]); TF_DeleteTensor(output_tensors[i]);
} }
} }
TF_DeleteTensor(input_tensor);
av_freep(&output_tensors); av_freep(&output_tensors);
av_freep(&tf_outputs); av_freep(&tf_outputs);
return DNN_SUCCESS; return DNN_SUCCESS;
@ -729,9 +712,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
if (tf_model->status){ if (tf_model->status){
TF_DeleteStatus(tf_model->status); TF_DeleteStatus(tf_model->status);
} }
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
av_freep(&tf_model); av_freep(&tf_model);
av_freep(model); av_freep(model);
} }

View File

@ -31,7 +31,8 @@
DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata); DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata);
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame);
void ff_dnn_free_model_tf(DNNModel **model); void ff_dnn_free_model_tf(DNNModel **model);

View File

@ -51,9 +51,6 @@ typedef struct DNNModel{
// Gets model input information // Gets model input information
// Just reuse struct DNNData here, actually the DNNData.data field is not needed. // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name); DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
// Sets model input.
// Should be called every time before model execution.
DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name);
// set the pre process to transfer data from AVFrame to DNNData // set the pre process to transfer data from AVFrame to DNNData
// the default implementation within DNN is used if it is not provided by the filter // the default implementation within DNN is used if it is not provided by the filter
int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data); int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);
@ -66,8 +63,9 @@ typedef struct DNNModel{
typedef struct DNNModule{ typedef struct DNNModule{
// Loads model and parameters from given file. Returns NULL if it is not possible. // Loads model and parameters from given file. Returns NULL if it is not possible.
DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata); DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata);
// Executes model with specified output. Returns DNN_ERROR otherwise. // Executes model with specified input and output. Returns DNN_ERROR otherwise.
DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame,
const char **output_names, uint32_t nb_output, AVFrame *out_frame);
// Frees memory allocated for model. // Frees memory allocated for model.
void (*free_model)(DNNModel **model); void (*free_model)(DNNModel **model);
} DNNModule; } DNNModule;

View File

@ -80,13 +80,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
const char *model_output_name = "y"; const char *model_output_name = "y";
AVFrame *out; AVFrame *out;
dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x");
if (dnn_result != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
av_frame_free(&in);
return AVERROR(EIO);
}
out = ff_get_video_buffer(outlink, outlink->w, outlink->h); out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
if (!out) { if (!out) {
av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n"); av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n");
@ -95,7 +88,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
} }
av_frame_copy_props(out, in); av_frame_copy_props(out, in);
dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out); dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, "x", in, &model_output_name, 1, out);
if (dnn_result != DNN_SUCCESS){ if (dnn_result != DNN_SUCCESS){
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
av_frame_free(&in); av_frame_free(&in);

View File

@ -236,15 +236,11 @@ static int config_output(AVFilterLink *outlink)
AVFrame *out = NULL; AVFrame *out = NULL;
AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
if (result != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
return AVERROR(EIO);
}
// have a try run in case that the dnn model resize the frame // have a try run in case that the dnn model resize the frame
out = ff_get_video_buffer(inlink, inlink->w, inlink->h); out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, fake_in,
(const char **)&ctx->model_outputname, 1, out);
if (result != DNN_SUCCESS){ if (result != DNN_SUCCESS){
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
return AVERROR(EIO); return AVERROR(EIO);
@ -293,13 +289,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
DNNReturnType dnn_result; DNNReturnType dnn_result;
AVFrame *out; AVFrame *out;
dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname);
if (dnn_result != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
av_frame_free(&in);
return AVERROR(EIO);
}
out = ff_get_video_buffer(outlink, outlink->w, outlink->h); out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
if (!out) { if (!out) {
av_frame_free(&in); av_frame_free(&in);
@ -307,7 +296,8 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
} }
av_frame_copy_props(out, in); av_frame_copy_props(out, in);
dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); dnn_result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, in,
(const char **)&ctx->model_outputname, 1, out);
if (dnn_result != DNN_SUCCESS){ if (dnn_result != DNN_SUCCESS){
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
av_frame_free(&in); av_frame_free(&in);

View File

@ -114,16 +114,11 @@ static int config_output(AVFilterLink *outlink)
AVFrame *out = NULL; AVFrame *out = NULL;
const char *model_output_name = "y"; const char *model_output_name = "y";
AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
if (result != DNN_SUCCESS) {
av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
return AVERROR(EIO);
}
// have a try run in case that the dnn model resize the frame // have a try run in case that the dnn model resize the frame
AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
out = ff_get_video_buffer(inlink, inlink->w, inlink->h); out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out); result = (ctx->dnn_module->execute_model)(ctx->model, "x", fake_in,
(const char **)&model_output_name, 1, out);
if (result != DNN_SUCCESS){ if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n"); av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
return AVERROR(EIO); return AVERROR(EIO);
@ -178,19 +173,13 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
sws_scale(ctx->sws_pre_scale, sws_scale(ctx->sws_pre_scale,
(const uint8_t **)in->data, in->linesize, 0, in->height, (const uint8_t **)in->data, in->linesize, 0, in->height,
out->data, out->linesize); out->data, out->linesize);
dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x"); dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", out,
(const char **)&model_output_name, 1, out);
} else { } else {
dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x"); dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", in,
(const char **)&model_output_name, 1, out);
} }
if (dnn_result != DNN_SUCCESS) {
av_frame_free(&in);
av_frame_free(&out);
av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
return AVERROR(EIO);
}
dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
if (dnn_result != DNN_SUCCESS){ if (dnn_result != DNN_SUCCESS){
av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n"); av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
av_frame_free(&in); av_frame_free(&in);