import logging
import os
import platform
import subprocess
from abc import abstractmethod

import onnx
import open_clip
import torch
from onnx2torch import convert
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from tinynn.converter import TFLiteConverter


class ExportBase(torch.nn.Module):
    """Base wrapper for models that are traced and exported to TFLite / ARM NN."""

    input_shape: tuple[int, ...]

    def __init__(self, device: torch.device, name: str):
        super().__init__()
        self.device = device
        self.name = name
        self.optimize = 5
        self.nchw_transpose = False

    @abstractmethod
    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
        pass

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand((1, 3, 224, 224), device=self.device)


class ArcFace(ExportBase):
    input_shape = (1, 3, 112, 112)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        onnx_model = onnx.load_model(onnx_model_path)
        # Pin the dynamic input shape so the graph can be converted downstream.
        make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
        fix_output_shapes(onnx_model)
        self.model = convert(onnx_model).to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.FloatTensor = self.model(
            input_tensor.half() if self.device.type == "cuda" else input_tensor
        ).float()
        assert isinstance(embedding, torch.FloatTensor)
        return embedding

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)


class RetinaFace(ExportBase):
    input_shape = (1, 3, 640, 640)

    def __init__(self, onnx_model_path: str, device: torch.device):
        name, _ = os.path.splitext(os.path.basename(onnx_model_path))
        super().__init__(device, name)
        self.optimize = 3
        self.model = convert(onnx_model_path).eval().to(device)
        if self.device.type == "cuda":
            self.model = self.model.half()

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
        out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
        return tuple(o.float() for o in out)

    def dummy_input(self) -> torch.FloatTensor:
        return torch.rand(self.input_shape, device=self.device)


class ClipVision(ExportBase):
    input_shape = (1, 3, 224, 224)

    def __init__(self, model_name: str, weights: str, device: torch.device):
        super().__init__(device, model_name + "__" + weights)
        self.model = open_clip.create_model(
            model_name,
            weights,
            precision="fp16" if device.type == "cuda" else "fp32",
            jit=False,
            require_pretrained=True,
            device=device,
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
        embedding: torch.Tensor = self.model.encode_image(
            input_tensor.half() if self.device.type == "cuda" else input_tensor,
            normalize=True,
        ).float()
        return embedding


def export(model: ExportBase) -> None:
    # Trace the model to TorchScript, convert it to TFLite, then to ARM NN.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    dummy_input = model.dummy_input()
    model(dummy_input)
    jit = torch.jit.trace(model, dummy_input)  # type: ignore[no-untyped-call,attr-defined]
    tflite_model_path = f"output/{model.name}.tflite"
    os.makedirs("output", exist_ok=True)

    converter = TFLiteConverter(
        jit,
        dummy_input,
        tflite_model_path,
        optimize=model.optimize,
        nchw_transpose=model.nchw_transpose,
    )
    # segfaults on ARM, must run on x86_64 / AMD64
    converter.convert()

    armnn_model_path = f"output/{model.name}.armnn"
    os.environ["LD_LIBRARY_PATH"] = "armnn"
    subprocess.run(
        [
            "./armnnconverter",
            "-f",
            "tflite-binary",
            "-m",
            tflite_model_path,
            "-i",
            "input_tensor",
            "-o",
            "output_tensor",
            "-p",
            armnn_model_path,
        ]
    )


def main() -> None:
    if platform.machine() not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type != "cuda":
        logging.warning(
            "No CUDA available, cannot create fp16 model! "
            "Proceeding to create an fp32 model (use only for testing)."
        )
    models = [
        ClipVision("ViT-B-32", "openai", device),
        ArcFace("buffalo_l_rec.onnx", device),
        RetinaFace("buffalo_l_det.onnx", device),
    ]
    for model in models:
        export(model)


if __name__ == "__main__":
    with torch.no_grad():
        main()