From e8b001f62f4ecdae3ed6dbd94a381ae5f3133579 Mon Sep 17 00:00:00 2001 From: DawidPietrykowski <53954695+DawidPietrykowski@users.noreply.github.com> Date: Mon, 4 Mar 2024 01:48:56 +0100 Subject: [PATCH] feat: preloading of machine learning models (#7540) --- docs/docs/install/environment-variables.md | 22 +++++++------ machine-learning/.gitignore | 2 ++ machine-learning/app/config.py | 9 ++++- machine-learning/app/main.py | 16 +++++++-- machine-learning/app/models/cache.py | 37 +++++++-------------- machine-learning/app/test_main.py | 38 +++++++++++++++++----- 6 files changed, 75 insertions(+), 49 deletions(-) diff --git a/docs/docs/install/environment-variables.md b/docs/docs/install/environment-variables.md index 2849c60549..a915e20e20 100644 --- a/docs/docs/install/environment-variables.md +++ b/docs/docs/install/environment-variables.md @@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding: ## Machine Learning -| Variable | Description | Default | Services | -| :----------------------------------------------- | :----------------------------------------------------------------- | :-----------------: | :--------------- | -| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning | -| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning | -| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning | -| `MACHINE_LEARNING_REQUEST_THREADS`\*1 | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning | -| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning | -| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning | -| `MACHINE_LEARNING_WORKERS`\*2 | Number of worker processes to spawn | `1` | machine learning | -| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning | +| Variable | Description | Default | Services | +| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- | +| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning | +| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning | +| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning | +| `MACHINE_LEARNING_REQUEST_THREADS`\*1 | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning | +| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning | +| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning | +| `MACHINE_LEARNING_WORKERS`\*2 | Number of worker processes to spawn | `1` | machine learning | +| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning | +| `MACHINE_LEARNING_PRELOAD__CLIP` | Name of a CLIP model to be preloaded and kept in cache | | machine learning | +| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION` | Name of a facial recognition model to be preloaded and kept in cache | | machine learning | \*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones. diff --git a/machine-learning/.gitignore b/machine-learning/.gitignore index e31c7773ee..d3163ea5b0 100644 --- a/machine-learning/.gitignore +++ b/machine-learning/.gitignore @@ -167,6 +167,8 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ +# VS Code +.vscode *.onnx *.zip \ No newline at end of file diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py index c48b3278d6..a911659dbc 100644 --- a/machine-learning/app/config.py +++ b/machine-learning/app/config.py @@ -6,7 +6,7 @@ from pathlib import Path from socket import socket from gunicorn.arbiter import Arbiter -from pydantic import BaseSettings +from pydantic import BaseModel, BaseSettings from rich.console import Console from rich.logging import RichHandler from uvicorn import Server @@ -15,6 +15,11 @@ from uvicorn.workers import UvicornWorker from .schemas import ModelType +class PreloadModelData(BaseModel): + clip: str | None + facial_recognition: str | None + + class Settings(BaseSettings): cache_folder: str = "/cache" model_ttl: int = 300 @@ -27,10 +32,12 @@ class Settings(BaseSettings): model_inter_op_threads: int = 0 model_intra_op_threads: int = 0 ann: bool = True + preload: PreloadModelData | None = None class Config: env_prefix = "MACHINE_LEARNING_" case_sensitive = False + env_nested_delimiter = "__" class LogSettings(BaseSettings): diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index bde40f36e4..277ad76898 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser from app.models.base import InferenceModel -from .config import log, settings +from .config import PreloadModelData, log, settings from .models.cache import ModelCache from .schemas import ( MessageResponse, @@ -27,7 +27,7 @@ from .schemas import ( MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger -model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) +model_cache = ModelCache(revalidate=settings.model_ttl > 0) thread_pool: ThreadPoolExecutor | None = None lock = threading.Lock() active_requests = 0 @@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: log.info(f"Initialized request thread pool with {settings.request_threads} threads.") if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0: asyncio.ensure_future(idle_shutdown_task()) + if settings.preload is not None: + await preload_models(settings.preload) yield finally: log.handlers.clear() @@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: gc.collect() +async def preload_models(preload_models: PreloadModelData) -> None: + log.info(f"Preloading models: {preload_models}") + if preload_models.clip is not None: + await load(await model_cache.get(preload_models.clip, ModelType.CLIP)) + if preload_models.facial_recognition is not None: + await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION)) + + def update_state() -> Iterator[None]: global active_requests, last_called active_requests += 1 @@ -103,7 +113,7 @@ async def predict( except orjson.JSONDecodeError: raise HTTPException(400, f"Invalid options JSON: {options}") - model = await load(await model_cache.get(model_name, model_type, **kwargs)) + model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs)) model.configure(**kwargs) outputs = await run(model.predict, inputs) return ORJSONResponse(outputs) diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py index 62afd05a09..781a9caea0 100644 --- a/machine-learning/app/models/cache.py +++ b/machine-learning/app/models/cache.py @@ -2,7 +2,7 @@ from typing import Any from aiocache.backends.memory import SimpleMemoryCache from aiocache.lock import OptimisticLock -from aiocache.plugins import BasePlugin, TimingPlugin +from aiocache.plugins import TimingPlugin from app.models import from_model_type @@ -15,28 +15,25 @@ class ModelCache: def __init__( self, - ttl: float | None = None, revalidate: bool = False, timeout: int | None = None, profiling: bool = False, ) -> None: """ Args: - ttl: Unloads model after this duration. Disabled if None. Defaults to None. revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False. timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None. profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False. """ - self.ttl = ttl plugins = [] - if revalidate: - plugins.append(RevalidationPlugin()) if profiling: plugins.append(TimingPlugin()) - self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None) + self.revalidate_enable = revalidate + + self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None) async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel: """ @@ -49,11 +46,14 @@ class ModelCache: """ key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" + async with OptimisticLock(self.cache, key) as lock: model: InferenceModel | None = await self.cache.get(key) if model is None: model = from_model_type(model_type, model_name, **model_kwargs) - await lock.cas(model, ttl=self.ttl) + await lock.cas(model, ttl=model_kwargs.get("ttl", None)) + elif self.revalidate_enable: + await self.revalidate(key, model_kwargs.get("ttl", None)) return model async def get_profiling(self) -> dict[str, float] | None: @@ -62,21 +62,6 @@ class ModelCache: return self.cache.profiling - -class RevalidationPlugin(BasePlugin): # type: ignore[misc] - """Revalidates cache item's TTL after cache hit.""" - - async def post_get( - self, - client: SimpleMemoryCache, - key: str, - ret: Any | None = None, - namespace: str | None = None, - **kwargs: Any, - ) -> None: - if ret is None: - return - if namespace is not None: - key = client.build_key(key, namespace) - if key in client._handlers: - await client.expire(key, client.ttl) + async def revalidate(self, key: str, ttl: int | None) -> None: + if ttl is not None and key in self.cache._handlers: + await self.cache.expire(key, ttl) diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py index 0f802997fd..72cd020ff2 100644 --- a/machine-learning/app/test_main.py +++ b/machine-learning/app/test_main.py @@ -13,11 +13,12 @@ import onnxruntime as ort import pytest from fastapi.testclient import TestClient from PIL import Image +from pytest import MonkeyPatch from pytest_mock import MockerFixture -from app.main import load +from app.main import load, preload_models -from .config import log, settings +from .config import Settings, log, settings from .models.base import InferenceModel from .models.cache import ModelCache from .models.clip import MCLIPEncoder, OpenCLIPEncoder @@ -509,20 +510,20 @@ class TestCache: @mock.patch("app.models.cache.OptimisticLock", autospec=True) async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None: - model_cache = ModelCache(ttl=100) - await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) + model_cache = ModelCache() + await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100) mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100) @mock.patch("app.models.cache.SimpleMemoryCache.expire") async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None: - model_cache = ModelCache(ttl=100, revalidate=True) - await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) - await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) + model_cache = ModelCache(revalidate=True) + await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100) + await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100) mock_cache_expire.assert_called_once_with(mock.ANY, 100) async def test_profiling(self, mock_get_model: mock.Mock) -> None: - model_cache = ModelCache(ttl=100, profiling=True) - await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION) + model_cache = ModelCache(profiling=True) + await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100) profiling = await model_cache.get_profiling() assert isinstance(profiling, dict) assert profiling == model_cache.cache.profiling @@ -548,6 +549,25 @@ class TestCache: with pytest.raises(ValueError): await model_cache.get("test_model_name", ModelType.CLIP, mode="text") + async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None: + os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai" + os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s" + + settings = Settings() + assert settings.preload is not None + assert settings.preload.clip == "ViT-B-32__openai" + assert settings.preload.facial_recognition == "buffalo_s" + + model_cache = ModelCache() + monkeypatch.setattr("app.main.model_cache", model_cache) + + await preload_models(settings.preload) + assert len(model_cache.cache._cache) == 2 + assert mock_get_model.call_count == 2 + await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100) + await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100) + assert mock_get_model.call_count == 2 + @pytest.mark.asyncio class TestLoad: