Mirror of https://github.com/immich-app/immich.git

chore(ml): improve shutdown (#5689)

Mert 2023-12-14 14:51:24 -05:00 committed by GitHub
parent 9768931275
commit d729c863c8
4 changed files with 80 additions and 48 deletions

machine-learning/app/config.py

@@ -1,12 +1,16 @@
 import logging
 import os
+import sys
 from pathlib import Path
+from socket import socket
 
-import gunicorn
 import starlette
+from gunicorn.arbiter import Arbiter
 from pydantic import BaseSettings
 from rich.console import Console
 from rich.logging import RichHandler
+from uvicorn import Server
+from uvicorn.workers import UvicornWorker
 
 from .schemas import ModelType
@@ -69,10 +73,26 @@ log_settings = LogSettings()
 class CustomRichHandler(RichHandler):
     def __init__(self) -> None:
         console = Console(color_system="standard", no_color=log_settings.no_color)
-        super().__init__(
-            show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
-        )
+        super().__init__(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
 
 
 log = logging.getLogger("gunicorn.access")
 log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
+
+
+# patches this issue https://github.com/encode/uvicorn/discussions/1803
+class CustomUvicornServer(Server):
+    async def shutdown(self, sockets: list[socket] | None = None) -> None:
+        for sock in sockets or []:
+            sock.close()
+        await super().shutdown()
+
+
+class CustomUvicornWorker(UvicornWorker):
+    async def _serve(self) -> None:
+        self.config.app = self.wsgi
+        server = CustomUvicornServer(config=self.config)
+        self._install_sigquit_handler()
+        await server.serve(sockets=self.sockets)
+        if not server.started:
+            sys.exit(Arbiter.WORKER_BOOT_ERROR)
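A note on the patch above: per the linked uvicorn discussion, when gunicorn hands its listener sockets to a uvicorn worker, uvicorn's stock Server.shutdown() leaves them open, so the worker can keep accepting connections while it is trying to exit. Closing the sockets first fixes that. A standalone sketch of the same pattern, with an illustrative class name and a placeholder app string (not immich code):

    from socket import socket

    from uvicorn import Config, Server


    class EagerCloseServer(Server):
        # Close inherited listener sockets first so no new connections
        # arrive while the server is draining and exiting.
        async def shutdown(self, sockets: list[socket] | None = None) -> None:
            for sock in sockets or []:
                sock.close()
            await super().shutdown()


    server = EagerCloseServer(Config("app.main:app"))  # placeholder app path

Under gunicorn, the worker class supplies its own config and sockets instead, which is what CustomUvicornWorker._serve does above.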

machine-learning/app/test_main.py

@@ -1,5 +1,4 @@
-import json
 from pathlib import Path
 from typing import Any, Iterator
 from unittest import mock
@@ -8,7 +7,7 @@ import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
 
-from .main import app, init_state
+from .main import app
 from .schemas import ndarray_f32
@@ -29,9 +28,9 @@ def mock_get_model() -> Iterator[mock.Mock]:
 
 
 @pytest.fixture(scope="session")
-def deployed_app() -> TestClient:
-    init_state()
-    return TestClient(app)
+def deployed_app() -> Iterator[TestClient]:
+    with TestClient(app) as client:
+        yield client
 
 
 @pytest.fixture(scope="session")
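Why the fixture now uses a with block: entering TestClient as a context manager runs the app's startup handlers and leaving it runs the shutdown handlers, while a bare TestClient(app) does neither, which is why the old fixture had to call init_state() by hand. A minimal self-contained illustration (not immich code):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    state = {"ready": False}


    @app.on_event("startup")
    def mark_ready() -> None:
        state["ready"] = True


    with TestClient(app) as client:
        assert state["ready"]  # startup handlers ran on __enter__
    # shutdown handlers have run by this point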

machine-learning/app/main.py

@@ -1,15 +1,16 @@
 import asyncio
 import gc
 import os
+import signal
 import sys
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from starlette.formparsers import MultiPartParser
@@ -27,9 +28,16 @@ from .schemas import (
 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
 
 app = FastAPI()
+model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+thread_pool: ThreadPoolExecutor | None = None
+lock = threading.Lock()
+active_requests = 0
+last_called: float | None = None
+
 
-def init_state() -> None:
-    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+@app.on_event("startup")
+def startup() -> None:
+    global thread_pool
     log.info(
         (
             "Created in-memory cache with unloading "
@@ -37,17 +45,30 @@ def init_state() -> None:
         )
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
-    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
-    app.state.lock = threading.Lock()
-    app.state.last_called = None
+    thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
     if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
         asyncio.ensure_future(idle_shutdown_task())
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
-@app.on_event("startup")
-async def startup_event() -> None:
-    init_state()
+@app.on_event("shutdown")
+def shutdown() -> None:
+    log.handlers.clear()
+    for model in model_cache.cache._cache.values():
+        del model
+    if thread_pool is not None:
+        thread_pool.shutdown()
+    gc.collect()
+
+
+def update_state() -> Iterator[None]:
+    global active_requests, last_called
+    active_requests += 1
+    last_called = time.time()
+    try:
+        yield
+    finally:
+        active_requests -= 1
 
 
 @app.get("/", response_model=MessageResponse)
@@ -60,7 +81,7 @@ def ping() -> str:
     return "pong"
 
 
-@app.post("/predict")
+@app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
     model_name: str = Form(alias="modelName"),
     model_type: ModelType = Form(alias="modelType"),
@@ -79,17 +100,16 @@ async def predict(
     except orjson.JSONDecodeError:
         raise HTTPException(400, f"Invalid options JSON: {options}")
 
-    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)
 
 
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    app.state.last_called = time.time()
-    if app.state.thread_pool is None:
+    if thread_pool is None:
         return model.predict(inputs)
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(thread_pool, model.predict, inputs)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
@@ -97,15 +117,15 @@ async def load(model: InferenceModel) -> InferenceModel:
         return model
 
     def _load() -> None:
-        with app.state.lock:
+        with lock:
             model.load()
 
     loop = asyncio.get_running_loop()
     try:
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
         log.warn(
@@ -115,32 +135,23 @@ async def load(model: InferenceModel) -> InferenceModel:
             )
         )
         model.clear_cache()
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
 
 
 async def idle_shutdown_task() -> None:
     while True:
         log.debug("Checking for inactivity...")
-        if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl:
+        if (
+            last_called is not None
+            and not active_requests
+            and not lock.locked()
+            and time.time() - last_called > settings.model_ttl
+        ):
             log.info("Shutting down due to inactivity.")
-            loop = asyncio.get_running_loop()
-            for task in asyncio.all_tasks(loop):
-                if task is not asyncio.current_task():
-                    try:
-                        task.cancel()
-                    except asyncio.CancelledError:
-                        pass
-            sys.stderr.close()
-            sys.stdout.close()
-            sys.stdout = sys.stderr = open(os.devnull, "w")
-            try:
-                await app.state.model_cache.cache.clear()
-                gc.collect()
-                loop.stop()
-            except asyncio.CancelledError:
-                pass
+            os.kill(os.getpid(), signal.SIGINT)
+            break
         await asyncio.sleep(settings.model_ttl_poll_s)
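Two details of the main.py changes are worth spelling out. First, update_state is a FastAPI generator dependency: the code before the yield runs as each request starts, and the code after the yield runs once the response has been sent, so active_requests is an accurate in-flight counter and the idle check cannot fire mid-request. A minimal sketch of the pattern with illustrative names (not immich code):

    import time
    from typing import Iterator

    from fastapi import Depends, FastAPI

    app = FastAPI()
    active_requests = 0
    last_called: float | None = None


    def track_request() -> Iterator[None]:
        # Before the yield: runs when the request starts.
        global active_requests, last_called
        active_requests += 1
        last_called = time.time()
        try:
            yield
        finally:
            # After the yield: runs once the response has been sent.
            active_requests -= 1


    @app.get("/ping", dependencies=[Depends(track_request)])
    def ping() -> str:
        return "pong"

Second, the idle path now ends in os.kill(os.getpid(), signal.SIGINT) instead of cancelling tasks and closing stdout/stderr. SIGINT is a signal gunicorn workers already treat as a quick shutdown, so the worker exits cleanly and the arbiter boots a replacement; presumably the point is that a fresh worker holds no cached models, so the memory is actually released.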

machine-learning/start.sh

@@ -1,6 +1,7 @@
 #!/usr/bin/env sh
 
 export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
+export LD_BIND_NOW=1
 
 : "${MACHINE_LEARNING_HOST:=0.0.0.0}"
 : "${MACHINE_LEARNING_PORT:=3003}"
@@ -8,8 +9,9 @@ export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
 : "${MACHINE_LEARNING_WORKER_TIMEOUT:=120}"
 
 gunicorn app.main:app \
-	-k uvicorn.workers.UvicornWorker \
+	-k app.config.CustomUvicornWorker \
 	-w $MACHINE_LEARNING_WORKERS \
 	-b $MACHINE_LEARNING_HOST:$MACHINE_LEARNING_PORT \
 	-t $MACHINE_LEARNING_WORKER_TIMEOUT \
-	--log-config-json log_conf.json
+	--log-config-json log_conf.json \
+	--graceful-timeout 0
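On the start.sh changes: LD_BIND_NOW=1 makes the dynamic linker resolve all symbols at load time rather than lazily on first use, and --graceful-timeout 0 tells gunicorn to skip the grace period it normally gives workers to finish in-flight requests before force-killing them. That pairing fits the idle-shutdown task above, which only signals the worker when active_requests is zero and the load lock is free, so there is presumably nothing worth waiting for when the worker goes down.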