1
0
mirror of https://github.com/immich-app/immich.git synced 2024-11-24 08:52:28 +02:00

fix(ml): race condition when loading models (#3207)

* sync model loading, disabled model ttl by default

* disable revalidation if model unloading disabled

* moved lock
This commit is contained in:
Mert 2023-07-11 13:01:21 -04:00 committed by GitHub
parent 9ad024c189
commit 848ba685eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 10 deletions

View File

@@ -13,7 +13,7 @@ class Settings(BaseSettings):
facial_recognition_model: str = "buffalo_l"
min_tag_score: float = 0.9
eager_startup: bool = True
model_ttl: int = 300
model_ttl: int = 0
host: str = "0.0.0.0"
port: int = 3003
workers: int = 1

View File

@@ -25,7 +25,7 @@ app = FastAPI()
def init_state() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
async def load_models() -> None:

View File

@@ -1,4 +1,3 @@
import asyncio
from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
@@ -48,13 +47,10 @@ class ModelCache:
"""
key = self.cache.build_key(model_name, model_type.value)
async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key)
if model is None:
async with OptimisticLock(self.cache, key) as lock:
model = await asyncio.get_running_loop().run_in_executor(
None,
lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
)
model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
return model