diff --git a/server/apps/immich/src/api-v1/auth/auth.service.spec.ts b/server/apps/immich/src/api-v1/auth/auth.service.spec.ts index 22882d67fc..0fbd9c1ae1 100644 --- a/server/apps/immich/src/api-v1/auth/auth.service.spec.ts +++ b/server/apps/immich/src/api-v1/auth/auth.service.spec.ts @@ -39,6 +39,7 @@ describe('AuthService', () => { userRepositoryMock = { get: jest.fn(), getAdmin: jest.fn(), + getByOAuthId: jest.fn(), getByEmail: jest.fn(), getList: jest.fn(), create: jest.fn(), diff --git a/server/apps/immich/src/api-v1/oauth/oauth.service.spec.ts b/server/apps/immich/src/api-v1/oauth/oauth.service.spec.ts index 9934701d1d..d62d442084 100644 --- a/server/apps/immich/src/api-v1/oauth/oauth.service.spec.ts +++ b/server/apps/immich/src/api-v1/oauth/oauth.service.spec.ts @@ -20,12 +20,14 @@ const mockConfig = (config: Partial) => { }; const email = 'user@immich.com'; +const sub = 'my-auth-user-sub'; const user = { id: 'user', email, firstName: 'user', lastName: 'imimch', + oauthId: '', } as UserEntity; const loginResponse = { @@ -53,13 +55,14 @@ describe('OAuthService', () => { authorizationUrl: jest.fn().mockReturnValue('http://authorization-url'), callbackParams: jest.fn().mockReturnValue({ state: 'state' }), callback: jest.fn().mockReturnValue({ access_token: 'access-token' }), - userinfo: jest.fn().mockResolvedValue({ email }), + userinfo: jest.fn().mockResolvedValue({ sub, email }), }), } as any); userRepositoryMock = { get: jest.fn(), getAdmin: jest.fn(), + getByOAuthId: jest.fn(), getByEmail: jest.fn(), getList: jest.fn(), create: jest.fn(), @@ -132,6 +135,26 @@ describe('OAuthService', () => { expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1); }); + it('should link an existing user', async () => { + configServiceMock.get.mockImplementation( + mockConfig({ + OAUTH_ENABLED: true, + OAUTH_AUTO_REGISTER: false, + }), + ); + sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock); + jest.spyOn(sut['logger'], 'debug').mockImplementation(() => null); + jest.spyOn(sut['logger'], 'warn').mockImplementation(() => null); + userRepositoryMock.getByEmail.mockResolvedValue(user); + userRepositoryMock.update.mockResolvedValue(user); + immichJwtServiceMock.createLoginResponse.mockResolvedValue(loginResponse); + + await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' })).resolves.toEqual(loginResponse); + + expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1); + expect(userRepositoryMock.update).toHaveBeenCalledWith(user.id, { oauthId: sub }); + }); + it('should allow auto registering by default', async () => { configServiceMock.get.mockImplementation(mockConfig({ OAUTH_ENABLED: true })); sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock); diff --git a/server/apps/immich/src/api-v1/oauth/oauth.service.ts b/server/apps/immich/src/api-v1/oauth/oauth.service.ts index 74c642a11f..349b7d3cf4 100644 --- a/server/apps/immich/src/api-v1/oauth/oauth.service.ts +++ b/server/apps/immich/src/api-v1/oauth/oauth.service.ts @@ -63,8 +63,17 @@ export class OAuthService { const profile = await client.userinfo(tokens.access_token || ''); this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`); - let user = await this.userRepository.getByEmail(profile.email); + let user = await this.userRepository.getByOAuthId(profile.sub); + // link existing user + if (!user) { + const emailUser = await this.userRepository.getByEmail(profile.email); + if (emailUser) { + user = await this.userRepository.update(emailUser.id, { oauthId: profile.sub }); + } + } + + // register new user if (!user) { if (!this.autoRegister) { this.logger.warn( @@ -73,11 +82,12 @@ export class OAuthService { throw new BadRequestException(`User does not exist and auto registering is disabled.`); } - this.logger.log(`Registering new user: ${profile.email}`); + this.logger.log(`Registering new user: ${profile.email}/${profile.sub}`); user = await this.userRepository.create({ firstName: profile.given_name || '', lastName: profile.family_name || '', email: profile.email, + oauthId: profile.sub, }); } diff --git a/server/apps/immich/src/api-v1/user/user-repository.ts b/server/apps/immich/src/api-v1/user/user-repository.ts index e9bccf1e79..574feed292 100644 --- a/server/apps/immich/src/api-v1/user/user-repository.ts +++ b/server/apps/immich/src/api-v1/user/user-repository.ts @@ -8,6 +8,7 @@ export interface IUserRepository { get(id: string, withDeleted?: boolean): Promise; getAdmin(): Promise; getByEmail(email: string, withPassword?: boolean): Promise; + getByOAuthId(oauthId: string): Promise; getList(filter?: { excludeId?: string }): Promise; create(user: Partial): Promise; update(id: string, user: Partial): Promise; @@ -41,6 +42,10 @@ export class UserRepository implements IUserRepository { return builder.getOne(); } + public async getByOAuthId(oauthId: string): Promise { + return this.userRepository.findOne({ where: { oauthId } }); + } + public async getList({ excludeId }: { excludeId?: string } = {}): Promise { if (!excludeId) { return this.userRepository.find(); // TODO: this should also be ordered the same as below diff --git a/server/apps/immich/src/api-v1/user/user.service.spec.ts b/server/apps/immich/src/api-v1/user/user.service.spec.ts index 8539e88f46..49d82938ab 100644 --- a/server/apps/immich/src/api-v1/user/user.service.spec.ts +++ b/server/apps/immich/src/api-v1/user/user.service.spec.ts @@ -27,6 +27,7 @@ describe('UserService', () => { firstName: 'admin_first_name', lastName: 'admin_last_name', isAdmin: true, + oauthId: '', shouldChangePassword: false, profileImagePath: '', createdAt: '2021-01-01', @@ -40,6 +41,7 @@ describe('UserService', () => { firstName: 'immich_first_name', lastName: 'immich_last_name', isAdmin: false, + oauthId: '', shouldChangePassword: false, profileImagePath: '', createdAt: '2021-01-01', @@ -53,6 +55,7 @@ describe('UserService', () => { firstName: 'updated_immich_first_name', lastName: 'updated_immich_last_name', isAdmin: false, + oauthId: '', shouldChangePassword: true, profileImagePath: '', createdAt: '2021-01-01', diff --git a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts index e0ea9e0555..9936f9de35 100644 --- a/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts +++ b/server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts @@ -52,6 +52,7 @@ describe('ImmichJwtService', () => { email: 'test@immich.com', password: 'changeme', salt: '123', + oauthId: '', profileImagePath: '', shouldChangePassword: false, createdAt: 'today', diff --git a/server/apps/immich/test/test-utils.ts b/server/apps/immich/test/test-utils.ts index b9ea5f5f6a..fee852820e 100644 --- a/server/apps/immich/test/test-utils.ts +++ b/server/apps/immich/test/test-utils.ts @@ -20,6 +20,7 @@ export function newUserRepositoryMock(): jest.Mocked { get: jest.fn(), getAdmin: jest.fn(), getByEmail: jest.fn(), + getByOAuthId: jest.fn(), getList: jest.fn(), create: jest.fn(), update: jest.fn(), diff --git a/server/libs/database/src/entities/user.entity.ts b/server/libs/database/src/entities/user.entity.ts index d0f0b50c7d..c114101c64 100644 --- a/server/libs/database/src/entities/user.entity.ts +++ b/server/libs/database/src/entities/user.entity.ts @@ -23,6 +23,9 @@ export class UserEntity { @Column({ default: '', select: false }) salt?: string; + @Column({ default: '', select: false }) + oauthId!: string; + @Column({ default: '' }) profileImagePath!: string; diff --git a/server/libs/database/src/migrations/1670104716264-OAuthId.ts b/server/libs/database/src/migrations/1670104716264-OAuthId.ts new file mode 100644 index 0000000000..46b99a79d5 --- /dev/null +++ b/server/libs/database/src/migrations/1670104716264-OAuthId.ts @@ -0,0 +1,14 @@ +import { MigrationInterface, QueryRunner } from "typeorm"; + +export class OAuthId1670104716264 implements MigrationInterface { + name = 'OAuthId1670104716264' + + public async up(queryRunner: QueryRunner): Promise { + await queryRunner.query(`ALTER TABLE "users" ADD "oauthId" character varying NOT NULL DEFAULT ''`); + } + + public async down(queryRunner: QueryRunner): Promise { + await queryRunner.query(`ALTER TABLE "users" DROP COLUMN "oauthId"`); + } + +}