mirror of
https://github.com/immich-app/immich.git
synced 2024-11-28 09:33:27 +02:00
fixed setting different clip, removed unused stubs (#2987)
This commit is contained in:
parent
b3e97a1a0c
commit
4d3ce0a65e
@ -27,13 +27,10 @@ app = FastAPI()
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
|
||||
same_clip = settings.clip_image_model == settings.clip_text_model
|
||||
app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
|
||||
app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
|
||||
models = [
|
||||
(settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
|
||||
(settings.clip_image_model, app.state.clip_vision_type),
|
||||
(settings.clip_text_model, app.state.clip_text_type),
|
||||
(settings.clip_image_model, ModelType.CLIP),
|
||||
(settings.clip_text_model, ModelType.CLIP),
|
||||
(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
|
||||
]
|
||||
|
||||
@ -87,9 +84,7 @@ async def image_classification(
|
||||
async def clip_encode_image(
|
||||
image: Image.Image = Depends(dep_pil_image),
|
||||
) -> list[float]:
|
||||
model = await app.state.model_cache.get(
|
||||
settings.clip_image_model, app.state.clip_vision_type
|
||||
)
|
||||
model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
|
||||
embedding = model.predict(image)
|
||||
return embedding
|
||||
|
||||
@ -100,9 +95,7 @@ async def clip_encode_image(
|
||||
status_code=200,
|
||||
)
|
||||
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
|
||||
model = await app.state.model_cache.get(
|
||||
settings.clip_text_model, app.state.clip_text_type
|
||||
)
|
||||
model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
|
||||
embedding = model.predict(payload.text)
|
||||
return embedding
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
|
||||
from .clip import CLIPSTEncoder
|
||||
from .facial_recognition import FaceRecognizer
|
||||
from .image_classification import ImageClassifier
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
@ -25,13 +25,3 @@ class CLIPSTEncoder(InferenceModel):
|
||||
|
||||
def predict(self, image_or_text: Image | str) -> list[float]:
|
||||
return self.model.encode(image_or_text).tolist()
|
||||
|
||||
|
||||
# stubs to allow different behavior between the two in the future
|
||||
# and handle loading different image and text clip models
|
||||
class CLIPSTVisionEncoder(CLIPSTEncoder):
|
||||
_model_type = ModelType.CLIP_VISION
|
||||
|
||||
|
||||
class CLIPSTTextEncoder(CLIPSTEncoder):
|
||||
_model_type = ModelType.CLIP_TEXT
|
||||
|
@ -61,6 +61,4 @@ class FaceResponse(BaseModel):
|
||||
class ModelType(Enum):
|
||||
IMAGE_CLASSIFICATION = "image-classification"
|
||||
CLIP = "clip"
|
||||
CLIP_VISION = "clip-vision"
|
||||
CLIP_TEXT = "clip-text"
|
||||
FACIAL_RECOGNITION = "facial-recognition"
|
||||
|
Loading…
Reference in New Issue
Block a user