feat(ml): better multilingual search with nllb models (#13567)
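Passes the user's locale through the search pipeline into the CLIP textual encoder. For NLLB-based models, the query text is prefixed with its FLORES-200 language code before tokenization; unknown locales fall back first to the bare language code (country suffix stripped) and then to English (eng_Latn).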
@@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
 
 from immich_ml.config import log
 from immich_ml.models.base import InferenceModel
+from immich_ml.models.constants import WEBLATE_TO_FLORES200
 from immich_ml.models.transforms import clean_text, serialize_np_array
 from immich_ml.schemas import ModelSession, ModelTask, ModelType
 
@@ -18,8 +19,9 @@ class BaseCLIPTextualEncoder(InferenceModel):
     depends = []
     identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
 
-    def _predict(self, inputs: str, **kwargs: Any) -> str:
-        res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
+    def _predict(self, inputs: str, language: str | None = None, **kwargs: Any) -> str:
+        tokens = self.tokenize(inputs, language=language)
+        res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
         return serialize_np_array(res)
 
     def _load(self) -> ModelSession:
@@ -28,6 +30,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         self.tokenizer = self._load_tokenizer()
         tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
         self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
+        self.is_nllb = self.model_name.startswith("nllb")
         log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
 
         return session
@@ -37,7 +40,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
         pass
 
     @abstractmethod
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         pass
 
     @property
@@ -92,14 +95,23 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
 
         return tokenizer
 
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
+        if self.is_nllb and language is not None:
+            flores_code = WEBLATE_TO_FLORES200.get(language)
+            if flores_code is None:
+                no_country = language.split("-")[0]
+                flores_code = WEBLATE_TO_FLORES200.get(no_country)
+            if flores_code is None:
+                log.warning(f"Language '{language}' not found, defaulting to 'en'")
+                flores_code = "eng_Latn"
+            text = f"{flores_code}{text}"
         tokens: Encoding = self.tokenizer.encode(text)
         return {"text": np.array([tokens.ids], dtype=np.int32)}
 
 
 class MClipTextualEncoder(OpenClipTextualEncoder):
-    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
+    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
         text = clean_text(text, canonicalize=self.canonicalize)
         tokens: Encoding = self.tokenizer.encode(text)
         return {
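Below is a minimal standalone sketch of the language-code fallback that the new tokenize logic performs. The resolve_flores_code helper and the mapping entries shown are illustrative assumptions for this sketch, not part of the codebase; the real table is immich_ml.models.constants.WEBLATE_TO_FLORES200, which maps Weblate locale codes to FLORES-200 codes.

    # Illustrative stand-in for WEBLATE_TO_FLORES200 (real entries may differ).
    WEBLATE_TO_FLORES200 = {
        "en": "eng_Latn",
        "de": "deu_Latn",
        "pt": "por_Latn",
        "pt-BR": "por_Latn",  # region-specific entries can exist alongside base codes
    }

    def resolve_flores_code(language: str) -> str:
        """Mirror the lookup order in OpenClipTextualEncoder.tokenize (hypothetical helper)."""
        # 1. Try the full locale code, e.g. "pt-BR".
        flores_code = WEBLATE_TO_FLORES200.get(language)
        if flores_code is None:
            # 2. Retry with the country suffix stripped, e.g. "de-CH" -> "de".
            flores_code = WEBLATE_TO_FLORES200.get(language.split("-")[0])
        if flores_code is None:
            # 3. Fall back to English rather than failing the search request.
            flores_code = "eng_Latn"
        return flores_code

    assert resolve_flores_code("pt-BR") == "por_Latn"  # exact match
    assert resolve_flores_code("de-CH") == "deu_Latn"  # country suffix stripped
    assert resolve_flores_code("xx") == "eng_Latn"     # unknown locale -> English

NLLB tokenizers treat FLORES-200 codes such as eng_Latn as special tokens, which is why tokenize can simply prepend the resolved code to the cleaned query text (text = f"{flores_code}{text}") to tell the model the query's source language.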