1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-13 15:35:15 +02:00

fixed tests (#5017)

This commit is contained in:
Mert 2023-11-13 14:37:39 -05:00 committed by GitHub
parent 464cf903f4
commit 291159e7fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,9 +75,9 @@ class TestCLIP:
embedding = clip_encoder.predict(pil_image) embedding = clip_encoder.predict(pil_image)
assert clip_encoder.mode == "vision" assert clip_encoder.mode == "vision"
assert isinstance(embedding, list) assert isinstance(embedding, np.ndarray)
assert len(embedding) == clip_model_cfg["embed_dim"] assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert all([isinstance(num, float) for num in embedding]) assert embedding.dtype == np.float32
clip_encoder.vision_model.run.assert_called_once() clip_encoder.vision_model.run.assert_called_once()
def test_basic_text( def test_basic_text(
@ -97,9 +97,9 @@ class TestCLIP:
embedding = clip_encoder.predict("test search query") embedding = clip_encoder.predict("test search query")
assert clip_encoder.mode == "text" assert clip_encoder.mode == "text"
assert isinstance(embedding, list) assert isinstance(embedding, np.ndarray)
assert len(embedding) == clip_model_cfg["embed_dim"] assert embedding.shape[0] == clip_model_cfg["embed_dim"]
assert all([isinstance(num, float) for num in embedding]) assert embedding.dtype == np.float32
clip_encoder.text_model.run.assert_called_once() clip_encoder.text_model.run.assert_called_once()
@ -133,9 +133,9 @@ class TestFaceRecognition:
for face in faces: for face in faces:
assert face["imageHeight"] == 800 assert face["imageHeight"] == 800
assert face["imageWidth"] == 600 assert face["imageWidth"] == 600
assert isinstance(face["embedding"], list) assert isinstance(face["embedding"], np.ndarray)
assert len(face["embedding"]) == 512 assert face["embedding"].shape[0] == 512
assert all([isinstance(num, float) for num in face["embedding"]]) assert face["embedding"].dtype == np.float32
det_model.detect.assert_called_once() det_model.detect.assert_called_once()
assert rec_model.get_feat.call_count == num_faces assert rec_model.get_feat.call_count == num_faces