1
0
mirror of https://github.com/immich-app/immich.git synced 2024-11-28 09:33:27 +02:00

feat(server): optimize partial facial recognition (#6634)

* optimize partial facial recognition

* add tests

* use map

* bulk insert faces
This commit is contained in:
Mert 2024-01-25 01:27:39 -05:00 committed by GitHub
parent 852effa998
commit bd87eb309c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 77 additions and 46 deletions

View File

@ -788,11 +788,13 @@ describe(`${AssetController.name} (e2e)`, () => {
const personRepository = app.get<IPersonRepository>(IPersonRepository);
const person = await personRepository.create({ ownerId: asset1.ownerId, name: 'Test Person' });
await personRepository.createFace({
assetId: asset1.id,
personId: person.id,
embedding: Array.from({ length: 512 }, Math.random),
});
await personRepository.createFaces([
{
assetId: asset1.id,
personId: person.id,
embedding: Array.from({ length: 512 }, Math.random),
},
]);
const { status, body } = await request(server)
.put(`/asset/${asset1.id}`)
@ -1377,11 +1379,13 @@ describe(`${AssetController.name} (e2e)`, () => {
beforeEach(async () => {
const personRepository = app.get<IPersonRepository>(IPersonRepository);
const person = await personRepository.create({ ownerId: asset1.ownerId, name: 'Test Person' });
await personRepository.createFace({
assetId: asset1.id,
personId: person.id,
embedding: Array.from({ length: 512 }, Math.random),
});
await personRepository.createFaces([
{
assetId: asset1.id,
personId: person.id,
embedding: Array.from({ length: 512 }, Math.random),
},
]);
});
it('should not return asset with facesRecognizedAt unset', async () => {

View File

@ -38,11 +38,13 @@ describe(`${PersonController.name}`, () => {
name: 'visible_person',
thumbnailPath: '/thumbnail/face_asset',
});
await personRepository.createFace({
assetId: faceAsset.id,
personId: visiblePerson.id,
embedding: Array.from({ length: 512 }, Math.random),
});
await personRepository.createFaces([
{
assetId: faceAsset.id,
personId: visiblePerson.id,
embedding: Array.from({ length: 512 }, Math.random),
},
]);
hiddenPerson = await personRepository.create({
ownerId: loginResponse.userId,
@ -50,11 +52,13 @@ describe(`${PersonController.name}`, () => {
isHidden: true,
thumbnailPath: '/thumbnail/face_asset',
});
await personRepository.createFace({
assetId: faceAsset.id,
personId: hiddenPerson.id,
embedding: Array.from({ length: 512 }, Math.random),
});
await personRepository.createFaces([
{
assetId: faceAsset.id,
personId: hiddenPerson.id,
embedding: Array.from({ length: 512 }, Math.random),
},
]);
});
describe('GET /person', () => {

View File

@ -61,6 +61,7 @@ describe(JobService.name, () => {
{ name: JobName.QUEUE_GENERATE_THUMBNAILS, data: { force: false } },
{ name: JobName.CLEAN_OLD_AUDIT_LOGS },
{ name: JobName.USER_SYNC_USAGE },
{ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } },
]);
});
});
@ -318,7 +319,7 @@ describe(JobService.name, () => {
},
{
item: { name: JobName.FACE_DETECTION, data: { id: 'asset-1' } },
jobs: [JobName.QUEUE_FACIAL_RECOGNITION],
jobs: [],
},
{
item: { name: JobName.FACIAL_RECOGNITION, data: { id: 'asset-1' } },

View File

@ -174,6 +174,7 @@ export class JobService {
{ name: JobName.QUEUE_GENERATE_THUMBNAILS, data: { force: false } },
{ name: JobName.CLEAN_OLD_AUDIT_LOGS },
{ name: JobName.USER_SYNC_USAGE },
{ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } },
]);
}
@ -255,11 +256,6 @@ export class JobService {
}
break;
}
case JobName.FACE_DETECTION: {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: item.data });
break;
}
}
}
}

View File

@ -607,14 +607,23 @@ describe(PersonService.name, () => {
describe('handleQueueRecognizeFaces', () => {
it('should return if machine learning is disabled', async () => {
jobMock.getJobCounts.mockResolvedValue({ active: 1, waiting: 0, paused: 0, completed: 0, failed: 0, delayed: 0 });
configMock.load.mockResolvedValue([{ key: SystemConfigKey.MACHINE_LEARNING_ENABLED, value: false }]);
await expect(sut.handleQueueRecognizeFaces({})).resolves.toBe(true);
expect(jobMock.queue).not.toHaveBeenCalled();
expect(jobMock.queueAll).not.toHaveBeenCalled();
expect(configMock.load).toHaveBeenCalled();
});
it('should return if recognition jobs are already queued', async () => {
jobMock.getJobCounts.mockResolvedValue({ active: 1, waiting: 1, paused: 0, completed: 0, failed: 0, delayed: 0 });
await expect(sut.handleQueueRecognizeFaces({})).resolves.toBe(true);
expect(jobMock.queueAll).not.toHaveBeenCalled();
});
it('should queue missing assets', async () => {
jobMock.getJobCounts.mockResolvedValue({ active: 1, waiting: 0, paused: 0, completed: 0, failed: 0, delayed: 0 });
personMock.getAllFaces.mockResolvedValue({
items: [faceStub.face1],
hasNextPage: false,
@ -632,6 +641,7 @@ describe(PersonService.name, () => {
});
it('should queue all assets', async () => {
jobMock.getJobCounts.mockResolvedValue({ active: 1, waiting: 0, paused: 0, completed: 0, failed: 0, delayed: 0 });
personMock.getAll.mockResolvedValue({
items: [],
hasNextPage: false,
@ -653,6 +663,7 @@ describe(PersonService.name, () => {
});
it('should delete existing people and faces if forced', async () => {
jobMock.getJobCounts.mockResolvedValue({ active: 1, waiting: 0, paused: 0, completed: 0, failed: 0, delayed: 0 });
personMock.getAll.mockResolvedValue({
items: [faceStub.face1.person],
hasNextPage: false,
@ -727,7 +738,7 @@ describe(PersonService.name, () => {
modelName: 'buffalo_l',
},
);
expect(personMock.createFace).not.toHaveBeenCalled();
expect(personMock.createFaces).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
expect(jobMock.queueAll).not.toHaveBeenCalled();
@ -738,13 +749,12 @@ describe(PersonService.name, () => {
expect(assetMock.upsertJobStatus.mock.calls[0][0].facesRecognizedAt?.getTime()).toBeGreaterThan(start);
});
it('should create a face with no person', async () => {
it('should create a face with no person and queue recognition job', async () => {
personMock.createFaces.mockResolvedValue([faceStub.face1.id]);
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
smartInfoMock.searchFaces.mockResolvedValue([{ face: faceStub.face1, distance: 0.7 }]);
assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleDetectFaces({ id: assetStub.image.id });
expect(personMock.createFace).toHaveBeenCalledWith({
const face = {
assetId: 'asset-id',
embedding: [1, 2, 3, 4],
boundingBoxX1: 100,
@ -753,7 +763,14 @@ describe(PersonService.name, () => {
boundingBoxY2: 200,
imageHeight: 500,
imageWidth: 400,
});
};
await sut.handleDetectFaces({ id: assetStub.image.id });
expect(personMock.createFaces).toHaveBeenCalledWith([face]);
expect(jobMock.queueAll).toHaveBeenCalledWith([
{ name: JobName.FACIAL_RECOGNITION, data: { id: faceStub.face1.id } },
]);
expect(personMock.reassignFace).not.toHaveBeenCalled();
expect(personMock.reassignFaces).not.toHaveBeenCalled();
});
@ -767,7 +784,7 @@ describe(PersonService.name, () => {
expect(personMock.reassignFaces).not.toHaveBeenCalled();
expect(personMock.create).not.toHaveBeenCalled();
expect(personMock.createFace).not.toHaveBeenCalled();
expect(personMock.createFaces).not.toHaveBeenCalled();
});
it('should return true if face already has an assigned person', async () => {
@ -777,7 +794,7 @@ describe(PersonService.name, () => {
expect(personMock.reassignFaces).not.toHaveBeenCalled();
expect(personMock.create).not.toHaveBeenCalled();
expect(personMock.createFace).not.toHaveBeenCalled();
expect(personMock.createFaces).not.toHaveBeenCalled();
});
it('should match existing person', async () => {

View File

@ -332,8 +332,10 @@ export class PersonService {
this.logger.debug(`${faces.length} faces detected in ${asset.resizePath}`);
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
for (const face of faces) {
const mappedFace = {
if (faces.length) {
await this.jobRepository.queue({ name: JobName.QUEUE_FACIAL_RECOGNITION, data: { force: false } });
const mappedFaces = faces.map((face) => ({
assetId: asset.id,
embedding: face.embedding,
imageHeight: face.imageHeight,
@ -342,9 +344,10 @@ export class PersonService {
boundingBoxX2: face.boundingBox.x2,
boundingBoxY1: face.boundingBox.y1,
boundingBoxY2: face.boundingBox.y2,
};
}));
await this.repository.createFace(mappedFace);
const faceIds = await this.repository.createFaces(mappedFaces);
await this.jobRepository.queueAll(faceIds.map((id) => ({ name: JobName.FACIAL_RECOGNITION, data: { id } })));
}
await this.assetRepository.upsertJobStatus({
@ -362,9 +365,15 @@ export class PersonService {
}
await this.jobRepository.waitForQueueCompletion(QueueName.THUMBNAIL_GENERATION, QueueName.FACE_DETECTION);
const { waiting } = await this.jobRepository.getJobCounts(QueueName.FACIAL_RECOGNITION);
if (force) {
await this.deleteAllPeople();
} else if (waiting) {
this.logger.debug(
`Skipping facial recognition queueing because ${waiting} job${waiting > 1 ? 's are' : ' is'} already queued`,
);
return true;
}
const facePagination = usePagination(JOBS_ASSET_PAGINATION_SIZE, (pagination) =>

View File

@ -38,7 +38,7 @@ export interface IPersonRepository {
getAssets(personId: string): Promise<AssetEntity[]>;
create(entity: Partial<PersonEntity>): Promise<PersonEntity>;
createFace(entity: Partial<AssetFaceEntity>): Promise<void>;
createFaces(entities: Partial<AssetFaceEntity>[]): Promise<string[]>;
delete(entities: PersonEntity[]): Promise<void>;
deleteAll(): Promise<void>;
deleteAllFaces(): Promise<void>;

View File

@ -215,11 +215,11 @@ export class PersonRepository implements IPersonRepository {
return this.personRepository.save(entity);
}
async createFace(entity: AssetFaceEntity): Promise<void> {
if (!entity.embedding) {
throw new Error('Embedding is required to create a face');
}
await this.assetFaceRepository.insert({ ...entity, embedding: () => asVector(entity.embedding, true) });
// Bulk-inserts all detected faces for an asset in a single INSERT statement
// (replaces the former one-row-at-a-time createFace).
// Each entity's embedding is written through the asVector() helper as a raw
// SQL expression — presumably a pgvector literal; confirm against asVector's
// definition — rather than as a bound parameter.
// Returns the generated ids of the inserted asset_face rows so callers can
// queue a FACIAL_RECOGNITION job per face.
// NOTE(review): unlike the removed createFace, there is no longer a guard
// throwing when entity.embedding is missing — assumes every entity carries an
// embedding; verify at call sites.
async createFaces(entities: AssetFaceEntity[]): Promise<string[]> {
const res = await this.assetFaceRepository.insert(
entities.map((entity) => ({ ...entity, embedding: () => asVector(entity.embedding, true) })),
);
// TypeORM's InsertResult.identifiers holds the generated primary keys.
return res.identifiers.map((row) => row.id);
}
async update(entity: Partial<PersonEntity>): Promise<PersonEntity> {

View File

@ -22,7 +22,7 @@ export const newPersonRepositoryMock = (): jest.Mocked<IPersonRepository> => {
getRandomFace: jest.fn(),
reassignFaces: jest.fn(),
createFace: jest.fn(),
createFaces: jest.fn(),
getFaces: jest.fn(),
reassignFace: jest.fn(),
getFaceById: jest.fn(),