1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-25 17:15:28 +02:00
Mert 95cfe22866
feat(ml)!: cuda and openvino acceleration (#5619)
* cuda and openvino ep, refactor, update dockerfile

* updated workflow

* typing fixes

* added tests

* updated ml test gh action

* updated README

* updated docker-compose

* added compute to hwaccel.yml

* updated gh matrix

updated gh matrix

updated gh matrix

updated gh matrix

updated gh matrix

give up

* remove cuda/arm64 build

* add hwaccel image tags to docker-compose

* remove unnecessary quotes

* add suffix to git tag

* fixed kwargs in base model

* armnn ld_library_path

* update pyproject.toml

* add armnn workflow

* formatting

* consolidate hwaccel files, update docker compose

* update hw transcoding docs

* add ml hwaccel docs

* update dev and prod docker-compose

* added armnn prerequisite docs

* support 3.10

* updated docker-compose comments

* formatting

* test coverage

* don't set arena extend strategy for openvino

* working openvino

* formatting

* fix dockerfile

* added type annotation

* add wsl configuration for openvino

* updated lock file

* copy python3

* comment out extends section

* fix platforms

* simplify workflow suffix tagging

* simplify aio transcoding doc

* update docs and workflow for `hwaccel.yml` change

* revert docs
2024-01-21 18:22:39 -05:00

161 lines
4.7 KiB
Python

import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile
import orjson
from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel
from .config import log, settings
from .models.cache import ModelCache
from .schemas import (
MessageResponse,
ModelType,
TextResponse,
)
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
last_called: float | None = None
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
global thread_pool
log.info(
(
"Created in-memory cache with unloading "
f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
)
)
try:
if settings.request_threads > 0:
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
yield
finally:
log.handlers.clear()
for model in model_cache.cache._cache.values():
del model
if thread_pool is not None:
thread_pool.shutdown()
gc.collect()
def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
last_called = time.time()
try:
yield
finally:
active_requests -= 1
app = FastAPI(lifespan=lifespan)
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
return {"message": "Immich ML"}
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
return "pong"
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
model_name: str = Form(alias="modelName"),
model_type: ModelType = Form(alias="modelType"),
options: str = Form(default="{}"),
text: str | None = Form(default=None),
image: UploadFile | None = None,
) -> Any:
if image is not None:
inputs: str | bytes = await image.read()
elif text is not None:
inputs = text
else:
raise HTTPException(400, "Either image or text must be provided")
try:
kwargs = orjson.loads(options)
except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}")
model = await load(await model_cache.get(model_name, model_type, **kwargs))
model.configure(**kwargs)
outputs = await run(model.predict, inputs)
return ORJSONResponse(outputs)
async def run(func: Callable[..., Any], inputs: Any) -> Any:
if thread_pool is None:
return func(inputs)
return await asyncio.get_running_loop().run_in_executor(thread_pool, func, inputs)
async def load(model: InferenceModel) -> InferenceModel:
if model.loaded:
return model
def _load() -> None:
with lock:
model.load()
loop = asyncio.get_running_loop()
try:
if thread_pool is None:
model.load()
else:
await loop.run_in_executor(thread_pool, _load)
return model
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
log.warning(
(
f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
"Clearing cache and retrying."
)
)
model.clear_cache()
if thread_pool is None:
model.load()
else:
await loop.run_in_executor(thread_pool, _load)
return model
async def idle_shutdown_task() -> None:
while True:
log.debug("Checking for inactivity...")
if (
last_called is not None
and not active_requests
and not lock.locked()
and time.time() - last_called > settings.model_ttl
):
log.info("Shutting down due to inactivity.")
os.kill(os.getpid(), signal.SIGINT)
break
await asyncio.sleep(settings.model_ttl_poll_s)