from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from ann.ann import Ann

from ..config import log, settings


class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Exposes ``get_inputs``, ``get_outputs`` and ``run`` on top of an ``Ann``
    instance that holds a single loaded model.
    """

    def __init__(self, model_path: Path):
        """Create the ANN instance and load the model stored at ``model_path``.

        Two files are created (empty) if they do not exist yet:
        ``gpu-tuning.ann`` inside ``settings.cache_folder`` and a ``.anncache``
        file next to the model.

        Args:
            model_path: Path of the model file to load.
        """
        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
        with tuning_file.open(mode="a"):
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            pass
        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())

        log.info("Loading ANN model %s ...", model_path)
        cache_file = model_path.with_suffix(".anncache")
        # Only ask ANN to save the cached network when no cache file existed
        # before this call.
        save = False
        if not cache_file.is_file():
            save = True
            with cache_file.open(mode="a"):
                # create empty model cache file
                pass

        self.model = self.ann.load(
            model_path.as_posix(),
            save_cached_network=save,
            cached_network_path=cache_file.as_posix(),
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """Unload the model (if one was loaded) and destroy the ANN instance.

        The finalizer also runs for instances whose ``__init__`` raised part
        way through, so neither ``self.ann`` nor ``self.model`` is assumed to
        exist.
        """
        ann = getattr(self, "ann", None)
        if ann is None:
            # __init__ failed before the ANN instance was created.
            return

        model = getattr(self, "model", None)
        try:
            if model is not None:
                ann.unload(model)
                log.info("Unloaded ANN model %d", model)
        finally:
            # Destroy the ANN instance even when there was no model to unload
            # or unloading raised.
            ann.destroy()

    def get_inputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per input shape of the loaded model."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[AnnNode]:
        """Return one unnamed ``AnnNode`` per output shape of the loaded model."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """Execute the loaded model on the arrays in ``input_feed``.

        Args:
            output_names: Not used by this implementation.
            input_feed: Mapping of input name to array. Only the values are
                used, in the mapping's iteration order; the keys are ignored.
            run_options: Not used by this implementation.

        Returns:
            Whatever ``Ann.execute`` returns for the loaded model.
        """
        # Each value is passed through np.ascontiguousarray before execution.
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)


class AnnNode(NamedTuple):
    # Node name; AnnSession always passes None here.
    name: str | None
    # Shape entry taken from the ANN input/output shapes of the model.
    shape: tuple[int, ...]