You've already forked immich
							
							
				mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-31 00:18:28 +02:00 
			
		
		
		
	fix(ml): limit load retries (#10494)
This commit is contained in:
		| @@ -192,23 +192,18 @@ async def load(model: InferenceModel) -> InferenceModel: | ||||
|         return model | ||||
|  | ||||
|     def _load(model: InferenceModel) -> InferenceModel: | ||||
|         if model.load_attempts > 1: | ||||
|             raise HTTPException(500, f"Failed to load model '{model.model_name}'") | ||||
|         with lock: | ||||
|             model.load() | ||||
|         return model | ||||
|  | ||||
|     try: | ||||
|         await run(_load, model) | ||||
|         return model | ||||
|         return await run(_load, model) | ||||
|     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile): | ||||
|         log.warning( | ||||
|             ( | ||||
|                 f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'." | ||||
|                 "Clearing cache and retrying." | ||||
|             ) | ||||
|         ) | ||||
|         log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.") | ||||
|         model.clear_cache() | ||||
|         await run(_load, model) | ||||
|         return model | ||||
|         return await run(_load, model) | ||||
|  | ||||
|  | ||||
| async def idle_shutdown_task() -> None: | ||||
|   | ||||
| @@ -31,6 +31,7 @@ class InferenceModel(ABC): | ||||
|         **model_kwargs: Any, | ||||
|     ) -> None: | ||||
|         self.loaded = False | ||||
|         self.load_attempts = 0 | ||||
|         self.model_name = clean_name(model_name) | ||||
|         self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default | ||||
|         self.providers = providers if providers is not None else self.providers_default | ||||
| @@ -48,9 +49,11 @@ class InferenceModel(ABC): | ||||
|     def load(self) -> None: | ||||
|         if self.loaded: | ||||
|             return | ||||
|         self.load_attempts += 1 | ||||
|  | ||||
|         self.download() | ||||
|         log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory") | ||||
|         attempt = f"Attempt #{self.load_attempts + 1} to load" if self.load_attempts else "Loading" | ||||
|         log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory") | ||||
|         self.session = self._load() | ||||
|         self.loaded = True | ||||
|  | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import cv2 | ||||
| import numpy as np | ||||
| import onnxruntime as ort | ||||
| import pytest | ||||
| from fastapi import HTTPException | ||||
| from fastapi.testclient import TestClient | ||||
| from PIL import Image | ||||
| from pytest import MonkeyPatch | ||||
| @@ -627,6 +628,7 @@ class TestLoad: | ||||
|     async def test_load(self) -> None: | ||||
|         mock_model = mock.Mock(spec=InferenceModel) | ||||
|         mock_model.loaded = False | ||||
|         mock_model.load_attempts = 0 | ||||
|  | ||||
|         res = await load(mock_model) | ||||
|  | ||||
| @@ -650,6 +652,7 @@ class TestLoad: | ||||
|         mock_model.model_task = ModelTask.SEARCH | ||||
|         mock_model.load.side_effect = [OSError, None] | ||||
|         mock_model.loaded = False | ||||
|         mock_model.load_attempts = 0 | ||||
|  | ||||
|         res = await load(mock_model) | ||||
|  | ||||
| @@ -657,6 +660,20 @@ class TestLoad: | ||||
|         mock_model.clear_cache.assert_called_once() | ||||
|         assert mock_model.load.call_count == 2 | ||||
|  | ||||
|     async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None: | ||||
|         mock_model = mock.Mock(spec=InferenceModel) | ||||
|         mock_model.model_name = "test_model_name" | ||||
|         mock_model.model_type = ModelType.VISUAL | ||||
|         mock_model.model_task = ModelTask.SEARCH | ||||
|         mock_model.loaded = False | ||||
|         mock_model.load_attempts = 2 | ||||
|  | ||||
|         with pytest.raises(HTTPException): | ||||
|             await load(mock_model) | ||||
|  | ||||
|         mock_model.clear_cache.assert_not_called() | ||||
|         mock_model.load.assert_not_called() | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
|     not settings.test_full, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user