mirror of
https://github.com/immich-app/immich.git
synced 2025-01-02 12:48:35 +02:00
Fix Smart Search when using OpenVINO (#7389)
* Fix external_path loading in OpenVINO EP
* Fix ruff lint
* Wrap block in try finally
* remove static input shape code
* add unit test
* remove unused imports
* remove repeat line
* linting
* formatting

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
parent
912d723281
commit
2a75f884d9
@ -1 +0,0 @@
|
|||||||
from .ann import Ann, is_available
|
|
@ -32,8 +32,7 @@ T = TypeVar("T", covariant=True)
|
|||||||
|
|
||||||
|
|
||||||
class Newable(Protocol[T]):
|
class Newable(Protocol[T]):
|
||||||
def new(self) -> None:
|
def new(self) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class _Singleton(type, Newable[T]):
|
class _Singleton(type, Newable[T]):
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import onnx
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from onnx.shape_inference import infer_shapes
|
|
||||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
|
||||||
|
|
||||||
import ann.ann
|
import ann.ann
|
||||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
from app.models.constants import SUPPORTED_PROVIDERS
|
||||||
|
|
||||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||||
from ..schemas import ModelRuntime, ModelType
|
from ..schemas import ModelRuntime, ModelType
|
||||||
@ -113,63 +111,25 @@ class InferenceModel(ABC):
|
|||||||
)
|
)
|
||||||
model_path = onnx_path
|
model_path = onnx_path
|
||||||
|
|
||||||
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
|
|
||||||
static_path = model_path.parent / "static_1" / "model.onnx"
|
|
||||||
static_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
if not static_path.is_file():
|
|
||||||
self._convert_to_static(model_path, static_path)
|
|
||||||
model_path = static_path
|
|
||||||
|
|
||||||
match model_path.suffix:
|
match model_path.suffix:
|
||||||
case ".armnn":
|
case ".armnn":
|
||||||
session = AnnSession(model_path)
|
session = AnnSession(model_path)
|
||||||
case ".onnx":
|
case ".onnx":
|
||||||
|
cwd = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(model_path.parent)
|
||||||
session = ort.InferenceSession(
|
session = ort.InferenceSession(
|
||||||
model_path.as_posix(),
|
model_path.as_posix(),
|
||||||
sess_options=self.sess_options,
|
sess_options=self.sess_options,
|
||||||
providers=self.providers,
|
providers=self.providers,
|
||||||
provider_options=self.provider_options,
|
provider_options=self.provider_options,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
os.chdir(cwd)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
|
|
||||||
inferred = infer_shapes(onnx.load(source_path))
|
|
||||||
inputs = self._get_static_dims(inferred.graph.input)
|
|
||||||
outputs = self._get_static_dims(inferred.graph.output)
|
|
||||||
|
|
||||||
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
|
|
||||||
check_model = onnx.checker.check_model
|
|
||||||
try:
|
|
||||||
|
|
||||||
def check_model_stub(*args: Any, **kwargs: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
onnx.checker.check_model = check_model_stub
|
|
||||||
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
|
|
||||||
finally:
|
|
||||||
onnx.checker.check_model = check_model
|
|
||||||
|
|
||||||
onnx.save(
|
|
||||||
updated_model,
|
|
||||||
target_path,
|
|
||||||
save_as_external_data=True,
|
|
||||||
all_tensors_to_one_file=False,
|
|
||||||
size_threshold=1048576,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
|
|
||||||
return {
|
|
||||||
field.name: [
|
|
||||||
d.dim_value if d.HasField("dim_value") else dim_size
|
|
||||||
for shape in field.type.ListFields()
|
|
||||||
if (dim := shape[1].shape.dim)
|
|
||||||
for d in dim
|
|
||||||
]
|
|
||||||
for field in graph_io
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_type(self) -> ModelType:
|
def model_type(self) -> ModelType:
|
||||||
return self._model_type
|
return self._model_type
|
||||||
|
@ -54,9 +54,6 @@ _INSIGHTFACE_MODELS = {
|
|||||||
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"]
|
||||||
|
|
||||||
|
|
||||||
STATIC_INPUT_PROVIDERS = ["OpenVINOExecutionProvider"]
|
|
||||||
|
|
||||||
|
|
||||||
def is_openclip(model_name: str) -> bool:
|
def is_openclip(model_name: str) -> bool:
|
||||||
return clean_name(model_name) in _OPENCLIP_MODELS
|
return clean_name(model_name) in _OPENCLIP_MODELS
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import randint
|
from random import randint
|
||||||
@ -237,12 +238,12 @@ class TestBase:
|
|||||||
mock_model_path.is_file.return_value = True
|
mock_model_path.is_file.return_value = True
|
||||||
mock_model_path.suffix = ".armnn"
|
mock_model_path.suffix = ".armnn"
|
||||||
mock_model_path.with_suffix.return_value = mock_model_path
|
mock_model_path.with_suffix.return_value = mock_model_path
|
||||||
mock_session = mocker.patch("app.models.base.AnnSession")
|
mock_ann = mocker.patch("app.models.base.AnnSession")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||||
encoder._make_session(mock_model_path)
|
encoder._make_session(mock_model_path)
|
||||||
|
|
||||||
mock_session.assert_called_once()
|
mock_ann.assert_called_once()
|
||||||
|
|
||||||
def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None:
|
def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None:
|
||||||
mock_armnn_path = mocker.Mock()
|
mock_armnn_path = mocker.Mock()
|
||||||
@ -256,6 +257,7 @@ class TestBase:
|
|||||||
|
|
||||||
mock_ann = mocker.patch("app.models.base.AnnSession")
|
mock_ann = mocker.patch("app.models.base.AnnSession")
|
||||||
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
||||||
|
mocker.patch("app.models.base.os.chdir")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||||
encoder._make_session(mock_armnn_path)
|
encoder._make_session(mock_armnn_path)
|
||||||
@ -278,6 +280,26 @@ class TestBase:
|
|||||||
mock_ann.assert_not_called()
|
mock_ann.assert_not_called()
|
||||||
mock_ort.assert_not_called()
|
mock_ort.assert_not_called()
|
||||||
|
|
||||||
|
def test_make_session_changes_cwd(self, mocker: MockerFixture) -> None:
|
||||||
|
mock_model_path = mocker.Mock()
|
||||||
|
mock_model_path.is_file.return_value = True
|
||||||
|
mock_model_path.suffix = ".onnx"
|
||||||
|
mock_model_path.parent = "model_parent"
|
||||||
|
mock_model_path.with_suffix.return_value = mock_model_path
|
||||||
|
mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
|
||||||
|
mock_chdir = mocker.patch("app.models.base.os.chdir")
|
||||||
|
|
||||||
|
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||||
|
encoder._make_session(mock_model_path)
|
||||||
|
|
||||||
|
mock_chdir.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call(mock_model_path.parent),
|
||||||
|
mock.call(os.getcwd()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
mock_ort.assert_called_once()
|
||||||
|
|
||||||
def test_download(self, mocker: MockerFixture) -> None:
|
def test_download(self, mocker: MockerFixture) -> None:
|
||||||
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
||||||
|
|
||||||
|
@ -82,10 +82,10 @@ warn_untyped_fields = true
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
select = ["E", "F", "I"]
|
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.lint]
|
||||||
"test_main.py" = ["F403"]
|
select = ["E", "F", "I"]
|
||||||
|
per-file-ignores = { "test_main.py" = ["F403"] }
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
|
Loading…
Reference in New Issue
Block a user