mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2025-01-24 13:56:33 +02:00
lavfi/dnn: add classify support with openvino backend
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
This commit is contained in:
parent
a3b74651a0
commit
fc26dca64e
@ -29,6 +29,7 @@
|
||||
#include "libavutil/avassert.h"
|
||||
#include "libavutil/opt.h"
|
||||
#include "libavutil/avstring.h"
|
||||
#include "libavutil/detection_bbox.h"
|
||||
#include "../internal.h"
|
||||
#include "queue.h"
|
||||
#include "safe_queue.h"
|
||||
@ -74,6 +75,7 @@ typedef struct TaskItem {
|
||||
// one task might have multiple inferences
|
||||
typedef struct InferenceItem {
|
||||
TaskItem *task;
|
||||
uint32_t bbox_index;
|
||||
} InferenceItem;
|
||||
|
||||
// one request for one call to openvino
|
||||
@ -182,12 +184,23 @@ static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request
|
||||
request->inferences[i] = inference;
|
||||
request->inference_count = i + 1;
|
||||
task = inference->task;
|
||||
if (task->do_ioproc) {
|
||||
if (ov_model->model->frame_pre_proc != NULL) {
|
||||
ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
|
||||
} else {
|
||||
ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
|
||||
switch (task->ov_model->model->func_type) {
|
||||
case DFT_PROCESS_FRAME:
|
||||
case DFT_ANALYTICS_DETECT:
|
||||
if (task->do_ioproc) {
|
||||
if (ov_model->model->frame_pre_proc != NULL) {
|
||||
ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
|
||||
} else {
|
||||
ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case DFT_ANALYTICS_CLASSIFY:
|
||||
ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx);
|
||||
break;
|
||||
default:
|
||||
av_assert0(!"should not reach here");
|
||||
break;
|
||||
}
|
||||
input.data = (uint8_t *)input.data
|
||||
+ input.width * input.height * input.channels * get_datatype_size(input.dt);
|
||||
@ -276,6 +289,13 @@ static void infer_completion_callback(void *args)
|
||||
}
|
||||
task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
|
||||
break;
|
||||
case DFT_ANALYTICS_CLASSIFY:
|
||||
if (!task->ov_model->model->classify_post_proc) {
|
||||
av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n");
|
||||
return;
|
||||
}
|
||||
task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx);
|
||||
break;
|
||||
default:
|
||||
av_assert0(!"should not reach here");
|
||||
break;
|
||||
@ -513,7 +533,44 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue)
|
||||
/**
 * Check whether the frame carries detection bounding boxes that are
 * usable for classification.
 *
 * Returns 1 when the frame has AV_FRAME_DATA_DETECTION_BBOXES side data
 * with at least one bbox, every bbox lies fully inside the frame, and
 * every bbox still has a free slot for another classification result;
 * returns 0 otherwise.
 */
static int contain_valid_detection_bbox(AVFrame *frame)
{
    AVFrameSideData *sd;
    const AVDetectionBBoxHeader *header;
    const AVDetectionBBox *bbox;

    sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
    if (!sd) { // this frame has nothing detected
        return 0;
    }

    if (!sd->size) {
        return 0;
    }

    header = (const AVDetectionBBoxHeader *)sd->data;
    if (!header->nb_bboxes) {
        return 0;
    }

    for (uint32_t i = 0; i < header->nb_bboxes; i++) {
        bbox = av_get_detection_bbox(header, i);
        if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) {
            return 0;
        }
        // bugfix: the vertical extent must be checked against the frame
        // height; the original compared y + h with frame->width.
        if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->height) {
            return 0;
        }

        // no room left to append one more classification to this bbox
        if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) {
            return 0;
        }
    }

    return 1;
}
|
||||
|
||||
static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params)
|
||||
{
|
||||
switch (func_type) {
|
||||
case DFT_PROCESS_FRAME:
|
||||
@ -532,6 +589,45 @@ static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task
|
||||
}
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
case DFT_ANALYTICS_CLASSIFY:
|
||||
{
|
||||
const AVDetectionBBoxHeader *header;
|
||||
AVFrame *frame = task->in_frame;
|
||||
AVFrameSideData *sd;
|
||||
DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params;
|
||||
|
||||
task->inference_todo = 0;
|
||||
task->inference_done = 0;
|
||||
|
||||
if (!contain_valid_detection_bbox(frame)) {
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
|
||||
sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
|
||||
header = (const AVDetectionBBoxHeader *)sd->data;
|
||||
|
||||
for (uint32_t i = 0; i < header->nb_bboxes; i++) {
|
||||
InferenceItem *inference;
|
||||
const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);
|
||||
|
||||
if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
inference = av_malloc(sizeof(*inference));
|
||||
if (!inference) {
|
||||
return DNN_ERROR;
|
||||
}
|
||||
task->inference_todo++;
|
||||
inference->task = task;
|
||||
inference->bbox_index = i;
|
||||
if (ff_queue_push_back(inference_queue, inference) < 0) {
|
||||
av_freep(&inference);
|
||||
return DNN_ERROR;
|
||||
}
|
||||
}
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
default:
|
||||
av_assert0(!"should not reach here");
|
||||
return DNN_ERROR;
|
||||
@ -598,7 +694,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
|
||||
task.out_frame = out_frame;
|
||||
task.ov_model = ov_model;
|
||||
|
||||
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
|
||||
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) {
|
||||
av_frame_free(&out_frame);
|
||||
av_frame_free(&in_frame);
|
||||
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
|
||||
@ -690,6 +786,14 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
if (model->func_type == DFT_ANALYTICS_CLASSIFY) {
|
||||
// Once we add async support for tensorflow backend and native backend,
|
||||
// we'll combine the two sync/async functions in dnn_interface.h to
|
||||
// simplify the code in filter, and async will be an option within backends.
|
||||
// so, do not support now, and classify filter will not call this function.
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
if (ctx->options.batch_size > 1) {
|
||||
avpriv_report_missing_feature(ctx, "batch mode for sync execution");
|
||||
return DNN_ERROR;
|
||||
@ -710,7 +814,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
|
||||
task.out_frame = exec_params->out_frame ? exec_params->out_frame : exec_params->in_frame;
|
||||
task.ov_model = ov_model;
|
||||
|
||||
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
|
||||
if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
|
||||
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
|
||||
return DNN_ERROR;
|
||||
}
|
||||
@ -730,6 +834,7 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
|
||||
OVContext *ctx = &ov_model->ctx;
|
||||
RequestItem *request;
|
||||
TaskItem *task;
|
||||
DNNReturnType ret;
|
||||
|
||||
if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) {
|
||||
return DNN_ERROR;
|
||||
@ -761,23 +866,25 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) {
|
||||
if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
|
||||
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) {
|
||||
// not enough inference items queued for a batch
|
||||
return DNN_SUCCESS;
|
||||
while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) {
|
||||
request = ff_safe_queue_pop_front(ov_model->request_queue);
|
||||
if (!request) {
|
||||
av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
ret = execute_model_ov(request, ov_model->inference_queue);
|
||||
if (ret != DNN_SUCCESS) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
request = ff_safe_queue_pop_front(ov_model->request_queue);
|
||||
if (!request) {
|
||||
av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
|
||||
return DNN_ERROR;
|
||||
}
|
||||
|
||||
return execute_model_ov(request, ov_model->inference_queue);
|
||||
return DNN_SUCCESS;
|
||||
}
|
||||
|
||||
DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out)
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "libavutil/imgutils.h"
|
||||
#include "libswscale/swscale.h"
|
||||
#include "libavutil/avassert.h"
|
||||
#include "libavutil/detection_bbox.h"
|
||||
|
||||
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
|
||||
{
|
||||
@ -175,6 +176,65 @@ static enum AVPixelFormat get_pixel_format(DNNData *data)
|
||||
return AV_PIX_FMT_BGR24;
|
||||
}
|
||||
|
||||
/**
 * Crop the bbox_index-th detection bounding box out of frame, then scale
 * and pixel-format-convert that region into the DNN input buffer.
 *
 * The caller must guarantee the frame carries
 * AV_FRAME_DATA_DETECTION_BBOXES side data (asserted below) and that
 * bbox_index is a valid index into it.
 *
 * @param frame      source frame with detection side data
 * @param input      destination DNN input description (data/width/height/dt)
 * @param bbox_index which detected bbox to extract
 * @param log_ctx    context for av_log
 * @return DNN_SUCCESS on success, DNN_ERROR on swscale setup failure
 */
DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx)
{
    const AVPixFmtDescriptor *desc;
    int offsetx[4], offsety[4];
    uint8_t *bbox_data[4] = { NULL }; // robustness: planes absent from frame stay NULL
    struct SwsContext *sws_ctx;
    int linesizes[4];
    enum AVPixelFormat fmt;
    int left, top, width, height;
    const AVDetectionBBoxHeader *header;
    const AVDetectionBBox *bbox;
    AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
    av_assert0(sd); // caller guarantees the detection side data exists

    header = (const AVDetectionBBoxHeader *)sd->data;
    bbox = av_get_detection_bbox(header, bbox_index);

    left = bbox->x;
    width = bbox->w;
    top = bbox->y;
    height = bbox->h;

    fmt = get_pixel_format(input);
    sws_ctx = sws_getContext(width, height, frame->format,
                             input->width, input->height, fmt,
                             SWS_FAST_BILINEAR, NULL, NULL, NULL);
    if (!sws_ctx) {
        av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion "
               "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
               av_get_pix_fmt_name(frame->format), width, height,
               av_get_pix_fmt_name(fmt), input->width, input->height);
        return DNN_ERROR;
    }

    if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) {
        // bugfix: av_log messages conventionally end with a newline
        av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes\n");
        sws_freeContext(sws_ctx);
        return DNN_ERROR;
    }

    desc = av_pix_fmt_desc_get(frame->format);
    // Planes 1 and 2 are chroma and may be subsampled; planes 0 (luma) and
    // 3 (alpha) use full resolution.
    // NOTE(review): the offsets are applied in bytes, which assumes one byte
    // per sample per plane (planar 8-bit formats) — confirm behavior for
    // packed or >8-bit formats.
    offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w);
    offsetx[0] = offsetx[3] = left;

    offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h);
    offsety[0] = offsety[3] = top;

    // Point each plane pointer at the top-left sample of the bbox region.
    for (int k = 0; frame->data[k]; k++)
        bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k];

    sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize,
              0, height,
              (uint8_t *const *)(&input->data), linesizes);

    sws_freeContext(sws_ctx);

    return DNN_SUCCESS;
}
|
||||
|
||||
static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx)
|
||||
{
|
||||
struct SwsContext *sws_ctx;
|
||||
|
@ -32,5 +32,6 @@
|
||||
|
||||
DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx);
|
||||
DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
|
||||
DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx);
|
||||
|
||||
#endif
|
||||
|
@ -77,6 +77,12 @@ int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
 * Register the filter-provided callback that interprets classification
 * output (DNNData) back into frame side data.
 *
 * @return always 0
 */
int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc)
{
    DNNModel *model = ctx->model;

    model->classify_post_proc = post_proc;
    return 0;
}
|
||||
|
||||
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
|
||||
{
|
||||
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
|
||||
@ -112,6 +118,21 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
|
||||
return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params);
|
||||
}
|
||||
|
||||
/**
 * Run classification for the detected bounding boxes whose label matches
 * target. Wraps the common execution parameters together with the target
 * label and dispatches through the backend's async entry point.
 */
DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target)
{
    DNNExecClassificationParams class_params = { 0 };

    class_params.base.input_name   = ctx->model_inputname;
    class_params.base.output_names = (const char **)&ctx->model_outputname;
    class_params.base.nb_output    = 1;
    class_params.base.in_frame     = in_frame;
    class_params.base.out_frame    = out_frame;
    class_params.target            = target;

    return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base);
}
|
||||
|
||||
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame)
|
||||
{
|
||||
return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame);
|
||||
|
@ -50,10 +50,12 @@ typedef struct DnnContext {
|
||||
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
|
||||
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
|
||||
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
|
||||
int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc);
|
||||
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
|
||||
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
|
||||
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
|
||||
DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
|
||||
DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target);
|
||||
DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame);
|
||||
DNNReturnType ff_dnn_flush(DnnContext *ctx);
|
||||
void ff_dnn_uninit(DnnContext *ctx);
|
||||
|
@ -52,7 +52,7 @@ typedef enum {
|
||||
DFT_NONE,
|
||||
DFT_PROCESS_FRAME, // process the whole frame
|
||||
DFT_ANALYTICS_DETECT, // detect from the whole frame
|
||||
// we can add more such as detect_from_crop, classify_from_bbox, etc.
|
||||
DFT_ANALYTICS_CLASSIFY, // classify for each bounding box
|
||||
}DNNFunctionType;
|
||||
|
||||
typedef struct DNNData{
|
||||
@ -71,8 +71,14 @@ typedef struct DNNExecBaseParams {
|
||||
AVFrame *out_frame;
|
||||
} DNNExecBaseParams;
|
||||
|
||||
typedef struct DNNExecClassificationParams {
    DNNExecBaseParams base;  // common execution parameters (input/output names, frames)
    const char *target;      // only bboxes whose detect_label matches this are classified
} DNNExecClassificationParams;
|
||||
|
||||
typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
|
||||
typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
|
||||
typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx);
|
||||
|
||||
typedef struct DNNModel{
|
||||
// Stores model that can be different for different backends.
|
||||
@ -97,6 +103,8 @@ typedef struct DNNModel{
|
||||
FramePrePostProc frame_post_proc;
|
||||
// set the post process to interpret detect result from DNNData
|
||||
DetectPostProc detect_post_proc;
|
||||
// set the post process to interpret classify result from DNNData
|
||||
ClassifyPostProc classify_post_proc;
|
||||
} DNNModel;
|
||||
|
||||
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
|
||||
|
Loading…
x
Reference in New Issue
Block a user