mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2024-12-28 20:53:54 +02:00
libavfilter/vf_dnn_detect: Add yolov4 support
The difference of yolov4 is that sigmoid function needed to be applied on x, y coordinates. Also make it compatiple with NHWC output as the yolov4 model from openvino model zoo has NHWC output layout. Model refer to: https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/yolo-v4-tf Signed-off-by: Wenbin Chen <wenbin.chen@intel.com> Reviewed-by: Guo Yejun <yejun.guo@intel.com>
This commit is contained in:
parent
a882fc0294
commit
1fa3346c70
@ -35,7 +35,8 @@
|
|||||||
typedef enum {
|
typedef enum {
|
||||||
DDMT_SSD,
|
DDMT_SSD,
|
||||||
DDMT_YOLOV1V2,
|
DDMT_YOLOV1V2,
|
||||||
DDMT_YOLOV3
|
DDMT_YOLOV3,
|
||||||
|
DDMT_YOLOV4
|
||||||
} DNNDetectionModelType;
|
} DNNDetectionModelType;
|
||||||
|
|
||||||
typedef struct DnnDetectContext {
|
typedef struct DnnDetectContext {
|
||||||
@ -75,6 +76,7 @@ static const AVOption dnn_detect_options[] = {
|
|||||||
{ "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" },
|
{ "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" },
|
||||||
{ "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" },
|
{ "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" },
|
||||||
{ "yolov3", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV3 }, 0, 0, FLAGS, "model_type" },
|
{ "yolov3", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV3 }, 0, 0, FLAGS, "model_type" },
|
||||||
|
{ "yolov4", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV4 }, 0, 0, FLAGS, "model_type" },
|
||||||
{ "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
{ "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
||||||
{ "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
{ "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
||||||
{ "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
{ "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
|
||||||
@ -84,6 +86,14 @@ static const AVOption dnn_detect_options[] = {
|
|||||||
|
|
||||||
AVFILTER_DEFINE_CLASS(dnn_detect);
|
AVFILTER_DEFINE_CLASS(dnn_detect);
|
||||||
|
|
||||||
|
static inline float sigmoid(float x) {
|
||||||
|
return 1.f / (1.f + exp(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline float linear(float x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
|
static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
|
||||||
{
|
{
|
||||||
float max_prob = 0;
|
float max_prob = 0;
|
||||||
@ -147,6 +157,8 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
|
|||||||
float *output_data = output[output_index].data;
|
float *output_data = output[output_index].data;
|
||||||
float *anchors = ctx->anchors;
|
float *anchors = ctx->anchors;
|
||||||
AVDetectionBBox *bbox;
|
AVDetectionBBox *bbox;
|
||||||
|
float (*post_process_raw_data)(float x);
|
||||||
|
int is_NHWC = 0;
|
||||||
|
|
||||||
if (ctx->model_type == DDMT_YOLOV1V2) {
|
if (ctx->model_type == DDMT_YOLOV1V2) {
|
||||||
cell_w = ctx->cell_w;
|
cell_w = ctx->cell_w;
|
||||||
@ -154,13 +166,30 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
|
|||||||
scale_w = cell_w;
|
scale_w = cell_w;
|
||||||
scale_h = cell_h;
|
scale_h = cell_h;
|
||||||
} else {
|
} else {
|
||||||
cell_w = output[output_index].width;
|
if (output[output_index].height != output[output_index].width &&
|
||||||
cell_h = output[output_index].height;
|
output[output_index].height == output[output_index].channels) {
|
||||||
|
is_NHWC = 1;
|
||||||
|
cell_w = output[output_index].height;
|
||||||
|
cell_h = output[output_index].channels;
|
||||||
|
} else {
|
||||||
|
cell_w = output[output_index].width;
|
||||||
|
cell_h = output[output_index].height;
|
||||||
|
}
|
||||||
scale_w = ctx->scale_width;
|
scale_w = ctx->scale_width;
|
||||||
scale_h = ctx->scale_height;
|
scale_h = ctx->scale_height;
|
||||||
}
|
}
|
||||||
box_size = nb_classes + 5;
|
box_size = nb_classes + 5;
|
||||||
|
|
||||||
|
switch (ctx->model_type) {
|
||||||
|
case DDMT_YOLOV1V2:
|
||||||
|
case DDMT_YOLOV3:
|
||||||
|
post_process_raw_data = linear;
|
||||||
|
break;
|
||||||
|
case DDMT_YOLOV4:
|
||||||
|
post_process_raw_data = sigmoid;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if (!cell_h || !cell_w) {
|
if (!cell_h || !cell_w) {
|
||||||
av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are detected\n");
|
av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are detected\n");
|
||||||
return AVERROR(EINVAL);
|
return AVERROR(EINVAL);
|
||||||
@ -198,19 +227,36 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
|
|||||||
float *detection_boxes_data;
|
float *detection_boxes_data;
|
||||||
int label_id;
|
int label_id;
|
||||||
|
|
||||||
detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
|
if (is_NHWC) {
|
||||||
conf = detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h];
|
detection_boxes_data = output_data +
|
||||||
|
((cy * cell_w + cx) * detection_boxes + box_id) * box_size;
|
||||||
|
conf = post_process_raw_data(detection_boxes_data[4]);
|
||||||
|
} else {
|
||||||
|
detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
|
||||||
|
conf = post_process_raw_data(
|
||||||
|
detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h]);
|
||||||
|
}
|
||||||
if (conf < conf_threshold) {
|
if (conf < conf_threshold) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
x = detection_boxes_data[cy * cell_w + cx];
|
if (is_NHWC) {
|
||||||
y = detection_boxes_data[cy * cell_w + cx + cell_w * cell_h];
|
x = post_process_raw_data(detection_boxes_data[0]);
|
||||||
w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
|
y = post_process_raw_data(detection_boxes_data[1]);
|
||||||
h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
|
w = detection_boxes_data[2];
|
||||||
label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
|
h = detection_boxes_data[3];
|
||||||
detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
|
label_id = dnn_detect_get_label_id(ctx->nb_classes, 1, detection_boxes_data + 5);
|
||||||
conf = conf * detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h];
|
conf = conf * post_process_raw_data(detection_boxes_data[label_id + 5]);
|
||||||
|
} else {
|
||||||
|
x = post_process_raw_data(detection_boxes_data[cy * cell_w + cx]);
|
||||||
|
y = post_process_raw_data(detection_boxes_data[cy * cell_w + cx + cell_w * cell_h]);
|
||||||
|
w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
|
||||||
|
h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
|
||||||
|
label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
|
||||||
|
detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
|
||||||
|
conf = conf * post_process_raw_data(
|
||||||
|
detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h]);
|
||||||
|
}
|
||||||
|
|
||||||
bbox = av_mallocz(sizeof(*bbox));
|
bbox = av_mallocz(sizeof(*bbox));
|
||||||
if (!bbox)
|
if (!bbox)
|
||||||
@ -410,6 +456,7 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, int nb_outpu
|
|||||||
if (ret < 0)
|
if (ret < 0)
|
||||||
return ret;
|
return ret;
|
||||||
case DDMT_YOLOV3:
|
case DDMT_YOLOV3:
|
||||||
|
case DDMT_YOLOV4:
|
||||||
ret = dnn_detect_post_proc_yolov3(frame, output, filter_ctx, nb_outputs);
|
ret = dnn_detect_post_proc_yolov3(frame, output, filter_ctx, nb_outputs);
|
||||||
if (ret < 0)
|
if (ret < 0)
|
||||||
return ret;
|
return ret;
|
||||||
|
Loading…
Reference in New Issue
Block a user