Mirror of https://github.com/immich-app/immich.git (synced 2024-12-25 10:43:13 +02:00)

feat: preloading of machine learning models (#7540)

commit e8b001f62f (parent 762c4684f8)
@@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding:

 ## Machine Learning

 | Variable                                          | Description                                                           |       Default       | Services         |
-| :----------------------------------------------- | :----------------------------------------------------------------- | :-----------------: | :--------------- |
+| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- |
 | `MACHINE_LEARNING_MODEL_TTL`                      | Inactivity time (s) before a model is unloaded (disabled if \<= 0)    | `300`               | machine learning |
 | `MACHINE_LEARNING_MODEL_TTL_POLL_S`               | Interval (s) between checks for the model TTL (disabled if \<= 0)     | `10`                | machine learning |
 | `MACHINE_LEARNING_CACHE_FOLDER`                   | Directory where models are downloaded                                 | `/cache`            | machine learning |
 | `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup>  | Thread count of the request thread pool (disabled if \<= 0)           | number of CPU cores | machine learning |
 | `MACHINE_LEARNING_MODEL_INTER_OP_THREADS`         | Number of parallel model operations                                   | `1`                 | machine learning |
 | `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS`         | Number of threads for each model operation                            | `2`                 | machine learning |
 | `MACHINE_LEARNING_WORKERS`<sup>\*2</sup>          | Number of worker processes to spawn                                   | `1`                 | machine learning |
 | `MACHINE_LEARNING_WORKER_TIMEOUT`                 | Maximum time (s) of unresponsiveness before a worker is killed        | `120`               | machine learning |
+| `MACHINE_LEARNING_PRELOAD__CLIP`                  | Name of a CLIP model to be preloaded and kept in cache                |                     | machine learning |
+| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION`    | Name of a facial recognition model to be preloaded and kept in cache  |                     | machine learning |

 \*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
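To show how the two new preload variables are consumed, here is a minimal sketch mirroring the test added later in this commit. It assumes the immich machine-learning sources are importable as the `app` package; the model names are the ones used in that test.

```python
import os

# Hypothetical illustration: provide the preload variables the same way the
# container environment would, before the settings object is constructed.
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"

from app.config import Settings  # assumes the machine-learning sources are on the import path

settings = Settings()
# env_nested_delimiter = "__" (added below) maps the double underscore in the
# variable names onto the nested PreloadModelData model.
assert settings.preload is not None
assert settings.preload.clip == "ViT-B-32__openai"
assert settings.preload.facial_recognition == "buffalo_s"
```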
machine-learning/.gitignore (vendored, 2 changes)

@@ -167,6 +167,8 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/

+# VS Code
+.vscode

 *.onnx
 *.zip
@@ -6,7 +6,7 @@ from pathlib import Path
 from socket import socket

 from gunicorn.arbiter import Arbiter
-from pydantic import BaseSettings
+from pydantic import BaseModel, BaseSettings
 from rich.console import Console
 from rich.logging import RichHandler
 from uvicorn import Server
@@ -15,6 +15,11 @@ from uvicorn.workers import UvicornWorker
 from .schemas import ModelType


+class PreloadModelData(BaseModel):
+    clip: str | None
+    facial_recognition: str | None
+
+
 class Settings(BaseSettings):
     cache_folder: str = "/cache"
     model_ttl: int = 300
@@ -27,10 +32,12 @@ class Settings(BaseSettings):
     model_inter_op_threads: int = 0
     model_intra_op_threads: int = 0
     ann: bool = True
+    preload: PreloadModelData | None = None

     class Config:
         env_prefix = "MACHINE_LEARNING_"
         case_sensitive = False
+        env_nested_delimiter = "__"


 class LogSettings(BaseSettings):
@@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser

 from app.models.base import InferenceModel

-from .config import log, settings
+from .config import PreloadModelData, log, settings
 from .models.cache import ModelCache
 from .schemas import (
     MessageResponse,
@@ -27,7 +27,7 @@ from .schemas import (

 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger

-model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+model_cache = ModelCache(revalidate=settings.model_ttl > 0)
 thread_pool: ThreadPoolExecutor | None = None
 lock = threading.Lock()
 active_requests = 0
@@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
         log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
         if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
             asyncio.ensure_future(idle_shutdown_task())
+        if settings.preload is not None:
+            await preload_models(settings.preload)
         yield
     finally:
         log.handlers.clear()
@@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
         gc.collect()


+async def preload_models(preload_models: PreloadModelData) -> None:
+    log.info(f"Preloading models: {preload_models}")
+    if preload_models.clip is not None:
+        await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
+    if preload_models.facial_recognition is not None:
+        await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
+
+
 def update_state() -> Iterator[None]:
     global active_requests, last_called
     active_requests += 1
@@ -103,7 +113,7 @@ async def predict(
     except orjson.JSONDecodeError:
         raise HTTPException(400, f"Invalid options JSON: {options}")

-    model = await load(await model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
     model.configure(**kwargs)
     outputs = await run(model.predict, inputs)
     return ORJSONResponse(outputs)
@@ -2,7 +2,7 @@ from typing import Any

 from aiocache.backends.memory import SimpleMemoryCache
 from aiocache.lock import OptimisticLock
-from aiocache.plugins import BasePlugin, TimingPlugin
+from aiocache.plugins import TimingPlugin

 from app.models import from_model_type

@@ -15,28 +15,25 @@ class ModelCache:

     def __init__(
         self,
-        ttl: float | None = None,
         revalidate: bool = False,
         timeout: int | None = None,
         profiling: bool = False,
     ) -> None:
         """
         Args:
-            ttl: Unloads model after this duration. Disabled if None. Defaults to None.
             revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
             timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
             profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
         """

-        self.ttl = ttl
         plugins = []

-        if revalidate:
-            plugins.append(RevalidationPlugin())
         if profiling:
             plugins.append(TimingPlugin())

-        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
+        self.revalidate_enable = revalidate
+
+        self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

     async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
         """
@@ -49,11 +46,14 @@ class ModelCache:
         """

         key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"

         async with OptimisticLock(self.cache, key) as lock:
             model: InferenceModel | None = await self.cache.get(key)
             if model is None:
                 model = from_model_type(model_type, model_name, **model_kwargs)
-                await lock.cas(model, ttl=self.ttl)
+                await lock.cas(model, ttl=model_kwargs.get("ttl", None))
+            elif self.revalidate_enable:
+                await self.revalidate(key, model_kwargs.get("ttl", None))
         return model

     async def get_profiling(self) -> dict[str, float] | None:
@@ -62,21 +62,6 @@ class ModelCache:
         return self.cache.profiling

-
-class RevalidationPlugin(BasePlugin):  # type: ignore[misc]
-    """Revalidates cache item's TTL after cache hit."""
-
-    async def post_get(
-        self,
-        client: SimpleMemoryCache,
-        key: str,
-        ret: Any | None = None,
-        namespace: str | None = None,
-        **kwargs: Any,
-    ) -> None:
-        if ret is None:
-            return
-        if namespace is not None:
-            key = client.build_key(key, namespace)
-        if key in client._handlers:
-            await client.expire(key, client.ttl)
+    async def revalidate(self, key: str, ttl: int | None) -> None:
+        if ttl is not None and key in self.cache._handlers:
+            await self.cache.expire(key, ttl)
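For reference, a minimal usage sketch of the reworked cache API follows. This is an illustration, not part of the commit; it assumes the immich machine-learning sources are importable as the `app` package and reuses the CLIP model name from the test below. Note that `get()` only creates and caches the model entry; weight loading still happens through `load()` in `app.main`.

```python
import asyncio

from app.models.cache import ModelCache  # assumes the machine-learning sources are on the import path
from app.schemas import ModelType


async def main() -> None:
    # The TTL is no longer fixed when the cache is constructed; it is passed per get() call,
    # and revalidate=True refreshes an entry's TTL on every cache hit.
    cache = ModelCache(revalidate=True)
    await cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=300)  # miss: entry is created and cached with a 300 s TTL
    await cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=300)  # hit: cached entry returned, TTL reset to 300 s


asyncio.run(main())
```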
@@ -13,11 +13,12 @@ import onnxruntime as ort
 import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
+from pytest import MonkeyPatch
 from pytest_mock import MockerFixture

-from app.main import load
+from app.main import load, preload_models

-from .config import log, settings
+from .config import Settings, log, settings
 from .models.base import InferenceModel
 from .models.cache import ModelCache
 from .models.clip import MCLIPEncoder, OpenCLIPEncoder
@@ -509,20 +510,20 @@ class TestCache:

     @mock.patch("app.models.cache.OptimisticLock", autospec=True)
     async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache()
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
         mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)

     @mock.patch("app.models.cache.SimpleMemoryCache.expire")
     async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100, revalidate=True)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache(revalidate=True)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
         mock_cache_expire.assert_called_once_with(mock.ANY, 100)

     async def test_profiling(self, mock_get_model: mock.Mock) -> None:
-        model_cache = ModelCache(ttl=100, profiling=True)
-        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+        model_cache = ModelCache(profiling=True)
+        await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
         profiling = await model_cache.get_profiling()
         assert isinstance(profiling, dict)
         assert profiling == model_cache.cache.profiling
@@ -548,6 +549,25 @@ class TestCache:
         with pytest.raises(ValueError):
             await model_cache.get("test_model_name", ModelType.CLIP, mode="text")

+    async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
+        os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
+        os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
+
+        settings = Settings()
+        assert settings.preload is not None
+        assert settings.preload.clip == "ViT-B-32__openai"
+        assert settings.preload.facial_recognition == "buffalo_s"
+
+        model_cache = ModelCache()
+        monkeypatch.setattr("app.main.model_cache", model_cache)
+
+        await preload_models(settings.preload)
+        assert len(model_cache.cache._cache) == 2
+        assert mock_get_model.call_count == 2
+        await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
+        await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
+        assert mock_get_model.call_count == 2
+

 @pytest.mark.asyncio
 class TestLoad: