mirror of
https://github.com/immich-app/immich.git
synced 2024-12-26 10:50:29 +02:00
feat(server): search across own+partner assets (#5966)
* feat(server): search across own+partner assets * generate sql * fix sql parameter
This commit is contained in:
parent
2688e05033
commit
cc7ba3c21a
@ -334,7 +334,7 @@ export class PersonService {
|
|||||||
|
|
||||||
for (const { embedding, ...rest } of faces) {
|
for (const { embedding, ...rest } of faces) {
|
||||||
const matches = await this.smartInfoRepository.searchFaces({
|
const matches = await this.smartInfoRepository.searchFaces({
|
||||||
ownerId: asset.ownerId,
|
userIds: [asset.ownerId],
|
||||||
embedding,
|
embedding,
|
||||||
numResults: 1,
|
numResults: 1,
|
||||||
maxDistance: machineLearning.facialRecognition.maxDistance,
|
maxDistance: machineLearning.facialRecognition.maxDistance,
|
||||||
|
@ -199,5 +199,5 @@ export interface IAssetRepository {
|
|||||||
search(options: AssetSearchOptions): Promise<AssetEntity[]>;
|
search(options: AssetSearchOptions): Promise<AssetEntity[]>;
|
||||||
getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
getAssetIdByCity(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
||||||
getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
getAssetIdByTag(userId: string, options: AssetExploreFieldOptions): Promise<SearchExploreItem<string>>;
|
||||||
searchMetadata(query: string, userId: string, options: MetadataSearchOptions): Promise<AssetEntity[]>;
|
searchMetadata(query: string, userIds: string[], options: MetadataSearchOptions): Promise<AssetEntity[]>;
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ export const ISmartInfoRepository = 'ISmartInfoRepository';
|
|||||||
export type Embedding = number[];
|
export type Embedding = number[];
|
||||||
|
|
||||||
export interface EmbeddingSearch {
|
export interface EmbeddingSearch {
|
||||||
ownerId: string;
|
userIds: string[];
|
||||||
embedding: Embedding;
|
embedding: Embedding;
|
||||||
numResults: number;
|
numResults: number;
|
||||||
maxDistance?: number;
|
maxDistance?: number;
|
||||||
|
@ -4,6 +4,7 @@ import {
|
|||||||
authStub,
|
authStub,
|
||||||
newAssetRepositoryMock,
|
newAssetRepositoryMock,
|
||||||
newMachineLearningRepositoryMock,
|
newMachineLearningRepositoryMock,
|
||||||
|
newPartnerRepositoryMock,
|
||||||
newPersonRepositoryMock,
|
newPersonRepositoryMock,
|
||||||
newSmartInfoRepositoryMock,
|
newSmartInfoRepositoryMock,
|
||||||
newSystemConfigRepositoryMock,
|
newSystemConfigRepositoryMock,
|
||||||
@ -13,6 +14,7 @@ import { mapAsset } from '../asset';
|
|||||||
import {
|
import {
|
||||||
IAssetRepository,
|
IAssetRepository,
|
||||||
IMachineLearningRepository,
|
IMachineLearningRepository,
|
||||||
|
IPartnerRepository,
|
||||||
IPersonRepository,
|
IPersonRepository,
|
||||||
ISmartInfoRepository,
|
ISmartInfoRepository,
|
||||||
ISystemConfigRepository,
|
ISystemConfigRepository,
|
||||||
@ -29,6 +31,7 @@ describe(SearchService.name, () => {
|
|||||||
let machineMock: jest.Mocked<IMachineLearningRepository>;
|
let machineMock: jest.Mocked<IMachineLearningRepository>;
|
||||||
let personMock: jest.Mocked<IPersonRepository>;
|
let personMock: jest.Mocked<IPersonRepository>;
|
||||||
let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
|
let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
|
||||||
|
let partnerMock: jest.Mocked<IPartnerRepository>;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
assetMock = newAssetRepositoryMock();
|
assetMock = newAssetRepositoryMock();
|
||||||
@ -36,7 +39,8 @@ describe(SearchService.name, () => {
|
|||||||
machineMock = newMachineLearningRepositoryMock();
|
machineMock = newMachineLearningRepositoryMock();
|
||||||
personMock = newPersonRepositoryMock();
|
personMock = newPersonRepositoryMock();
|
||||||
smartInfoMock = newSmartInfoRepositoryMock();
|
smartInfoMock = newSmartInfoRepositoryMock();
|
||||||
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock);
|
partnerMock = newPartnerRepositoryMock();
|
||||||
|
sut = new SearchService(configMock, machineMock, personMock, smartInfoMock, assetMock, partnerMock);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should work', () => {
|
it('should work', () => {
|
||||||
@ -87,6 +91,7 @@ describe(SearchService.name, () => {
|
|||||||
it('should search by metadata if `clip` option is false', async () => {
|
it('should search by metadata if `clip` option is false', async () => {
|
||||||
const dto: SearchDto = { q: 'test query', clip: false };
|
const dto: SearchDto = { q: 'test query', clip: false };
|
||||||
assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]);
|
assetMock.searchMetadata.mockResolvedValueOnce([assetStub.image]);
|
||||||
|
partnerMock.getAll.mockResolvedValueOnce([]);
|
||||||
const expectedResponse = {
|
const expectedResponse = {
|
||||||
albums: {
|
albums: {
|
||||||
total: 0,
|
total: 0,
|
||||||
@ -105,7 +110,7 @@ describe(SearchService.name, () => {
|
|||||||
const result = await sut.search(authStub.user1, dto);
|
const result = await sut.search(authStub.user1, dto);
|
||||||
|
|
||||||
expect(result).toEqual(expectedResponse);
|
expect(result).toEqual(expectedResponse);
|
||||||
expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, authStub.user1.user.id, { numResults: 250 });
|
expect(assetMock.searchMetadata).toHaveBeenCalledWith(dto.q, [authStub.user1.user.id], { numResults: 250 });
|
||||||
expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled();
|
expect(smartInfoMock.searchCLIP).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -114,6 +119,7 @@ describe(SearchService.name, () => {
|
|||||||
const embedding = [1, 2, 3];
|
const embedding = [1, 2, 3];
|
||||||
smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]);
|
smartInfoMock.searchCLIP.mockResolvedValueOnce([assetStub.image]);
|
||||||
machineMock.encodeText.mockResolvedValueOnce(embedding);
|
machineMock.encodeText.mockResolvedValueOnce(embedding);
|
||||||
|
partnerMock.getAll.mockResolvedValueOnce([]);
|
||||||
const expectedResponse = {
|
const expectedResponse = {
|
||||||
albums: {
|
albums: {
|
||||||
total: 0,
|
total: 0,
|
||||||
@ -133,7 +139,7 @@ describe(SearchService.name, () => {
|
|||||||
|
|
||||||
expect(result).toEqual(expectedResponse);
|
expect(result).toEqual(expectedResponse);
|
||||||
expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({
|
expect(smartInfoMock.searchCLIP).toHaveBeenCalledWith({
|
||||||
ownerId: authStub.user1.user.id,
|
userIds: [authStub.user1.user.id],
|
||||||
embedding,
|
embedding,
|
||||||
numResults: 100,
|
numResults: 100,
|
||||||
});
|
});
|
||||||
|
@ -7,6 +7,7 @@ import { PersonResponseDto } from '../person';
|
|||||||
import {
|
import {
|
||||||
IAssetRepository,
|
IAssetRepository,
|
||||||
IMachineLearningRepository,
|
IMachineLearningRepository,
|
||||||
|
IPartnerRepository,
|
||||||
IPersonRepository,
|
IPersonRepository,
|
||||||
ISmartInfoRepository,
|
ISmartInfoRepository,
|
||||||
ISystemConfigRepository,
|
ISystemConfigRepository,
|
||||||
@ -28,6 +29,7 @@ export class SearchService {
|
|||||||
@Inject(IPersonRepository) private personRepository: IPersonRepository,
|
@Inject(IPersonRepository) private personRepository: IPersonRepository,
|
||||||
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
|
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
|
||||||
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
|
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
|
||||||
|
@Inject(IPartnerRepository) private partnerRepository: IPartnerRepository,
|
||||||
) {
|
) {
|
||||||
this.configCore = SystemConfigCore.create(configRepository);
|
this.configCore = SystemConfigCore.create(configRepository);
|
||||||
}
|
}
|
||||||
@ -64,6 +66,7 @@ export class SearchService {
|
|||||||
throw new Error('CLIP is not enabled');
|
throw new Error('CLIP is not enabled');
|
||||||
}
|
}
|
||||||
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
|
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
|
||||||
|
const userIds = await this.getUserIdsToSearch(auth);
|
||||||
|
|
||||||
let assets: AssetEntity[] = [];
|
let assets: AssetEntity[] = [];
|
||||||
|
|
||||||
@ -74,10 +77,10 @@ export class SearchService {
|
|||||||
{ text: query },
|
{ text: query },
|
||||||
machineLearning.clip,
|
machineLearning.clip,
|
||||||
);
|
);
|
||||||
assets = await this.smartInfoRepository.searchCLIP({ ownerId: auth.user.id, embedding, numResults: 100 });
|
assets = await this.smartInfoRepository.searchCLIP({ userIds: userIds, embedding, numResults: 100 });
|
||||||
break;
|
break;
|
||||||
case SearchStrategy.TEXT:
|
case SearchStrategy.TEXT:
|
||||||
assets = await this.assetRepository.searchMetadata(query, auth.user.id, { numResults: 250 });
|
assets = await this.assetRepository.searchMetadata(query, userIds, { numResults: 250 });
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -97,4 +100,14 @@ export class SearchService {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async getUserIdsToSearch(auth: AuthDto): Promise<string[]> {
|
||||||
|
const userIds: string[] = [auth.user.id];
|
||||||
|
const partners = await this.partnerRepository.getAll(auth.user.id);
|
||||||
|
const partnersIds = partners
|
||||||
|
.filter((partner) => partner.sharedBy && partner.inTimeline)
|
||||||
|
.map((partner) => partner.sharedById);
|
||||||
|
userIds.push(...partnersIds);
|
||||||
|
return userIds;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -804,10 +804,14 @@ export class AssetRepository implements IAssetRepository {
|
|||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
@GenerateSql({ params: [DummyValue.STRING, DummyValue.UUID, { numResults: 250 }] })
|
@GenerateSql({ params: [DummyValue.STRING, [DummyValue.UUID], { numResults: 250 }] })
|
||||||
async searchMetadata(query: string, ownerId: string, { numResults }: MetadataSearchOptions): Promise<AssetEntity[]> {
|
async searchMetadata(
|
||||||
|
query: string,
|
||||||
|
userIds: string[],
|
||||||
|
{ numResults }: MetadataSearchOptions,
|
||||||
|
): Promise<AssetEntity[]> {
|
||||||
const rows = await this.getBuilder({
|
const rows = await this.getBuilder({
|
||||||
userIds: [ownerId],
|
userIds: userIds,
|
||||||
exifInfo: false,
|
exifInfo: false,
|
||||||
isArchived: false,
|
isArchived: false,
|
||||||
})
|
})
|
||||||
|
@ -41,9 +41,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@GenerateSql({
|
@GenerateSql({
|
||||||
params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
|
params: [{ userIds: [DummyValue.UUID], embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
|
||||||
})
|
})
|
||||||
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
async searchCLIP({ userIds, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
||||||
if (!isValidInteger(numResults, { min: 1 })) {
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
}
|
}
|
||||||
@ -55,13 +55,13 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|||||||
results = await manager
|
results = await manager
|
||||||
.createQueryBuilder(AssetEntity, 'a')
|
.createQueryBuilder(AssetEntity, 'a')
|
||||||
.innerJoin('a.smartSearch', 's')
|
.innerJoin('a.smartSearch', 's')
|
||||||
.where('a.ownerId = :ownerId')
|
.where('a.ownerId IN (:...userIds )')
|
||||||
.andWhere('a.isVisible = true')
|
.andWhere('a.isVisible = true')
|
||||||
.andWhere('a.isArchived = false')
|
.andWhere('a.isArchived = false')
|
||||||
.andWhere('a.fileCreatedAt < NOW()')
|
.andWhere('a.fileCreatedAt < NOW()')
|
||||||
.leftJoinAndSelect('a.exifInfo', 'e')
|
.leftJoinAndSelect('a.exifInfo', 'e')
|
||||||
.orderBy('s.embedding <=> :embedding')
|
.orderBy('s.embedding <=> :embedding')
|
||||||
.setParameters({ ownerId, embedding: asVector(embedding) })
|
.setParameters({ userIds, embedding: asVector(embedding) })
|
||||||
.limit(numResults)
|
.limit(numResults)
|
||||||
.getMany();
|
.getMany();
|
||||||
});
|
});
|
||||||
@ -72,14 +72,14 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|||||||
@GenerateSql({
|
@GenerateSql({
|
||||||
params: [
|
params: [
|
||||||
{
|
{
|
||||||
ownerId: DummyValue.UUID,
|
userIds: [DummyValue.UUID],
|
||||||
embedding: Array.from({ length: 512 }, Math.random),
|
embedding: Array.from({ length: 512 }, Math.random),
|
||||||
numResults: 100,
|
numResults: 100,
|
||||||
maxDistance: 0.6,
|
maxDistance: 0.6,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
async searchFaces({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
async searchFaces({ userIds, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
||||||
if (!isValidInteger(numResults, { min: 1 })) {
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
}
|
}
|
||||||
@ -91,9 +91,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|||||||
.createQueryBuilder(AssetFaceEntity, 'faces')
|
.createQueryBuilder(AssetFaceEntity, 'faces')
|
||||||
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
||||||
.innerJoin('faces.asset', 'asset')
|
.innerJoin('faces.asset', 'asset')
|
||||||
.where('asset.ownerId = :ownerId')
|
.where('asset.ownerId IN (:...userIds )')
|
||||||
.orderBy('1 + (faces.embedding <=> :embedding)')
|
.orderBy('1 + (faces.embedding <=> :embedding)')
|
||||||
.setParameters({ ownerId, embedding: asVector(embedding) })
|
.setParameters({ userIds, embedding: asVector(embedding) })
|
||||||
.limit(numResults);
|
.limit(numResults);
|
||||||
|
|
||||||
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col));
|
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col}`, col));
|
||||||
|
@ -69,7 +69,7 @@ FROM
|
|||||||
LEFT JOIN "exif" "e" ON "e"."assetId" = "a"."id"
|
LEFT JOIN "exif" "e" ON "e"."assetId" = "a"."id"
|
||||||
WHERE
|
WHERE
|
||||||
(
|
(
|
||||||
"a"."ownerId" = $1
|
"a"."ownerId" IN ($1)
|
||||||
AND "a"."isVisible" = true
|
AND "a"."isVisible" = true
|
||||||
AND "a"."isArchived" = false
|
AND "a"."isArchived" = false
|
||||||
AND "a"."fileCreatedAt" < NOW()
|
AND "a"."fileCreatedAt" < NOW()
|
||||||
@ -103,7 +103,7 @@ WITH
|
|||||||
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
|
INNER JOIN "assets" "asset" ON "asset"."id" = "faces"."assetId"
|
||||||
AND ("asset"."deletedAt" IS NULL)
|
AND ("asset"."deletedAt" IS NULL)
|
||||||
WHERE
|
WHERE
|
||||||
"asset"."ownerId" = $2
|
"asset"."ownerId" IN ($2)
|
||||||
ORDER BY
|
ORDER BY
|
||||||
1 + ("faces"."embedding" <= > $3) ASC
|
1 + ("faces"."embedding" <= > $3) ASC
|
||||||
LIMIT
|
LIMIT
|
||||||
|
Loading…
Reference in New Issue
Block a user