diff --git a/docs/docs/install/environment-variables.md b/docs/docs/install/environment-variables.md
index 2849c60549..a915e20e20 100644
--- a/docs/docs/install/environment-variables.md
+++ b/docs/docs/install/environment-variables.md
@@ -124,16 +124,18 @@ Redis (Sentinel) URL example JSON before encoding:
## Machine Learning
-| Variable | Description | Default | Services |
-| :----------------------------------------------- | :----------------------------------------------------------------- | :-----------------: | :--------------- |
-| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
-| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
-| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
-| `MACHINE_LEARNING_REQUEST_THREADS`\*1 | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
-| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
-| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
-| `MACHINE_LEARNING_WORKERS`\*2 | Number of worker processes to spawn | `1` | machine learning |
-| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
+| Variable | Description | Default | Services |
+| :----------------------------------------------- | :------------------------------------------------------------------- | :-----------------: | :--------------- |
+| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if \<= 0) | `300` | machine learning |
+| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if \<= 0) | `10` | machine learning |
+| `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning |
+| `MACHINE_LEARNING_REQUEST_THREADS`\*1 | Thread count of the request thread pool (disabled if \<= 0) | number of CPU cores | machine learning |
+| `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning |
+| `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning |
+| `MACHINE_LEARNING_WORKERS`\*2 | Number of worker processes to spawn | `1` | machine learning |
+| `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning |
+| `MACHINE_LEARNING_PRELOAD__CLIP` | Name of a CLIP model to be preloaded and kept in cache | | machine learning |
+| `MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION` | Name of a facial recognition model to be preloaded and kept in cache | | machine learning |
\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
diff --git a/machine-learning/.gitignore b/machine-learning/.gitignore
index e31c7773ee..d3163ea5b0 100644
--- a/machine-learning/.gitignore
+++ b/machine-learning/.gitignore
@@ -167,6 +167,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+# VS Code
+.vscode
*.onnx
*.zip
\ No newline at end of file
diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py
index c48b3278d6..a911659dbc 100644
--- a/machine-learning/app/config.py
+++ b/machine-learning/app/config.py
@@ -6,7 +6,7 @@ from pathlib import Path
from socket import socket
from gunicorn.arbiter import Arbiter
-from pydantic import BaseSettings
+from pydantic import BaseModel, BaseSettings
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
@@ -15,6 +15,11 @@ from uvicorn.workers import UvicornWorker
from .schemas import ModelType
+class PreloadModelData(BaseModel):
+ clip: str | None
+ facial_recognition: str | None
+
+
class Settings(BaseSettings):
cache_folder: str = "/cache"
model_ttl: int = 300
@@ -27,10 +32,12 @@ class Settings(BaseSettings):
model_inter_op_threads: int = 0
model_intra_op_threads: int = 0
ann: bool = True
+ preload: PreloadModelData | None = None
class Config:
env_prefix = "MACHINE_LEARNING_"
case_sensitive = False
+ env_nested_delimiter = "__"
class LogSettings(BaseSettings):
diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index bde40f36e4..277ad76898 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel
-from .config import log, settings
+from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
MessageResponse,
@@ -27,7 +27,7 @@ from .schemas import (
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
-model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
@@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
+ if settings.preload is not None:
+ await preload_models(settings.preload)
yield
finally:
log.handlers.clear()
@@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
gc.collect()
+async def preload_models(preload_models: PreloadModelData) -> None:
+ log.info(f"Preloading models: {preload_models}")
+ if preload_models.clip is not None:
+ await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
+ if preload_models.facial_recognition is not None:
+ await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
+
+
def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
@@ -103,7 +113,7 @@ async def predict(
except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}")
- model = await load(await model_cache.get(model_name, model_type, **kwargs))
+ model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
model.configure(**kwargs)
outputs = await run(model.predict, inputs)
return ORJSONResponse(outputs)
diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py
index 62afd05a09..781a9caea0 100644
--- a/machine-learning/app/models/cache.py
+++ b/machine-learning/app/models/cache.py
@@ -2,7 +2,7 @@ from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
-from aiocache.plugins import BasePlugin, TimingPlugin
+from aiocache.plugins import TimingPlugin
from app.models import from_model_type
@@ -15,28 +15,25 @@ class ModelCache:
def __init__(
self,
- ttl: float | None = None,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
- ttl: Unloads model after this duration. Disabled if None. Defaults to None.
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""
- self.ttl = ttl
plugins = []
- if revalidate:
- plugins.append(RevalidationPlugin())
if profiling:
plugins.append(TimingPlugin())
- self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
+ self.revalidate_enable = revalidate
+
+ self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
"""
@@ -49,11 +46,14 @@ class ModelCache:
"""
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
+
async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_type, model_name, **model_kwargs)
- await lock.cas(model, ttl=self.ttl)
+ await lock.cas(model, ttl=model_kwargs.get("ttl", None))
+ elif self.revalidate_enable:
+ await self.revalidate(key, model_kwargs.get("ttl", None))
return model
async def get_profiling(self) -> dict[str, float] | None:
@@ -62,21 +62,6 @@ class ModelCache:
return self.cache.profiling
-
-class RevalidationPlugin(BasePlugin): # type: ignore[misc]
- """Revalidates cache item's TTL after cache hit."""
-
- async def post_get(
- self,
- client: SimpleMemoryCache,
- key: str,
- ret: Any | None = None,
- namespace: str | None = None,
- **kwargs: Any,
- ) -> None:
- if ret is None:
- return
- if namespace is not None:
- key = client.build_key(key, namespace)
- if key in client._handlers:
- await client.expire(key, client.ttl)
+ async def revalidate(self, key: str, ttl: int | None) -> None:
+ if ttl is not None and key in self.cache._handlers:
+ await self.cache.expire(key, ttl)
diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index 0f802997fd..72cd020ff2 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -13,11 +13,12 @@ import onnxruntime as ort
import pytest
from fastapi.testclient import TestClient
from PIL import Image
+from pytest import MonkeyPatch
from pytest_mock import MockerFixture
-from app.main import load
+from app.main import load, preload_models
-from .config import log, settings
+from .config import Settings, log, settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
@@ -509,20 +510,20 @@ class TestCache:
@mock.patch("app.models.cache.OptimisticLock", autospec=True)
async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
- model_cache = ModelCache(ttl=100)
- await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+ model_cache = ModelCache()
+ await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
- model_cache = ModelCache(ttl=100, revalidate=True)
- await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
- await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+ model_cache = ModelCache(revalidate=True)
+ await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
+ await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
- model_cache = ModelCache(ttl=100, profiling=True)
- await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
+ model_cache = ModelCache(profiling=True)
+ await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION, ttl=100)
profiling = await model_cache.get_profiling()
assert isinstance(profiling, dict)
assert profiling == model_cache.cache.profiling
@@ -548,6 +549,25 @@ class TestCache:
with pytest.raises(ValueError):
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
+ async def test_preloads_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
+ os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
+ os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
+
+ settings = Settings()
+ assert settings.preload is not None
+ assert settings.preload.clip == "ViT-B-32__openai"
+ assert settings.preload.facial_recognition == "buffalo_s"
+
+ model_cache = ModelCache()
+ monkeypatch.setattr("app.main.model_cache", model_cache)
+
+ await preload_models(settings.preload)
+ assert len(model_cache.cache._cache) == 2
+ assert mock_get_model.call_count == 2
+ await model_cache.get("ViT-B-32__openai", ModelType.CLIP, ttl=100)
+ await model_cache.get("buffalo_s", ModelType.FACIAL_RECOGNITION, ttl=100)
+ assert mock_get_model.call_count == 2
+
@pytest.mark.asyncio
class TestLoad: