
feat(ml)!: switch image classification and CLIP models to ONNX (#3809)

Mert
2023-08-25 00:28:51 -04:00
committed by GitHub
parent 8211afb726
commit 165b91b068
14 changed files with 1617 additions and 507 deletions

@@ -1,14 +1,17 @@
from __future__ import annotations
import os
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any
from zipfile import BadZipFile
import onnxruntime as ort
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
from ..config import get_cache_dir
from ..config import get_cache_dir, settings
from ..schemas import ModelType
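
The thread-count defaults added to __init__ below come from the shared settings object that is now imported alongside get_cache_dir. As a rough sketch of where those values might come from, assuming a pydantic-style BaseSettings class (only the field names model_inter_op_threads and model_intra_op_threads are taken from the diff; the defaults and environment prefix are illustrative):

from pydantic import BaseSettings


class Settings(BaseSettings):
    # 0 lets onnxruntime pick its own thread counts; values above 1 for the
    # inter-op setting also enable parallel execution in __init__ below.
    model_inter_op_threads: int = 0
    model_intra_op_threads: int = 0

    class Config:
        env_prefix = "MACHINE_LEARNING_"  # illustrative prefix, not confirmed by this diff
        case_sensitive = False


settings = Settings()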
@@ -16,12 +19,31 @@ class InferenceModel(ABC):
_model_type: ModelType
def __init__(
self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
self,
model_name: str,
cache_dir: Path | str | None = None,
eager: bool = True,
inter_op_num_threads: int = settings.model_inter_op_threads,
intra_op_num_threads: int = settings.model_intra_op_threads,
**model_kwargs: Any,
) -> None:
self.model_name = model_name
self._loaded = False
self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
loader = self.load if eager else self.download
self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
# don't pre-allocate more memory than needed
self.provider_options = model_kwargs.pop(
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
)
self.sess_options = PicklableSessionOptions()
# avoid thread contention between models
if inter_op_num_threads > 1:
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
self.sess_options.inter_op_num_threads = inter_op_num_threads
self.sess_options.intra_op_num_threads = intra_op_num_threads
try:
loader(**model_kwargs)
except (OSError, InvalidProtobuf, BadZipFile):
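
The providers, provider_options, and sess_options assembled above are ultimately what a concrete model hands to ONNX Runtime when it builds its session. A minimal sketch of that hand-off using the public ort.InferenceSession API; the model path and literal values are placeholders rather than anything taken from this diff:

import onnxruntime as ort

# Stand-ins for the values built in __init__ above.
providers = ["CPUExecutionProvider"]
provider_options = [{"arena_extend_strategy": "kSameAsRequested"}]

sess_options = ort.SessionOptions()
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
sess_options.inter_op_num_threads = 2
sess_options.intra_op_num_threads = 2

session = ort.InferenceSession(
    "model.onnx",  # hypothetical path; the real code resolves files under cache_dir
    sess_options=sess_options,
    providers=providers,
    provider_options=provider_options,
)

The kSameAsRequested arena strategy grows the memory arena only by the amount actually requested, which is what the "don't pre-allocate more memory than needed" comment above refers to.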
@@ -30,6 +52,7 @@ class InferenceModel(ABC):
def download(self, **model_kwargs: Any) -> None:
if not self.cached:
print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
self._download(**model_kwargs)
def load(self, **model_kwargs: Any) -> None:
@@ -39,6 +62,7 @@ class InferenceModel(ABC):
def predict(self, inputs: Any) -> Any:
if not self._loaded:
print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
self.load()
return self._predict(inputs)
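
With eager=False the constructor only downloads the model files, and the ONNX session is built lazily on the first predict() call, which triggers load() as shown above. A hypothetical subclass to illustrate the flow; the hook names _download, _load, and _predict follow the calls visible in this diff (_load is assumed, since only _download and _predict appear here), and the model name is made up:

from typing import Any


class DummyModel(InferenceModel):
    _model_type = ModelType.CLIP  # assuming CLIP is a ModelType member

    def _download(self, **model_kwargs: Any) -> None:
        ...  # fetch the weights into self.cache_dir

    def _load(self, **model_kwargs: Any) -> None:
        ...  # build an ort.InferenceSession from the cached files

    def _predict(self, inputs: Any) -> Any:
        ...  # run the session and post-process the outputs


model = DummyModel("some-model", eager=False)  # downloads only
model.predict(b"raw image bytes")              # first call loads, then predicts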
@@ -89,3 +113,14 @@ class InferenceModel(ABC):
else:
self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True)
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):
def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Any) -> None:
self.__init__() # type: ignore
for attr, val in pickle.loads(state):
setattr(self, attr, val)
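
The wrapper class addresses the "HF deep copies configs" comment: a plain ort.SessionOptions is a pybind11 object with no pickle support, so libraries that deep-copy their configuration would fail on it. A short usage sketch, assuming the class above is importable (the module path here is hypothetical):

import pickle

import onnxruntime as ort

from app.models.base import PicklableSessionOptions  # hypothetical import path

opts = PicklableSessionOptions()
opts.inter_op_num_threads = 2
opts.execution_mode = ort.ExecutionMode.ORT_PARALLEL

# Survives pickle (and therefore copy.deepcopy) by re-applying every
# non-callable attribute onto a freshly constructed instance.
restored = pickle.loads(pickle.dumps(opts))
assert restored.inter_op_num_threads == 2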