1
0
mirror of https://github.com/immich-app/immich.git synced 2024-12-22 01:47:08 +02:00

feat(ml): improve test coverage (#7041)

* update e2e

* tokenizer tests

* more tests, remove unnecessary code

* fix e2e setting

* add tests for loading model

* update workflow

* fixed test
This commit is contained in:
Mert 2024-02-11 17:58:56 -05:00 committed by GitHub
parent 6e853e2a9d
commit 0c4df216d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 501 additions and 1636 deletions

View File

@ -247,7 +247,7 @@ jobs:
poetry run mypy --install-types --non-interactive --strict app/ poetry run mypy --install-types --non-interactive --strict app/
- name: Run tests and coverage - name: Run tests and coverage
run: | run: |
poetry run pytest --cov app poetry run pytest app --cov=app --cov-report term-missing
generated-api-up-to-date: generated-api-up-to-date:
name: OpenAPI Clients name: OpenAPI Clients

View File

@ -119,16 +119,12 @@ async def load(model: InferenceModel) -> InferenceModel:
if model.loaded: if model.loaded:
return model return model
def _load() -> None: def _load(model: InferenceModel) -> None:
with lock: with lock:
model.load() model.load()
loop = asyncio.get_running_loop()
try: try:
if thread_pool is None: await run(_load, model)
model.load()
else:
await loop.run_in_executor(thread_pool, _load)
return model return model
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile): except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
log.warning( log.warning(
@ -138,10 +134,7 @@ async def load(model: InferenceModel) -> InferenceModel:
) )
) )
model.clear_cache() model.clear_cache()
if thread_pool is None: await run(_load, model)
model.load()
else:
await loop.run_in_executor(thread_pool, _load)
return model return model

View File

@ -21,4 +21,4 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
case _: case _:
raise ValueError(f"Unknown model type {model_type}") raise ValueError(f"Unknown model type {model_type}")
raise ValueError(f"Unknown ${model_type} model {model_name}") raise ValueError(f"Unknown {model_type} model {model_name}")

View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
@ -11,7 +10,6 @@ import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims from onnx.tools.update_model_dims import update_inputs_outputs_dims
from typing_extensions import Buffer
import ann.ann import ann.ann
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
@ -200,7 +198,7 @@ class InferenceModel(ABC):
@providers.setter @providers.setter
def providers(self, providers: list[str]) -> None: def providers(self, providers: list[str]) -> None:
log.debug( log.info(
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"), (f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
) )
self._providers = providers self._providers = providers
@ -217,7 +215,7 @@ class InferenceModel(ABC):
@provider_options.setter @provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None: def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.info(f"Setting execution provider options to {provider_options}") log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options self._provider_options = provider_options
@property @property
@ -255,7 +253,7 @@ class InferenceModel(ABC):
@property @property
def sess_options_default(self) -> ort.SessionOptions: def sess_options_default(self) -> ort.SessionOptions:
sess_options = PicklableSessionOptions() sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models # avoid thread contention between models
@ -287,15 +285,3 @@ class InferenceModel(ABC):
@property @property
def preferred_runtime_default(self) -> ModelRuntime: def preferred_runtime_default(self) -> ModelRuntime:
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Buffer) -> None:
self.__init__() # type: ignore[misc]
attrs: list[tuple[str, Any]] = pickle.loads(state)
for attr, val in attrs:
setattr(self, attr, val)

View File

@ -80,20 +80,3 @@ class RevalidationPlugin(BasePlugin): # type: ignore[misc]
key = client.build_key(key, namespace) key = client.build_key(key, namespace)
if key in client._handlers: if key in client._handlers:
await client.expire(key, client.ttl) await client.expire(key, client.ttl)
async def post_multi_get(
self,
client: SimpleMemoryCache,
keys: list[str],
ret: list[Any] | None = None,
namespace: str | None = None,
**kwargs: Any,
) -> None:
if ret is None:
return
for key, val in zip(keys, ret):
if namespace is not None:
key = client.build_key(key, namespace)
if val is not None and key in client._handlers:
await client.expire(key, client.ttl)

View File

@ -144,9 +144,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
def _load(self) -> None: def _load(self) -> None:
super()._load() super()._load()
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"] self._load_tokenizer()
context_length: int = text_cfg.get("context_length", 77)
pad_token: int = self.tokenizer_cfg["pad_token"]
size: list[int] | int = self.preprocess_cfg["size"] size: list[int] | int = self.preprocess_cfg["size"]
self.size = size[0] if isinstance(size, list) else size self.size = size[0] if isinstance(size, list) else size
@ -155,11 +153,19 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32) self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32) self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
def _load_tokenizer(self) -> Tokenizer:
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'") log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
context_length: int = text_cfg.get("context_length", 77)
pad_token: str = self.tokenizer_cfg["pad_token"]
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix()) self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
pad_id: int = self.tokenizer.token_to_id(pad_token) pad_id: int = self.tokenizer.token_to_id(pad_token)
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id) self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
self.tokenizer.enable_truncation(max_length=context_length) self.tokenizer.enable_truncation(max_length=context_length)
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'") log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]: def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:

View File

@ -1,7 +1,8 @@
import json import json
import pickle
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from random import randint
from types import SimpleNamespace
from typing import Any, Callable from typing import Any, Callable
from unittest import mock from unittest import mock
@ -13,10 +14,12 @@ from fastapi.testclient import TestClient
from PIL import Image from PIL import Image
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from app.main import load
from .config import log, settings from .config import log, settings
from .models.base import InferenceModel, PicklableSessionOptions from .models.base import InferenceModel
from .models.cache import ModelCache from .models.cache import ModelCache
from .models.clip import OpenCLIPEncoder from .models.clip import MCLIPEncoder, OpenCLIPEncoder
from .models.facial_recognition import FaceRecognizer from .models.facial_recognition import FaceRecognizer
from .schemas import ModelRuntime, ModelType from .schemas import ModelRuntime, ModelType
@ -72,6 +75,17 @@ class TestBase:
{"arena_extend_strategy": "kSameAsRequested"}, {"arena_extend_strategy": "kSameAsRequested"},
] ]
def test_sets_openvino_device_id_if_possible(self, mocker: MockerFixture) -> None:
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"])
assert encoder.provider_options == [
{"device_id": "GPU.0"},
{"arena_extend_strategy": "kSameAsRequested"},
]
def test_sets_provider_options_kwarg(self) -> None: def test_sets_provider_options_kwarg(self) -> None:
encoder = OpenCLIPEncoder( encoder = OpenCLIPEncoder(
"ViT-B-32__openai", "ViT-B-32__openai",
@ -119,7 +133,7 @@ class TestBase:
def test_sets_default_cache_dir(self) -> None: def test_sets_default_cache_dir(self) -> None:
encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder = OpenCLIPEncoder("ViT-B-32__openai")
assert encoder.cache_dir == Path("/cache/clip/ViT-B-32__openai") assert encoder.cache_dir == Path(settings.cache_folder) / "clip" / "ViT-B-32__openai"
def test_sets_cache_dir_kwarg(self) -> None: def test_sets_cache_dir_kwarg(self) -> None:
cache_dir = Path("/test_cache") cache_dir = Path("/test_cache")
@ -170,7 +184,7 @@ class TestBase:
encoder.clear_cache() encoder.clear_cache()
mock_rmtree.assert_called_once_with(encoder.cache_dir) mock_rmtree.assert_called_once_with(encoder.cache_dir)
assert info.call_count == 2 info.assert_called_with(f"Cleared cache directory for model '{encoder.model_name}'.")
def test_clear_cache_warns_if_path_does_not_exist(self, mocker: MockerFixture) -> None: def test_clear_cache_warns_if_path_does_not_exist(self, mocker: MockerFixture) -> None:
mock_rmtree = mocker.patch("app.models.base.rmtree", autospec=True) mock_rmtree = mocker.patch("app.models.base.rmtree", autospec=True)
@ -267,7 +281,7 @@ class TestBase:
def test_download(self, mocker: MockerFixture) -> None: def test_download(self, mocker: MockerFixture) -> None:
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download") mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
encoder = OpenCLIPEncoder("ViT-B-32__openai") encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="/path/to/cache")
encoder.download() encoder.download()
mock_snapshot_download.assert_called_once_with( mock_snapshot_download.assert_called_once_with(
@ -348,6 +362,60 @@ class TestCLIP:
assert embedding.dtype == np.float32 assert embedding.dtype == np.float32
mocked.run.assert_called_once() mocked.run.assert_called_once()
def test_openclip_tokenizer(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenCLIPEncoder, "download")
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
clip_encoder._load_tokenizer()
tokens = clip_encoder.tokenize("test search query")
assert "text" in tokens
assert isinstance(tokens["text"], np.ndarray)
assert tokens["text"].shape == (1, 77)
assert tokens["text"].dtype == np.int32
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
def test_mclip_tokenizer(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenCLIPEncoder, "download")
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
mock_ids = [randint(0, 50000) for _ in range(77)]
mock_attention_mask = [randint(0, 1) for _ in range(77)]
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids, attention_mask=mock_attention_mask)
clip_encoder = MCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
clip_encoder._load_tokenizer()
tokens = clip_encoder.tokenize("test search query")
assert "input_ids" in tokens
assert "attention_mask" in tokens
assert isinstance(tokens["input_ids"], np.ndarray)
assert isinstance(tokens["attention_mask"], np.ndarray)
assert tokens["input_ids"].shape == (1, 77)
assert tokens["attention_mask"].shape == (1, 77)
assert np.allclose(tokens["input_ids"], np.array([mock_ids], dtype=np.int32), atol=0)
assert np.allclose(tokens["attention_mask"], np.array([mock_attention_mask], dtype=np.int32), atol=0)
class TestFaceRecognition: class TestFaceRecognition:
def test_set_min_score(self, mocker: MockerFixture) -> None: def test_set_min_score(self, mocker: MockerFixture) -> None:
@ -420,12 +488,75 @@ class TestCache:
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100) mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
@mock.patch("app.models.cache.SimpleMemoryCache.expire") @mock.patch("app.models.cache.SimpleMemoryCache.expire")
async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None: async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, revalidate=True) model_cache = ModelCache(ttl=100, revalidate=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
mock_cache_expire.assert_called_once_with(mock.ANY, 100) mock_cache_expire.assert_called_once_with(mock.ANY, 100)
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
model_cache = ModelCache(ttl=100, profiling=True)
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
profiling = await model_cache.get_profiling()
assert isinstance(profiling, dict)
assert profiling == model_cache.cache.profiling
async def test_loads_mclip(self) -> None:
model_cache = ModelCache()
model = await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.CLIP, mode="text")
assert isinstance(model, MCLIPEncoder)
assert model.model_name == "XLM-Roberta-Large-Vit-B-32"
async def test_raises_exception_if_invalid_model_type(self) -> None:
invalid: Any = SimpleNamespace(value="invalid")
model_cache = ModelCache()
with pytest.raises(ValueError):
await model_cache.get("XLM-Roberta-Large-Vit-B-32", invalid, mode="text")
async def test_raises_exception_if_unknown_model_name(self) -> None:
model_cache = ModelCache()
with pytest.raises(ValueError):
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
@pytest.mark.asyncio
class TestLoad:
async def test_load(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.loaded = False
res = await load(mock_model)
assert res is mock_model
mock_model.load.assert_called_once()
mock_model.clear_cache.assert_not_called()
async def test_load_returns_model_if_loaded(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.loaded = True
res = await load(mock_model)
assert res is mock_model
mock_model.load.assert_not_called()
async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.model_name = "test_model_name"
mock_model.model_type = ModelType.CLIP
mock_model.load.side_effect = [OSError, None]
mock_model.loaded = False
res = await load(mock_model)
assert res is mock_model
mock_model.clear_cache.assert_called_once()
assert mock_model.load.call_count == 2
@pytest.mark.skipif( @pytest.mark.skipif(
not settings.test_full, not settings.test_full,
@ -437,15 +568,21 @@ class TestEndpoints:
) -> None: ) -> None:
byte_image = BytesIO() byte_image = BytesIO()
pil_image.save(byte_image, format="jpeg") pil_image.save(byte_image, format="jpeg")
expected = responses["clip"]["image"]
response = deployed_app.post( response = deployed_app.post(
"http://localhost:3003/predict", "http://localhost:3003/predict",
data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})}, data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
files={"image": byte_image.getvalue()}, files={"image": byte_image.getvalue()},
) )
actual = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == responses["clip"]["image"] assert np.allclose(expected, actual)
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None: def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
expected = responses["clip"]["text"]
response = deployed_app.post( response = deployed_app.post(
"http://localhost:3003/predict", "http://localhost:3003/predict",
data={ data={
@ -455,12 +592,15 @@ class TestEndpoints:
"options": json.dumps({"mode": "text"}), "options": json.dumps({"mode": "text"}),
}, },
) )
actual = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == responses["clip"]["text"] assert np.allclose(expected, actual)
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None: def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
byte_image = BytesIO() byte_image = BytesIO()
pil_image.save(byte_image, format="jpeg") pil_image.save(byte_image, format="jpeg")
expected = responses["facial-recognition"]
response = deployed_app.post( response = deployed_app.post(
"http://localhost:3003/predict", "http://localhost:3003/predict",
@ -471,15 +611,13 @@ class TestEndpoints:
}, },
files={"image": byte_image.getvalue()}, files={"image": byte_image.getvalue()},
) )
actual = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == responses["facial-recognition"] assert len(expected) == len(actual)
for expected_face, actual_face in zip(expected, actual):
assert expected_face["imageHeight"] == actual_face["imageHeight"]
def test_sess_options() -> None: assert expected_face["imageWidth"] == actual_face["imageWidth"]
sess_options = PicklableSessionOptions() assert expected_face["boundingBox"] == actual_face["boundingBox"]
sess_options.intra_op_num_threads = 1 assert np.allclose(expected_face["embedding"], actual_face["embedding"])
sess_options.inter_op_num_threads = 1 assert np.allclose(expected_face["score"], actual_face["score"])
pickled = pickle.dumps(sess_options)
unpickled = pickle.loads(pickled)
assert unpickled.intra_op_num_threads == 1
assert unpickled.inter_op_num_threads == 1

File diff suppressed because it is too large Load Diff