mirror of
https://github.com/immich-app/immich.git
synced 2024-12-23 02:06:15 +02:00
20be42cec0
* chore(deps): update machine-learning
* fix typing, use new lifespan syntax
* wrap in try / finally
* move log

---------

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
69 lines
2.1 KiB
Python
69 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any, NamedTuple
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from ann.ann import Ann
|
|
|
|
from ..config import log, settings
|
|
|
|
|
|
class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    The session owns one ``Ann`` instance (``self.ann``) and the identifier of
    the single model loaded into it (``self.model``). Both are released by
    ``__del__``.
    """

    def __init__(self, model_path: Path):
        """
        Create the ANN instance and load the model stored at ``model_path``.

        Two files are created if they do not exist yet: ``gpu-tuning.ann``
        inside ``settings.cache_folder`` and ``<model_path>.anncache`` next
        to the model.
        """
        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
        with tuning_file.open(mode="a"):
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            pass
        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())

        log.info("Loading ANN model %s ...", model_path)
        cache_file = model_path.with_suffix(".anncache")
        # Only ask ANN to save the cached network when no cache file existed
        # before this call.
        save = not cache_file.is_file()
        if save:
            with cache_file.open(mode="a"):
                # create empty model cache file
                pass

        self.model = self.ann.load(
            model_path.as_posix(),
            save_cached_network=save,
            cached_network_path=cache_file.as_posix(),
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """
        Unload the model (if one was loaded) and destroy the ANN instance.

        ``__del__`` also runs for instances whose ``__init__`` raised part-way
        through, so neither ``self.ann`` nor ``self.model`` is assumed to
        exist.
        """
        ann = getattr(self, "ann", None)
        if ann is None:
            # Ann() itself failed (or was never reached): nothing to release.
            return
        try:
            if hasattr(self, "model"):
                ann.unload(self.model)
                log.info("Unloaded ANN model %d", self.model)
        finally:
            # Always release the ANN instance, even when unloading failed or
            # no model was ever loaded.
            ann.destroy()

    def get_inputs(self) -> list[AnnNode]:
        """Return one ``AnnNode`` per input shape of the loaded model; names are ``None``."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[AnnNode]:
        """Return one ``AnnNode`` per output shape of the loaded model; names are ``None``."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """
        Execute the loaded model on the values of ``input_feed``.

        ``output_names`` and ``run_options`` are accepted but not used. The
        keys of ``input_feed`` are not used either: the values are passed to
        ANN in the dict's iteration order, each made C-contiguous first.
        """
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)
|
|
|
|
|
|
class AnnNode(NamedTuple):
    """Description of a single model input or output: an optional name and a shape."""

    # Node name, or None when no name is available.
    name: str | None
    # Dimensions of the node.
    shape: tuple[int, ...]
|