1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-02 12:48:35 +02:00

feat(ml): ARMNN acceleration (#5667)

* feat(ml): ARMNN acceleration for CLIP

* wrap ANN as ONNX-Session

* strict typing

* normalize ARMNN CLIP embedding

* mutex to handle concurrent execution

* make inputs contiguous

* fine-grained locking; concurrent network execution

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
Fynn Petersen-Frey 2024-01-11 18:26:46 +01:00 committed by GitHub
parent 29747437f6
commit 753292956e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 956 additions and 44 deletions

11
docker/mlaccel-armnn.yml Normal file
View File

@ -0,0 +1,11 @@
version: "3.8"
# ML acceleration on supported Mali ARM GPUs using ARM-NN
services:
mlaccel:
devices:
- /dev/mali0:/dev/mali0
volumes:
- /lib/firmware/mali_csffw.bin:/lib/firmware/mali_csffw.bin:ro # Mali firmware for your chipset (not always required depending on the driver)
- /usr/lib/libmali.so:/usr/lib/libmali.so:ro # Mali driver for your chipset (always required)

View File

@ -13,17 +13,40 @@ ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
COPY poetry.lock pyproject.toml ./
RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
FROM python:3.11-slim-bookworm@sha256:8f64a67710f3d981cf3008d6f9f1dbe61accd7927f165f4e37ea3f8b883ccc3f
ARG TARGETPLATFORM
ENV ARMNN_PATH=/opt/armnn
COPY ann /opt/ann
RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
mkdir /opt/armnn && \
curl -SL "https://github.com/ARM-software/armnn/releases/download/v23.11/ArmNN-linux-aarch64.tar.gz" | tar -zx -C /opt/armnn && \
cd /opt/ann && \
sh build.sh; \
else \
mkdir /opt/armnn; \
fi
FROM python:3.11-slim-bookworm@sha256:8f64a67710f3d981cf3008d6f9f1dbe61accd7927f165f4e37ea3f8b883ccc3f
ARG TARGETPLATFORM
RUN apt-get update && apt-get install -y --no-install-recommends tini libmimalloc2.0 && rm -rf /var/lib/apt/lists/*
RUN if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
apt-get update && apt-get install -y --no-install-recommends ocl-icd-libopencl1 mesa-opencl-icd && \
rm -rf /var/lib/apt/lists/* && \
mkdir --parents /etc/OpenCL/vendors && \
echo "/usr/lib/libmali.so" > /etc/OpenCL/vendors/mali.icd && \
mkdir /opt/armnn; \
fi
WORKDIR /usr/src/app
ENV NODE_ENV=production \
TRANSFORMERS_CACHE=/cache \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH="/opt/venv/bin:$PATH" \
PYTHONPATH=/usr/src
PYTHONPATH=/usr/src \
LD_LIBRARY_PATH=/opt/armnn
# prevent core dumps
RUN echo "hard core 0" >> /etc/security/limits.conf && \
@ -31,7 +54,10 @@ RUN echo "hard core 0" >> /etc/security/limits.conf && \
echo 'ulimit -S -c 0 > /dev/null 2>&1' >> /etc/profile
COPY --from=builder /opt/venv /opt/venv
COPY --from=builder /opt/armnn/libarmnn.so.?? /opt/armnn/libarmnnOnnxParser.so.?? /opt/armnn/libarmnnDeserializer.so.?? /opt/armnn/libarmnnTfLiteParser.so.?? /opt/armnn/libprotobuf.so.?.??.?.? /opt/ann/libann.s[o] /opt/ann/build.sh /opt/armnn
COPY ann/ann.py /usr/src/ann/ann.py
COPY start.sh log_conf.json ./
COPY app .
ENTRYPOINT ["tini", "--"]
CMD ["./start.sh"]

View File

@ -0,0 +1 @@
from .ann import Ann, is_available

View File

@ -0,0 +1,281 @@
#include <fstream>
#include <mutex>
#include <atomic>
#include "armnn/IRuntime.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Types.hpp"
#include "armnnDeserializer/IDeserializer.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnnOnnxParser/IOnnxParser.hpp"
using namespace armnn;
struct IOInfos
{
std::vector<BindingPointInfo> inputInfos;
std::vector<BindingPointInfo> outputInfos;
};
// from https://rigtorp.se/spinlock/
struct SpinLock
{
std::atomic<bool> lock_ = {false};
void lock()
{
for (;;)
{
if (!lock_.exchange(true, std::memory_order_acquire))
{
break;
}
while (lock_.load(std::memory_order_relaxed))
;
}
}
void unlock() { lock_.store(false, std::memory_order_release); }
};
class Ann
{
public:
int load(const char *modelPath,
bool fastMath,
bool fp16,
bool saveCachedNetwork,
const char *cachedNetworkPath)
{
INetworkPtr network = loadModel(modelPath);
IOptimizedNetworkPtr optNet = OptimizeNetwork(network.get(), fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
const IOInfos infos = getIOInfos(optNet.get());
NetworkId netId;
mutex.lock();
Status status = runtime->LoadNetwork(netId, std::move(optNet));
mutex.unlock();
if (status != Status::Success)
{
return -1;
}
spinLock.lock();
ioInfos[netId] = infos;
mutexes.emplace(netId, std::make_unique<std::mutex>());
spinLock.unlock();
return netId;
}
void execute(NetworkId netId, const void **inputData, void **outputData)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
auto m = mutexes[netId].get();
spinLock.unlock();
InputTensors inputTensors;
inputTensors.reserve(infos->inputInfos.size());
size_t i = 0;
for (const BindingPointInfo &info : infos->inputInfos)
inputTensors.emplace_back(info.first, ConstTensor(info.second, inputData[i++]));
OutputTensors outputTensors;
outputTensors.reserve(infos->outputInfos.size());
i = 0;
for (const BindingPointInfo &info : infos->outputInfos)
outputTensors.emplace_back(info.first, Tensor(info.second, outputData[i++]));
m->lock();
runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
m->unlock();
}
void unload(NetworkId netId)
{
mutex.lock();
runtime->UnloadNetwork(netId);
mutex.unlock();
}
int tensors(NetworkId netId, bool isInput = false)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
spinLock.unlock();
return (int)(isInput ? infos->inputInfos.size() : infos->outputInfos.size());
}
unsigned long shape(NetworkId netId, bool isInput = false, int index = 0)
{
spinLock.lock();
const IOInfos *infos = &ioInfos[netId];
spinLock.unlock();
const TensorShape shape = (isInput ? infos->inputInfos : infos->outputInfos)[index].second.GetShape();
unsigned long s = 0;
for (unsigned int d = 0; d < shape.GetNumDimensions(); d++)
s |= ((unsigned long)shape[d]) << (d * 16); // stores up to 4 16-bit values in a 64-bit value
return s;
}
Ann(int tuningLevel, const char *tuningFile)
{
IRuntime::CreationOptions runtimeOptions;
BackendOptions backendOptions{"GpuAcc",
{
{"TuningLevel", tuningLevel},
{"MemoryOptimizerStrategy", "ConstantMemoryStrategy"}, // SingleAxisPriorityList or ConstantMemoryStrategy
}};
if (tuningFile)
backendOptions.AddOption({"TuningFile", tuningFile});
runtimeOptions.m_BackendOptions.emplace_back(backendOptions);
runtime = IRuntime::CreateRaw(runtimeOptions);
};
~Ann()
{
IRuntime::Destroy(runtime);
};
private:
INetworkPtr loadModel(const char *modelPath)
{
const auto path = std::string(modelPath);
if (path.rfind(".tflite") == path.length() - 7) // endsWith()
{
auto parser = armnnTfLiteParser::ITfLiteParser::CreateRaw();
return parser->CreateNetworkFromBinaryFile(modelPath);
}
else if (path.rfind(".onnx") == path.length() - 5) // endsWith()
{
auto parser = armnnOnnxParser::IOnnxParser::CreateRaw();
return parser->CreateNetworkFromBinaryFile(modelPath);
}
else
{
std::ifstream ifs(path, std::ifstream::in | std::ifstream::binary);
auto parser = armnnDeserializer::IDeserializer::CreateRaw();
return parser->CreateNetworkFromBinary(ifs);
}
}
static BindingPointInfo getInputTensorInfo(LayerBindingId inputBindingId, TensorInfo info)
{
const auto newInfo = TensorInfo{info.GetShape(), info.GetDataType(),
info.GetQuantizationScale(),
info.GetQuantizationOffset(),
true};
return {inputBindingId, newInfo};
}
IOptimizedNetworkPtr OptimizeNetwork(INetwork *network, bool fastMath, bool fp16, bool saveCachedNetwork, const char *cachedNetworkPath)
{
const bool allowExpandedDims = false;
const ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly;
OptimizerOptionsOpaque options;
options.SetReduceFp32ToFp16(fp16);
options.SetShapeInferenceMethod(shapeInferenceMethod);
options.SetAllowExpandedDims(allowExpandedDims);
BackendOptions gpuAcc("GpuAcc", {{"FastMathEnabled", fastMath}});
if (cachedNetworkPath)
{
gpuAcc.AddOption({"SaveCachedNetwork", saveCachedNetwork});
gpuAcc.AddOption({"CachedNetworkFilePath", cachedNetworkPath});
}
options.AddModelOption(gpuAcc);
// No point in using ARMNN for CPU, use ONNX (quantized) instead.
// BackendOptions cpuAcc("CpuAcc",
// {
// {"FastMathEnabled", fastMath},
// {"NumberOfThreads", 0},
// });
// options.AddModelOption(cpuAcc);
BackendOptions allowExDimOpt("AllowExpandedDims",
{{"AllowExpandedDims", allowExpandedDims}});
options.AddModelOption(allowExDimOpt);
BackendOptions shapeInferOpt("ShapeInferenceMethod",
{{"InferAndValidate", shapeInferenceMethod == ShapeInferenceMethod::InferAndValidate}});
options.AddModelOption(shapeInferOpt);
std::vector<BackendId> backends = {
BackendId("GpuAcc"),
// BackendId("CpuAcc"),
// BackendId("CpuRef"),
};
return Optimize(*network, backends, runtime->GetDeviceSpec(), options);
}
IOInfos getIOInfos(IOptimizedNetwork *optNet)
{
struct InfoStrategy : IStrategy
{
void ExecuteStrategy(const IConnectableLayer *layer,
const BaseDescriptor &descriptor,
const std::vector<ConstTensor> &constants,
const char *name,
const LayerBindingId id = 0) override
{
IgnoreUnused(descriptor, constants, id);
const LayerType lt = layer->GetType();
if (lt == LayerType::Input)
ioInfos.inputInfos.push_back(getInputTensorInfo(id, layer->GetOutputSlot(0).GetTensorInfo()));
else if (lt == LayerType::Output)
ioInfos.outputInfos.push_back({id, layer->GetInputSlot(0).GetTensorInfo()});
}
IOInfos ioInfos;
};
InfoStrategy infoStrategy;
optNet->ExecuteStrategy(infoStrategy);
return infoStrategy.ioInfos;
}
IRuntime *runtime;
std::map<NetworkId, IOInfos> ioInfos;
std::map<NetworkId, std::unique_ptr<std::mutex>> mutexes; // mutex per network to not execute the same the same network concurrently
std::mutex mutex; // global mutex for load/unload calls to the runtime
SpinLock spinLock; // fast spin lock to guard access to the ioInfos and mutexes maps
};
extern "C" void *init(int logLevel, int tuningLevel, const char *tuningFile)
{
LogSeverity level = static_cast<LogSeverity>(logLevel);
ConfigureLogging(true, true, level);
Ann *ann = new Ann(tuningLevel, tuningFile);
return ann;
}
extern "C" void destroy(void *ann)
{
delete ((Ann *)ann);
}
extern "C" int load(void *ann,
const char *path,
bool fastMath,
bool fp16,
bool saveCachedNetwork,
const char *cachedNetworkPath)
{
return ((Ann *)ann)->load(path, fastMath, fp16, saveCachedNetwork, cachedNetworkPath);
}
extern "C" void unload(void *ann, NetworkId netId)
{
((Ann *)ann)->unload(netId);
}
extern "C" void execute(void *ann, NetworkId netId, const void **inputData, void **outputData)
{
((Ann *)ann)->execute(netId, inputData, outputData);
}
extern "C" unsigned long shape(void *ann, NetworkId netId, bool isInput, int index)
{
return ((Ann *)ann)->shape(netId, isInput, index);
}
extern "C" int tensors(void *ann, NetworkId netId, bool isInput)
{
return ((Ann *)ann)->tensors(netId, isInput);
}

162
machine-learning/ann/ann.py Normal file
View File

@ -0,0 +1,162 @@
from __future__ import annotations
from ctypes import CDLL, Array, c_bool, c_char_p, c_int, c_ulong, c_void_p
from os.path import exists
from typing import Any, Generic, Protocol, Type, TypeVar
import numpy as np
from numpy.typing import NDArray
from app.config import log
try:
CDLL("libmali.so") # fail if libmali.so is not mounted into container
libann = CDLL("libann.so")
libann.init.argtypes = c_int, c_int, c_char_p
libann.init.restype = c_void_p
libann.load.argtypes = c_void_p, c_char_p, c_bool, c_bool, c_bool, c_char_p
libann.load.restype = c_int
libann.execute.argtypes = c_void_p, c_int, Array[c_void_p], Array[c_void_p]
libann.unload.argtypes = c_void_p, c_int
libann.destroy.argtypes = (c_void_p,)
libann.shape.argtypes = c_void_p, c_int, c_bool, c_int
libann.shape.restype = c_ulong
libann.tensors.argtypes = c_void_p, c_int, c_bool
libann.tensors.restype = c_int
is_available = True
except OSError as e:
log.debug("Could not load ANN shared libraries, using ONNX: %s", e)
is_available = False
T = TypeVar("T", covariant=True)
class Newable(Protocol[T]):
def new(self) -> None:
...
class _Singleton(type, Newable[T]):
_instances: dict[_Singleton[T], Newable[T]] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Newable[T]:
if cls not in cls._instances:
obj: Newable[T] = super(_Singleton, cls).__call__(*args, **kwargs)
cls._instances[cls] = obj
else:
obj = cls._instances[cls]
obj.new()
return obj
class Ann(metaclass=_Singleton):
def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
if not is_available:
raise RuntimeError("libann is not available!")
if tuning_file and not exists(tuning_file):
raise ValueError("tuning_file must point to an existing (possibly empty) file!")
if tuning_level == 0 and tuning_file is None:
raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
if tuning_level < 0 or tuning_level > 3:
raise ValueError("tuning_level must be 0 (load from tuning_file), 1, 2 or 3.")
if log_level < 0 or log_level > 5:
raise ValueError("log_level must be 0 (trace), 1 (debug), 2 (info), 3 (warning), 4 (error) or 5 (fatal)")
self.log_level = log_level
self.tuning_level = tuning_level
self.tuning_file = tuning_file
self.output_shapes: dict[int, tuple[tuple[int], ...]] = {}
self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
self.ann: int | None = None
self.new()
def new(self) -> None:
if self.ann is None:
self.ann = libann.init(
self.log_level,
self.tuning_level,
self.tuning_file.encode() if self.tuning_file is not None else None,
)
self.ref_count = 0
self.ref_count += 1
def destroy(self) -> None:
self.ref_count -= 1
if self.ref_count <= 0 and self.ann is not None:
libann.destroy(self.ann)
self.ann = None
def __del__(self) -> None:
if self.ann is not None:
libann.destroy(self.ann)
self.ann = None
def load(
self,
model_path: str,
fast_math: bool = True,
fp16: bool = False,
save_cached_network: bool = False,
cached_network_path: str | None = None,
) -> int:
if not model_path.endswith((".armnn", ".tflite", ".onnx")):
raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
if not exists(model_path):
raise ValueError("model_path must point to an existing file!")
if cached_network_path is not None and not exists(cached_network_path):
raise ValueError("cached_network_path must point to an existing (possibly empty) file!")
if save_cached_network and cached_network_path is None:
raise ValueError("save_cached_network is True, cached_network_path must be specified!")
net_id: int = libann.load(
self.ann,
model_path.encode(),
fast_math,
fp16,
save_cached_network,
cached_network_path.encode() if cached_network_path is not None else None,
)
self.input_shapes[net_id] = tuple(
self.shape(net_id, input=True, index=i) for i in range(self.tensors(net_id, input=True))
)
self.output_shapes[net_id] = tuple(
self.shape(net_id, input=False, index=i) for i in range(self.tensors(net_id, input=False))
)
return net_id
def unload(self, network_id: int) -> None:
libann.unload(self.ann, network_id)
del self.output_shapes[network_id]
def execute(self, network_id: int, input_tensors: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
if not isinstance(input_tensors, list):
raise ValueError("input_tensors needs to be a list!")
net_input_shapes = self.input_shapes[network_id]
if len(input_tensors) != len(net_input_shapes):
raise ValueError(f"input_tensors lengths {len(input_tensors)} != network inputs {len(net_input_shapes)}")
for net_input_shape, input_tensor in zip(net_input_shapes, input_tensors):
if net_input_shape != input_tensor.shape:
raise ValueError(f"input_tensor shape {input_tensor.shape} != network input shape {net_input_shape}")
if not input_tensor.flags.c_contiguous:
raise ValueError("input_tensors must be c_contiguous numpy ndarrays")
output_tensors: list[NDArray[np.float32]] = [
np.ndarray(s, dtype=np.float32) for s in self.output_shapes[network_id]
]
input_type = c_void_p * len(input_tensors)
inputs = input_type(*[t.ctypes.data_as(c_void_p) for t in input_tensors])
output_type = c_void_p * len(output_tensors)
outputs = output_type(*[t.ctypes.data_as(c_void_p) for t in output_tensors])
libann.execute(self.ann, network_id, inputs, outputs)
return output_tensors
def shape(self, network_id: int, input: bool = False, index: int = 0) -> tuple[int]:
s = libann.shape(self.ann, network_id, input, index)
a = []
while s != 0:
a.append(s & 0xFFFF)
s >>= 16
return tuple(a)
def tensors(self, network_id: int, input: bool = False) -> int:
tensors: int = libann.tensors(self.ann, network_id, input)
return tensors

View File

@ -0,0 +1 @@
g++ -shared -O3 -o libann.so -fuse-ld=gold -std=c++17 -I$ARMNN_PATH/include -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -L$ARMNN_PATH ann.cpp

View File

@ -0,0 +1,2 @@
armnn*
output/

View File

@ -0,0 +1,4 @@
#!/bin/sh
cd armnn-23.11/
g++ -o ../armnnconverter -O1 -DARMNN_ONNX_PARSER -DARMNN_SERIALIZER -DARMNN_TF_LITE_PARSER -fuse-ld=gold -std=c++17 -Iinclude -Isrc/armnnUtils -Ithird-party -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -larmnnSerializer -L../armnn src/armnnConverter/ArmnnConverter.cpp

View File

@ -0,0 +1,8 @@
#!/bin/sh
# binaries
mkdir armnn
curl -SL "https://github.com/ARM-software/armnn/releases/download/v23.11/ArmNN-linux-x86_64.tar.gz" | tar -zx -C armnn
# source to build ArmnnConverter
curl -SL "https://github.com/ARM-software/armnn/archive/refs/tags/v23.11.tar.gz" | tar -zx

View File

@ -0,0 +1,201 @@
name: annexport
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- aiohttp=3.9.1=py310h2372a71_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- arpack=3.8.0=nompi_h0baa96a_101
- async-timeout=4.0.3=pyhd8ed1ab_0
- attrs=23.1.0=pyh71513ae_1
- aws-c-auth=0.7.3=h28f7589_1
- aws-c-cal=0.6.1=hc309b26_1
- aws-c-common=0.9.0=hd590300_0
- aws-c-compression=0.2.17=h4d4d85c_2
- aws-c-event-stream=0.3.1=h2e3709c_4
- aws-c-http=0.7.11=h00aa349_4
- aws-c-io=0.13.32=he9a53bd_1
- aws-c-mqtt=0.9.3=hb447be9_1
- aws-c-s3=0.3.14=hf3aad02_1
- aws-c-sdkutils=0.1.12=h4d4d85c_1
- aws-checksums=0.1.17=h4d4d85c_1
- aws-crt-cpp=0.21.0=hb942446_5
- aws-sdk-cpp=1.10.57=h85b1a90_19
- blas=2.120=openblas
- blas-devel=3.9.0=20_linux64_openblas
- brotli-python=1.0.9=py310hd8f1fbe_9
- bzip2=1.0.8=hd590300_5
- c-ares=1.23.0=hd590300_0
- ca-certificates=2023.11.17=hbcca054_0
- certifi=2023.11.17=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.7=unix_pyh707e725_0
- colorama=0.4.6=pyhd8ed1ab_0
- coloredlogs=15.0.1=pyhd8ed1ab_3
- cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0
- cuda-libraries=11.7.1=0
- cuda-nvrtc=11.7.99=0
- cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0
- dataclasses=0.8=pyhc8e2a94_3
- datasets=2.14.7=pyhd8ed1ab_0
- dill=0.3.7=pyhd8ed1ab_0
- filelock=3.13.1=pyhd8ed1ab_0
- flatbuffers=23.5.26=h59595ed_1
- freetype=2.12.1=h267a509_2
- frozenlist=1.4.0=py310h2372a71_1
- fsspec=2023.10.0=pyhca7485f_0
- ftfy=6.1.3=pyhd8ed1ab_0
- gflags=2.2.2=he1b5a44_1004
- glog=0.6.0=h6f12383_0
- glpk=5.0=h445213a_0
- gmp=6.3.0=h59595ed_0
- gmpy2=2.1.2=py310h3ec546c_1
- huggingface_hub=0.17.3=pyhd8ed1ab_0
- humanfriendly=10.0=pyhd8ed1ab_6
- icu=73.2=h59595ed_0
- idna=3.6=pyhd8ed1ab_0
- importlib-metadata=7.0.0=pyha770c72_0
- importlib_metadata=7.0.0=hd8ed1ab_0
- joblib=1.3.2=pyhd8ed1ab_0
- keyutils=1.6.1=h166bdaf_0
- krb5=1.21.2=h659d440_0
- lcms2=2.15=h7f713cb_2
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libabseil=20230125.3=cxx17_h59595ed_0
- libarrow=12.0.1=hb87d912_8_cpu
- libblas=3.9.0=20_linux64_openblas
- libbrotlicommon=1.0.9=h166bdaf_9
- libbrotlidec=1.0.9=h166bdaf_9
- libbrotlienc=1.0.9=h166bdaf_9
- libcblas=3.9.0=20_linux64_openblas
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.8.1.2=0
- libcurand=10.3.4.101=0
- libcurl=8.5.0=hca28451_0
- libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0
- libdeflate=1.19=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
- libevent=2.1.12=hf998b51_1
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_3
- libgfortran-ng=13.2.0=h69a702a_3
- libgfortran5=13.2.0=ha4646dd_3
- libgoogle-cloud=2.12.0=hac9eb74_1
- libgrpc=1.54.3=hb20ce57_0
- libhwloc=2.9.3=default_h554bfaf_1009
- libiconv=1.17=hd590300_1
- libjpeg-turbo=2.1.5.1=hd590300_1
- liblapack=3.9.0=20_linux64_openblas
- liblapacke=3.9.0=20_linux64_openblas
- libnghttp2=1.58.0=h47da74e_1
- libnpp=11.7.4.75=0
- libnsl=2.0.1=hd590300_0
- libnuma=2.0.16=h0b41bf4_1
- libnvjpeg=11.8.0.2=0
- libopenblas=0.3.25=pthreads_h413a1c8_0
- libpng=1.6.39=h753d276_0
- libprotobuf=3.21.12=hfc55251_2
- libsentencepiece=0.1.99=h180e1df_0
- libsqlite=3.44.2=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-ng=13.2.0=h7e041cc_3
- libthrift=0.18.1=h8fd135c_2
- libtiff=4.6.0=h29866fb_1
- libutf8proc=2.8.0=h166bdaf_0
- libuuid=2.38.1=h0b41bf4_0
- libwebp-base=1.3.2=hd590300_0
- libxcb=1.15=h0b41bf4_0
- libxml2=2.11.6=h232c23b_0
- libzlib=1.2.13=hd590300_5
- llvm-openmp=17.0.6=h4dfa4b3_0
- lz4-c=1.9.4=hcb278e6_0
- mkl=2022.2.1=h84fe81f_16997
- mkl-devel=2022.2.1=ha770c72_16998
- mkl-include=2022.2.1=h84fe81f_16997
- mpc=1.3.1=hfe3b2da_0
- mpfr=4.2.1=h9458935_0
- mpmath=1.3.0=pyhd8ed1ab_0
- multidict=6.0.4=py310h2372a71_1
- multiprocess=0.70.15=py310h2372a71_1
- ncurses=6.4=h59595ed_2
- numpy=1.26.2=py310hb13e2d6_0
- onnx=1.14.0=py310ha3deec4_1
- onnx2torch=1.5.13=pyhd8ed1ab_0
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
- open-clip-torch=2.23.0=pyhd8ed1ab_1
- openblas=0.3.25=pthreads_h7a3da1a_0
- openjpeg=2.5.0=h488ebb8_3
- openssl=3.2.0=hd590300_1
- orc=1.9.0=h2f23424_1
- packaging=23.2=pyhd8ed1ab_0
- pandas=2.1.4=py310hcc13569_0
- pillow=10.0.1=py310h29da1c1_1
- pip=23.3.1=pyhd8ed1ab_0
- protobuf=4.21.12=py310heca2aa9_0
- pthread-stubs=0.4=h36c2ea0_1001
- pyarrow=12.0.1=py310h0576679_8_cpu
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.13=hd12c33a_0_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
- python-tzdata=2023.3=pyhd8ed1ab_0
- python-xxhash=3.4.1=py310h2372a71_0
- python_abi=3.10=4_cp310
- pytorch=1.13.1=cpu_py310hd11e9c7_1
- pytorch-cuda=11.7=h778d358_5
- pytorch-mutex=1.0=cuda
- pytz=2023.3.post1=pyhd8ed1ab_0
- pyyaml=6.0.1=py310h2372a71_1
- rdma-core=28.9=h59595ed_1
- re2=2023.03.02=h8c504da_0
- readline=8.2=h8228510_1
- regex=2023.10.3=py310h2372a71_0
- requests=2.31.0=pyhd8ed1ab_0
- s2n=1.3.49=h06160fa_0
- sacremoses=0.0.53=pyhd8ed1ab_0
- safetensors=0.3.3=py310hcb5633a_1
- sentencepiece=0.1.99=hff52083_0
- sentencepiece-python=0.1.99=py310hebdb9f0_0
- sentencepiece-spm=0.1.99=h180e1df_0
- setuptools=68.2.2=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- sleef=3.5.1=h9b69904_2
- snappy=1.1.10=h9fff704_0
- sympy=1.12=pypyh9d50eac_103
- tbb=2021.11.0=h00ab1b0_0
- texttable=1.7.0=pyhd8ed1ab_0
- timm=0.9.12=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- tokenizers=0.14.1=py310h320607d_2
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
- tqdm=4.66.1=pyhd8ed1ab_0
- transformers=4.35.2=pyhd8ed1ab_0
- typing-extensions=4.9.0=hd8ed1ab_0
- typing_extensions=4.9.0=pyha770c72_0
- tzdata=2023c=h71feb2d_0
- ucx=1.14.1=h64cca9d_5
- urllib3=2.1.0=pyhd8ed1ab_0
- wcwidth=0.2.12=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xxhash=0.8.2=hd590300_0
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.9.3=py310h2372a71_0
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.5=hfc55251_0
- pip:
- git+https://github.com/fyfrey/TinyNeuralNetwork.git

View File

@ -0,0 +1,157 @@
import logging
import os
import platform
import subprocess
from abc import abstractmethod
import onnx
import open_clip
import torch
from onnx2torch import convert
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from tinynn.converter import TFLiteConverter
class ExportBase(torch.nn.Module):
input_shape: tuple[int, ...]
def __init__(self, device: torch.device, name: str):
super().__init__()
self.device = device
self.name = name
self.optimize = 5
self.nchw_transpose = False
@abstractmethod
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
pass
def dummy_input(self) -> torch.FloatTensor:
return torch.rand((1, 3, 224, 224), device=self.device)
class ArcFace(ExportBase):
input_shape = (1, 3, 112, 112)
def __init__(self, onnx_model_path: str, device: torch.device):
name, _ = os.path.splitext(os.path.basename(onnx_model_path))
super().__init__(device, name)
onnx_model = onnx.load_model(onnx_model_path)
make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
fix_output_shapes(onnx_model)
self.model = convert(onnx_model).to(device)
if self.device.type == "cuda":
self.model = self.model.half()
def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
embedding: torch.FloatTensor = self.model(
input_tensor.half() if self.device.type == "cuda" else input_tensor
).float()
assert isinstance(embedding, torch.FloatTensor)
return embedding
def dummy_input(self) -> torch.FloatTensor:
return torch.rand(self.input_shape, device=self.device)
class RetinaFace(ExportBase):
input_shape = (1, 3, 640, 640)
def __init__(self, onnx_model_path: str, device: torch.device):
name, _ = os.path.splitext(os.path.basename(onnx_model_path))
super().__init__(device, name)
self.optimize = 3
self.model = convert(onnx_model_path).eval().to(device)
if self.device.type == "cuda":
self.model = self.model.half()
def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
return tuple(o.float() for o in out)
def dummy_input(self) -> torch.FloatTensor:
return torch.rand(self.input_shape, device=self.device)
class ClipVision(ExportBase):
input_shape = (1, 3, 224, 224)
def __init__(self, model_name: str, weights: str, device: torch.device):
super().__init__(device, model_name + "__" + weights)
self.model = open_clip.create_model(
model_name,
weights,
precision="fp16" if device.type == "cuda" else "fp32",
jit=False,
require_pretrained=True,
device=device,
)
def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
embedding: torch.Tensor = self.model.encode_image(
input_tensor.half() if self.device.type == "cuda" else input_tensor,
normalize=True,
).float()
return embedding
def export(model: ExportBase) -> None:
model.eval()
for param in model.parameters():
param.requires_grad = False
dummy_input = model.dummy_input()
model(dummy_input)
jit = torch.jit.trace(model, dummy_input) # type: ignore[no-untyped-call,attr-defined]
tflite_model_path = f"output/{model.name}.tflite"
os.makedirs("output", exist_ok=True)
converter = TFLiteConverter(
jit,
dummy_input,
tflite_model_path,
optimize=model.optimize,
nchw_transpose=model.nchw_transpose,
)
# segfaults on ARM, must run on x86_64 / AMD64
converter.convert()
armnn_model_path = f"output/{model.name}.armnn"
os.environ["LD_LIBRARY_PATH"] = "armnn"
subprocess.run(
[
"./armnnconverter",
"-f",
"tflite-binary",
"-m",
tflite_model_path,
"-i",
"input_tensor",
"-o",
"output_tensor",
"-p",
armnn_model_path,
]
)
def main() -> None:
if platform.machine() not in ("x86_64", "AMD64"):
raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
logging.warning(
"No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
)
models = [
ClipVision("ViT-B-32", "openai", device),
ArcFace("buffalo_l_rec.onnx", device),
RetinaFace("buffalo_l_det.onnx", device),
]
for model in models:
export(model)
if __name__ == "__main__":
with torch.no_grad():
main()

View File

@ -26,6 +26,7 @@ class Settings(BaseSettings):
request_threads: int = os.cpu_count() or 4
model_inter_op_threads: int = 1
model_intra_op_threads: int = 2
ann: bool = True
class Config:
env_prefix = "MACHINE_LEARNING_"

View File

@ -0,0 +1,68 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
from numpy import ascontiguousarray
from ann.ann import Ann
from app.schemas import ndarray_f32, ndarray_i32
from ..config import log, settings
class AnnSession:
"""
Wrapper for ANN to be drop-in replacement for ONNX session.
"""
def __init__(self, model_path: Path):
tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
with tuning_file.open(mode="a"):
# make sure tuning file exists (without clearing contents)
# once filled, the tuning file reduces the cost/time of the first
# inference after model load by 10s of seconds
pass
self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())
log.info("Loading ANN model %s ...", model_path)
cache_file = model_path.with_suffix(".anncache")
save = False
if not cache_file.is_file():
save = True
with cache_file.open(mode="a"):
# create empty model cache file
pass
self.model = self.ann.load(
model_path.as_posix(),
save_cached_network=save,
cached_network_path=cache_file.as_posix(),
)
log.info("Loaded ANN model with ID %d", self.model)
def __del__(self) -> None:
self.ann.unload(self.model)
log.info("Unloaded ANN model %d", self.model)
self.ann.destroy()
def get_inputs(self) -> list[AnnNode]:
shapes = self.ann.input_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def get_outputs(self) -> list[AnnNode]:
shapes = self.ann.output_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, ndarray_f32] | dict[str, ndarray_i32],
run_options: Any = None,
) -> list[ndarray_f32]:
inputs: list[ndarray_f32] = [ascontiguousarray(v) for v in input_feed.values()]
return self.ann.execute(self.model, inputs)
class AnnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]

View File

@ -10,8 +10,11 @@ import onnxruntime as ort
from huggingface_hub import snapshot_download
from typing_extensions import Buffer
import ann.ann
from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelType
from .ann import AnnSession
class InferenceModel(ABC):
@ -138,6 +141,21 @@ class InferenceModel(ABC):
self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
armnn_path = model_path.with_suffix(".armnn")
if settings.ann and ann.ann.is_available and armnn_path.is_file():
session = AnnSession(armnn_path)
elif model_path.is_file():
session = ort.InferenceSession(
model_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
else:
raise ValueError(f"the file model_path='{model_path}' does not exist")
return session
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]

View File

@ -6,7 +6,6 @@ from pathlib import Path
from typing import Any, Literal
import numpy as np
import onnxruntime as ort
from PIL import Image
from tokenizers import Encoding, Tokenizer
@ -33,24 +32,12 @@ class BaseCLIPEncoder(InferenceModel):
def _load(self) -> None:
if self.mode == "text" or self.mode is None:
log.debug(f"Loading clip text model '{self.model_name}'")
self.text_model = ort.InferenceSession(
self.textual_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.text_model = self._make_session(self.textual_path)
log.debug(f"Loaded clip text model '{self.model_name}'")
if self.mode == "vision" or self.mode is None:
log.debug(f"Loading clip vision model '{self.model_name}'")
self.vision_model = ort.InferenceSession(
self.visual_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.vision_model = self._make_session(self.visual_path)
log.debug(f"Loaded clip vision model '{self.model_name}'")
def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
@ -61,12 +48,10 @@ class BaseCLIPEncoder(InferenceModel):
case Image.Image():
if self.mode == "text":
raise TypeError("Cannot encode image as text-only model")
outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
case str():
if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model")
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

View File

@ -3,7 +3,6 @@ from typing import Any
import cv2
import numpy as np
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
@ -27,23 +26,8 @@ class FaceRecognizer(InferenceModel):
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
def _load(self) -> None:
self.det_model = RetinaFace(
session=ort.InferenceSession(
self.det_file.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
),
)
self.rec_model = ArcFaceONNX(
self.rec_file.as_posix(),
session=ort.InferenceSession(
self.rec_file.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
),
)
self.det_model = RetinaFace(session=self._make_session(self.det_file))
self.rec_model = ArcFaceONNX(self.rec_file.as_posix(), session=self._make_session(self.rec_file))
self.det_model.prepare(
ctx_id=0,

View File

@ -13,7 +13,7 @@ from PIL import Image
from pytest_mock import MockerFixture
from .config import settings
from .models.base import PicklableSessionOptions
from .models.base import InferenceModel, PicklableSessionOptions
from .models.cache import ModelCache
from .models.clip import OpenCLIPEncoder
from .models.facial_recognition import FaceRecognizer
@ -36,9 +36,10 @@ class TestCLIP:
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mocked.run.return_value = [[self.embedding]]
mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
mocked.return_value.run.return_value = [[self.embedding]]
clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
embedding = clip_encoder.predict(pil_image)
@ -47,7 +48,7 @@ class TestCLIP:
assert isinstance(embedding, np.ndarray)
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert embedding.dtype == np.float32
clip_encoder.vision_model.run.assert_called_once()
mocked.run.assert_called_once()
def test_basic_text(
self,
@ -60,9 +61,10 @@ class TestCLIP:
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
mocked.run.return_value = [[self.embedding]]
mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
mocked.return_value.run.return_value = [[self.embedding]]
clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
embedding = clip_encoder.predict("test search query")
@ -71,7 +73,7 @@ class TestCLIP:
assert isinstance(embedding, np.ndarray)
assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert embedding.dtype == np.float32
clip_encoder.text_model.run.assert_called_once()
mocked.run.assert_called_once()
class TestFaceRecognition: