From cc7ba3c21a863a6f165fc227f966661f44d84978 Mon Sep 17 00:00:00 2001 From: Fynn Petersen-Frey <10599762+fyfrey@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:25:22 +0100 Subject: [PATCH] feat(server): search across own+partner assets (#5966) * feat(server): search across own+partner assets * generate sql * fix sql parameter --- server/src/domain/person/person.service.ts | 2 +- .../src/domain/repositories/asset.repository.ts | 2 +- .../repositories/smart-info.repository.ts | 2 +- server/src/domain/search/search.service.spec.ts | 12 +++++++++--- server/src/domain/search/search.service.ts | 17 +++++++++++++++-- .../src/infra/repositories/asset.repository.ts | 10 +++++++--- .../infra/repositories/smart-info.repository.ts | 16 ++++++++-------- server/src/infra/sql/smart.info.repository.sql | 4 ++-- 8 files changed, 44 insertions(+), 21 deletions(-) diff --git a/server/src/domain/person/person.service.ts b/server/src/domain/person/person.service.ts index 73fb37489b..928685da08 100644 --- a/server/src/domain/person/person.service.ts +++ b/server/src/domain/person/person.service.ts @@ -334,7 +334,7 @@ export class PersonService { for (const { embedding, ...rest } of faces) { const matches = await this.smartInfoRepository.searchFaces({ - ownerId: asset.ownerId, + userIds: [asset.ownerId], embedding, numResults: 1, maxDistance: machineLearning.facialRecognition.maxDistance, diff --git a/server/src/domain/repositories/asset.repository.ts b/server/src/domain/repositories/asset.repository.ts index 48f83de37b..296d4f40cf 100644 --- a/server/src/domain/repositories/asset.repository.ts +++ b/server/src/domain/repositories/asset.repository.ts @@ -199,5 +199,5 @@ export interface IAssetRepository { search(options: AssetSearchOptions): Promise; getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise>; getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise>; - searchMetadata(query: string, userId: string, options: MetadataSearchOptions): Promise; + searchMetadata(query: string, userIds: string[], options: MetadataSearchOptions): Promise; } diff --git a/server/src/domain/repositories/smart-info.repository.ts b/server/src/domain/repositories/smart-info.repository.ts index f26061f1b0..c35ec1d84c 100644 --- a/server/src/domain/repositories/smart-info.repository.ts +++ b/server/src/domain/repositories/smart-info.repository.ts @@ -5,7 +5,7 @@ export const ISmartInfoRepository = 'ISmartInfoRepository'; export type Embedding = number[]; export interface EmbeddingSearch { - ownerId: string; + userIds: string[]; embedding: Embedding; numResults: number; maxDistance?: number; diff --git a/server/src/domain/search/search.service.spec.ts b/server/src/domain/search/search.service.spec.ts index c6fe2abf66..9541d8f1d4 100644 --- a/server/src/domain/search/search.service.spec.ts +++ b/server/src/domain/search/search.service.spec.ts @@ -4,6 +4,7 @@ import { authStub, newAssetRepositoryMock, newMachineLearningRepositoryMock, + newPartnerRepositoryMock, newPersonRepositoryMock, newSmartInfoRepositoryMock, newSystemConfigRepositoryMock, @@ -13,6 +14,7 @@ import { mapAsset } from '../asset'; import { IAssetRepository, IMachineLearningRepository, + IPartnerRepository, IPersonRepository, ISmartInfoRepository, ISystemConfigRepository, @@ -29,6 +31,7 @@ describe(SearchService.name, () => { let machineMock: jest.Mocked; let personMock: jest.Mocked; let smartInfoMock: jest.Mocked; + let partnerMock: jest.Mocked; beforeEach(() => { assetMock = newAssetRepositoryMock(); @@ -36,7 +39,8 @@ describe(SearchService.name, () => { machineMock = newMachineLearningRepositoryMock(); personMock = newPersonRepositoryMock(); smartInfoMock = newSmartInfoRepositoryMock(); - sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock); + partnerMock = newPartnerRepositoryMock(); + sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock, partnerMock); }); it('should work', () => { @@ -87,6 +91,7 @@ describe(SearchService.name, () => { it('should search by metadata if `clip` option is false', async () => { const dto: SearchDto = { q: 'test query', clip: false }; assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]); + partnerMock.getAll.mockResolvedValueOnce([]); const expectedResponse = { albums: { total: 0, @@ -105,7 +110,7 @@ describe(SearchService.name, () => { const result = await sut.search(authStub.user1, dto); expect(result).toEqual(expectedResponse); - expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, authStub.user1.user.id, { numResults: 250 }); + expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, [authStub.user1.user.id], { numResults: 250 }); expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled(); }); @@ -114,6 +119,7 @@ describe(SearchService.name, () => { const embedding = [1, 2, 3]; smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]); machineMock.encodeText.mockResolvedValueOnce(embedding); + partnerMock.getAll.mockResolvedValueOnce([]); const expectedResponse = { albums: { total: 0, @@ -133,7 +139,7 @@ describe(SearchService.name, () => { expect(result).toEqual(expectedResponse); expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({ - ownerId: authStub.user1.user.id, + userIds: [authStub.user1.user.id], embedding, numResults: 100, }); diff --git a/server/src/domain/search/search.service.ts b/server/src/domain/search/search.service.ts index 0bceb43578..ef5a42fe46 100644 --- a/server/src/domain/search/search.service.ts +++ b/server/src/domain/search/search.service.ts @@ -7,6 +7,7 @@ import { PersonResponseDto } from '../person'; import { IAssetRepository, IMachineLearningRepository, + IPartnerRepository, IPersonRepository, ISmartInfoRepository, ISystemConfigRepository, @@ -28,6 +29,7 @@ export class SearchService { @Inject(IPersonRepository) private personRepository: IPersonRepository, @Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository, @Inject(IAssetRepository) private assetRepository: IAssetRepository, + @Inject(IPartnerRepository) private partnerRepository: IPartnerRepository, ) { this.configCore = SystemConfigCore.create(configRepository); } @@ -64,6 +66,7 @@ export class SearchService { throw new Error('CLIP is not enabled'); } const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT; + const userIds = await this.getUserIdsToSearch(auth); let assets: AssetEntity[] = []; @@ -74,10 +77,10 @@ export class SearchService { { text: query }, machineLearning.clip, ); - assets = await this.smartInfoRepository.searchCLIP({ ownerId: auth.user.id, embedding, numResults: 100 }); + assets = await this.smartInfoRepository.searchCLIP({ userIds: userIds, embedding, numResults: 100 }); break; case SearchStrategy.TEXT: - assets = await this.assetRepository.searchMetadata(query, auth.user.id, { numResults: 250 }); + assets = await this.assetRepository.searchMetadata(query, userIds, { numResults: 250 }); default: break; } @@ -97,4 +100,14 @@ export class SearchService { }, }; } + + private async getUserIdsToSearch(auth: AuthDto): Promise { + const userIds: string[] = [auth.user.id]; + const partners = await this.partnerRepository.getAll(auth.user.id); + const partnersIds = partners + .filter((partner) => partner.sharedBy && partner.inTimeline) + .map((partner) => partner.sharedById); + userIds.push(...partnersIds); + return userIds; + } } diff --git a/server/src/infra/repositories/asset.repository.ts b/server/src/infra/repositories/asset.repository.ts index ca3ca685a7..e56a2827bc 100644 --- a/server/src/infra/repositories/asset.repository.ts +++ b/server/src/infra/repositories/asset.repository.ts @@ -804,10 +804,14 @@ export class AssetRepository implements IAssetRepository { return builder; } - @GenerateSql({ params: [DummyValue.STRING, DummyValue.UUID, { numResults: 250 }] }) - async searchMetadata(query: string, ownerId: string, { numResults }: MetadataSearchOptions): Promise { + @GenerateSql({ params: [DummyValue.STRING, [DummyValue.UUID], { numResults: 250 }] }) + async searchMetadata( + query: string, + userIds: string[], + { numResults }: MetadataSearchOptions, + ): Promise { const rows = await this.getBuilder({ - userIds: [ownerId], + userIds: userIds, exifInfo: false, isArchived: false, }) diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts index dc7d5a2db6..ae8bea2d09 100644 --- a/server/src/infra/repositories/smart-info.repository.ts +++ b/server/src/infra/repositories/smart-info.repository.ts @@ -41,9 +41,9 @@ export class SmartInfoRepository implements ISmartInfoRepository { } @GenerateSql({ - params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }], + params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }], }) - async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise { + async searchCLIP({ userIds, embedding, numResults }: EmbeddingSearch): Promise { if (!isValidInteger(numResults, { min: 1 })) { throw new Error(`Invalid value for 'numResults': ${numResults}`); } @@ -55,13 +55,13 @@ export class SmartInfoRepository implements ISmartInfoRepository { results = await manager .createQueryBuilder(AssetEntity, 'a') .innerJoin('a.smartSearch', 's') - .where('a.ownerId = :ownerId') + .where('a.ownerId IN (:...userIds )') .andWhere('a.isVisible = true') .andWhere('a.isArchived = false') .andWhere('a.fileCreatedAt < NOW()') .leftJoinAndSelect('a.exifInfo', 'e') .orderBy('s.embedding <=> :embedding') - .setParameters({ ownerId, embedding: asVector(embedding) }) + .setParameters({ userIds, embedding: asVector(embedding) }) .limit(numResults) .getMany(); }); @@ -72,14 +72,14 @@ export class SmartInfoRepository implements ISmartInfoRepository { @GenerateSql({ params: [ { - ownerId: DummyValue.UUID, + userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100, maxDistance: 0.6, }, ], }) - async searchFaces({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise { + async searchFaces({ userIds, embedding, numResults, maxDistance }: EmbeddingSearch): Promise { if (!isValidInteger(numResults, { min: 1 })) { throw new Error(`Invalid value for 'numResults': ${numResults}`); } @@ -91,9 +91,9 @@ export class SmartInfoRepository implements ISmartInfoRepository { .createQueryBuilder(AssetFaceEntity, 'faces') .select('1 + (faces.embedding <=> :embedding)', 'distance') .innerJoin('faces.asset', 'asset') - .where('asset.ownerId = :ownerId') + .where('asset.ownerId IN (:...userIds )') .orderBy('1 + (faces.embedding <=> :embedding)') - .setParameters({ ownerId, embedding: asVector(embedding) }) + .setParameters({ userIds, embedding: asVector(embedding) }) .limit(numResults); this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col)); diff --git a/server/src/infra/sql/smart.info.repository.sql b/server/src/infra/sql/smart.info.repository.sql index b2cb551a61..56a4738281 100644 --- a/server/src/infra/sql/smart.info.repository.sql +++ b/server/src/infra/sql/smart.info.repository.sql @@ -69,7 +69,7 @@ FROM LEFT JOIN "exif" "e" ON "e"."assetId" = "a"."id" WHERE ( - "a"."ownerId" = $1 + "a"."ownerId" IN ($1) AND "a"."isVisible" = true AND "a"."isArchived" = false AND "a"."fileCreatedAt" < NOW() @@ -103,7 +103,7 @@ WITH INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId" AND ("asset"."deletedAt" IS NULL) WHERE - "asset"."ownerId" = $2 + "asset"."ownerId" IN ($2) ORDER BY 1 + ("faces"."embedding" <= > $3) ASC LIMIT