1
0
mirror of https://github.com/immich-app/immich.git synced 2025-08-08 23:07:06 +02:00

refactor(ml): model sessions (#10559)

This commit is contained in:
Mert
2024-06-25 12:00:24 -04:00
committed by GitHub
parent 6538ad8de7
commit 6356c28f64
11 changed files with 529 additions and 375 deletions

View File

@ -0,0 +1,5 @@
from app.schemas import ModelSession
def has_batch_axis(session: ModelSession) -> bool:
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0

View File

@ -0,0 +1,57 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
from numpy.typing import NDArray
from ann.ann import Ann
from app.schemas import SessionNode
from ..config import log, settings
class AnnSession:
"""
Wrapper for ANN to be drop-in replacement for ONNX session.
"""
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
self.model_path = model_path
self.cache_dir = cache_dir
self.ann = Ann(tuning_level=3, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
log.info("Loading ANN model %s ...", model_path)
self.model = self.ann.load(
model_path.as_posix(),
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
)
log.info("Loaded ANN model with ID %d", self.model)
def __del__(self) -> None:
self.ann.unload(self.model)
log.info("Unloaded ANN model %d", self.model)
self.ann.destroy()
def get_inputs(self) -> list[SessionNode]:
shapes = self.ann.input_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def get_outputs(self) -> list[SessionNode]:
shapes = self.ann.output_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
return self.ann.execute(self.model, inputs)
class AnnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]

View File

@ -0,0 +1,129 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from app.models.constants import SUPPORTED_PROVIDERS
from app.schemas import SessionNode
from ..config import log, settings
class OrtSession:
def __init__(
self,
model_path: Path | str,
providers: list[str] | None = None,
provider_options: list[dict[str, Any]] | None = None,
sess_options: ort.SessionOptions | None = None,
):
self.model_path = Path(model_path)
self.providers = providers if providers is not None else self._providers_default
self.provider_options = provider_options if provider_options is not None else self._provider_options_default
self.sess_options = sess_options if sess_options is not None else self._sess_options_default
self.session = ort.InferenceSession(
self.model_path.as_posix(),
providers=self.providers,
provider_options=self.provider_options,
sess_options=self.sess_options,
)
def get_inputs(self) -> list[SessionNode]:
inputs: list[SessionNode] = self.session.get_inputs()
return inputs
def get_outputs(self) -> list[SessionNode]:
outputs: list[SessionNode] = self.session.get_outputs()
return outputs
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
return outputs
@property
def providers(self) -> list[str]:
return self._providers
@providers.setter
def providers(self, providers: list[str]) -> None:
log.info(f"Setting execution providers to {providers}, in descending order of preference")
self._providers = providers
@property
def _providers_default(self) -> list[str]:
available_providers = set(ort.get_available_providers())
log.debug(f"Available ORT providers: {available_providers}")
if (openvino := "OpenVINOExecutionProvider") in available_providers:
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
log.debug(f"Available OpenVINO devices: {device_ids}")
gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
if not gpu_devices:
log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
available_providers.remove(openvino)
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]
@property
def provider_options(self) -> list[dict[str, Any]]:
return self._provider_options
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
def _provider_options_default(self) -> list[dict[str, Any]]:
options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider" | "CUDAExecutionProvider":
option = {"arena_extend_strategy": "kSameAsRequested"}
case "OpenVINOExecutionProvider":
option = {"device_type": "GPU_FP32", "cache_dir": (self.model_path.parent / "openvino").as_posix()}
case _:
option = {}
options.append(option)
return options
@property
def sess_options(self) -> ort.SessionOptions:
return self._sess_options
@sess_options.setter
def sess_options(self, sess_options: ort.SessionOptions) -> None:
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
self._sess_options = sess_options
@property
def _sess_options_default(self) -> ort.SessionOptions:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models
if settings.model_inter_op_threads > 0:
sess_options.inter_op_num_threads = settings.model_inter_op_threads
# these defaults work well for CPU, but bottleneck GPU
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.inter_op_num_threads = 1
if settings.model_intra_op_threads > 0:
sess_options.intra_op_num_threads = settings.model_intra_op_threads
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.intra_op_num_threads = 2
if sess_options.inter_op_num_threads > 1:
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
return sess_options