Mirror of https://github.com/immich-app/immich.git (synced 2025-07-17 08:47:43 +02:00)

convert to static

Author: mertalev
Date: 2024-02-02 22:15:14 -05:00
Parent: b768eef44d
Commit: b374052de2
2 changed files with 46 additions and 6 deletions


@@ -5,13 +5,16 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any

+import onnx
+from onnx.tools.update_model_dims import update_inputs_outputs_dims
+from onnx.shape_inference import infer_shapes
 import onnxruntime as ort
 from huggingface_hub import snapshot_download
 from typing_extensions import Buffer

 import ann.ann
-from app.models.constants import SUPPORTED_PROVIDERS
+from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS

 from ..config import get_cache_dir, get_hf_model_name, log, settings
 from ..schemas import ModelRuntime, ModelType
@@ -114,6 +117,13 @@ class InferenceModel(ABC):
             )
             model_path = onnx_path

+        if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
+            static_path = model_path.parent / "static_1" / "model.onnx"
+            static_path.parent.mkdir(parents=True, exist_ok=True)
+            if not static_path.is_file():
+                self._convert_to_static(model_path, static_path)
+            model_path = static_path
+
         match model_path.suffix:
             case ".armnn":
                 session = AnnSession(model_path)
@@ -128,6 +138,37 @@ class InferenceModel(ABC):
                 raise ValueError(f"Unsupported model file type: {model_path.suffix}")
         return session

+    def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
+        inferred = infer_shapes(onnx.load(source_path))
+        inputs = self._get_static_dims(inferred.graph.input)
+        outputs = self._get_static_dims(inferred.graph.output)
+
+        check_model = onnx.checker.check_model
+        try:
+            onnx.checker.check_model = lambda _: None
+            updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
+        finally:
+            onnx.checker.check_model = check_model
+
+        onnx.save(
+            updated_model,
+            target_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=False,
+            size_threshold=1048576,
+        )
+
+    def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
+        return {
+            field.name: [
+                d.dim_value if d.HasField("dim_value") else dim_size
+                for shape in field.type.ListFields()
+                if (dim := shape[1].shape.dim)
+                for d in dim
+            ]
+            for field in graph_io
+        }
+
     @property
     def model_type(self) -> ModelType:
         return self._model_type
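The heart of the change is ONNX shape inference followed by update_inputs_outputs_dims, with every symbolic dimension pinned to 1; onnx.checker.check_model is temporarily no-oped because update_inputs_outputs_dims validates the model internally, presumably so that validation does not fail on very large models. A rough standalone sketch of the same idea follows (the file paths and the pinned dim size of 1 are illustrative assumptions, and it reads dimensions through type.tensor_type instead of the ListFields() form used above):

# Rough standalone sketch (not the commit's code): pin all dynamic ONNX dims to 1.
from pathlib import Path

import onnx
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims


def static_dims(graph_io, dim_size: int = 1) -> dict[str, list[int]]:
    # Keep concrete dim_value entries; replace symbolic/unknown dims with dim_size.
    return {
        io.name: [
            d.dim_value if d.HasField("dim_value") else dim_size
            for d in io.type.tensor_type.shape.dim
        ]
        for io in graph_io
    }


source = Path("model.onnx")            # placeholder input path
target = Path("static_1/model.onnx")   # placeholder output path
target.parent.mkdir(parents=True, exist_ok=True)

inferred = infer_shapes(onnx.load(str(source)))
fixed = update_inputs_outputs_dims(
    inferred,
    static_dims(inferred.graph.input),
    static_dims(inferred.graph.output),
)
onnx.save(fixed, str(target))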


@@ -51,11 +51,10 @@ _INSIGHTFACE_MODELS = {
 }

-SUPPORTED_PROVIDERS = [
-    "CUDAExecutionProvider",
-    "OpenVINOExecutionProvider",
-    "CPUExecutionProvider",
-]
+SUPPORTED_PROVIDERS = {"CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"}
+
+STATIC_INPUT_PROVIDERS = {"OpenVINOExecutionProvider"}


 def is_openclip(model_name: str) -> bool:
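These constants gate the new behavior: SUPPORTED_PROVIDERS is collapsed to a one-line set, and a provider listed in STATIC_INPUT_PROVIDERS (currently only OpenVINO) signals that the model must first be rewritten with fixed input shapes. A minimal sketch of evaluating that gate against whatever providers onnxruntime reports as available; choose_providers() is a hypothetical helper, not part of the commit:

# Sketch: filter available providers and decide whether static shapes are needed.
import onnxruntime as ort

SUPPORTED_PROVIDERS = {"CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"}
STATIC_INPUT_PROVIDERS = {"OpenVINOExecutionProvider"}


def choose_providers() -> list[str]:
    # Keep only providers this service supports, in onnxruntime's preferred order.
    return [p for p in ort.get_available_providers() if p in SUPPORTED_PROVIDERS]


providers = choose_providers()
needs_static_shapes = any(p in STATIC_INPUT_PROVIDERS for p in providers)
print(providers, needs_static_shapes)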