mirror of
https://github.com/immich-app/immich.git
synced 2024-12-25 10:43:13 +02:00
fix(ml): race condition when loading models (#3207)
* sync model loading, disabled model ttl by default * disable revalidation if model unloading disabled * moved lock
This commit is contained in:
parent
9ad024c189
commit
848ba685eb
@ -13,7 +13,7 @@ class Settings(BaseSettings):
|
||||
facial_recognition_model: str = "buffalo_l"
|
||||
min_tag_score: float = 0.9
|
||||
eager_startup: bool = True
|
||||
model_ttl: int = 300
|
||||
model_ttl: int = 0
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 3003
|
||||
workers: int = 1
|
||||
|
@ -25,7 +25,7 @@ app = FastAPI()
|
||||
|
||||
|
||||
def init_state() -> None:
|
||||
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
|
||||
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
|
||||
|
||||
|
||||
async def load_models() -> None:
|
||||
|
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from aiocache.backends.memory import SimpleMemoryCache
|
||||
@ -48,13 +47,10 @@ class ModelCache:
|
||||
"""
|
||||
|
||||
key = self.cache.build_key(model_name, model_type.value)
|
||||
model = await self.cache.get(key)
|
||||
if model is None:
|
||||
async with OptimisticLock(self.cache, key) as lock:
|
||||
model = await asyncio.get_running_loop().run_in_executor(
|
||||
None,
|
||||
lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
|
||||
)
|
||||
async with OptimisticLock(self.cache, key) as lock:
|
||||
model = await self.cache.get(key)
|
||||
if model is None:
|
||||
model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
|
||||
await lock.cas(model, ttl=self.ttl)
|
||||
return model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user