mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2024-12-23 12:43:46 +02:00
libavfilter/dnn: avoid memcpy for tensorflow dnn output
use TF_Tensor's cpu address to avoid extra memcpy. Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
This commit is contained in:
parent
e2b92896c4
commit
7adfb6132e
@ -35,6 +35,7 @@ typedef struct TFModel{
|
||||
TF_Status *status;
|
||||
TF_Output input, output;
|
||||
TF_Tensor *input_tensor;
|
||||
TF_Tensor *output_tensor;
|
||||
} TFModel;
|
||||
|
||||
static void free_buffer(void *data, size_t length)
|
||||
@ -460,13 +461,11 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
|
||||
return NULL;
|
||||
}
|
||||
|
||||
tf_model = av_malloc(sizeof(TFModel));
|
||||
tf_model = av_mallocz(sizeof(TFModel));
|
||||
if (!tf_model){
|
||||
av_freep(&model);
|
||||
return NULL;
|
||||
}
|
||||
tf_model->session = NULL;
|
||||
tf_model->input_tensor = NULL;
|
||||
|
||||
if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
|
||||
if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
|
||||
@ -488,36 +487,22 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
|
||||
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output)
|
||||
{
|
||||
TFModel *tf_model = (TFModel *)model->model;
|
||||
TF_Tensor *output_tensor;
|
||||
uint64_t count;
|
||||
uint64_t old_count = output->height * output->width * output->channels * sizeof(float);
|
||||
if (tf_model->output_tensor)
|
||||
TF_DeleteTensor(tf_model->output_tensor);
|
||||
|
||||
TF_SessionRun(tf_model->session, NULL,
|
||||
&tf_model->input, &tf_model->input_tensor, 1,
|
||||
&tf_model->output, &output_tensor, 1,
|
||||
&tf_model->output, &tf_model->output_tensor, 1,
|
||||
NULL, 0, NULL, tf_model->status);
|
||||
|
||||
if (TF_GetCode(tf_model->status) != TF_OK){
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
output->height = TF_Dim(output_tensor, 1);
|
||||
output->width = TF_Dim(output_tensor, 2);
|
||||
output->channels = TF_Dim(output_tensor, 3);
|
||||
count = output->height * output->width * output->channels * sizeof(float);
|
||||
if (output->data) {
|
||||
if (count > old_count) {
|
||||
av_freep(&output->data);
|
||||
}
|
||||
}
|
||||
if (!output->data) {
|
||||
output->data = av_malloc(count);
|
||||
if (!output->data){
|
||||
return DNN_ERROR;
|
||||
}
|
||||
}
|
||||
memcpy(output->data, TF_TensorData(output_tensor), count);
|
||||
TF_DeleteTensor(output_tensor);
|
||||
output->height = TF_Dim(tf_model->output_tensor, 1);
|
||||
output->width = TF_Dim(tf_model->output_tensor, 2);
|
||||
output->channels = TF_Dim(tf_model->output_tensor, 3);
|
||||
output->data = TF_TensorData(tf_model->output_tensor);
|
||||
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
@ -541,6 +526,9 @@ void ff_dnn_free_model_tf(DNNModel **model)
|
||||
if (tf_model->input_tensor){
|
||||
TF_DeleteTensor(tf_model->input_tensor);
|
||||
}
|
||||
if (tf_model->output_tensor){
|
||||
TF_DeleteTensor(tf_model->output_tensor);
|
||||
}
|
||||
av_freep(&tf_model);
|
||||
av_freep(model);
|
||||
}
|
||||
|
@ -274,9 +274,6 @@ static av_cold void uninit(AVFilterContext *context)
|
||||
int i;
|
||||
SRContext *sr_context = context->priv;
|
||||
|
||||
if (sr_context->backend_type == DNN_TF)
|
||||
av_freep(&sr_context->output.data);
|
||||
|
||||
if (sr_context->dnn_module){
|
||||
(sr_context->dnn_module->free_model)(&sr_context->model);
|
||||
av_freep(&sr_context->dnn_module);
|
||||
|
Loading…
Reference in New Issue
Block a user