1
0
mirror of https://github.com/immich-app/immich.git synced 2025-07-17 15:47:54 +02:00

convert to static

This commit is contained in:
mertalev
2024-02-02 22:15:14 -05:00
parent b768eef44d
commit b374052de2
2 changed files with 46 additions and 6 deletions

View File

@ -5,13 +5,16 @@ from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any
import onnx
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from onnx.shape_inference import infer_shapes
import onnxruntime as ort
from huggingface_hub import snapshot_download
from typing_extensions import Buffer
import ann.ann
from app.models.constants import SUPPORTED_PROVIDERS
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelRuntime, ModelType
@ -114,6 +117,13 @@ class InferenceModel(ABC):
)
model_path = onnx_path
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
static_path = model_path.parent / "static_1" / "model.onnx"
static_path.parent.mkdir(parents=True, exist_ok=True)
if not static_path.is_file():
self._convert_to_static(model_path, static_path)
model_path = static_path
match model_path.suffix:
case ".armnn":
session = AnnSession(model_path)
@ -128,6 +138,37 @@ class InferenceModel(ABC):
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
inferred = infer_shapes(onnx.load(source_path))
inputs = self._get_static_dims(inferred.graph.input)
outputs = self._get_static_dims(inferred.graph.output)
check_model = onnx.checker.check_model
try:
onnx.checker.check_model = lambda _: None
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
finally:
onnx.checker.check_model = check_model
onnx.save(
updated_model,
target_path,
save_as_external_data=True,
all_tensors_to_one_file=False,
size_threshold=1048576,
)
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
return {
field.name: [
d.dim_value if d.HasField("dim_value") else dim_size
for shape in field.type.ListFields()
if (dim := shape[1].shape.dim)
for d in dim
]
for field in graph_io
}
@property
def model_type(self) -> ModelType:
return self._model_type

View File

@ -51,11 +51,10 @@ _INSIGHTFACE_MODELS = {
}
SUPPORTED_PROVIDERS = [
"CUDAExecutionProvider",
"OpenVINOExecutionProvider",
"CPUExecutionProvider",
]
SUPPORTED_PROVIDERS = {"CUDAExecutionProvider", "OpenVINOExecutionProvider", "CPUExecutionProvider"}
STATIC_INPUT_PROVIDERS = {"OpenVINOExecutionProvider"}
def is_openclip(model_name: str) -> bool: