mirror of
https://github.com/immich-app/immich.git
synced 2025-01-12 15:32:36 +02:00
cleanup
This commit is contained in:
parent
2c2cf59f09
commit
bb56bd3297
@ -9,7 +9,7 @@ from typing import Any
|
|||||||
import onnx
|
import onnx
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes_path
|
||||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||||
from typing_extensions import Buffer
|
from typing_extensions import Buffer
|
||||||
import ann.ann
|
import ann.ann
|
||||||
@ -117,8 +117,7 @@ class InferenceModel(ABC):
|
|||||||
model_path = onnx_path
|
model_path = onnx_path
|
||||||
|
|
||||||
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
|
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
|
||||||
static_path = model_path.parent / "static_1" / "model.onnx"
|
static_path = model_path.parent / "model_static_1.onnx"
|
||||||
static_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
if not static_path.is_file():
|
if not static_path.is_file():
|
||||||
self._convert_to_static(model_path, static_path)
|
self._convert_to_static(model_path, static_path)
|
||||||
model_path = static_path
|
model_path = static_path
|
||||||
@ -138,29 +137,24 @@ class InferenceModel(ABC):
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
|
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
|
||||||
inferred = infer_shapes(onnx.load(source_path))
|
infer_shapes_path(source_path, strict_mode=True)
|
||||||
inputs = self._get_static_dims(inferred.graph.input)
|
proto = onnx.load(source_path, load_external_data=False)
|
||||||
outputs = self._get_static_dims(inferred.graph.output)
|
inputs = self._get_static_dims(proto.graph.input)
|
||||||
|
outputs = self._get_static_dims(proto.graph.output)
|
||||||
|
|
||||||
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
|
# check_model gets called in update_inputs_outputs_dims
|
||||||
check_model = onnx.checker.check_model
|
check_model = onnx.checker.check_model
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def check_model_stub(*args: Any, **kwargs: Any) -> None:
|
def check_model_stub(*args: Any, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
onnx.checker.check_model = check_model_stub
|
onnx.checker.check_model = check_model_stub
|
||||||
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
|
updated_model = update_inputs_outputs_dims(proto, inputs, outputs)
|
||||||
finally:
|
finally:
|
||||||
onnx.checker.check_model = check_model
|
onnx.checker.check_model = check_model
|
||||||
|
|
||||||
onnx.save(
|
onnx.save(updated_model, target_path)
|
||||||
updated_model,
|
|
||||||
target_path,
|
|
||||||
save_as_external_data=True,
|
|
||||||
all_tensors_to_one_file=False,
|
|
||||||
size_threshold=1048576,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
|
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
|
||||||
return {
|
return {
|
||||||
|
Loading…
Reference in New Issue
Block a user