diff --git a/machine-learning/ann/__init__.py b/machine-learning/ann/__init__.py index 0793d1011b..e69de29bb2 100644 --- a/machine-learning/ann/__init__.py +++ b/machine-learning/ann/__init__.py @@ -1 +0,0 @@ -from .ann import Ann, is_available diff --git a/machine-learning/ann/ann.py b/machine-learning/ann/ann.py index 94f665bfc7..148d5ba101 100644 --- a/machine-learning/ann/ann.py +++ b/machine-learning/ann/ann.py @@ -32,8 +32,7 @@ T = TypeVar("T", covariant=True) class Newable(Protocol[T]): - def new(self) -> None: - ... + def new(self) -> None: ... class _Singleton(type, Newable[T]): diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 6097c7c987..6909a935c3 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -1,18 +1,16 @@ from __future__ import annotations +import os from abc import ABC, abstractmethod from pathlib import Path from shutil import rmtree from typing import Any -import onnx import onnxruntime as ort from huggingface_hub import snapshot_download -from onnx.shape_inference import infer_shapes -from onnx.tools.update_model_dims import update_inputs_outputs_dims import ann.ann -from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS +from app.models.constants import SUPPORTED_PROVIDERS from ..config import get_cache_dir, get_hf_model_name, log, settings from ..schemas import ModelRuntime, ModelType @@ -113,63 +111,25 @@ class InferenceModel(ABC): ) model_path = onnx_path - if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers): - static_path = model_path.parent / "static_1" / "model.onnx" - static_path.parent.mkdir(parents=True, exist_ok=True) - if not static_path.is_file(): - self._convert_to_static(model_path, static_path) - model_path = static_path - match model_path.suffix: case ".armnn": session = AnnSession(model_path) case ".onnx": - session = ort.InferenceSession( - model_path.as_posix(), - sess_options=self.sess_options, - providers=self.providers, - provider_options=self.provider_options, - ) + cwd = os.getcwd() + try: + os.chdir(model_path.parent) + session = ort.InferenceSession( + model_path.as_posix(), + sess_options=self.sess_options, + providers=self.providers, + provider_options=self.provider_options, + ) + finally: + os.chdir(cwd) case _: raise ValueError(f"Unsupported model file type: {model_path.suffix}") return session - def _convert_to_static(self, source_path: Path, target_path: Path) -> None: - inferred = infer_shapes(onnx.load(source_path)) - inputs = self._get_static_dims(inferred.graph.input) - outputs = self._get_static_dims(inferred.graph.output) - - # check_model gets called in update_inputs_outputs_dims and doesn't work for large models - check_model = onnx.checker.check_model - try: - - def check_model_stub(*args: Any, **kwargs: Any) -> None: - pass - - onnx.checker.check_model = check_model_stub - updated_model = update_inputs_outputs_dims(inferred, inputs, outputs) - finally: - onnx.checker.check_model = check_model - - onnx.save( - updated_model, - target_path, - save_as_external_data=True, - all_tensors_to_one_file=False, - size_threshold=1048576, - ) - - def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]: - return { - field.name: [ - d.dim_value if d.HasField("dim_value") else dim_size - for shape in field.type.ListFields() - if (dim := shape[1].shape.dim) - for d in dim - ] - for field in graph_io - } - @property def model_type(self) -> ModelType: return self._model_type diff --git a/machine-learning/app/models/constants.py b/machine-learning/app/models/constants.py index 18965d2b1d..b112e9279d 100644 --- a/machine-learning/app/models/constants.py +++ b/machine-learning/app/models/constants.py @@ -54,9 +54,6 @@ _INSIGHTFACE_MODELS = { SUPPORTED_PROVIDERS = ["CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"] -STATIC_INPUT_PROVIDERS = ["OpenVINOExecutionProvider"] - - def is_openclip(model_name: str) -> bool: return clean_name(model_name) in _OPENCLIP_MODELS diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index cf941c1bbf..e25099e67e 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -1,4 +1,5 @@ import json +import os from io import BytesIO from pathlib import Path from random import randint @@ -237,12 +238,12 @@ class TestBase: mock_model_path.is_file.return_value = True mock_model_path.suffix = ".armnn" mock_model_path.with_suffix.return_value = mock_model_path - mock_session = mocker.patch("app.models.base.AnnSession") + mock_ann = mocker.patch("app.models.base.AnnSession") encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder._make_session(mock_model_path) - mock_session.assert_called_once() + mock_ann.assert_called_once() def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None: mock_armnn_path = mocker.Mock() @@ -256,6 +257,7 @@ class TestBase: mock_ann = mocker.patch("app.models.base.AnnSession") mock_ort = mocker.patch("app.models.base.ort.InferenceSession") + mocker.patch("app.models.base.os.chdir") encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder._make_session(mock_armnn_path) @@ -278,6 +280,26 @@ class TestBase: mock_ann.assert_not_called() mock_ort.assert_not_called() + def test_make_session_changes_cwd(self, mocker: MockerFixture) -> None: + mock_model_path = mocker.Mock() + mock_model_path.is_file.return_value = True + mock_model_path.suffix = ".onnx" + mock_model_path.parent = "model_parent" + mock_model_path.with_suffix.return_value = mock_model_path + mock_ort = mocker.patch("app.models.base.ort.InferenceSession") + mock_chdir = mocker.patch("app.models.base.os.chdir") + + encoder = OpenCLIPEncoder("ViT-B-32__openai") + encoder._make_session(mock_model_path) + + mock_chdir.assert_has_calls( + [ + mock.call(mock_model_path.parent), + mock.call(os.getcwd()), + ] + ) + mock_ort.assert_called_once() + def test_download(self, mocker: MockerFixture) -> None: mock_snapshot_download = mocker.patch("app.models.base.snapshot_download") diff --git a/machine-learning/pyproject.toml b/machine-learning/pyproject.toml index 750ca65f26..fec5c72130 100644 --- a/machine-learning/pyproject.toml +++ b/machine-learning/pyproject.toml @@ -82,10 +82,10 @@ warn_untyped_fields = true [tool.ruff] line-length = 120 target-version = "py311" -select = ["E", "F", "I"] -[tool.ruff.per-file-ignores] -"test_main.py" = ["F403"] +[tool.ruff.lint] +select = ["E", "F", "I"] +per-file-ignores = { "test_main.py" = ["F403"] } [tool.black] line-length = 120