1
0
mirror of https://github.com/immich-app/immich.git synced 2025-01-12 15:32:36 +02:00

feat(server): link via profile.sub (#1055)

This commit is contained in:
Jason Rasmussen 2022-12-03 22:59:24 -05:00 committed by GitHub
parent 424b11cf50
commit 99854e90be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 64 additions and 3 deletions

View File

@ -39,6 +39,7 @@ describe('AuthService', () => {
userRepositoryMock = { userRepositoryMock = {
get: jest.fn(), get: jest.fn(),
getAdmin: jest.fn(), getAdmin: jest.fn(),
getByOAuthId: jest.fn(),
getByEmail: jest.fn(), getByEmail: jest.fn(),
getList: jest.fn(), getList: jest.fn(),
create: jest.fn(), create: jest.fn(),

View File

@ -20,12 +20,14 @@ const mockConfig = (config: Partial<OAuthConfig>) => {
}; };
const email = 'user@immich.com'; const email = 'user@immich.com';
const sub = 'my-auth-user-sub';
const user = { const user = {
id: 'user', id: 'user',
email, email,
firstName: 'user', firstName: 'user',
lastName: 'imimch', lastName: 'imimch',
oauthId: '',
} as UserEntity; } as UserEntity;
const loginResponse = { const loginResponse = {
@ -53,13 +55,14 @@ describe('OAuthService', () => {
authorizationUrl: jest.fn().mockReturnValue('http://authorization-url'), authorizationUrl: jest.fn().mockReturnValue('http://authorization-url'),
callbackParams: jest.fn().mockReturnValue({ state: 'state' }), callbackParams: jest.fn().mockReturnValue({ state: 'state' }),
callback: jest.fn().mockReturnValue({ access_token: 'access-token' }), callback: jest.fn().mockReturnValue({ access_token: 'access-token' }),
userinfo: jest.fn().mockResolvedValue({ email }), userinfo: jest.fn().mockResolvedValue({ sub, email }),
}), }),
} as any); } as any);
userRepositoryMock = { userRepositoryMock = {
get: jest.fn(), get: jest.fn(),
getAdmin: jest.fn(), getAdmin: jest.fn(),
getByOAuthId: jest.fn(),
getByEmail: jest.fn(), getByEmail: jest.fn(),
getList: jest.fn(), getList: jest.fn(),
create: jest.fn(), create: jest.fn(),
@ -132,6 +135,26 @@ describe('OAuthService', () => {
expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1); expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1);
}); });
it('should link an existing user', async () => {
configServiceMock.get.mockImplementation(
mockConfig({
OAUTH_ENABLED: true,
OAUTH_AUTO_REGISTER: false,
}),
);
sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock);
jest.spyOn(sut['logger'], 'debug').mockImplementation(() => null);
jest.spyOn(sut['logger'], 'warn').mockImplementation(() => null);
userRepositoryMock.getByEmail.mockResolvedValue(user);
userRepositoryMock.update.mockResolvedValue(user);
immichJwtServiceMock.createLoginResponse.mockResolvedValue(loginResponse);
await expect(sut.callback({ url: 'http://immich/auth/login?code=abc123' })).resolves.toEqual(loginResponse);
expect(userRepositoryMock.getByEmail).toHaveBeenCalledTimes(1);
expect(userRepositoryMock.update).toHaveBeenCalledWith(user.id, { oauthId: sub });
});
it('should allow auto registering by default', async () => { it('should allow auto registering by default', async () => {
configServiceMock.get.mockImplementation(mockConfig({ OAUTH_ENABLED: true })); configServiceMock.get.mockImplementation(mockConfig({ OAUTH_ENABLED: true }));
sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock); sut = new OAuthService(immichJwtServiceMock, configServiceMock, userRepositoryMock);

View File

@ -63,8 +63,17 @@ export class OAuthService {
const profile = await client.userinfo<OAuthProfile>(tokens.access_token || ''); const profile = await client.userinfo<OAuthProfile>(tokens.access_token || '');
this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`); this.logger.debug(`Logging in with OAuth: ${JSON.stringify(profile)}`);
let user = await this.userRepository.getByEmail(profile.email); let user = await this.userRepository.getByOAuthId(profile.sub);
// link existing user
if (!user) {
const emailUser = await this.userRepository.getByEmail(profile.email);
if (emailUser) {
user = await this.userRepository.update(emailUser.id, { oauthId: profile.sub });
}
}
// register new user
if (!user) { if (!user) {
if (!this.autoRegister) { if (!this.autoRegister) {
this.logger.warn( this.logger.warn(
@ -73,11 +82,12 @@ export class OAuthService {
throw new BadRequestException(`User does not exist and auto registering is disabled.`); throw new BadRequestException(`User does not exist and auto registering is disabled.`);
} }
this.logger.log(`Registering new user: ${profile.email}`); this.logger.log(`Registering new user: ${profile.email}/${profile.sub}`);
user = await this.userRepository.create({ user = await this.userRepository.create({
firstName: profile.given_name || '', firstName: profile.given_name || '',
lastName: profile.family_name || '', lastName: profile.family_name || '',
email: profile.email, email: profile.email,
oauthId: profile.sub,
}); });
} }

View File

@ -8,6 +8,7 @@ export interface IUserRepository {
get(id: string, withDeleted?: boolean): Promise<UserEntity | null>; get(id: string, withDeleted?: boolean): Promise<UserEntity | null>;
getAdmin(): Promise<UserEntity | null>; getAdmin(): Promise<UserEntity | null>;
getByEmail(email: string, withPassword?: boolean): Promise<UserEntity | null>; getByEmail(email: string, withPassword?: boolean): Promise<UserEntity | null>;
getByOAuthId(oauthId: string): Promise<UserEntity | null>;
getList(filter?: { excludeId?: string }): Promise<UserEntity[]>; getList(filter?: { excludeId?: string }): Promise<UserEntity[]>;
create(user: Partial<UserEntity>): Promise<UserEntity>; create(user: Partial<UserEntity>): Promise<UserEntity>;
update(id: string, user: Partial<UserEntity>): Promise<UserEntity>; update(id: string, user: Partial<UserEntity>): Promise<UserEntity>;
@ -41,6 +42,10 @@ export class UserRepository implements IUserRepository {
return builder.getOne(); return builder.getOne();
} }
public async getByOAuthId(oauthId: string): Promise<UserEntity | null> {
return this.userRepository.findOne({ where: { oauthId } });
}
public async getList({ excludeId }: { excludeId?: string } = {}): Promise<UserEntity[]> { public async getList({ excludeId }: { excludeId?: string } = {}): Promise<UserEntity[]> {
if (!excludeId) { if (!excludeId) {
return this.userRepository.find(); // TODO: this should also be ordered the same as below return this.userRepository.find(); // TODO: this should also be ordered the same as below

View File

@ -27,6 +27,7 @@ describe('UserService', () => {
firstName: 'admin_first_name', firstName: 'admin_first_name',
lastName: 'admin_last_name', lastName: 'admin_last_name',
isAdmin: true, isAdmin: true,
oauthId: '',
shouldChangePassword: false, shouldChangePassword: false,
profileImagePath: '', profileImagePath: '',
createdAt: '2021-01-01', createdAt: '2021-01-01',
@ -40,6 +41,7 @@ describe('UserService', () => {
firstName: 'immich_first_name', firstName: 'immich_first_name',
lastName: 'immich_last_name', lastName: 'immich_last_name',
isAdmin: false, isAdmin: false,
oauthId: '',
shouldChangePassword: false, shouldChangePassword: false,
profileImagePath: '', profileImagePath: '',
createdAt: '2021-01-01', createdAt: '2021-01-01',
@ -53,6 +55,7 @@ describe('UserService', () => {
firstName: 'updated_immich_first_name', firstName: 'updated_immich_first_name',
lastName: 'updated_immich_last_name', lastName: 'updated_immich_last_name',
isAdmin: false, isAdmin: false,
oauthId: '',
shouldChangePassword: true, shouldChangePassword: true,
profileImagePath: '', profileImagePath: '',
createdAt: '2021-01-01', createdAt: '2021-01-01',

View File

@ -52,6 +52,7 @@ describe('ImmichJwtService', () => {
email: 'test@immich.com', email: 'test@immich.com',
password: 'changeme', password: 'changeme',
salt: '123', salt: '123',
oauthId: '',
profileImagePath: '', profileImagePath: '',
shouldChangePassword: false, shouldChangePassword: false,
createdAt: 'today', createdAt: 'today',

View File

@ -20,6 +20,7 @@ export function newUserRepositoryMock(): jest.Mocked<IUserRepository> {
get: jest.fn(), get: jest.fn(),
getAdmin: jest.fn(), getAdmin: jest.fn(),
getByEmail: jest.fn(), getByEmail: jest.fn(),
getByOAuthId: jest.fn(),
getList: jest.fn(), getList: jest.fn(),
create: jest.fn(), create: jest.fn(),
update: jest.fn(), update: jest.fn(),

View File

@ -23,6 +23,9 @@ export class UserEntity {
@Column({ default: '', select: false }) @Column({ default: '', select: false })
salt?: string; salt?: string;
@Column({ default: '', select: false })
oauthId!: string;
@Column({ default: '' }) @Column({ default: '' })
profileImagePath!: string; profileImagePath!: string;

View File

@ -0,0 +1,14 @@
import { MigrationInterface, QueryRunner } from "typeorm";
export class OAuthId1670104716264 implements MigrationInterface {
name = 'OAuthId1670104716264'
public async up(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(`ALTER TABLE "users" ADD "oauthId" character varying NOT NULL DEFAULT ''`);
}
public async down(queryRunner: QueryRunner): Promise<void> {
await queryRunner.query(`ALTER TABLE "users" DROP COLUMN "oauthId"`);
}
}