You've already forked immich
mirror of
https://github.com/immich-app/immich.git
synced 2025-06-27 05:11:11 +02:00
feat(ml): configurable batch size for facial recognition (#13689)
* configurable batch size, default openvino to 1 * update docs * don't add a new dependency for two lines * fix typing
This commit is contained in:
@ -549,7 +549,7 @@ class TestFaceRecognition:
|
||||
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is True
|
||||
assert face_recognizer.batch_size is None
|
||||
update_dims.assert_called_once_with(proto, {"input.1": ["batch", 3, 224, 224]}, {"output.1": ["batch", 800]})
|
||||
onnx.save.assert_called_once_with(update_dims.return_value, face_recognizer.model_path)
|
||||
|
||||
@ -572,7 +572,7 @@ class TestFaceRecognition:
|
||||
face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is True
|
||||
assert face_recognizer.batch_size is None
|
||||
update_dims.assert_not_called()
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
@ -596,7 +596,33 @@ class TestFaceRecognition:
|
||||
face_recognizer = FaceRecognizer("buffalo_s", model_format=ModelFormat.ARMNN, cache_dir=path)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch is False
|
||||
assert face_recognizer.batch_size == 1
|
||||
update_dims.assert_not_called()
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
||||
def test_recognition_does_not_add_batch_axis_for_openvino(
|
||||
self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
|
||||
) -> None:
|
||||
onnx = mocker.patch("app.models.facial_recognition.recognition.onnx", autospec=True)
|
||||
update_dims = mocker.patch(
|
||||
"app.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
|
||||
)
|
||||
mocker.patch("app.models.base.InferenceModel.download")
|
||||
mocker.patch("app.models.facial_recognition.recognition.ArcFaceONNX")
|
||||
path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
|
||||
|
||||
inputs = [SimpleNamespace(name="input.1", shape=("batch", 3, 224, 224))]
|
||||
outputs = [SimpleNamespace(name="output.1", shape=("batch", 800))]
|
||||
ort_session.return_value.get_inputs.return_value = inputs
|
||||
ort_session.return_value.get_outputs.return_value = outputs
|
||||
|
||||
face_recognizer = FaceRecognizer(
|
||||
"buffalo_s", model_format=ModelFormat.ARMNN, cache_dir=path, providers=["OpenVINOExecutionProvider"]
|
||||
)
|
||||
face_recognizer.load()
|
||||
|
||||
assert face_recognizer.batch_size == 1
|
||||
update_dims.assert_not_called()
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
Reference in New Issue
Block a user