libavfi/dnn: enable LibTorch xpu device option support
Add xpu device support to the LibTorch backend. To enable xpu support, add "-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed" to "--extra-libs" when configuring FFmpeg.

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
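For reference, a configure invocation along these lines should pick up the XPU runtime library (a sketch: --enable-libtorch is FFmpeg's existing switch for the LibTorch backend; any extra include/library paths depend on your libtorch and Intel Extension for PyTorch installation):

    ./configure --enable-libtorch \
                --extra-libs="-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed"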
@@ -250,6 +250,10 @@ static int th_start_inference(void *args)
         av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
         return DNN_GENERIC_ERROR;
     }
+    // Transfer tensor to the same device as model
+    c10::Device device = (*th_model->jit_model->parameters().begin()).device();
+    if (infer_request->input_tensor->device() != device)
+        *infer_request->input_tensor = infer_request->input_tensor->to(device);
     inputs.push_back(*infer_request->input_tensor);
 
     *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
@@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) {
     switch (th_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
+            // Post process can only deal with CPU memory.
+            if (output->device() != torch::kCPU)
+                *output = output->to(torch::kCPU);
             outputs.scale = 255;
             outputs.data = output->data_ptr();
             if (th_model->model.frame_post_proc != NULL) {
@@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
     th_model->ctx = ctx;
 
     c10::Device device = c10::Device(device_name);
-    if (!device.is_cpu()) {
+    if (device.is_xpu()) {
+        if (!at::hasXPU()) {
+            av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
+            goto fail;
+        }
+        at::detail::getXPUHooks().initXPU();
+    } else if (!device.is_cpu()) {
         av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
         goto fail;
     }
@@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
     try {
         th_model->jit_model = new torch::jit::Module;
         (*th_model->jit_model) = torch::jit::load(ctx->model_filename);
+        th_model->jit_model->to(device);
     } catch (const c10::Error& e) {
         av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
         goto fail;
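Once built, selecting the XPU device would look roughly like this (a sketch, not taken from the commit: it assumes the dnn_processing filter with the LibTorch backend, the generic per-backend device option, and model.pt as a placeholder TorchScript model):

    ffmpeg -i input.mp4 -vf dnn_processing=dnn_backend=torch:model=model.pt:device=xpu output.mp4

The device string is parsed into a c10::Device, so "xpu" here is what makes device.is_xpu() true in dnn_load_model_th above.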