mirror of
https://github.com/immich-app/immich.git
synced 2024-12-27 10:58:13 +02:00
fix(ml): batch axis not being added for recognition model (#12588)
* fix has_batch_axis * fix typing
This commit is contained in:
parent
fa095c3ca0
commit
22dc9bcebb
@ -13,7 +13,6 @@ from app.config import log
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.transforms import decode_cv2
|
||||
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from app.sessions import has_batch_axis
|
||||
|
||||
|
||||
class FaceRecognizer(InferenceModel):
|
||||
@ -27,7 +26,7 @@ class FaceRecognizer(InferenceModel):
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
session = self._make_session(self.model_path)
|
||||
if self.batch and not has_batch_axis(session):
|
||||
if self.batch and str(session.get_inputs()[0].shape[0]) != "batch":
|
||||
self._add_batch_axis(self.model_path)
|
||||
session = self._make_session(self.model_path)
|
||||
self.model = ArcFaceONNX(
|
||||
|
@ -1,5 +0,0 @@
|
||||
from app.schemas import ModelSession
|
||||
|
||||
|
||||
def has_batch_axis(session: ModelSession) -> bool:
|
||||
return not isinstance(session.get_inputs()[0].shape[0], int) or session.get_inputs()[0].shape[0] < 0
|
Loading…
Reference in New Issue
Block a user