package apis_test import ( "encoding/json" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/router" "github.com/pocketbase/pocketbase/tools/types" ) func TestEnrichRecords(t *testing.T) { t.Parallel() // mock test data // --- app, _ := tests.NewTestApp() defer app.Cleanup() freshRecords := func(records []*core.Record) []*core.Record { result := make([]*core.Record, len(records)) for i, r := range records { result[i] = r.Fresh() } return result } user, err := app.FindAuthRecordByEmail("users", "test@example.com") if err != nil { t.Fatal(err) } superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com") if err != nil { t.Fatal(err) } usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"}) if err != nil { t.Fatal(err) } nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"}) if err != nil { t.Fatal(err) } demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"}) if err != nil { t.Fatal(err) } demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"}) if err != nil { t.Fatal(err) } // temp update the view rule to ensure that request context is set to "expand" demo4, err := app.FindCollectionByNameOrId("demo4") if err != nil { t.Fatal(err) } demo4.ViewRule = types.Pointer("@request.context = 'expand'") if err := app.Save(demo4); err != nil { t.Fatal(err) } // --- scenarios := []struct { name string auth *core.Record records []*core.Record queryExpand string defaultExpands []string expected []string notExpected []string }{ // email visibility checks { name: "[emailVisibility] guest", auth: nil, records: freshRecords(usersRecords), queryExpand: "", defaultExpands: nil, expected: []string{ `"customField":"123"`, `"test3@example.com"`, // emailVisibility=true }, notExpected: []string{ `"test@example.com"`, }, }, { name: "[emailVisibility] owner", auth: user, records: freshRecords(usersRecords), queryExpand: "", defaultExpands: nil, expected: []string{ `"customField":"123"`, `"test3@example.com"`, // emailVisibility=true `"test@example.com"`, // owner }, }, { name: "[emailVisibility] manager", auth: user, records: freshRecords(nologinRecords), queryExpand: "", defaultExpands: nil, expected: []string{ `"customField":"123"`, `"test3@example.com"`, `"test@example.com"`, }, }, { name: "[emailVisibility] superuser", auth: superuser, records: freshRecords(nologinRecords), queryExpand: "", defaultExpands: nil, expected: []string{ `"customField":"123"`, `"test3@example.com"`, `"test@example.com"`, }, }, { name: "[emailVisibility + expand] recursive auth rule checks (regular user)", auth: user, records: freshRecords(demo1Records), queryExpand: "", defaultExpands: []string{"rel_many"}, expected: []string{ `"customField":"123"`, `"expand":{"rel_many"`, `"expand":{}`, `"test@example.com"`, }, notExpected: []string{ `"id":"bgs820n361vj1qd"`, `"id":"oap640cot4yru2s"`, }, }, { name: "[emailVisibility + expand] recursive auth rule checks (superuser)", auth: superuser, records: freshRecords(demo1Records), queryExpand: "", defaultExpands: []string{"rel_many"}, expected: []string{ `"customField":"123"`, `"test@example.com"`, `"expand":{"rel_many"`, `"id":"bgs820n361vj1qd"`, `"id":"4q1xlclmfloku33"`, `"id":"oap640cot4yru2s"`, }, notExpected: []string{ `"expand":{}`, }, }, // expand checks { name: "[expand] guest (query)", auth: nil, records: freshRecords(usersRecords), queryExpand: "rel", defaultExpands: nil, expected: []string{ `"customField":"123"`, `"expand":{"rel"`, `"id":"llvuca81nly1qls"`, `"id":"0yxhwia2amd8gec"`, }, notExpected: []string{ `"expand":{}`, }, }, { name: "[expand] guest (default expands)", auth: nil, records: freshRecords(usersRecords), queryExpand: "", defaultExpands: []string{"rel"}, expected: []string{ `"customField":"123"`, `"expand":{"rel"`, `"id":"llvuca81nly1qls"`, `"id":"0yxhwia2amd8gec"`, }, }, { name: "[expand] @request.context=expand check", auth: nil, records: freshRecords(demo5Records), queryExpand: "rel_one", defaultExpands: []string{"rel_many"}, expected: []string{ `"customField":"123"`, `"expand":{}`, `"expand":{"`, `"rel_many":[{`, `"rel_one":{`, `"id":"i9naidtvr6qsgb4"`, `"id":"qzaqccwrmva4o1n"`, }, }, } for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { app, _ := tests.NewTestApp() defer app.Cleanup() app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error { e.Record.WithCustomData(true) e.Record.Set("customField", "123") return e.Next() }) req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil) rec := httptest.NewRecorder() requestEvent := new(core.RequestEvent) requestEvent.App = app requestEvent.Request = req requestEvent.Response = rec requestEvent.Auth = s.auth err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...) if err != nil { t.Fatal(err) } raw, err := json.Marshal(s.records) if err != nil { t.Fatal(err) } rawStr := string(raw) for _, str := range s.expected { if !strings.Contains(rawStr, str) { t.Fatalf("Expected\n%q\nin\n%v", str, rawStr) } } for _, str := range s.notExpected { if strings.Contains(rawStr, str) { t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr) } } }) } } func TestRecordAuthResponseAuthRuleCheck(t *testing.T) { app, _ := tests.NewTestApp() defer app.Cleanup() event := new(core.RequestEvent) event.App = app event.Request = httptest.NewRequest(http.MethodGet, "/", nil) event.Response = httptest.NewRecorder() user, err := app.FindAuthRecordByEmail("users", "test@example.com") if err != nil { t.Fatal(err) } scenarios := []struct { name string rule *string expectError bool }{ { "admin only rule", nil, true, }, { "empty rule", types.Pointer(""), false, }, { "false rule", types.Pointer("1=2"), true, }, { "true rule", types.Pointer("1=1"), false, }, } for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { user.Collection().AuthRule = s.rule err := apis.RecordAuthResponse(event, user, "", nil) hasErr := err != nil if s.expectError != hasErr { t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err) } // in all cases login alert shouldn't be send because of the empty auth method if app.TestMailer.TotalSend() != 0 { t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML) } if !hasErr { return } apiErr, ok := err.(*router.ApiError) if !ok || apiErr == nil { t.Fatalf("Expected ApiError, got %v", apiErr) } if apiErr.Status != http.StatusForbidden { t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status) } }) } } func TestRecordAuthResponseAuthAlertCheck(t *testing.T) { const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784" scenarios := []struct { name string devices []string // mock existing device fingerprints expectDevices []string enabled bool expectEmail bool }{ { name: "first login", devices: nil, expectDevices: []string{testFingerprint}, enabled: true, expectEmail: false, }, { name: "existing device", devices: []string{"1", testFingerprint}, expectDevices: []string{"1", testFingerprint}, enabled: true, expectEmail: false, }, { name: "new device (< 5)", devices: []string{"1", "2"}, expectDevices: []string{"1", "2", testFingerprint}, enabled: true, expectEmail: true, }, { name: "new device (>= 5)", devices: []string{"1", "2", "3", "4", "5"}, expectDevices: []string{"2", "3", "4", "5", testFingerprint}, enabled: true, expectEmail: true, }, { name: "with disabled auth alert collection flag", devices: []string{"1", "2"}, expectDevices: []string{"1", "2"}, enabled: false, expectEmail: false, }, } for _, s := range scenarios { t.Run(s.name, func(t *testing.T) { app, _ := tests.NewTestApp() defer app.Cleanup() event := new(core.RequestEvent) event.App = app event.Request = httptest.NewRequest(http.MethodGet, "/", nil) event.Response = httptest.NewRecorder() user, err := app.FindAuthRecordByEmail("users", "test@example.com") if err != nil { t.Fatal(err) } user.Collection().MFA.Enabled = false user.Collection().AuthRule = types.Pointer("") user.Collection().AuthAlert.Enabled = s.enabled // ensure that there are no other auth origins err = app.DeleteAllAuthOriginsByRecord(user) if err != nil { t.Fatal(err) } mockCreated := types.NowDateTime().Add(-time.Duration(len(s.devices)+1) * time.Second) // insert the mock devices for _, fingerprint := range s.devices { mockCreated = mockCreated.Add(1 * time.Second) d := core.NewAuthOrigin(app) d.SetCollectionRef(user.Collection().Id) d.SetRecordRef(user.Id) d.SetFingerprint(fingerprint) d.SetRaw("created", mockCreated) d.SetRaw("updated", mockCreated) if err = app.Save(d); err != nil { t.Fatal(err) } } err = apis.RecordAuthResponse(event, user, "example", nil) if err != nil { t.Fatalf("Failed to resolve auth response: %v", err) } var expectTotalSend int if s.expectEmail { expectTotalSend = 1 } if total := app.TestMailer.TotalSend(); total != expectTotalSend { t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total) } devices, err := app.FindAllAuthOriginsByRecord(user) if err != nil { t.Fatalf("Failed to retrieve auth origins: %v", err) } if len(devices) != len(s.expectDevices) { t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices)) } for _, fingerprint := range s.expectDevices { var exists bool fingerprints := make([]string, 0, len(devices)) for _, d := range devices { if d.Fingerprint() == fingerprint { exists = true break } fingerprints = append(fingerprints, d.Fingerprint()) } if !exists { t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints) } } }) } } func TestRecordAuthResponseMFACheck(t *testing.T) { app, _ := tests.NewTestApp() defer app.Cleanup() user, err := app.FindAuthRecordByEmail("users", "test@example.com") if err != nil { t.Fatal(err) } user2, err := app.FindAuthRecordByEmail("users", "test2@example.com") if err != nil { t.Fatal(err) } rec := httptest.NewRecorder() event := new(core.RequestEvent) event.App = app event.Request = httptest.NewRequest(http.MethodGet, "/", nil) event.Response = rec resetMFAs := func(authRecord *core.Record) { // ensure that mfa is enabled user.Collection().MFA.Enabled = true user.Collection().MFA.Duration = 5 user.Collection().MFA.Rule = "" mfas, err := app.FindAllMFAsByRecord(authRecord) if err != nil { t.Fatalf("Failed to retrieve mfas: %v", err) } for _, mfa := range mfas { if err := app.Delete(mfa); err != nil { t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err) } } // reset response rec = httptest.NewRecorder() event.Response = rec } totalMFAs := func(authRecord *core.Record) int { mfas, err := app.FindAllMFAsByRecord(authRecord) if err != nil { t.Fatalf("Failed to retrieve mfas: %v", err) } return len(mfas) } t.Run("no collection MFA enabled", func(t *testing.T) { resetMFAs(user) user.Collection().MFA.Enabled = false err = apis.RecordAuthResponse(event, user, "example", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } body := rec.Body.String() if strings.Contains(body, "mfaId") { t.Fatalf("Expected no mfaId in the response body, got\n%v", body) } if !strings.Contains(body, "token") { t.Fatalf("Expected auth token in the response body, got\n%v", body) } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected no mfa records to be created, got %d", total) } }) t.Run("no explicit auth method", func(t *testing.T) { resetMFAs(user) err = apis.RecordAuthResponse(event, user, "", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } body := rec.Body.String() if strings.Contains(body, "mfaId") { t.Fatalf("Expected no mfaId in the response body, got\n%v", body) } if !strings.Contains(body, "token") { t.Fatalf("Expected auth token in the response body, got\n%v", body) } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected no mfa records to be created, got %d", total) } }) t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) { resetMFAs(user) user.Collection().MFA.Rule = "1=2" err = apis.RecordAuthResponse(event, user, "example", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } body := rec.Body.String() if strings.Contains(body, "mfaId") { t.Fatalf("Expected no mfaId in the response body, got\n%v", body) } if !strings.Contains(body, "token") { t.Fatalf("Expected auth token in the response body, got\n%v", body) } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected no mfa records to be created, got %d", total) } }) t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) { resetMFAs(user) user.Collection().MFA.Rule = "1=1" err = apis.RecordAuthResponse(event, user, "example", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } body := rec.Body.String() if !strings.Contains(body, "mfaId") { t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body) } if total := totalMFAs(user); total != 1 { t.Fatalf("Expected a single mfa record to be created, got %d", total) } }) t.Run("mfa first-time", func(t *testing.T) { resetMFAs(user) err = apis.RecordAuthResponse(event, user, "example", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } body := rec.Body.String() if !strings.Contains(body, "mfaId") { t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body) } if total := totalMFAs(user); total != 1 { t.Fatalf("Expected a single mfa record to be created, got %d", total) } }) t.Run("mfa second-time with the same auth method", func(t *testing.T) { resetMFAs(user) // create a dummy mfa record mfa := core.NewMFA(app) mfa.SetCollectionRef(user.Collection().Id) mfa.SetRecordRef(user.Id) mfa.SetMethod("example") if err = app.Save(mfa); err != nil { t.Fatal(err) } event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil) err = apis.RecordAuthResponse(event, user, "example", nil) if err == nil { t.Fatal("Expected error, got nil") } if total := totalMFAs(user); total != 1 { t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total) } }) t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) { resetMFAs(user) // create a dummy mfa record mfa := core.NewMFA(app) mfa.SetCollectionRef(user.Collection().Id) mfa.SetRecordRef(user.Id) mfa.SetMethod("example1") if err = app.Save(mfa); err != nil { t.Fatal(err) } event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil) err = apis.RecordAuthResponse(event, user, "example2", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total) } }) t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) { resetMFAs(user) // create a dummy mfa record mfa := core.NewMFA(app) mfa.SetCollectionRef(user.Collection().Id) mfa.SetRecordRef(user.Id) mfa.SetMethod("example1") if err = app.Save(mfa); err != nil { t.Fatal(err) } event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`)) event.Request.Header.Add("content-type", "application/json") err = apis.RecordAuthResponse(event, user, "example2", nil) if err != nil { t.Fatalf("Expected nil, got error: %v", err) } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total) } }) t.Run("missing mfa", func(t *testing.T) { resetMFAs(user) event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil) err = apis.RecordAuthResponse(event, user, "example2", nil) if err == nil { t.Fatal("Expected error, got nil") } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected 0 mfa records, got %d", total) } }) t.Run("expired mfa", func(t *testing.T) { resetMFAs(user) // create a dummy expired mfa record mfa := core.NewMFA(app) mfa.SetCollectionRef(user.Collection().Id) mfa.SetRecordRef(user.Id) mfa.SetMethod("example1") mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour)) mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour)) if err = app.Save(mfa); err != nil { t.Fatal(err) } event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil) err = apis.RecordAuthResponse(event, user, "example2", nil) if err == nil { t.Fatal("Expected error, got nil") } if totalMFAs(user) != 0 { t.Fatal("Expected the expired mfa record to be deleted") } }) t.Run("mfa for different auth record", func(t *testing.T) { resetMFAs(user) // create a dummy expired mfa record mfa := core.NewMFA(app) mfa.SetCollectionRef(user2.Collection().Id) mfa.SetRecordRef(user2.Id) mfa.SetMethod("example1") if err = app.Save(mfa); err != nil { t.Fatal(err) } event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil) err = apis.RecordAuthResponse(event, user, "example2", nil) if err == nil { t.Fatal("Expected error, got nil") } if total := totalMFAs(user); total != 0 { t.Fatalf("Expected no user mfas, got %d", total) } if total := totalMFAs(user2); total != 1 { t.Fatalf("Expected only 1 user2 mfa, got %d", total) } }) }