diff --git a/packages/server/src/utils/routeUtils.test.ts b/packages/server/src/utils/routeUtils.test.ts index bbd48add6..b2573f495 100644 --- a/packages/server/src/utils/routeUtils.test.ts +++ b/packages/server/src/utils/routeUtils.test.ts @@ -1,5 +1,6 @@ import { isValidOrigin, parseSubPath, splitItemPath } from './routeUtils'; import { ItemAddressingType } from '../db'; +import { RouteType } from './types'; describe('routeUtils', function() { @@ -41,7 +42,7 @@ describe('routeUtils', function() { } }); - it('should check the request origin', async function() { + it('should check the request origin for API URLs', async function() { const testCases: any[] = [ [ 'https://example.com', // Request origin @@ -79,7 +80,37 @@ describe('routeUtils', function() { for (const testCase of testCases) { const [requestOrigin, configBaseUrl, expected] = testCase; - expect(isValidOrigin(requestOrigin, configBaseUrl)).toBe(expected); + expect(isValidOrigin(requestOrigin, configBaseUrl, RouteType.Api)).toBe(expected); + } + }); + + it('should check the request origin for User Content URLs', async function() { + const testCases: any[] = [ + [ + 'https://usercontent.local', // Request origin + 'https://usercontent.local', // Config base URL + true, + ], + [ + 'http://usercontent.local', + 'https://usercontent.local', + true, + ], + [ + 'https://abcd.usercontent.local', + 'https://usercontent.local', + true, + ], + [ + 'https://bad.local', + 'https://usercontent.local', + false, + ], + ]; + + for (const testCase of testCases) { + const [requestOrigin, configBaseUrl, expected] = testCase; + expect(isValidOrigin(requestOrigin, configBaseUrl, RouteType.UserContent)).toBe(expected); } }); diff --git a/packages/server/src/utils/routeUtils.ts b/packages/server/src/utils/routeUtils.ts index 32472e1fa..14af2c48c 100644 --- a/packages/server/src/utils/routeUtils.ts +++ b/packages/server/src/utils/routeUtils.ts @@ -158,8 +158,12 @@ export function isValidOrigin(requestOrigin: string, endPointBaseUrl: string, ro const host2 = (new URL(endPointBaseUrl)).host; if (routeType === RouteType.UserContent) { + // At this point we only check if eg usercontent.com has been accessed + // with origin usercontent.com, or something.usercontent.com. We don't + // check that the user ID is valid or is event present. This will be + // done by the /share end point, which will also check that the share + // owner ID matches the origin URL. if (host1 === host2) return true; - const hostNoPrefix = host1.split('.').slice(1).join('.'); return hostNoPrefix === host2; } else { diff --git a/packages/server/src/utils/setupAppContext.ts b/packages/server/src/utils/setupAppContext.ts index e1685eb37..2d2441f54 100644 --- a/packages/server/src/utils/setupAppContext.ts +++ b/packages/server/src/utils/setupAppContext.ts @@ -23,7 +23,7 @@ async function setupServices(env: Env, models: Models, config: Config): Promise< return output; } -export default async function(appContext: AppContext, env: Env, dbConnection: DbConnection, appLogger: ()=> LoggerWrapper) { +export default async function(appContext: AppContext, env: Env, dbConnection: DbConnection, appLogger: ()=> LoggerWrapper): Promise { appContext.env = env; appContext.db = dbConnection; appContext.models = newModelFactory(appContext.db, config()); @@ -32,4 +32,6 @@ export default async function(appContext: AppContext, env: Env, dbConnection: Db appContext.routes = { ...routes }; if (env === Env.Prod) delete appContext.routes['api/debug']; + + return appContext; } diff --git a/packages/server/src/utils/testing/testUtils.ts b/packages/server/src/utils/testing/testUtils.ts index 0e374013c..1ef2e1b14 100644 --- a/packages/server/src/utils/testing/testUtils.ts +++ b/packages/server/src/utils/testing/testUtils.ts @@ -177,24 +177,24 @@ export async function koaAppContext(options: AppContextTestOptions = null): Prom // Set type to "any" because the Koa context has many properties and we // don't need to mock all of them. - const appContext: any = {}; - - await setupAppContext(appContext, Env.Dev, db_, () => appLogger); - - appContext.env = Env.Dev; - appContext.db = db_; - appContext.models = models(); - appContext.appLogger = () => appLogger; - appContext.path = req.url; - appContext.owner = owner; - appContext.cookies = new FakeCookies(); - appContext.request = new FakeRequest(req); - appContext.response = new FakeResponse(); - appContext.headers = { ...reqOptions.headers }; - appContext.req = req; - appContext.query = req.query; - appContext.method = req.method; - appContext.redirect = () => {}; + const appContext: any = { + ...await setupAppContext({} as any, Env.Dev, db_, () => appLogger), + env: Env.Dev, + db: db_, + models: models(), + appLogger: () => appLogger, + path: req.url, + owner: owner, + cookies: new FakeCookies(), + request: new FakeRequest(req), + response: new FakeResponse(), + headers: { ...reqOptions.headers }, + req: req, + query: req.query, + method: req.method, + redirect: () => {}, + URL: { origin: config().baseUrl }, + }; if (options.sessionId) { appContext.cookies.set('sessionId', options.sessionId);