From a42af06889f73a2b291aebab5918915ce706f959 Mon Sep 17 00:00:00 2001
From: Mert <101130780+mertalev@users.noreply.github.com>
Date: Thu, 20 Jun 2024 14:13:18 -0400
Subject: [PATCH] fix(ml): limit load retries (#10494)

---
Review notes (this region is ignored when the patch is applied):
* InferenceModel gains a `load_attempts` counter that load() increments on
  every call that actually attempts a load. `_load` in main.py raises
  HTTPException(500) once the counter exceeds 1, so a model that already
  failed its initial load and its post-cache-clear retry is not loaded again.
* Corrected relative to the original patch: load() increments
  `load_attempts` before building its log message, so the original
  expression (`#{self.load_attempts + 1}` guarded by `if self.load_attempts`)
  reported "Attempt #2" on the very first load. The message now prints the
  counter as-is and only switches wording from the second attempt onwards.
* The post-image blob id on the base.py `index` line predates this correction.

 machine-learning/app/main.py        | 15 +++++----------
 machine-learning/app/models/base.py |  5 ++++-
 machine-learning/app/test_main.py   | 17 +++++++++++++++++
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index 3c607015d9..ac493a059a 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -192,23 +192,18 @@ async def load(model: InferenceModel) -> InferenceModel:
         return model
 
     def _load(model: InferenceModel) -> InferenceModel:
+        if model.load_attempts > 1:
+            raise HTTPException(500, f"Failed to load model '{model.model_name}'")
         with lock:
             model.load()
         return model
 
     try:
-        await run(_load, model)
-        return model
+        return await run(_load, model)
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
-        log.warning(
-            (
-                f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
-                "Clearing cache and retrying."
-            )
-        )
+        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
         model.clear_cache()
-        await run(_load, model)
-        return model
+        return await run(_load, model)
 
 
 async def idle_shutdown_task() -> None:
diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py
index f64a873010..4ad6fd6eb7 100644
--- a/machine-learning/app/models/base.py
+++ b/machine-learning/app/models/base.py
@@ -31,6 +31,7 @@ class InferenceModel(ABC):
         **model_kwargs: Any,
     ) -> None:
         self.loaded = False
+        self.load_attempts = 0
         self.model_name = clean_name(model_name)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
         self.providers = providers if providers is not None else self.providers_default
@@ -48,9 +49,11 @@ class InferenceModel(ABC):
     def load(self) -> None:
         if self.loaded:
             return
+        self.load_attempts += 1
 
         self.download()
-        log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
+        attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
+        log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
         self.session = self._load()
         self.loaded = True
 
diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index d9d1455bd1..2068c7a4c6 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -11,6 +11,7 @@ import cv2
 import numpy as np
 import onnxruntime as ort
 import pytest
+from fastapi import HTTPException
 from fastapi.testclient import TestClient
 from PIL import Image
 from pytest import MonkeyPatch
@@ -627,6 +628,7 @@ class TestLoad:
     async def test_load(self) -> None:
         mock_model = mock.Mock(spec=InferenceModel)
         mock_model.loaded = False
+        mock_model.load_attempts = 0
 
         res = await load(mock_model)
 
@@ -650,6 +652,7 @@ class TestLoad:
         mock_model.model_task = ModelTask.SEARCH
         mock_model.load.side_effect = [OSError, None]
         mock_model.loaded = False
+        mock_model.load_attempts = 0
 
         res = await load(mock_model)
 
@@ -657,6 +660,20 @@ class TestLoad:
         mock_model.clear_cache.assert_called_once()
         assert mock_model.load.call_count == 2
 
+    async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
+        mock_model = mock.Mock(spec=InferenceModel)
+        mock_model.model_name = "test_model_name"
+        mock_model.model_type = ModelType.VISUAL
+        mock_model.model_task = ModelTask.SEARCH
+        mock_model.loaded = False
+        mock_model.load_attempts = 2
+
+        with pytest.raises(HTTPException):
+            await load(mock_model)
+
+        mock_model.clear_cache.assert_not_called()
+        mock_model.load.assert_not_called()
+
 
 @pytest.mark.skipif(
     not settings.test_full,