1
0
mirror of https://github.com/immich-app/immich.git synced 2025-08-08 23:07:06 +02:00

feat: Add additional env variables for Machine Learning (#15326)

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Update config.py

* Add additional variables to preload part ML models

* Add additional variables to preload part ML models

* Apply formatting

* minor update

* formatting

* root validator

* minor update

* minor update

* minor update

* change to support explicit models

* minor update

* minor change

* minor change

* minor change

* minor update

* add logs, resolve errors

* minor change

* add new enviornment variables

* minor revisons

* remove comments
This commit is contained in:
Tempest
2025-01-14 16:06:01 -06:00
committed by GitHub
parent 5d2e421800
commit c5476a99b1
4 changed files with 87 additions and 34 deletions

View File

@ -14,9 +14,41 @@ from uvicorn import Server
from uvicorn.workers import UvicornWorker
class ClipSettings(BaseModel):
textual: str | None = None
visual: str | None = None
class FacialRecognitionSettings(BaseModel):
recognition: str | None = None
detection: str | None = None
class PreloadModelData(BaseModel):
clip: str | None = None
facial_recognition: str | None = None
clip: ClipSettings = ClipSettings()
facial_recognition: FacialRecognitionSettings = FacialRecognitionSettings()
clip_model_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__CLIP", None)
facial_recognition_model_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION", None)
def update_from_fallbacks(self) -> None:
if self.clip_model_fallback:
self.clip.textual = self.clip_model_fallback
self.clip.visual = self.clip_model_fallback
log.warning(
"Deprecated env variable: MACHINE_LEARNING_PRELOAD__CLIP. "
"Use MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL and "
"MACHINE_LEARNING_PRELOAD__CLIP__VISUAL instead."
)
if self.facial_recognition_model_fallback:
self.facial_recognition.recognition = self.facial_recognition_model_fallback
self.facial_recognition.detection = self.facial_recognition_model_fallback
log.warning(
"Deprecated environment variable: MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION. "
"Use MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION and "
"MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION instead."
)
class MaxBatchSize(BaseModel):

View File

@ -76,18 +76,29 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
async def preload_models(preload: PreloadModelData) -> None:
log.info(f"Preloading models: {preload}")
if preload.clip is not None:
model = await model_cache.get(preload.clip, ModelType.TEXTUAL, ModelTask.SEARCH)
if preload.clip.textual is not None:
model = await model_cache.get(preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH)
await load(model)
model = await model_cache.get(preload.clip, ModelType.VISUAL, ModelTask.SEARCH)
if preload.clip.visual is not None:
model = await model_cache.get(preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH)
await load(model)
if preload.facial_recognition is not None:
model = await model_cache.get(preload.facial_recognition, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
if preload.facial_recognition.detection is not None:
model = await model_cache.get(
preload.facial_recognition.detection,
ModelType.DETECTION,
ModelTask.FACIAL_RECOGNITION,
)
await load(model)
model = await model_cache.get(preload.facial_recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
if preload.facial_recognition.recognition is not None:
model = await model_cache.get(
preload.facial_recognition.recognition,
ModelType.RECOGNITION,
ModelTask.FACIAL_RECOGNITION,
)
await load(model)

View File

@ -700,11 +700,13 @@ class TestCache:
await model_cache.get("test_model_name", ModelType.TEXTUAL, ModelTask.SEARCH)
async def test_preloads_clip_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__VISUAL"] = "ViT-B-32__openai"
settings = Settings()
assert settings.preload is not None
assert settings.preload.clip == "ViT-B-32__openai"
assert settings.preload.clip.textual == "ViT-B-32__openai"
assert settings.preload.clip.visual == "ViT-B-32__openai"
model_cache = ModelCache()
monkeypatch.setattr("app.main.model_cache", model_cache)
@ -721,11 +723,13 @@ class TestCache:
async def test_preloads_facial_recognition_models(
self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock
) -> None:
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION"] = "buffalo_s"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION"] = "buffalo_s"
settings = Settings()
assert settings.preload is not None
assert settings.preload.facial_recognition == "buffalo_s"
assert settings.preload.facial_recognition.detection == "buffalo_s"
assert settings.preload.facial_recognition.recognition == "buffalo_s"
model_cache = ModelCache()
monkeypatch.setattr("app.main.model_cache", model_cache)
@ -740,13 +744,17 @@ class TestCache:
)
async def test_preloads_all_models(self, monkeypatch: MonkeyPatch, mock_get_model: mock.Mock) -> None:
os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"] = "buffalo_s"
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__VISUAL"] = "ViT-B-32__openai"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION"] = "buffalo_s"
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION"] = "buffalo_s"
settings = Settings()
assert settings.preload is not None
assert settings.preload.clip == "ViT-B-32__openai"
assert settings.preload.facial_recognition == "buffalo_s"
assert settings.preload.clip.visual == "ViT-B-32__openai"
assert settings.preload.clip.textual == "ViT-B-32__openai"
assert settings.preload.facial_recognition.recognition == "buffalo_s"
assert settings.preload.facial_recognition.detection == "buffalo_s"
model_cache = ModelCache()
monkeypatch.setattr("app.main.model_cache", model_cache)