1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-13 15:35:15 +02:00

feat(server): search across own+partner assets (#5966)

* feat(server): search across own+partner assets

* generate sql

* fix sql parameter
This commit is contained in:
Fynn Petersen-Frey 2024-01-01 23:25:22 +01:00 committed by GitHub
parent 2688e05033
commit cc7ba3c21a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 21 deletions

View File

@ -334,7 +334,7 @@ export class PersonService {
for (const { embedding, ...rest } of faces) { for (const { embedding, ...rest } of faces) {
const matches = await this.smartInfoRepository.searchFaces({ const matches = await this.smartInfoRepository.searchFaces({
ownerId: asset.ownerId, userIds: [asset.ownerId],
embedding, embedding,
numResults: 1, numResults: 1,
maxDistance: machineLearning.facialRecognition.maxDistance, maxDistance: machineLearning.facialRecognition.maxDistance,

View File

@ -199,5 +199,5 @@ export interface IAssetRepository {
search(options: AssetSearchOptions): Promise<AssetEntity[]>; search(options: AssetSearchOptions): Promise<AssetEntity[]>;
getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>; getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>; getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
searchMetadata(query: string, userId: string, options: MetadataSearchOptions): Promise<AssetEntity[]>; searchMetadata(query: string, userIds: string[], options: MetadataSearchOptions): Promise<AssetEntity[]>;
} }

View File

@ -5,7 +5,7 @@ export const ISmartInfoRepository = 'ISmartInfoRepository';
export type Embedding = number[]; export type Embedding = number[];
export interface EmbeddingSearch { export interface EmbeddingSearch {
ownerId: string; userIds: string[];
embedding: Embedding; embedding: Embedding;
numResults: number; numResults: number;
maxDistance?: number; maxDistance?: number;

View File

@ -4,6 +4,7 @@ import {
authStub, authStub,
newAssetRepositoryMock, newAssetRepositoryMock,
newMachineLearningRepositoryMock, newMachineLearningRepositoryMock,
newPartnerRepositoryMock,
newPersonRepositoryMock, newPersonRepositoryMock,
newSmartInfoRepositoryMock, newSmartInfoRepositoryMock,
newSystemConfigRepositoryMock, newSystemConfigRepositoryMock,
@ -13,6 +14,7 @@ import { mapAsset } from '../asset';
import { import {
IAssetRepository, IAssetRepository,
IMachineLearningRepository, IMachineLearningRepository,
IPartnerRepository,
IPersonRepository, IPersonRepository,
ISmartInfoRepository, ISmartInfoRepository,
ISystemConfigRepository, ISystemConfigRepository,
@ -29,6 +31,7 @@ describe(SearchService.name, () => {
let machineMock: jest.Mocked<IMachineLearningRepository>; let machineMock: jest.Mocked<IMachineLearningRepository>;
let personMock: jest.Mocked<IPersonRepository>; let personMock: jest.Mocked<IPersonRepository>;
let smartInfoMock: jest.Mocked<ISmartInfoRepository>; let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
let partnerMock: jest.Mocked<IPartnerRepository>;
beforeEach(() => { beforeEach(() => {
assetMock = newAssetRepositoryMock(); assetMock = newAssetRepositoryMock();
@ -36,7 +39,8 @@ describe(SearchService.name, () => {
machineMock = newMachineLearningRepositoryMock(); machineMock = newMachineLearningRepositoryMock();
personMock = newPersonRepositoryMock(); personMock = newPersonRepositoryMock();
smartInfoMock = newSmartInfoRepositoryMock(); smartInfoMock = newSmartInfoRepositoryMock();
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock); partnerMock = newPartnerRepositoryMock();
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock, partnerMock);
}); });
it('should work', () => { it('should work', () => {
@ -87,6 +91,7 @@ describe(SearchService.name, () => {
it('should search by metadata if `clip` option is false', async () => { it('should search by metadata if `clip` option is false', async () => {
const dto: SearchDto = { q: 'test query', clip: false }; const dto: SearchDto = { q: 'test query', clip: false };
assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]); assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]);
partnerMock.getAll.mockResolvedValueOnce([]);
const expectedResponse = { const expectedResponse = {
albums: { albums: {
total: 0, total: 0,
@ -105,7 +110,7 @@ describe(SearchService.name, () => {
const result = await sut.search(authStub.user1, dto); const result = await sut.search(authStub.user1, dto);
expect(result).toEqual(expectedResponse); expect(result).toEqual(expectedResponse);
expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, authStub.user1.user.id, { numResults: 250 }); expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, [authStub.user1.user.id], { numResults: 250 });
expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled(); expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled();
}); });
@ -114,6 +119,7 @@ describe(SearchService.name, () => {
const embedding = [1, 2, 3]; const embedding = [1, 2, 3];
smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]); smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]);
machineMock.encodeText.mockResolvedValueOnce(embedding); machineMock.encodeText.mockResolvedValueOnce(embedding);
partnerMock.getAll.mockResolvedValueOnce([]);
const expectedResponse = { const expectedResponse = {
albums: { albums: {
total: 0, total: 0,
@ -133,7 +139,7 @@ describe(SearchService.name, () => {
expect(result).toEqual(expectedResponse); expect(result).toEqual(expectedResponse);
expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({ expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({
ownerId: authStub.user1.user.id, userIds: [authStub.user1.user.id],
embedding, embedding,
numResults: 100, numResults: 100,
}); });

View File

@ -7,6 +7,7 @@ import { PersonResponseDto } from '../person';
import { import {
IAssetRepository, IAssetRepository,
IMachineLearningRepository, IMachineLearningRepository,
IPartnerRepository,
IPersonRepository, IPersonRepository,
ISmartInfoRepository, ISmartInfoRepository,
ISystemConfigRepository, ISystemConfigRepository,
@ -28,6 +29,7 @@ export class SearchService {
@Inject(IPersonRepository) private personRepository: IPersonRepository, @Inject(IPersonRepository) private personRepository: IPersonRepository,
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository, @Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
@Inject(IAssetRepository) private assetRepository: IAssetRepository, @Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(IPartnerRepository) private partnerRepository: IPartnerRepository,
) { ) {
this.configCore = SystemConfigCore.create(configRepository); this.configCore = SystemConfigCore.create(configRepository);
} }
@ -64,6 +66,7 @@ export class SearchService {
throw new Error('CLIP is not enabled'); throw new Error('CLIP is not enabled');
} }
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT; const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
const userIds = await this.getUserIdsToSearch(auth);
let assets: AssetEntity[] = []; let assets: AssetEntity[] = [];
@ -74,10 +77,10 @@ export class SearchService {
{ text: query }, { text: query },
machineLearning.clip, machineLearning.clip,
); );
assets = await this.smartInfoRepository.searchCLIP({ ownerId: auth.user.id, embedding, numResults: 100 }); assets = await this.smartInfoRepository.searchCLIP({ userIds: userIds, embedding, numResults: 100 });
break; break;
case SearchStrategy.TEXT: case SearchStrategy.TEXT:
assets = await this.assetRepository.searchMetadata(query, auth.user.id, { numResults: 250 }); assets = await this.assetRepository.searchMetadata(query, userIds, { numResults: 250 });
default: default:
break; break;
} }
@ -97,4 +100,14 @@ export class SearchService {
}, },
}; };
} }
private async getUserIdsToSearch(auth: AuthDto): Promise<string[]> {
const userIds: string[] = [auth.user.id];
const partners = await this.partnerRepository.getAll(auth.user.id);
const partnersIds = partners
.filter((partner) => partner.sharedBy && partner.inTimeline)
.map((partner) => partner.sharedById);
userIds.push(...partnersIds);
return userIds;
}
} }

View File

@ -804,10 +804,14 @@ export class AssetRepository implements IAssetRepository {
return builder; return builder;
} }
@GenerateSql({ params: [DummyValue.STRING, DummyValue.UUID, { numResults: 250 }] }) @GenerateSql({ params: [DummyValue.STRING, [DummyValue.UUID], { numResults: 250 }] })
async searchMetadata(query: string, ownerId: string, { numResults }: MetadataSearchOptions): Promise<AssetEntity[]> { async searchMetadata(
query: string,
userIds: string[],
{ numResults }: MetadataSearchOptions,
): Promise<AssetEntity[]> {
const rows = await this.getBuilder({ const rows = await this.getBuilder({
userIds: [ownerId], userIds: userIds,
exifInfo: false, exifInfo: false,
isArchived: false, isArchived: false,
}) })

View File

@ -41,9 +41,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
} }
@GenerateSql({ @GenerateSql({
params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }], params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
}) })
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> { async searchCLIP({ userIds, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) { if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
} }
@ -55,13 +55,13 @@ export class SmartInfoRepository implements ISmartInfoRepository {
results = await manager results = await manager
.createQueryBuilder(AssetEntity, 'a') .createQueryBuilder(AssetEntity, 'a')
.innerJoin('a.smartSearch', 's') .innerJoin('a.smartSearch', 's')
.where('a.ownerId = :ownerId') .where('a.ownerId IN (:...userIds )')
.andWhere('a.isVisible = true') .andWhere('a.isVisible = true')
.andWhere('a.isArchived = false') .andWhere('a.isArchived = false')
.andWhere('a.fileCreatedAt < NOW()') .andWhere('a.fileCreatedAt < NOW()')
.leftJoinAndSelect('a.exifInfo', 'e') .leftJoinAndSelect('a.exifInfo', 'e')
.orderBy('s.embedding <=> :embedding') .orderBy('s.embedding <=> :embedding')
.setParameters({ ownerId, embedding: asVector(embedding) }) .setParameters({ userIds, embedding: asVector(embedding) })
.limit(numResults) .limit(numResults)
.getMany(); .getMany();
}); });
@ -72,14 +72,14 @@ export class SmartInfoRepository implements ISmartInfoRepository {
@GenerateSql({ @GenerateSql({
params: [ params: [
{ {
ownerId: DummyValue.UUID, userIds: [DummyValue.UUID],
embedding: Array.from({ length: 512 }, Math.random), embedding: Array.from({ length: 512 }, Math.random),
numResults: 100, numResults: 100,
maxDistance: 0.6, maxDistance: 0.6,
}, },
], ],
}) })
async searchFaces({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> { async searchFaces({ userIds, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) { if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
} }
@ -91,9 +91,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
.createQueryBuilder(AssetFaceEntity, 'faces') .createQueryBuilder(AssetFaceEntity, 'faces')
.select('1 + (faces.embedding <=> :embedding)', 'distance') .select('1 + (faces.embedding <=> :embedding)', 'distance')
.innerJoin('faces.asset', 'asset') .innerJoin('faces.asset', 'asset')
.where('asset.ownerId = :ownerId') .where('asset.ownerId IN (:...userIds )')
.orderBy('1 + (faces.embedding <=> :embedding)') .orderBy('1 + (faces.embedding <=> :embedding)')
.setParameters({ ownerId, embedding: asVector(embedding) }) .setParameters({ userIds, embedding: asVector(embedding) })
.limit(numResults); .limit(numResults);
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col)); this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col));

View File

@ -69,7 +69,7 @@ FROM
LEFT JOIN "exif" "e" ON "e"."assetId" = "a"."id" LEFT JOIN "exif" "e" ON "e"."assetId" = "a"."id"
WHERE WHERE
( (
"a"."ownerId" = $1 "a"."ownerId" IN ($1)
AND "a"."isVisible" = true AND "a"."isVisible" = true
AND "a"."isArchived" = false AND "a"."isArchived" = false
AND "a"."fileCreatedAt" < NOW() AND "a"."fileCreatedAt" < NOW()
@ -103,7 +103,7 @@ WITH
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId" INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
AND ("asset"."deletedAt" IS NULL) AND ("asset"."deletedAt" IS NULL)
WHERE WHERE
"asset"."ownerId" = $2 "asset"."ownerId" IN ($2)
ORDER BY ORDER BY
1 + ("faces"."embedding" <= > $3) ASC 1 + ("faces"."embedding" <= > $3) ASC
LIMIT LIMIT