From bb56bd3297303332b25bdc3b112cfc278ee593ba Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Fri, 2 Feb 2024 23:34:32 -0500 Subject: [PATCH] cleanup --- machine-learning/app/models/base.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py index b0ac8b9b66..7f707abfee 100644 --- a/machine-learning/app/models/base.py +++ b/machine-learning/app/models/base.py @@ -9,7 +9,7 @@ from typing import Any import onnx import onnxruntime as ort from huggingface_hub import snapshot_download -from onnx.shape_inference import infer_shapes +from onnx.shape_inference import infer_shapes_path from onnx.tools.update_model_dims import update_inputs_outputs_dims from typing_extensions import Buffer import ann.ann @@ -117,8 +117,7 @@ class InferenceModel(ABC): model_path = onnx_path if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers): - static_path = model_path.parent / "static_1" / "model.onnx" - static_path.parent.mkdir(parents=True, exist_ok=True) + static_path = model_path.parent / "model_static_1.onnx" if not static_path.is_file(): self._convert_to_static(model_path, static_path) model_path = static_path @@ -138,29 +137,24 @@ class InferenceModel(ABC): return session def _convert_to_static(self, source_path: Path, target_path: Path) -> None: - inferred = infer_shapes(onnx.load(source_path)) - inputs = self._get_static_dims(inferred.graph.input) - outputs = self._get_static_dims(inferred.graph.output) + infer_shapes_path(source_path, strict_mode=True) + proto = onnx.load(source_path, load_external_data=False) + inputs = self._get_static_dims(proto.graph.input) + outputs = self._get_static_dims(proto.graph.output) - # check_model gets called in update_inputs_outputs_dims and doesn't work for large models + # check_model gets called in update_inputs_outputs_dims check_model = onnx.checker.check_model try: - + def check_model_stub(*args: Any, **kwargs: Any) -> None: pass onnx.checker.check_model = check_model_stub - updated_model = update_inputs_outputs_dims(inferred, inputs, outputs) + updated_model = update_inputs_outputs_dims(proto, inputs, outputs) finally: onnx.checker.check_model = check_model - onnx.save( - updated_model, - target_path, - save_as_external_data=True, - all_tensors_to_one_file=False, - size_threshold=1048576, - ) + onnx.save(updated_model, target_path) def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]: return {