mirror of https://github.com/immich-app/immich.git synced 2024-12-25 10:43:13 +02:00

feat: preloading of machine learning models (#7540)

DawidPietrykowski 2024-03-04 01:48:56 +01:00 committed by GitHub
parent 762c4684f8
commit e8b001f62f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 75 additions and 49 deletions

View File

@@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding:

## Machine Learning

| Variable | Description | Default | Services |
| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- |
| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup> | Number of worker processes to spawn | `1` | machine learning |
| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
+| `MACHINE_LEARNING_PRELOAD__CLIP` | Name of a CLIP model to be preloaded and kept in cache | | machine learning |
+| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION` | Name of a facial recognition model to be preloaded and kept in cache | | machine learning |

\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
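For illustration, a minimal `.env` sketch for the machine learning container using the two new variables. The model names below are only examples (they happen to be the ones used in this commit's tests), not required values:

```
# Example only: preload one CLIP model and one facial recognition model at startup
MACHINE_LEARNING_PRELOAD__CLIP=ViT-B-32__openai
MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION=buffalo_s
```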

View File

@@ -167,6 +167,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

+# VS Code
+.vscode

*.onnx
*.zip

View File

@ -6,7 +6,7 @@ from pathlib import Path
from socket import socket from socket import socket
from gunicorn.arbiter import Arbiter from gunicorn.arbiter import Arbiter
from pydantic import BaseSettings from pydantic import BaseModel, BaseSettings
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler from rich.logging import RichHandler
from uvicorn import Server from uvicorn import Server
@ -15,6 +15,11 @@ from uvicorn.workers import UvicornWorker
from .schemas import ModelType from .schemas import ModelType
class PreloadModelData(BaseModel):
clip: str | None
facial_recognition: str | None
class Settings(BaseSettings): class Settings(BaseSettings):
cache_folder: str = "/cache" cache_folder: str = "/cache"
model_ttl: int = 300 model_ttl: int = 300
@ -27,10 +32,12 @@ class Settings(BaseSettings):
model_inter_op_threads: int = 0 model_inter_op_threads: int = 0
model_intra_op_threads: int = 0 model_intra_op_threads: int = 0
ann: bool = True ann: bool = True
preload: PreloadModelData | None = None
class Config: class Config:
env_prefix = "MACHINE_LEARNING_" env_prefix = "MACHINE_LEARNING_"
case_sensitive = False case_sensitive = False
env_nested_delimiter = "__"
class LogSettings(BaseSettings): class LogSettings(BaseSettings):
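To make the new nested settings concrete, here is a small self-contained sketch (not part of the diff) of how `env_nested_delimiter = "__"` lets the double-underscore environment variables populate `settings.preload`. It assumes pydantic v1's `BaseSettings` on Python 3.10+, which is what the code above appears to use:

```python
# Standalone sketch of the nested env var mapping used above (assumes pydantic v1).
import os

from pydantic import BaseModel, BaseSettings


class PreloadModelData(BaseModel):
    clip: str | None
    facial_recognition: str | None


class Settings(BaseSettings):
    preload: PreloadModelData | None = None

    class Config:
        env_prefix = "MACHINE_LEARNING_"
        case_sensitive = False
        env_nested_delimiter = "__"


# MACHINE_LEARNING_ + PRELOAD + __ + CLIP maps to settings.preload.clip.
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
print(Settings().preload)  # clip='ViT-B-32__openai' facial_recognition='buffalo_s'
```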

View File

@@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel

-from .config import log, settings
+from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
    MessageResponse,

@@ -27,7 +27,7 @@ from .schemas import (
MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger

-model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0

@@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            asyncio.ensure_future(idle_shutdown_task())
+        if settings.preload is not None:
+            await preload_models(settings.preload)
        yield
    finally:
        log.handlers.clear()

@@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
        gc.collect()


+async def preload_models(preload_models: PreloadModelData) -> None:
+    log.info(f"Preloading models: {preload_models}")
+    if preload_models.clip is not None:
+        await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
+    if preload_models.facial_recognition is not None:
+        await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))


def update_state() -> Iterator[None]:
    global active_requests, last_called
    active_requests += 1

@@ -103,7 +113,7 @@ async def predict(
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")

-    model = await load(await model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
    model.configure(**kwargs)
    outputs = await run(model.predict, inputs)
    return ORJSONResponse(outputs)
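For readers less familiar with FastAPI lifespans, here is a generic sketch of the pattern `preload_models` hooks into (illustrative only, not Immich code): anything awaited before `yield` runs during startup, so the service only begins accepting requests once preloading has finished.

```python
# Generic FastAPI lifespan sketch: load expensive resources before serving requests.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

models: dict[str, str] = {}


async def preload() -> None:
    # Stand-in for loading ML models into an in-memory cache.
    models["clip"] = "loaded"


@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    await preload()  # runs during startup, before the first request is served
    yield            # the application handles requests while suspended here
    models.clear()   # teardown on shutdown


app = FastAPI(lifespan=lifespan)
```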

View File

@ -2,7 +2,7 @@ from typing import Any
from aiocache.backends.memory import SimpleMemoryCache from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin from aiocache.plugins import TimingPlugin
from app.models import from_model_type from app.models import from_model_type
@ -15,28 +15,25 @@ class ModelCache:
def __init__( def __init__(
self, self,
ttl: float | None = None,
revalidate: bool = False, revalidate: bool = False,
timeout: int | None = None, timeout: int | None = None,
profiling: bool = False, profiling: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False. revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None. timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False. profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
""" """
self.ttl = ttl
plugins = [] plugins = []
if revalidate:
plugins.append(RevalidationPlugin())
if profiling: if profiling:
plugins.append(TimingPlugin()) plugins.append(TimingPlugin())
self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None) self.revalidate_enable = revalidate
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel: async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
""" """
@ -49,11 +46,14 @@ class ModelCache:
""" """
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
async with OptimisticLock(self.cache, key) as lock: async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key) model: InferenceModel | None = await self.cache.get(key)
if model is None: if model is None:
model = from_model_type(model_type, model_name, **model_kwargs) model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl) await lock.cas(model, ttl=model_kwargs.get("ttl", None))
elif self.revalidate_enable:
await self.revalidate(key, model_kwargs.get("ttl", None))
return model return model
async def get_profiling(self) -> dict[str, float] | None: async def get_profiling(self) -> dict[str, float] | None:
@ -62,21 +62,6 @@ class ModelCache:
return self.cache.profiling return self.cache.profiling
async def revalidate(self, key: str, ttl: int | None) -> None:
class RevalidationPlugin(BasePlugin): # type: ignore[misc] if ttl is not None and key in self.cache._handlers:
"""Revalidates cache item's TTL after cache hit.""" await self.cache.expire(key, ttl)
async def post_get(
self,
client: SimpleMemoryCache,
key: str,
ret: Any | None = None,
namespace: str | None = None,
**kwargs: Any,
) -> None:
if ret is None:
return
if namespace is not None:
key = client.build_key(key, namespace)
if key in client._handlers:
await client.expire(key, client.ttl)
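To make the TTL handling above concrete: `preload_models` calls `get` without a `ttl`, so a preloaded entry is stored without an expiry, while the `predict` path passes `ttl=settings.model_ttl` and `revalidate` resets that countdown on each hit. A small standalone sketch (not Immich code, assuming a recent aiocache) of the underlying `SimpleMemoryCache` behavior:

```python
# Standalone aiocache sketch: no ttl means no expiry; expire() pushes the deadline back.
import asyncio

from aiocache.backends.memory import SimpleMemoryCache


async def main() -> None:
    cache = SimpleMemoryCache()
    await cache.set("preloaded", "model-a", ttl=None)  # preload path: cached indefinitely
    await cache.set("on-demand", "model-b", ttl=0.1)   # request path: short ttl for the demo
    await cache.expire("on-demand", 0.5)               # revalidate: reset the ttl on a hit
    await asyncio.sleep(0.2)
    print(await cache.get("preloaded"), await cache.get("on-demand"))  # both still cached


asyncio.run(main())
```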

View File

@@ -13,11 +13,12 @@ import onnxruntime as ort
import pytest
from fastapi.testclient import TestClient
from PIL import Image
+from pytest import MonkeyPatch
from pytest_mock import MockerFixture

-from app.main import load
+from app.main import load, preload_models
-from .config import log, settings
+from .config import Settings, log, settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .models.clip import MCLIPEncoder, OpenCLIPEncoder

@@ -509,20 +510,20 @@ class TestCache:
    @mock.patch("app.models.cache.OptimisticLock", autospec=True)
    async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache()
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)

        mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)

    @mock.patch("app.models.cache.SimpleMemoryCache.expire")
    async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100, revalidate=True)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache(revalidate=True)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)

        mock_cache_expire.assert_called_once_with(mock.ANY, 100)

    async def test_profiling(self, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100, profiling=True)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache(profiling=True)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)

        profiling = await model_cache.get_profiling()
        assert isinstance(profiling, dict)
        assert profiling == model_cache.cache.profiling

@@ -548,6 +549,25 @@ class TestCache:
        with pytest.raises(ValueError):
            await model_cache.get("test_model_name", ModelType.CLIP, mode="text")

+    async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
+        os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
+        os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
+        settings = Settings()
+        assert settings.preload is not None
+        assert settings.preload.clip == "ViT-B-32__openai"
+        assert settings.preload.facial_recognition == "buffalo_s"
+
+        model_cache = ModelCache()
+        monkeypatch.setattr("app.main.model_cache", model_cache)
+
+        await preload_models(settings.preload)
+        assert len(model_cache.cache._cache) == 2
+        assert mock_get_model.call_count == 2
+
+        await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
+        await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
+        assert mock_get_model.call_count == 2


@pytest.mark.asyncio
class TestLoad: