You've already forked FFmpeg
mirror of
https://github.com/FFmpeg/FFmpeg.git
synced 2025-08-10 06:10:52 +02:00
libavfi/dnn: enable LibTorch xpu device option support
Add xpu device support to libtorch backend. To enable xpu support you need to add "-Wl,--no-as-needed -lintel-ext-pt-gpu -Wl,--as-needed" to "--extra-libs" when configure ffmpeg. Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
This commit is contained in:
@@ -250,6 +250,10 @@ static int th_start_inference(void *args)
|
||||
av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
|
||||
return DNN_GENERIC_ERROR;
|
||||
}
|
||||
// Transfer tensor to the same device as model
|
||||
c10::Device device = (*th_model->jit_model->parameters().begin()).device();
|
||||
if (infer_request->input_tensor->device() != device)
|
||||
*infer_request->input_tensor = infer_request->input_tensor->to(device);
|
||||
inputs.push_back(*infer_request->input_tensor);
|
||||
|
||||
*infer_request->output = th_model->jit_model->forward(inputs).toTensor();
|
||||
@@ -285,6 +289,9 @@ static void infer_completion_callback(void *args) {
|
||||
switch (th_model->model.func_type) {
|
||||
case DFT_PROCESS_FRAME:
|
||||
if (task->do_ioproc) {
|
||||
// Post process can only deal with CPU memory.
|
||||
if (output->device() != torch::kCPU)
|
||||
*output = output->to(torch::kCPU);
|
||||
outputs.scale = 255;
|
||||
outputs.data = output->data_ptr();
|
||||
if (th_model->model.frame_post_proc != NULL) {
|
||||
@@ -424,7 +431,13 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
|
||||
th_model->ctx = ctx;
|
||||
|
||||
c10::Device device = c10::Device(device_name);
|
||||
if (!device.is_cpu()) {
|
||||
if (device.is_xpu()) {
|
||||
if (!at::hasXPU()) {
|
||||
av_log(ctx, AV_LOG_ERROR, "No XPU device found\n");
|
||||
goto fail;
|
||||
}
|
||||
at::detail::getXPUHooks().initXPU();
|
||||
} else if (!device.is_cpu()) {
|
||||
av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name);
|
||||
goto fail;
|
||||
}
|
||||
@@ -432,6 +445,7 @@ static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, A
|
||||
try {
|
||||
th_model->jit_model = new torch::jit::Module;
|
||||
(*th_model->jit_model) = torch::jit::load(ctx->model_filename);
|
||||
th_model->jit_model->to(device);
|
||||
} catch (const c10::Error& e) {
|
||||
av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
|
||||
goto fail;
|
||||
|
Reference in New Issue
Block a user