feat(ml): conditionally download .armnn models (#6650)
commit a84b6f5fb1
parent fa0913120d
@@ -14,7 +14,7 @@ import ann.ann
 from app.models.constants import SUPPORTED_PROVIDERS
 
 from ..config import get_cache_dir, get_hf_model_name, log, settings
-from ..schemas import ModelType
+from ..schemas import ModelRuntime, ModelType
 from .ann import AnnSession
 
 
@@ -28,6 +28,7 @@ class InferenceModel(ABC):
         providers: list[str] | None = None,
         provider_options: list[dict[str, Any]] | None = None,
         sess_options: ort.SessionOptions | None = None,
+        preferred_runtime: ModelRuntime | None = None,
         **model_kwargs: Any,
     ) -> None:
         self.loaded = False
@@ -36,6 +37,7 @@ class InferenceModel(ABC):
         self.providers = providers if providers is not None else self.providers_default
         self.provider_options = provider_options if provider_options is not None else self.provider_options_default
         self.sess_options = sess_options if sess_options is not None else self.sess_options_default
+        self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default
 
     def download(self) -> None:
         if not self.cached:
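
With the constructor now accepting a runtime preference, a caller can opt in explicitly or simply omit the kwarg and rely on the default resolved later in this diff. A minimal sketch, assuming the absolute module paths that the tests later patch (app.models.clip, app.schemas):

from app.models.clip import OpenCLIPEncoder
from app.schemas import ModelRuntime

# Explicitly prefer the ARM NN runtime; omitting the kwarg falls back to
# preferred_runtime_default, which is added further down in this diff.
encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN)
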
@@ -66,11 +68,13 @@ class InferenceModel(ABC):
         pass
 
     def _download(self) -> None:
+        ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"]
         snapshot_download(
             get_hf_model_name(self.model_name),
             cache_dir=self.cache_dir,
             local_dir=self.cache_dir,
             local_dir_use_symlinks=False,
+            ignore_patterns=ignore_patterns,
         )
 
     @abstractmethod
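
huggingface_hub's snapshot_download skips any repository file matching ignore_patterns, so the .armnn weights are only fetched when the ARM NN runtime is actually preferred. A standalone sketch of the same conditional download, with an illustrative repo id and cache path rather than the project's helpers:

from huggingface_hub import snapshot_download

def download_model(repo_id: str, cache_dir: str, prefer_armnn: bool) -> None:
    # Skip *.armnn files unless the ARM NN runtime is preferred.
    ignore_patterns = [] if prefer_armnn else ["*.armnn"]
    snapshot_download(
        repo_id,
        cache_dir=cache_dir,
        local_dir=cache_dir,
        local_dir_use_symlinks=False,
        ignore_patterns=ignore_patterns,
    )
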
@@ -100,18 +104,28 @@ class InferenceModel(ABC):
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 
     def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
-        armnn_path = model_path.with_suffix(".armnn")
-        if settings.ann and ann.ann.is_available and armnn_path.is_file():
-            session = AnnSession(armnn_path)
-        elif model_path.is_file():
+        if not model_path.is_file():
+            onnx_path = model_path.with_suffix(".onnx")
+            if not onnx_path.is_file():
+                raise ValueError(f"Model path '{model_path}' does not exist")
+
+            log.warning(
+                f"Could not find model path '{model_path}'. " f"Falling back to ONNX model path '{onnx_path}' instead.",
+            )
+            model_path = onnx_path
+
+        match model_path.suffix:
+            case ".armnn":
+                session = AnnSession(model_path)
+            case ".onnx":
                 session = ort.InferenceSession(
                     model_path.as_posix(),
                     sess_options=self.sess_options,
                     providers=self.providers,
                     provider_options=self.provider_options,
                 )
-        else:
-            raise ValueError(f"the file model_path='{model_path}' does not exist")
+            case _:
+                raise ValueError(f"Unsupported model file type: {model_path.suffix}")
         return session
 
     @property
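
The reworked _make_session no longer probes for a sibling .armnn file; it trusts the suffix of the path it is given (which now encodes the preferred runtime) and only falls back to the .onnx sibling when that file is missing. The match statement dispatches on Path.suffix and needs Python 3.10 or newer. A sketch of just the fallback step, with a hypothetical helper name:

from pathlib import Path

def resolve_model_path(model_path: Path) -> Path:
    # Use the requested file if it exists; otherwise fall back to the ONNX sibling.
    if model_path.is_file():
        return model_path
    onnx_path = model_path.with_suffix(".onnx")
    if not onnx_path.is_file():
        raise ValueError(f"Model path '{model_path}' does not exist")
    return onnx_path
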
@@ -132,7 +146,7 @@ class InferenceModel(ABC):
 
     @property
     def cached(self) -> bool:
-        return self.cache_dir.exists() and any(self.cache_dir.iterdir())
+        return self.cache_dir.is_dir() and any(self.cache_dir.iterdir())
 
     @property
     def providers(self) -> list[str]:
@@ -215,6 +229,19 @@ class InferenceModel(ABC):
 
         return sess_options
 
+    @property
+    def preferred_runtime(self) -> ModelRuntime:
+        return self._preferred_runtime
+
+    @preferred_runtime.setter
+    def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None:
+        log.debug(f"Setting preferred runtime to {preferred_runtime}")
+        self._preferred_runtime = preferred_runtime
+
+    @property
+    def preferred_runtime_default(self) -> ModelRuntime:
+        return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
+
 
 # HF deep copies configs, so we need to make session options picklable
 class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
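
preferred_runtime_default only picks ARM NN when the ann setting is enabled and the native library could actually be loaded; every other combination falls back to ONNX Runtime. The same decision expressed as a free function, with illustrative names and the assumed app.schemas import path:

from app.schemas import ModelRuntime

def default_runtime(ann_enabled: bool, ann_loadable: bool) -> ModelRuntime:
    # Mirror of preferred_runtime_default: both conditions must hold for ARM NN.
    return ModelRuntime.ARMNN if ann_enabled and ann_loadable else ModelRuntime.ONNX

assert default_runtime(True, True) == ModelRuntime.ARMNN
assert default_runtime(True, False) == ModelRuntime.ONNX
assert default_runtime(False, True) == ModelRuntime.ONNX
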
@@ -81,11 +81,11 @@ class BaseCLIPEncoder(InferenceModel):
 
     @property
     def textual_path(self) -> Path:
-        return self.textual_dir / "model.onnx"
+        return self.textual_dir / f"model.{self.preferred_runtime}"
 
     @property
     def visual_path(self) -> Path:
-        return self.visual_dir / "model.onnx"
+        return self.visual_dir / f"model.{self.preferred_runtime}"
 
     @property
     def tokenizer_file_path(self) -> Path:
@@ -77,11 +77,11 @@ class FaceRecognizer(InferenceModel):
 
     @property
     def det_file(self) -> Path:
-        return self.cache_dir / "detection" / "model.onnx"
+        return self.cache_dir / "detection" / f"model.{self.preferred_runtime}"
 
     @property
     def rec_file(self) -> Path:
-        return self.cache_dir / "recognition" / "model.onnx"
+        return self.cache_dir / "recognition" / f"model.{self.preferred_runtime}"
 
     def configure(self, **model_kwargs: Any) -> None:
         self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)
@@ -26,6 +26,11 @@ class ModelType(str, Enum):
     FACIAL_RECOGNITION = "facial-recognition"
 
 
+class ModelRuntime(str, Enum):
+    ONNX = "onnx"
+    ARMNN = "armnn"
+
+
 class HasProfiling(Protocol):
     profiling: dict[str, float]
 
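
ModelRuntime mixes str into the enum, so each member is itself a string equal to its plain value; that is what lets the runtime stand in for a file extension in the model-path properties above. A small self-contained check of that mix-in behavior:

from enum import Enum

class ModelRuntime(str, Enum):
    ONNX = "onnx"
    ARMNN = "armnn"

# Members are str instances and compare equal to their plain values.
assert isinstance(ModelRuntime.ARMNN, str)
assert ModelRuntime.ARMNN == "armnn"
assert ModelRuntime.ONNX.value == "onnx"
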
@@ -18,7 +18,7 @@ from .models.base import InferenceModel, PicklableSessionOptions
 from .models.cache import ModelCache
 from .models.clip import OpenCLIPEncoder
 from .models.facial_recognition import FaceRecognizer
-from .schemas import ModelType
+from .schemas import ModelRuntime, ModelType
 
 
 class TestBase:
@@ -127,6 +127,30 @@ class TestBase:
 
         assert encoder.cache_dir == cache_dir
 
+    def test_sets_default_preferred_runtime(self, mocker: MockerFixture) -> None:
+        mocker.patch.object(settings, "ann", True)
+        mocker.patch("ann.ann.is_available", False)
+
+        encoder = OpenCLIPEncoder("ViT-B-32__openai")
+
+        assert encoder.preferred_runtime == ModelRuntime.ONNX
+
+    def test_sets_default_preferred_runtime_to_armnn_if_available(self, mocker: MockerFixture) -> None:
+        mocker.patch.object(settings, "ann", True)
+        mocker.patch("ann.ann.is_available", True)
+
+        encoder = OpenCLIPEncoder("ViT-B-32__openai")
+
+        assert encoder.preferred_runtime == ModelRuntime.ARMNN
+
+    def test_sets_preferred_runtime_kwarg(self, mocker: MockerFixture) -> None:
+        mocker.patch.object(settings, "ann", False)
+        mocker.patch("ann.ann.is_available", False)
+
+        encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN)
+
+        assert encoder.preferred_runtime == ModelRuntime.ARMNN
+
     def test_casts_cache_dir_string_to_path(self) -> None:
         cache_dir = "/test_cache"
         encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir=cache_dir)
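
The three new preferred-runtime tests could equally be collapsed into a single parametrized test; a hedged sketch of that alternative (not how the repository structures them), assuming the same imports as the surrounding test module (settings, OpenCLIPEncoder, ModelRuntime, MockerFixture):

import pytest

@pytest.mark.parametrize(
    ("ann_setting", "ann_available", "kwargs", "expected"),
    [
        (True, False, {}, ModelRuntime.ONNX),
        (True, True, {}, ModelRuntime.ARMNN),
        (False, False, {"preferred_runtime": ModelRuntime.ARMNN}, ModelRuntime.ARMNN),
    ],
)
def test_preferred_runtime(
    mocker: MockerFixture, ann_setting: bool, ann_available: bool, kwargs: dict, expected: ModelRuntime
) -> None:
    mocker.patch.object(settings, "ann", ann_setting)
    mocker.patch("ann.ann.is_available", ann_available)

    encoder = OpenCLIPEncoder("ViT-B-32__openai", **kwargs)

    assert encoder.preferred_runtime == expected
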
@@ -195,46 +219,79 @@ class TestBase:
         warning.assert_called_once()
 
     def test_make_session_return_ann_if_available(self, mocker: MockerFixture) -> None:
-        mock_cache_dir = mocker.Mock()
-        mock_cache_dir.is_file.return_value = True
-        mock_cache_dir.with_suffix.return_value = mock_cache_dir
-        mocker.patch.object(settings, "ann", True)
-        mocker.patch("ann.ann.is_available", True)
+        mock_model_path = mocker.Mock()
+        mock_model_path.is_file.return_value = True
+        mock_model_path.suffix = ".armnn"
+        mock_model_path.with_suffix.return_value = mock_model_path
         mock_session = mocker.patch("app.models.base.AnnSession")
 
         encoder = OpenCLIPEncoder("ViT-B-32__openai")
-        encoder._make_session(mock_cache_dir)
+        encoder._make_session(mock_model_path)
 
         mock_session.assert_called_once()
 
     def test_make_session_return_ort_if_available_and_ann_is_not(self, mocker: MockerFixture) -> None:
-        mock_cache_dir = mocker.Mock()
-        mock_cache_dir.is_file.return_value = True
-        mock_cache_dir.with_suffix.return_value = mock_cache_dir
-        mocker.patch.object(settings, "ann", False)
-        mocker.patch("ann.ann.is_available", False)
-        mock_session = mocker.patch("app.models.base.ort.InferenceSession")
+        mock_armnn_path = mocker.Mock()
+        mock_armnn_path.is_file.return_value = False
+        mock_armnn_path.suffix = ".armnn"
+
+        mock_onnx_path = mocker.Mock()
+        mock_onnx_path.is_file.return_value = True
+        mock_onnx_path.suffix = ".onnx"
+        mock_armnn_path.with_suffix.return_value = mock_onnx_path
+
+        mock_ann = mocker.patch("app.models.base.AnnSession")
+        mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
 
         encoder = OpenCLIPEncoder("ViT-B-32__openai")
-        encoder._make_session(mock_cache_dir)
+        encoder._make_session(mock_armnn_path)
 
-        mock_session.assert_called_once()
+        mock_ort.assert_called_once()
+        mock_ann.assert_not_called()
 
     def test_make_session_raises_exception_if_path_does_not_exist(self, mocker: MockerFixture) -> None:
-        mock_cache_dir = mocker.Mock()
-        mock_cache_dir.is_file.return_value = False
-        mock_cache_dir.with_suffix.return_value = mock_cache_dir
-        mocker.patch("ann.ann.is_available", False)
-        mock_ann = mocker.patch("app.models.base.ort.InferenceSession")
+        mock_model_path = mocker.Mock()
+        mock_model_path.is_file.return_value = False
+        mock_model_path.suffix = ".onnx"
+        mock_model_path.with_suffix.return_value = mock_model_path
+        mock_ann = mocker.patch("app.models.base.AnnSession")
         mock_ort = mocker.patch("app.models.base.ort.InferenceSession")
 
         encoder = OpenCLIPEncoder("ViT-B-32__openai")
         with pytest.raises(ValueError):
-            encoder._make_session(mock_cache_dir)
+            encoder._make_session(mock_model_path)
 
         mock_ann.assert_not_called()
         mock_ort.assert_not_called()
 
+    def test_download(self, mocker: MockerFixture) -> None:
+        mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
+
+        encoder = OpenCLIPEncoder("ViT-B-32__openai")
+        encoder.download()
+
+        mock_snapshot_download.assert_called_once_with(
+            "immich-app/ViT-B-32__openai",
+            cache_dir=encoder.cache_dir,
+            local_dir=encoder.cache_dir,
+            local_dir_use_symlinks=False,
+            ignore_patterns=["*.armnn"],
+        )
+
+    def test_download_downloads_armnn_if_preferred_runtime(self, mocker: MockerFixture) -> None:
+        mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
+
+        encoder = OpenCLIPEncoder("ViT-B-32__openai", preferred_runtime=ModelRuntime.ARMNN)
+        encoder.download()
+
+        mock_snapshot_download.assert_called_once_with(
+            "immich-app/ViT-B-32__openai",
+            cache_dir=encoder.cache_dir,
+            local_dir=encoder.cache_dir,
+            local_dir_use_symlinks=False,
+            ignore_patterns=[],
+        )
+
 
 class TestCLIP:
     embedding = np.random.rand(512).astype(np.float32)
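
The session tests above fake pathlib.Path objects with pytest-mock rather than touching the filesystem: only the attributes _make_session reads (is_file, suffix, with_suffix) are configured. A self-contained sketch of that pattern, needing only pytest and pytest-mock:

from pytest_mock import MockerFixture

def test_mock_path_stands_in_for_a_model_file(mocker: MockerFixture) -> None:
    # Configure just the Path surface that _make_session relies on.
    mock_path = mocker.Mock()
    mock_path.is_file.return_value = True
    mock_path.suffix = ".armnn"
    mock_path.with_suffix.return_value = mock_path

    assert mock_path.is_file()
    assert mock_path.suffix == ".armnn"
    assert mock_path.with_suffix(".onnx") is mock_path
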