From a6af4892e31ffd24d4589e1d1d5711b8ff815e28 Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Thu, 16 Nov 2023 21:42:44 -0500 Subject: [PATCH] fix(ml): better model unloading (#3340) * restart process on inactivity * formatting * always update `last_called` * load models sequentially * renamed variable, updated docs * formatting * made poll env name consistent with model ttl env --- docs/docs/install/environment-variables.md | 13 ++++---- machine-learning/app/config.py | 3 +- machine-learning/app/main.py | 37 ++++++++++++++++++++-- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/docs/docs/install/environment-variables.md b/docs/docs/install/environment-variables.md index 403ff2c363..d41fef0a75 100644 --- a/docs/docs/install/environment-variables.md +++ b/docs/docs/install/environment-variables.md @@ -188,19 +188,18 @@ Typesense URL example JSON before encoding: | Variable | Description | Default | Services | | :----------------------------------------------- | :---------------------------------------------------------------- | :-----------------: | :--------------- | -| `MACHINE_LEARNING_MODEL_TTL`\*1 | Inactivity time (s) before a model is unloaded (disabled if <= 0) | `0` | machine learning | +| `MACHINE_LEARNING_MODEL_TTL` | Inactivity time (s) before a model is unloaded (disabled if <= 0) | `300` | machine learning | +| `MACHINE_LEARNING_MODEL_TTL_POLL_S` | Interval (s) between checks for the model TTL (disabled if <= 0) | `10` | machine learning | | `MACHINE_LEARNING_CACHE_FOLDER` | Directory where models are downloaded | `/cache` | machine learning | -| `MACHINE_LEARNING_REQUEST_THREADS`\*2 | Thread count of the request thread pool (disabled if <= 0) | number of CPU cores | machine learning | +| `MACHINE_LEARNING_REQUEST_THREADS`\*1 | Thread count of the request thread pool (disabled if <= 0) | number of CPU cores | machine learning | | `MACHINE_LEARNING_MODEL_INTER_OP_THREADS` | Number of parallel model operations | `1` | machine learning | | `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS` | Number of threads for each model operation | `2` | machine learning | -| `MACHINE_LEARNING_WORKERS`\*3 | Number of worker processes to spawn | `1` | machine learning | +| `MACHINE_LEARNING_WORKERS`\*2 | Number of worker processes to spawn | `1` | machine learning | | `MACHINE_LEARNING_WORKER_TIMEOUT` | Maximum time (s) of unresponsiveness before a worker is killed | `120` | machine learning | -\*1: This is an experimental feature. It may result in increased memory use over time when loading models repeatedly. +\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones. -\*2: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones. - -\*3: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around. +\*2: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around. :::info diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py index 8870b8c0e8..fa4fefeb37 100644 --- a/machine-learning/app/config.py +++ b/machine-learning/app/config.py @@ -13,7 +13,8 @@ from .schemas import ModelType class Settings(BaseSettings): cache_folder: str = "/cache" - model_ttl: int = 0 + model_ttl: int = 300 + model_ttl_poll_s: int = 10 host: str = "0.0.0.0" port: int = 3003 workers: int = 1 diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index e1d71e9fa2..2f6902760d 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -1,5 +1,9 @@ import asyncio +import gc +import os +import sys import threading +import time from concurrent.futures import ThreadPoolExecutor from typing import Any from zipfile import BadZipFile @@ -34,7 +38,10 @@ def init_state() -> None: ) # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None - app.state.locks = {model_type: threading.Lock() for model_type in ModelType} + app.state.lock = threading.Lock() + app.state.last_called = None + if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0: + asyncio.ensure_future(idle_shutdown_task()) log.info(f"Initialized request thread pool with {settings.request_threads} threads.") @@ -79,9 +86,9 @@ async def predict( async def run(model: InferenceModel, inputs: Any) -> Any: + app.state.last_called = time.time() if app.state.thread_pool is None: return model.predict(inputs) - return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs) @@ -90,7 +97,7 @@ async def load(model: InferenceModel) -> InferenceModel: return model def _load() -> None: - with app.state.locks[model.model_type]: + with app.state.lock: model.load() loop = asyncio.get_running_loop() @@ -113,3 +120,27 @@ async def load(model: InferenceModel) -> InferenceModel: else: await loop.run_in_executor(app.state.thread_pool, _load) return model + + +async def idle_shutdown_task() -> None: + while True: + log.debug("Checking for inactivity...") + if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl: + log.info("Shutting down due to inactivity.") + loop = asyncio.get_running_loop() + for task in asyncio.all_tasks(loop): + if task is not asyncio.current_task(): + try: + task.cancel() + except asyncio.CancelledError: + pass + sys.stderr.close() + sys.stdout.close() + sys.stdout = sys.stderr = open(os.devnull, "w") + try: + await app.state.model_cache.cache.clear() + gc.collect() + loop.stop() + except asyncio.CancelledError: + pass + await asyncio.sleep(settings.model_ttl_poll_s)