You've already forked immich
mirror of
https://github.com/immich-app/immich.git
synced 2025-08-07 23:03:36 +02:00
chore(ml): installable package (#17153)
* app -> immich_ml * fix test ci * omit file name * add new line * add new line
This commit is contained in:
58
machine-learning/immich_ml/sessions/ann/__init__.py
Normal file
58
machine-learning/immich_ml/sessions/ann/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.schemas import SessionNode
|
||||
|
||||
from .loader import Ann
|
||||
|
||||
|
||||
class AnnSession:
|
||||
"""
|
||||
Wrapper for ANN to be drop-in replacement for ONNX session.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
|
||||
self.model_path = model_path
|
||||
self.cache_dir = cache_dir
|
||||
self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
|
||||
|
||||
log.info("Loading ANN model %s ...", model_path)
|
||||
self.model = self.ann.load(
|
||||
model_path.as_posix(),
|
||||
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
|
||||
fp16=settings.ann_fp16_turbo,
|
||||
)
|
||||
log.info("Loaded ANN model with ID %d", self.model)
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.ann.unload(self.model)
|
||||
log.info("Unloaded ANN model %d", self.model)
|
||||
self.ann.destroy()
|
||||
|
||||
def get_inputs(self) -> list[SessionNode]:
|
||||
shapes = self.ann.input_shapes[self.model]
|
||||
return [AnnNode(None, s) for s in shapes]
|
||||
|
||||
def get_outputs(self) -> list[SessionNode]:
|
||||
shapes = self.ann.output_shapes[self.model]
|
||||
return [AnnNode(None, s) for s in shapes]
|
||||
|
||||
def run(
|
||||
self,
|
||||
output_names: list[str] | None,
|
||||
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
|
||||
run_options: Any = None,
|
||||
) -> list[NDArray[np.float32]]:
|
||||
inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
|
||||
return self.ann.execute(self.model, inputs)
|
||||
|
||||
|
||||
class AnnNode(NamedTuple):
|
||||
name: str | None
|
||||
shape: tuple[int, ...]
|
Reference in New Issue
Block a user