1
0
mirror of https://github.com/immich-app/immich.git synced 2024-11-24 08:52:28 +02:00

fix(ml): limit load retries (#10494)

This commit is contained in:
Mert 2024-06-20 14:13:18 -04:00 committed by GitHub
parent 79a8ab71ef
commit a42af06889
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 11 deletions

View File

@ -192,23 +192,18 @@ async def load(model: InferenceModel) -> InferenceModel:
return model return model
def _load(model: InferenceModel) -> InferenceModel: def _load(model: InferenceModel) -> InferenceModel:
if model.load_attempts > 1:
raise HTTPException(500, f"Failed to load model '{model.model_name}'")
with lock: with lock:
model.load() model.load()
return model return model
try: try:
await run(_load, model) return await run(_load, model)
return model
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile): except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
log.warning( log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
(
f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
"Clearing cache and retrying."
)
)
model.clear_cache() model.clear_cache()
await run(_load, model) return await run(_load, model)
return model
async def idle_shutdown_task() -> None: async def idle_shutdown_task() -> None:

View File

@ -31,6 +31,7 @@ class InferenceModel(ABC):
**model_kwargs: Any, **model_kwargs: Any,
) -> None: ) -> None:
self.loaded = False self.loaded = False
self.load_attempts = 0
self.model_name = clean_name(model_name) self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
self.providers = providers if providers is not None else self.providers_default self.providers = providers if providers is not None else self.providers_default
@ -48,9 +49,11 @@ class InferenceModel(ABC):
def load(self) -> None: def load(self) -> None:
if self.loaded: if self.loaded:
return return
self.load_attempts += 1
self.download() self.download()
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory") attempt = f"Attempt #{self.load_attempts + 1} to load" if self.load_attempts else "Loading"
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
self.session = self._load() self.session = self._load()
self.loaded = True self.loaded = True

View File

@ -11,6 +11,7 @@ import cv2
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
import pytest import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from PIL import Image from PIL import Image
from pytest import MonkeyPatch from pytest import MonkeyPatch
@ -627,6 +628,7 @@ class TestLoad:
async def test_load(self) -> None: async def test_load(self) -> None:
mock_model = mock.Mock(spec=InferenceModel) mock_model = mock.Mock(spec=InferenceModel)
mock_model.loaded = False mock_model.loaded = False
mock_model.load_attempts = 0
res = await load(mock_model) res = await load(mock_model)
@ -650,6 +652,7 @@ class TestLoad:
mock_model.model_task = ModelTask.SEARCH mock_model.model_task = ModelTask.SEARCH
mock_model.load.side_effect = [OSError, None] mock_model.load.side_effect = [OSError, None]
mock_model.loaded = False mock_model.loaded = False
mock_model.load_attempts = 0
res = await load(mock_model) res = await load(mock_model)
@ -657,6 +660,20 @@ class TestLoad:
mock_model.clear_cache.assert_called_once() mock_model.clear_cache.assert_called_once()
assert mock_model.load.call_count == 2 assert mock_model.load.call_count == 2
async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
mock_model = mock.Mock(spec=InferenceModel)
mock_model.model_name = "test_model_name"
mock_model.model_type = ModelType.VISUAL
mock_model.model_task = ModelTask.SEARCH
mock_model.loaded = False
mock_model.load_attempts = 2
with pytest.raises(HTTPException):
await load(mock_model)
mock_model.clear_cache.assert_not_called()
mock_model.load.assert_not_called()
@pytest.mark.skipif( @pytest.mark.skipif(
not settings.test_full, not settings.test_full,