diff --git a/machine-learning/Dockerfile b/machine-learning/Dockerfile
index 3f35c95df8..d6b45d21dd 100644
--- a/machine-learning/Dockerfile
+++ b/machine-learning/Dockerfile
@@ -1,14 +1,15 @@
 FROM python:3.10 as builder
 
 ENV PYTHONDONTWRITEBYTECODE=1 \
-  PYTHONUNBUFFERED=1 \
-  PIP_NO_CACHE_DIR=true
+    PYTHONUNBUFFERED=1 \
+    PIP_NO_CACHE_DIR=true
 
 RUN python -m venv /opt/venv
 RUN /opt/venv/bin/pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece flask Pillow gunicorn
+RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece fastapi Pillow uvicorn[standard]
 RUN /opt/venv/bin/pip install --no-deps sentence-transformers
+
 
 FROM python:3.10-slim
 
 ENV NODE_ENV=production
@@ -16,12 +17,12 @@ ENV NODE_ENV=production
 COPY --from=builder /opt/venv /opt/venv
 
 ENV TRANSFORMERS_CACHE=/cache \
-  PYTHONDONTWRITEBYTECODE=1 \
-  PYTHONUNBUFFERED=1 \
-  PATH="/opt/venv/bin:$PATH"
+    PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1 \
+    PATH="/opt/venv/bin:$PATH"
 
 WORKDIR /usr/src/app
 
 COPY . .
-
-CMD ["gunicorn", "src.main:server"]
+ENV PYTHONPATH=/usr/src/app
+CMD ["python", "src/main.py"]
\ No newline at end of file
diff --git a/machine-learning/gunicorn.conf.py b/machine-learning/gunicorn.conf.py
deleted file mode 100644
index 0db0e8ee7d..0000000000
--- a/machine-learning/gunicorn.conf.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""
-Gunicorn configuration options.
-https://docs.gunicorn.org/en/stable/settings.html
-"""
-import os
-
-
-# Set the bind address based on the env
-port = os.getenv("MACHINE_LEARNING_PORT") or "3003"
-listen_ip = os.getenv("MACHINE_LEARNING_IP") or "0.0.0.0"
-bind = [f"{listen_ip}:{port}"]
-
-# Preload the Flask app / models etc. before starting the server
-preload_app = True
-
-# Logging settings - log to stdout and set log level
-accesslog = "-"
-loglevel = os.getenv("MACHINE_LEARNING_LOG_LEVEL") or "info"
-
-# Worker settings
-# ----------------------
-# It is important these are chosen carefully as per
-# https://pythonspeed.com/articles/gunicorn-in-docker/
-# Otherwise we get workers failing to respond to heartbeat checks,
-# especially as requests take a long time to complete.
-workers = 2
-threads = 4
-worker_tmp_dir = "/dev/shm"
-timeout = 60
diff --git a/machine-learning/src/main.py b/machine-learning/src/main.py
index cd6726f4d0..cd6debbfbf 100644
--- a/machine-learning/src/main.py
+++ b/machine-learning/src/main.py
@@ -1,58 +1,77 @@
-import os
-from flask import Flask, request
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer, util
 from PIL import Image
+from fastapi import FastAPI
+import uvicorn
+import os
+from pydantic import BaseModel
+
+
+class MlRequestBody(BaseModel):
+    thumbnailPath: str
+
+
+class ClipRequestBody(BaseModel):
+    text: str
+
 
 is_dev = os.getenv('NODE_ENV') == 'development'
 server_port = os.getenv('MACHINE_LEARNING_PORT', 3003)
 server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')
-classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
+app = FastAPI()
+
+"""
+Model Initialization
+"""
+classification_model = os.getenv(
+    'MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
 object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
-clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
-clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
+clip_image_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
+clip_text_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
 
 _model_cache = {}
-def _get_model(model, task=None):
-    global _model_cache
-    key = '|'.join([model, str(task)])
-    if key not in _model_cache:
-        if task:
-            _model_cache[key] = pipeline(model=model, task=task)
-        else:
-            _model_cache[key] = SentenceTransformer(model)
-    return _model_cache[key]
 
-server = Flask(__name__)
 
-@server.route("/ping")
+@app.get("/")
+async def root():
+    return {"message": "Immich ML"}
+
+
+@app.get("/ping")
 def ping():
     return "pong"
 
-@server.route("/object-detection/detect-object", methods=['POST'])
-def object_detection():
+
+@app.post("/object-detection/detect-object", status_code=200)
+def object_detection(payload: MlRequestBody):
     model = _get_model(object_model, 'object-detection')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)
 
-@server.route("/image-classifier/tag-image", methods=['POST'])
-def image_classification():
+
+@app.post("/image-classifier/tag-image", status_code=200)
+def image_classification(payload: MlRequestBody):
     model = _get_model(classification_model, 'image-classification')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)
 
-@server.route("/sentence-transformer/encode-image", methods=['POST'])
-def clip_encode_image():
+
+@app.post("/sentence-transformer/encode-image", status_code=200)
+def clip_encode_image(payload: MlRequestBody):
     model = _get_model(clip_image_model)
-    assetPath = request.json['thumbnailPath']
-    return model.encode(Image.open(assetPath)).tolist(), 200
+    assetPath = payload.thumbnailPath
+    return model.encode(Image.open(assetPath)).tolist()
 
-@server.route("/sentence-transformer/encode-text", methods=['POST'])
-def clip_encode_text():
+
+@app.post("/sentence-transformer/encode-text", status_code=200)
+def clip_encode_text(payload: ClipRequestBody):
     model = _get_model(clip_text_model)
-    text = request.json['text']
-    return model.encode(text).tolist(), 200
+    text = payload.text
+    return model.encode(text).tolist()
+
 
 def run_engine(engine, path):
     result = []
@@ -69,5 +88,17 @@ def run_engine(engine, path):
     return result
 
 
+def _get_model(model, task=None):
+    global _model_cache
+    key = '|'.join([model, str(task)])
+    if key not in _model_cache:
+        if task:
+            _model_cache[key] = pipeline(model=model, task=task)
+        else:
+            _model_cache[key] = SentenceTransformer(model)
+    return _model_cache[key]
+
+
 if __name__ == "__main__":
-    server.run(debug=is_dev, host=server_host, port=server_port)
+    uvicorn.run("main:app", host=server_host,
+                port=int(server_port), reload=is_dev, workers=1)
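
For reviewers, a minimal sketch of how a client might exercise the migrated endpoints, assuming the service is reachable at its default bind of 0.0.0.0:3003 and that the requests package is available; the thumbnail path below is a hypothetical example and must point at a file the server can read:

import requests

BASE = "http://localhost:3003"  # default MACHINE_LEARNING_PORT is 3003

# Health check: FastAPI JSON-encodes the bare string returned by /ping,
# so the response body decodes back to the plain string "pong"
assert requests.get(f"{BASE}/ping").json() == "pong"

# Image classification; thumbnailPath is validated by the MlRequestBody
# pydantic model (the path here is a made-up example)
tags = requests.post(
    f"{BASE}/image-classifier/tag-image",
    json={"thumbnailPath": "/usr/src/app/upload/thumbs/example.jpg"},
).json()

# CLIP text embedding for smart search, returned as a plain list of floats
embedding = requests.post(
    f"{BASE}/sentence-transformer/encode-text",
    json={"text": "a photo of a dog on the beach"},
).json()

Note that pydantic now rejects requests with a missing or non-string thumbnailPath/text with a 422 instead of the Flask code's unguarded request.json lookup, and the uvicorn.run call keeps model loading in a single worker process, replacing the preload_app/worker tuning that gunicorn.conf.py handled.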