
fix(ml): better model unloading (#3340)

* restart process on inactivity

* formatting

* always update `last_called`

* load models sequentially

* renamed variable, updated docs

* formatting

* made poll env name consistent with model ttl env
Author: Mert, 2023-11-16 21:42:44 -05:00 (committed by GitHub)
Parent: 98f87c6548
Commit: a6af4892e3
3 changed files with 42 additions and 11 deletions
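At a high level, the change polls for inactivity and tears the worker down so its memory is reclaimed. A minimal, self-contained sketch of that pattern follows; names mirror the new settings, but this is a simplified illustration, not the immich code (the real implementation is in the diffs below):

```python
import asyncio
import time

MODEL_TTL = 300        # seconds of inactivity before the worker exits
MODEL_TTL_POLL_S = 10  # how often to check

last_called: float | None = None

async def idle_poller() -> None:
    # Compare time since the last request against the TTL; stopping the
    # loop lets the process exit so a supervisor can respawn it cleanly.
    while True:
        if last_called is not None and time.time() - last_called > MODEL_TTL:
            asyncio.get_running_loop().stop()
        await asyncio.sleep(MODEL_TTL_POLL_S)

if __name__ == "__main__":
    last_called = time.time() - MODEL_TTL - 1  # pretend the TTL already expired
    loop = asyncio.new_event_loop()
    loop.create_task(idle_poller())
    loop.run_forever()  # returns once idle_poller() calls loop.stop()
    print("worker stopped; the process manager would now respawn it")
```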


```diff
@@ -188,19 +188,18 @@ Typesense URL example JSON before encoding:
 | Variable                                          | Description                                                        |       Default       | Services         |
 | :------------------------------------------------ | :----------------------------------------------------------------- | :-----------------: | :--------------- |
-| `MACHINE_LEARNING_MODEL_TTL`<sup>\*1</sup>        | Inactivity time (s) before a model is unloaded (disabled if <= 0)  | `0`                 | machine learning |
+| `MACHINE_LEARNING_MODEL_TTL`                      | Inactivity time (s) before a model is unloaded (disabled if <= 0)  | `300`               | machine learning |
+| `MACHINE_LEARNING_MODEL_TTL_POLL_S`               | Interval (s) between checks for the model TTL (disabled if <= 0)   | `10`                | machine learning |
 | `MACHINE_LEARNING_CACHE_FOLDER`                   | Directory where models are downloaded                              | `/cache`            | machine learning |
-| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*2</sup>  | Thread count of the request thread pool (disabled if <= 0)         | number of CPU cores | machine learning |
+| `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup>  | Thread count of the request thread pool (disabled if <= 0)         | number of CPU cores | machine learning |
 | `MACHINE_LEARNING_MODEL_INTER_OP_THREADS`         | Number of parallel model operations                                | `1`                 | machine learning |
 | `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS`         | Number of threads for each model operation                         | `2`                 | machine learning |
-| `MACHINE_LEARNING_WORKERS`<sup>\*3</sup>          | Number of worker processes to spawn                                | `1`                 | machine learning |
+| `MACHINE_LEARNING_WORKERS`<sup>\*2</sup>          | Number of worker processes to spawn                                | `1`                 | machine learning |
 | `MACHINE_LEARNING_WORKER_TIMEOUT`                 | Maximum time (s) of unresponsiveness before a worker is killed     | `120`               | machine learning |
 
-\*1: This is an experimental feature. It may result in increased memory use over time when loading models repeatedly.
-\*2: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
-\*3: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around.
+\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.
+\*2: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around.
 
 :::info
```
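For the two `*_OP_THREADS` knobs in the table, this is how such settings typically land in ONNX Runtime session options; a hedged illustration, since immich's session-creation code is not part of this diff:

```python
# Illustration only: the actual session setup is not shown in this commit.
import onnxruntime as ort

opts = ort.SessionOptions()
opts.inter_op_num_threads = 1  # MACHINE_LEARNING_MODEL_INTER_OP_THREADS
opts.intra_op_num_threads = 2  # MACHINE_LEARNING_MODEL_INTRA_OP_THREADS
session = ort.InferenceSession("model.onnx", sess_options=opts)  # path is a placeholder
```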


```diff
@@ -13,7 +13,8 @@ from .schemas import ModelType
 class Settings(BaseSettings):
     cache_folder: str = "/cache"
-    model_ttl: int = 0
+    model_ttl: int = 300
+    model_ttl_poll_s: int = 10
     host: str = "0.0.0.0"
     port: int = 3003
     workers: int = 1
```
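The documented `MACHINE_LEARNING_*` variables presumably map onto these fields through pydantic's environment handling. A minimal sketch, assuming a `MACHINE_LEARNING_` env prefix (the prefix configuration itself is outside this hunk):

```python
from pydantic import BaseSettings  # pydantic v1-style settings

class Settings(BaseSettings):
    model_ttl: int = 300        # MACHINE_LEARNING_MODEL_TTL
    model_ttl_poll_s: int = 10  # MACHINE_LEARNING_MODEL_TTL_POLL_S

    class Config:
        env_prefix = "MACHINE_LEARNING_"  # assumed; not visible in this diff

settings = Settings()  # e.g. MACHINE_LEARNING_MODEL_TTL=600 overrides the default
```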


```diff
@@ -1,5 +1,9 @@
 import asyncio
+import gc
+import os
+import sys
 import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 from zipfile import BadZipFile
```
```diff
@@ -34,7 +38,10 @@ def init_state() -> None:
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
-    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
+    app.state.lock = threading.Lock()
+    app.state.last_called = None
+    if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
+        asyncio.ensure_future(idle_shutdown_task())
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
```
```diff
@@ -79,9 +86,9 @@ async def predict(
 async def run(model: InferenceModel, inputs: Any) -> Any:
+    app.state.last_called = time.time()
     if app.state.thread_pool is None:
         return model.predict(inputs)
     return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
```
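Because `run()` stamps `last_called` on every prediction, the TTL measures time since the most recent request, not since model load. Worked through with the new defaults:

```python
# Worst-case unload latency: the idle check only runs once per poll
# interval, so shutdown happens between model_ttl and
# model_ttl + model_ttl_poll_s seconds after the last request.
model_ttl, model_ttl_poll_s = 300, 10
print(f"idle worker exits {model_ttl}-{model_ttl + model_ttl_poll_s}s after its last request")
```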
```diff
@@ -90,7 +97,7 @@ async def load(model: InferenceModel) -> InferenceModel:
         return model
 
     def _load() -> None:
-        with app.state.locks[model.model_type]:
+        with app.state.lock:
             model.load()
 
     loop = asyncio.get_running_loop()
```
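The switch from per-model-type locks to one shared lock is what the commit message means by loading models sequentially: at most one model initializes at a time, which caps peak memory during concurrent first requests at the cost of some queueing. A toy illustration (not immich code; the model names are placeholders):

```python
import threading
import time

lock = threading.Lock()  # one lock shared by every model type, as in the new code

def load_model(name: str) -> None:
    with lock:  # concurrent loads queue here instead of running in parallel
        print(f"loading {name}...")
        time.sleep(0.1)  # stand-in for an expensive model load
        print(f"{name} ready")

threads = [threading.Thread(target=load_model, args=(n,)) for n in ("clip", "facial-recognition")]
for t in threads:
    t.start()
for t in threads:
    t.join()
```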
```diff
@@ -113,3 +120,27 @@ async def load(model: InferenceModel) -> InferenceModel:
     else:
         await loop.run_in_executor(app.state.thread_pool, _load)
     return model
+
+
+async def idle_shutdown_task() -> None:
+    while True:
+        log.debug("Checking for inactivity...")
+        if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl:
+            log.info("Shutting down due to inactivity.")
+            loop = asyncio.get_running_loop()
+            for task in asyncio.all_tasks(loop):
+                if task is not asyncio.current_task():
+                    try:
+                        task.cancel()
+                    except asyncio.CancelledError:
+                        pass
+            sys.stderr.close()
+            sys.stdout.close()
+            sys.stdout = sys.stderr = open(os.devnull, "w")
+            try:
+                await app.state.model_cache.cache.clear()
+                gc.collect()
+                loop.stop()
+            except asyncio.CancelledError:
+                pass
+        await asyncio.sleep(settings.model_ttl_poll_s)
```
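What makes this a restart rather than a plain shutdown is the process manager: once `loop.stop()` halts the event loop, the worker process exits, and the supervisor (gunicorn, going by the `MACHINE_LEARNING_WORKER_TIMEOUT` docs above) spawns a replacement with no models in memory. A toy supervisor showing the shape of that contract, not immich's actual process management:

```python
import multiprocessing as mp
import time

def worker() -> None:
    print("worker up with an empty model cache")
    time.sleep(0.2)  # stand-in for serving until the idle shutdown fires

if __name__ == "__main__":
    for respawn in range(3):  # the supervisor restarts any worker that exits
        p = mp.Process(target=worker)
        p.start()
        p.join()
        print(f"worker exited; respawn #{respawn + 1}")
```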