mirror of
https://github.com/immich-app/immich.git
synced 2025-04-27 13:42:33 +02:00
chore(ml): removed vit-b check and st warning (#4422)
This commit is contained in:
parent
b8d6cc1e09
commit
d8ecefaea5
@ -16,13 +16,6 @@ from ..config import log
|
|||||||
from ..schemas import ModelType
|
from ..schemas import ModelType
|
||||||
from .base import InferenceModel
|
from .base import InferenceModel
|
||||||
|
|
||||||
_ST_TO_JINA_MODEL_NAME = {
|
|
||||||
"clip-ViT-B-16": "ViT-B-16::openai",
|
|
||||||
"clip-ViT-B-32": "ViT-B-32::openai",
|
|
||||||
"clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
|
|
||||||
"clip-ViT-L-14": "ViT-L-14::openai",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoder(InferenceModel):
|
class CLIPEncoder(InferenceModel):
|
||||||
_model_type = ModelType.CLIP
|
_model_type = ModelType.CLIP
|
||||||
@ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel):
|
|||||||
) -> None:
|
) -> None:
|
||||||
if mode is not None and mode not in ("text", "vision"):
|
if mode is not None and mode not in ("text", "vision"):
|
||||||
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
|
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
|
||||||
if "vit-b" not in model_name.lower():
|
if model_name not in _MODELS:
|
||||||
raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
|
raise ValueError(f"Unknown model name {model_name}.")
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
jina_model_name = self._get_jina_model_name(model_name)
|
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||||
super().__init__(jina_model_name, cache_dir, **model_kwargs)
|
|
||||||
|
|
||||||
def _download(self) -> None:
|
def _download(self) -> None:
|
||||||
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
|
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
|
||||||
@ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel):
|
|||||||
|
|
||||||
return outputs[0][0].tolist()
|
return outputs[0][0].tolist()
|
||||||
|
|
||||||
def _get_jina_model_name(self, model_name: str) -> str:
|
|
||||||
if model_name in _MODELS:
|
|
||||||
return model_name
|
|
||||||
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
|
||||||
log.warn(
|
|
||||||
(
|
|
||||||
f"Sentence-Transformer models like '{model_name}' are not supported."
|
|
||||||
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return _ST_TO_JINA_MODEL_NAME[model_name]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown model name {model_name}.")
|
|
||||||
|
|
||||||
def _download_model(self, model_name: str, model_md5: str) -> bool:
|
def _download_model(self, model_name: str, model_md5: str) -> bool:
|
||||||
# downloading logic is adapted from clip-server's CLIPOnnxModel class
|
# downloading logic is adapted from clip-server's CLIPOnnxModel class
|
||||||
download_model(
|
download_model(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user