1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-02 12:48:35 +02:00

Fix Smart Search when using OpenVINO (#7389)

* Fix external_path loading in OpenVINO EP

* Fix ruff lint

* Wrap block in try/finally

* remove static input shape code

* add unit test

* remove unused imports

* remove repeated line

* linting

* formatting

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
Sourav Agrawal 2024-02-25 04:52:27 +05:30 committed by GitHub
parent 912d723281
commit 2a75f884d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 41 additions and 64 deletions

View File

@ -1 +0,0 @@
from .ann import Ann, is_available

View File

@ -32,8 +32,7 @@ T = TypeVar("T", covariant=True)
class Newable(Protocol[T]): class Newable(Protocol[T]):
def new(self) -> None: def new(self) -> None: ...
...
class _Singleton(type, Newable[T]): class _Singleton(type, Newable[T]):

View File

@ -1,18 +1,16 @@
from __future__ import annotations from __future__ import annotations
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from typing import Any from typing import Any
import onnx
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims
import ann.ann import ann.ann
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS from app.models.constants import SUPPORTED_PROVIDERS
from ..config import get_cache_dir, get_hf_model_name, log, settings from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelRuntime, ModelType from ..schemas import ModelRuntime, ModelType
@ -113,63 +111,25 @@ class InferenceModel(ABC):
) )
model_path = onnx_path model_path = onnx_path
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
static_path = model_path.parent / "static_1" / "model.onnx"
static_path.parent.mkdir(parents=True, exist_ok=True)
if not static_path.is_file():
self._convert_to_static(model_path, static_path)
model_path = static_path
match model_path.suffix: match model_path.suffix:
case ".armnn": case ".armnn":
session = AnnSession(model_path) session = AnnSession(model_path)
case ".onnx": case ".onnx":
cwd = os.getcwd()
try:
os.chdir(model_path.parent)
session = ort.InferenceSession( session = ort.InferenceSession(
model_path.as_posix(), model_path.as_posix(),
sess_options=self.sess_options, sess_options=self.sess_options,
providers=self.providers, providers=self.providers,
provider_options=self.provider_options, provider_options=self.provider_options,
) )
finally:
os.chdir(cwd)
case _: case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}") raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session return session
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
inferred = infer_shapes(onnx.load(source_path))
inputs = self._get_static_dims(inferred.graph.input)
outputs = self._get_static_dims(inferred.graph.output)
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
check_model = onnx.checker.check_model
try:
def check_model_stub(*args: Any, **kwargs: Any) -> None:
pass
onnx.checker.check_model = check_model_stub
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
finally:
onnx.checker.check_model = check_model
onnx.save(
updated_model,
target_path,
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=1048576,
)
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
return {
field.name: [
d.dim_value if d.HasField("dim_value") else dim_size
for shape in field.type.ListFields()
if (dim := shape[1].shape.dim)
for d in dim
]
for field in graph_io
}
@property @property
def model_type(self) -> ModelType: def model_type(self) -> ModelType:
return self._model_type return self._model_type

View File

@ -54,9 +54,6 @@ _INSIGHTFACE_MODELS = {
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
STATIC_INPUT_PROVIDERS = ["OpenVINOExecutionProvider"]
def is_openclip(model_name: str) -> bool: def is_openclip(model_name: str) -> bool:
return clean_name(model_name) in _OPENCLIP_MODELS return clean_name(model_name) in _OPENCLIP_MODELS

View File

@ -1,4 +1,5 @@
import json import json
import os
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from random import randint from random import randint
@ -237,12 +238,12 @@ class TestBase:
mock_model_path.is_file.return_value = True mock_model_path.is_file.return_value = True
mock_model_path.suffix = ".armnn" mock_model_path.suffix = ".armnn"
mock_model_path.with_suffix.return_value = mock_model_path mock_model_path.with_suffix.return_value = mock_model_path
mock_session = mocker.patch("app.models.base.AnnSession") mock_ann = mocker.patch("app.models.base.AnnSession")
encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder = OpenCLIPEncoder("ViT-B-32__openai")
encoder._make_session(mock_model_path) encoder._make_session(mock_model_path)
mock_session.assert_called_once() mock_ann.assert_called_once()
def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None: def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None:
mock_armnn_path = mocker.Mock() mock_armnn_path = mocker.Mock()
@ -256,6 +257,7 @@ class TestBase:
mock_ann = mocker.patch("app.models.base.AnnSession") mock_ann = mocker.patch("app.models.base.AnnSession")
mock_ort = mocker.patch("app.models.base.ort.InferenceSession") mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
mocker.patch("app.models.base.os.chdir")
encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder = OpenCLIPEncoder("ViT-B-32__openai")
encoder._make_session(mock_armnn_path) encoder._make_session(mock_armnn_path)
@ -278,6 +280,26 @@ class TestBase:
mock_ann.assert_not_called() mock_ann.assert_not_called()
mock_ort.assert_not_called() mock_ort.assert_not_called()
def test_make_session_changes_cwd(self, mocker: MockerFixture) -> None:
mock_model_path = mocker.Mock()
mock_model_path.is_file.return_value = True
mock_model_path.suffix = ".onnx"
mock_model_path.parent = "model_parent"
mock_model_path.with_suffix.return_value = mock_model_path
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
mock_chdir = mocker.patch("app.models.base.os.chdir")
encoder = OpenCLIPEncoder("ViT-B-32__openai")
encoder._make_session(mock_model_path)
mock_chdir.assert_has_calls(
[
mock.call(mock_model_path.parent),
mock.call(os.getcwd()),
]
)
mock_ort.assert_called_once()
def test_download(self, mocker: MockerFixture) -> None: def test_download(self, mocker: MockerFixture) -> None:
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download") mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")

View File

@ -82,10 +82,10 @@ warn_untyped_fields = true
[tool.ruff] [tool.ruff]
line-length = 120 line-length = 120
target-version = "py311" target-version = "py311"
select = ["E", "F", "I"]
[tool.ruff.per-file-ignores] [tool.ruff.lint]
"test_main.py" = ["F403"] select = ["E", "F", "I"]
per-file-ignores = { "test_main.py" = ["F403"] }
[tool.black] [tool.black]
line-length = 120 line-length = 120