You've already forked immich
							
							
				mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-31 00:18:28 +02:00 
			
		
		
		
	chore(ml): load models on start up (#2487)
* chore(ml): load models on start up * Download correct model
This commit is contained in:
		| @@ -5,7 +5,7 @@ import uvicorn | ||||
|  | ||||
| from insightface.app import FaceAnalysis | ||||
| from transformers import pipeline | ||||
| from sentence_transformers import SentenceTransformer, util | ||||
| from sentence_transformers import SentenceTransformer | ||||
| from PIL import Image | ||||
| from fastapi import FastAPI | ||||
| from pydantic import BaseModel | ||||
| @@ -20,22 +20,32 @@ class ClipRequestBody(BaseModel): | ||||
|  | ||||
|  | ||||
| classification_model = os.getenv( | ||||
|     'MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50') | ||||
| object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny') | ||||
| clip_image_model = os.getenv( | ||||
|     'MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32') | ||||
| clip_text_model = os.getenv( | ||||
|     'MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32') | ||||
|     "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50" | ||||
| ) | ||||
| object_model = os.getenv("MACHINE_LEARNING_OBJECT_MODEL", "hustvl/yolos-tiny") | ||||
| clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32") | ||||
| clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32") | ||||
| facial_recognition_model = os.getenv( | ||||
|     'MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL', 'buffalo_l') | ||||
|     "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l" | ||||
| ) | ||||
|  | ||||
| cache_folder = os.getenv('MACHINE_LEARNING_CACHE_FOLDER', '/cache') | ||||
| cache_folder = os.getenv("MACHINE_LEARNING_CACHE_FOLDER", "/cache") | ||||
|  | ||||
| _model_cache = {} | ||||
|  | ||||
| app = FastAPI() | ||||
|  | ||||
|  | ||||
| @app.on_event("startup") | ||||
| async def startup_event(): | ||||
|     # Get all models | ||||
|     _get_model(object_model, "object-detection") | ||||
|     _get_model(classification_model, "image-classification") | ||||
|     _get_model(clip_image_model) | ||||
|     _get_model(clip_text_model) | ||||
|     _get_model(facial_recognition_model, "facial-recognition") | ||||
|  | ||||
|  | ||||
| @app.get("/") | ||||
| async def root(): | ||||
|     return {"message": "Immich ML"} | ||||
| @@ -48,14 +58,14 @@ def ping(): | ||||
|  | ||||
| @app.post("/object-detection/detect-object", status_code=200) | ||||
| def object_detection(payload: MlRequestBody): | ||||
|     model = _get_model(object_model, 'object-detection') | ||||
|     model = _get_model(object_model, "object-detection") | ||||
|     assetPath = payload.thumbnailPath | ||||
|     return run_engine(model, assetPath) | ||||
|  | ||||
|  | ||||
| @app.post("/image-classifier/tag-image", status_code=200) | ||||
| def image_classification(payload: MlRequestBody): | ||||
|     model = _get_model(classification_model, 'image-classification') | ||||
|     model = _get_model(classification_model, "image-classification") | ||||
|     assetPath = payload.thumbnailPath | ||||
|     return run_engine(model, assetPath) | ||||
|  | ||||
| @@ -76,31 +86,32 @@ def clip_encode_text(payload: ClipRequestBody): | ||||
|  | ||||
| @app.post("/facial-recognition/detect-faces", status_code=200) | ||||
| def facial_recognition(payload: MlRequestBody): | ||||
|     model = _get_model(facial_recognition_model, 'facial-recognition') | ||||
|     model = _get_model(facial_recognition_model, "facial-recognition") | ||||
|     assetPath = payload.thumbnailPath | ||||
|     img = cv.imread(assetPath) | ||||
|     height, width, _ = img.shape | ||||
|     results = [] | ||||
|     faces = model.get(img) | ||||
|  | ||||
|     for face in faces: | ||||
|         if face.det_score < 0.7: | ||||
|             continue | ||||
|         x1, y1, x2, y2 = face.bbox | ||||
|         # min face size as percent of original image | ||||
|         # if (x2 - x1) / width < 0.03 or (y2 - y1) / height < 0.05: | ||||
|         #     continue | ||||
|         results.append({ | ||||
|             "imageWidth": width, | ||||
|             "imageHeight": height, | ||||
|             "boundingBox": { | ||||
|                 "x1": round(x1), | ||||
|                 "y1": round(y1), | ||||
|                 "x2": round(x2), | ||||
|                 "y2": round(y2), | ||||
|             }, | ||||
|             "score": face.det_score.item(), | ||||
|             "embedding": face.normed_embedding.tolist() | ||||
|         }) | ||||
|  | ||||
|         results.append( | ||||
|             { | ||||
|                 "imageWidth": width, | ||||
|                 "imageHeight": height, | ||||
|                 "boundingBox": { | ||||
|                     "x1": round(x1), | ||||
|                     "y1": round(y1), | ||||
|                     "x2": round(x2), | ||||
|                     "y2": round(y2), | ||||
|                 }, | ||||
|                 "score": face.det_score.item(), | ||||
|                 "embedding": face.normed_embedding.tolist(), | ||||
|             } | ||||
|         ) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| @@ -109,11 +120,11 @@ def run_engine(engine, path): | ||||
|     predictions = engine(path) | ||||
|  | ||||
|     for index, pred in enumerate(predictions): | ||||
|         tags = pred['label'].split(', ') | ||||
|         if (pred['score'] > 0.9): | ||||
|         tags = pred["label"].split(", ") | ||||
|         if pred["score"] > 0.9: | ||||
|             result = [*result, *tags] | ||||
|  | ||||
|     if (len(result) > 1): | ||||
|     if len(result) > 1: | ||||
|         result = list(set(result)) | ||||
|  | ||||
|     return result | ||||
| @@ -121,25 +132,27 @@ def run_engine(engine, path): | ||||
|  | ||||
| def _get_model(model, task=None): | ||||
|     global _model_cache | ||||
|     key = '|'.join([model, str(task)]) | ||||
|     key = "|".join([model, str(task)]) | ||||
|     if key not in _model_cache: | ||||
|         if task: | ||||
|             if task == 'facial-recognition': | ||||
|             if task == "facial-recognition": | ||||
|                 face_model = FaceAnalysis( | ||||
|                     name=model, root=cache_folder, allowed_modules=["detection", "recognition"]) | ||||
|                     name=model, | ||||
|                     root=cache_folder, | ||||
|                     allowed_modules=["detection", "recognition"], | ||||
|                 ) | ||||
|                 face_model.prepare(ctx_id=0, det_size=(640, 640)) | ||||
|                 _model_cache[key] = face_model | ||||
|             else: | ||||
|                 _model_cache[key] = pipeline(model=model, task=task) | ||||
|         else: | ||||
|             _model_cache[key] = SentenceTransformer( | ||||
|                 model, cache_folder=cache_folder) | ||||
|             _model_cache[key] = SentenceTransformer(model, cache_folder=cache_folder) | ||||
|     return _model_cache[key] | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0') | ||||
|     port = int(os.getenv('MACHINE_LEARNING_PORT', 3003)) | ||||
|     is_dev = os.getenv('NODE_ENV') == 'development' | ||||
|     host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0") | ||||
|     port = int(os.getenv("MACHINE_LEARNING_PORT", 3003)) | ||||
|     is_dev = os.getenv("NODE_ENV") == "development" | ||||
|  | ||||
|     uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user