mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2024-12-23 12:43:46 +02:00
avfilter/dnn: add a new interface to query dnn model's input info
To support DNN networks more generally, we need to know the input info of the DNN model. Background: the data type of a DNN model's input could be float32, uint8, fp16, etc., and the width/height of the input image could be fixed or variable. Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
This commit is contained in:
parent
e1b45b8596
commit
f4b3c0e55c
@ -28,6 +28,28 @@
|
|||||||
#include "dnn_backend_native_layer_conv2d.h"
|
#include "dnn_backend_native_layer_conv2d.h"
|
||||||
#include "dnn_backend_native_layers.h"
|
#include "dnn_backend_native_layers.h"
|
||||||
|
|
||||||
|
/**
 * Query the data type and shape of a named input operand of a native model.
 *
 * @param model      backend model; treated as a ConvolutionalNetwork
 * @param input      receives dt, height, width and channels on success;
 *                   no other field is written, and nothing is written on failure
 * @param input_name name of the operand to look up
 * @return DNN_SUCCESS when an operand with that name exists, is of type
 *         DOT_INPUT and has dims[0] == 1; DNN_ERROR otherwise
 */
static DNNReturnType get_input_native(void *model, DNNData *input, const char *input_name)
{
    ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;

    // strcmp() below must never see a NULL pointer
    if (!network || !input || !input_name)
        return DNN_ERROR;

    for (int i = 0; i < network->operands_num; ++i) {
        DnnOperand *oprd = &network->operands[i];

        if (strcmp(oprd->name, input_name) != 0)
            continue;

        if (oprd->type != DOT_INPUT)
            return DNN_ERROR;

        // Only a leading dimension of 1 is handled. Report anything else as
        // an error to the caller instead of aborting the whole process.
        if (oprd->dims[0] != 1)
            return DNN_ERROR;

        input->dt       = oprd->data_type;
        input->height   = oprd->dims[1];
        input->width    = oprd->dims[2];
        input->channels = oprd->dims[3];
        return DNN_SUCCESS;
    }

    // no operand with the requested name was found
    return DNN_ERROR;
}
|
||||||
|
|
||||||
static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
|
static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
|
||||||
{
|
{
|
||||||
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
|
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
|
||||||
@ -37,7 +59,6 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
|
|||||||
return DNN_ERROR;
|
return DNN_ERROR;
|
||||||
|
|
||||||
/* inputs */
|
/* inputs */
|
||||||
av_assert0(input->dt == DNN_FLOAT);
|
|
||||||
for (int i = 0; i < network->operands_num; ++i) {
|
for (int i = 0; i < network->operands_num; ++i) {
|
||||||
oprd = &network->operands[i];
|
oprd = &network->operands[i];
|
||||||
if (strcmp(oprd->name, input_name) == 0) {
|
if (strcmp(oprd->name, input_name) == 0) {
|
||||||
@ -234,6 +255,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename)
|
|||||||
}
|
}
|
||||||
|
|
||||||
model->set_input_output = &set_input_output_native;
|
model->set_input_output = &set_input_output_native;
|
||||||
|
model->get_input = &get_input_native;
|
||||||
|
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,37 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
|
|||||||
input_dims[1] * input_dims[2] * input_dims[3] * size);
|
input_dims[1] * input_dims[2] * input_dims[3] * size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Query the data type and shape of a named input of a TensorFlow model.
 *
 * The operation called input_name is looked up in the model graph and the
 * shape of its output 0 is requested as a 4-dimensional tensor.
 *
 * @param model      backend model; treated as a TFModel
 * @param input      receives dt, height, width and channels on success;
 *                   no other field is written, and nothing is written on failure
 * @param input_name name of the graph operation to look up
 * @return DNN_SUCCESS on success; DNN_ERROR if an argument is NULL, the
 *         operation does not exist, the shape query fails, or the leading
 *         dimension is neither 1 nor -1
 */
static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
{
    TFModel *tf_model = (TFModel *)model;
    TF_Status *status;
    TF_Output tf_output;
    TF_Code code;
    int64_t dims[4];

    if (!tf_model || !input || !input_name)
        return DNN_ERROR;

    tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
    if (!tf_output.oper)
        return DNN_ERROR;
    tf_output.index = 0;

    status = TF_NewStatus();
    TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
    code = TF_GetCode(status);
    TF_DeleteStatus(status);
    if (code != TF_OK)
        return DNN_ERROR;

    // currently only NHWC is supported
    // TensorFlow reports an unknown dimension as -1, so accept that for the
    // leading dimension as well as 1. Any other value is returned as an
    // error instead of aborting the process.
    if (dims[0] != 1 && dims[0] != -1)
        return DNN_ERROR;

    input->dt       = TF_OperationOutputType(tf_output);
    input->height   = dims[1];
    input->width    = dims[2];
    input->channels = dims[3];

    return DNN_SUCCESS;
}
|
||||||
|
|
||||||
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
|
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
|
||||||
{
|
{
|
||||||
TFModel *tf_model = (TFModel *)model;
|
TFModel *tf_model = (TFModel *)model;
|
||||||
@ -568,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
|
|||||||
|
|
||||||
model->model = (void *)tf_model;
|
model->model = (void *)tf_model;
|
||||||
model->set_input_output = &set_input_output_tf;
|
model->set_input_output = &set_input_output_tf;
|
||||||
|
model->get_input = &get_input_tf;
|
||||||
|
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,9 @@ typedef struct DNNData{
|
|||||||
typedef struct DNNModel{
|
typedef struct DNNModel{
|
||||||
// Stores model that can be different for different backends.
|
// Stores model that can be different for different backends.
|
||||||
void *model;
|
void *model;
|
||||||
|
// Gets model input information
|
||||||
|
// Just reuse struct DNNData here, actually the DNNData.data field is not needed.
|
||||||
|
DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
|
||||||
// Sets model input and output.
|
// Sets model input and output.
|
||||||
// Should be called at least once before model execution.
|
// Should be called at least once before model execution.
|
||||||
DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
|
DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
|
||||||
|
Loading…
Reference in New Issue
Block a user