From 47982641b222e2724de779222ae3d7ef040445de Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Tue, 27 Jun 2023 17:01:24 -0400 Subject: [PATCH] fix(ml): clear model cache on load error (#2951) * clear model cache on load error * updated caught exceptions --- machine-learning/app/models/base.py | 31 ++++++++++++++++--- machine-learning/app/models/clip.py | 9 ++---- .../app/models/facial_recognition.py | 11 ++++--- .../app/models/image_classification.py | 6 ++-- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index 0ef3173ce8..a62d7730c3 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -2,8 +2,11 @@ from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path +from shutil import rmtree from typing import Any +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf + from ..config import get_cache_dir from ..schemas import ModelType @@ -12,10 +15,8 @@ class InferenceModel(ABC): _model_type: ModelType def __init__( - self, - model_name: str, - cache_dir: Path | None = None, - ): + self, model_name: str, cache_dir: Path | None = None, **model_kwargs + ) -> None: self.model_name = model_name self._cache_dir = ( cache_dir @@ -23,6 +24,16 @@ class InferenceModel(ABC): else get_cache_dir(model_name, self.model_type) ) + try: + self.load(**model_kwargs) + except (OSError, InvalidProtobuf): + self.clear_cache() + self.load(**model_kwargs) + + @abstractmethod + def load(self, **model_kwargs: Any) -> None: + ... + @abstractmethod def predict(self, inputs: Any) -> Any: ... @@ -36,7 +47,7 @@ class InferenceModel(ABC): return self._cache_dir @cache_dir.setter - def cache_dir(self, cache_dir: Path): + def cache_dir(self, cache_dir: Path) -> None: self._cache_dir = cache_dir @classmethod @@ -50,3 +61,13 @@ class InferenceModel(ABC): raise ValueError(f"Unsupported model type: {model_type}") return subclasses[model_type](model_name, **model_kwargs) + + def clear_cache(self) -> None: + if not self.cache_dir.exists(): + return + elif not rmtree.avoids_symlink_attacks: + raise RuntimeError( + "Attempted to clear cache, but rmtree is not safe on this platform." + ) + + rmtree(self.cache_dir) diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py index 9e55b28d57..ac9d800cf4 100644 --- a/machine-learning/app/models/clip.py +++ b/machine-learning/app/models/clip.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any from PIL.Image import Image from sentence_transformers import SentenceTransformer @@ -10,13 +11,7 @@ from .base import InferenceModel class CLIPSTEncoder(InferenceModel): _model_type = ModelType.CLIP - def __init__( - self, - model_name: str, - cache_dir: Path | None = None, - **model_kwargs, - ): - super().__init__(model_name, cache_dir) + def load(self, **model_kwargs: Any) -> None: self.model = SentenceTransformer( self.model_name, cache_folder=self.cache_dir.as_posix(), diff --git a/machine-learning/app/models/facial_recognition.py b/machine-learning/app/models/facial_recognition.py index ff993c172d..99349409f0 100644 --- a/machine-learning/app/models/facial_recognition.py +++ b/machine-learning/app/models/facial_recognition.py @@ -18,21 +18,22 @@ class FaceRecognizer(InferenceModel): min_score: float = settings.min_face_score, cache_dir: Path | None = None, **model_kwargs, - ): - super().__init__(model_name, cache_dir) + ) -> None: self.min_score = min_score - model = FaceAnalysis( + super().__init__(model_name, cache_dir, **model_kwargs) + + def load(self, **model_kwargs: Any) -> None: + self.model = FaceAnalysis( name=self.model_name, root=self.cache_dir.as_posix(), allowed_modules=["detection", "recognition"], **model_kwargs, ) - model.prepare( + self.model.prepare( ctx_id=0, det_thresh=self.min_score, det_size=(640, 640), ) - self.model = model def predict(self, image: cv2.Mat) -> list[dict[str, Any]]: height, width, _ = image.shape diff --git a/machine-learning/app/models/image_classification.py b/machine-learning/app/models/image_classification.py index adb55181d8..9f7e4cfb6c 100644 --- a/machine-learning/app/models/image_classification.py +++ b/machine-learning/app/models/image_classification.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any from PIL.Image import Image from transformers.pipelines import pipeline @@ -17,10 +18,11 @@ class ImageClassifier(InferenceModel): min_score: float = settings.min_tag_score, cache_dir: Path | None = None, **model_kwargs, - ): - super().__init__(model_name, cache_dir) + ) -> None: self.min_score = min_score + super().__init__(model_name, cache_dir, **model_kwargs) + def load(self, **model_kwargs: Any) -> None: self.model = pipeline( self.model_type.value, self.model_name,