mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-02-14 17:00:06 +02:00
762 lines
19 KiB
Go
762 lines
19 KiB
Go
package apis_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pocketbase/pocketbase/apis"
|
|
"github.com/pocketbase/pocketbase/core"
|
|
"github.com/pocketbase/pocketbase/tests"
|
|
"github.com/pocketbase/pocketbase/tools/router"
|
|
"github.com/pocketbase/pocketbase/tools/types"
|
|
)
|
|
|
|
func TestEnrichRecords(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// mock test data
|
|
// ---
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
freshRecords := func(records []*core.Record) []*core.Record {
|
|
result := make([]*core.Record, len(records))
|
|
for i, r := range records {
|
|
result[i] = r.Fresh()
|
|
}
|
|
return result
|
|
}
|
|
|
|
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// temp update the view rule to ensure that request context is set to "expand"
|
|
demo4, err := app.FindCollectionByNameOrId("demo4")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
|
|
if err := app.Save(demo4); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// ---
|
|
|
|
scenarios := []struct {
|
|
name string
|
|
auth *core.Record
|
|
records []*core.Record
|
|
queryExpand string
|
|
defaultExpands []string
|
|
expected []string
|
|
notExpected []string
|
|
}{
|
|
// email visibility checks
|
|
{
|
|
name: "[emailVisibility] guest",
|
|
auth: nil,
|
|
records: freshRecords(usersRecords),
|
|
queryExpand: "",
|
|
defaultExpands: nil,
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"test3@example.com"`, // emailVisibility=true
|
|
},
|
|
notExpected: []string{
|
|
`"test@example.com"`,
|
|
},
|
|
},
|
|
{
|
|
name: "[emailVisibility] owner",
|
|
auth: user,
|
|
records: freshRecords(usersRecords),
|
|
queryExpand: "",
|
|
defaultExpands: nil,
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"test3@example.com"`, // emailVisibility=true
|
|
`"test@example.com"`, // owner
|
|
},
|
|
},
|
|
{
|
|
name: "[emailVisibility] manager",
|
|
auth: user,
|
|
records: freshRecords(nologinRecords),
|
|
queryExpand: "",
|
|
defaultExpands: nil,
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"test3@example.com"`,
|
|
`"test@example.com"`,
|
|
},
|
|
},
|
|
{
|
|
name: "[emailVisibility] superuser",
|
|
auth: superuser,
|
|
records: freshRecords(nologinRecords),
|
|
queryExpand: "",
|
|
defaultExpands: nil,
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"test3@example.com"`,
|
|
`"test@example.com"`,
|
|
},
|
|
},
|
|
{
|
|
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
|
|
auth: user,
|
|
records: freshRecords(demo1Records),
|
|
queryExpand: "",
|
|
defaultExpands: []string{"rel_many"},
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"expand":{"rel_many"`,
|
|
`"expand":{}`,
|
|
`"test@example.com"`,
|
|
},
|
|
notExpected: []string{
|
|
`"id":"bgs820n361vj1qd"`,
|
|
`"id":"oap640cot4yru2s"`,
|
|
},
|
|
},
|
|
{
|
|
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
|
|
auth: superuser,
|
|
records: freshRecords(demo1Records),
|
|
queryExpand: "",
|
|
defaultExpands: []string{"rel_many"},
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"test@example.com"`,
|
|
`"expand":{"rel_many"`,
|
|
`"id":"bgs820n361vj1qd"`,
|
|
`"id":"4q1xlclmfloku33"`,
|
|
`"id":"oap640cot4yru2s"`,
|
|
},
|
|
notExpected: []string{
|
|
`"expand":{}`,
|
|
},
|
|
},
|
|
|
|
// expand checks
|
|
{
|
|
name: "[expand] guest (query)",
|
|
auth: nil,
|
|
records: freshRecords(usersRecords),
|
|
queryExpand: "rel",
|
|
defaultExpands: nil,
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"expand":{"rel"`,
|
|
`"id":"llvuca81nly1qls"`,
|
|
`"id":"0yxhwia2amd8gec"`,
|
|
},
|
|
notExpected: []string{
|
|
`"expand":{}`,
|
|
},
|
|
},
|
|
{
|
|
name: "[expand] guest (default expands)",
|
|
auth: nil,
|
|
records: freshRecords(usersRecords),
|
|
queryExpand: "",
|
|
defaultExpands: []string{"rel"},
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"expand":{"rel"`,
|
|
`"id":"llvuca81nly1qls"`,
|
|
`"id":"0yxhwia2amd8gec"`,
|
|
},
|
|
},
|
|
{
|
|
name: "[expand] @request.context=expand check",
|
|
auth: nil,
|
|
records: freshRecords(demo5Records),
|
|
queryExpand: "rel_one",
|
|
defaultExpands: []string{"rel_many"},
|
|
expected: []string{
|
|
`"customField":"123"`,
|
|
`"expand":{}`,
|
|
`"expand":{"`,
|
|
`"rel_many":[{`,
|
|
`"rel_one":{`,
|
|
`"id":"i9naidtvr6qsgb4"`,
|
|
`"id":"qzaqccwrmva4o1n"`,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, s := range scenarios {
|
|
t.Run(s.name, func(t *testing.T) {
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
|
|
e.Record.WithCustomData(true)
|
|
e.Record.Set("customField", "123")
|
|
return e.Next()
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
requestEvent := new(core.RequestEvent)
|
|
requestEvent.App = app
|
|
requestEvent.Request = req
|
|
requestEvent.Response = rec
|
|
requestEvent.Auth = s.auth
|
|
|
|
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
raw, err := json.Marshal(s.records)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
rawStr := string(raw)
|
|
|
|
for _, str := range s.expected {
|
|
if !strings.Contains(rawStr, str) {
|
|
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
|
|
}
|
|
}
|
|
|
|
for _, str := range s.notExpected {
|
|
if strings.Contains(rawStr, str) {
|
|
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
event := new(core.RequestEvent)
|
|
event.App = app
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
event.Response = httptest.NewRecorder()
|
|
|
|
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
scenarios := []struct {
|
|
name string
|
|
rule *string
|
|
expectError bool
|
|
}{
|
|
{
|
|
"admin only rule",
|
|
nil,
|
|
true,
|
|
},
|
|
{
|
|
"empty rule",
|
|
types.Pointer(""),
|
|
false,
|
|
},
|
|
{
|
|
"false rule",
|
|
types.Pointer("1=2"),
|
|
true,
|
|
},
|
|
{
|
|
"true rule",
|
|
types.Pointer("1=1"),
|
|
false,
|
|
},
|
|
}
|
|
|
|
for _, s := range scenarios {
|
|
t.Run(s.name, func(t *testing.T) {
|
|
user.Collection().AuthRule = s.rule
|
|
|
|
err := apis.RecordAuthResponse(event, user, "", nil)
|
|
|
|
hasErr := err != nil
|
|
if s.expectError != hasErr {
|
|
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
|
}
|
|
|
|
// in all cases login alert shouldn't be send because of the empty auth method
|
|
if app.TestMailer.TotalSend() != 0 {
|
|
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
|
|
}
|
|
|
|
if !hasErr {
|
|
return
|
|
}
|
|
|
|
apiErr, ok := err.(*router.ApiError)
|
|
|
|
if !ok || apiErr == nil {
|
|
t.Fatalf("Expected ApiError, got %v", apiErr)
|
|
}
|
|
|
|
if apiErr.Status != http.StatusForbidden {
|
|
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
|
|
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
|
|
|
|
scenarios := []struct {
|
|
name string
|
|
devices []string // mock existing device fingerprints
|
|
expectDevices []string
|
|
enabled bool
|
|
expectEmail bool
|
|
}{
|
|
{
|
|
name: "first login",
|
|
devices: nil,
|
|
expectDevices: []string{testFingerprint},
|
|
enabled: true,
|
|
expectEmail: false,
|
|
},
|
|
{
|
|
name: "existing device",
|
|
devices: []string{"1", testFingerprint},
|
|
expectDevices: []string{"1", testFingerprint},
|
|
enabled: true,
|
|
expectEmail: false,
|
|
},
|
|
{
|
|
name: "new device (< 5)",
|
|
devices: []string{"1", "2"},
|
|
expectDevices: []string{"1", "2", testFingerprint},
|
|
enabled: true,
|
|
expectEmail: true,
|
|
},
|
|
{
|
|
name: "new device (>= 5)",
|
|
devices: []string{"1", "2", "3", "4", "5"},
|
|
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
|
|
enabled: true,
|
|
expectEmail: true,
|
|
},
|
|
{
|
|
name: "with disabled auth alert collection flag",
|
|
devices: []string{"1", "2"},
|
|
expectDevices: []string{"1", "2"},
|
|
enabled: false,
|
|
expectEmail: false,
|
|
},
|
|
}
|
|
|
|
for _, s := range scenarios {
|
|
t.Run(s.name, func(t *testing.T) {
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
event := new(core.RequestEvent)
|
|
event.App = app
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
event.Response = httptest.NewRecorder()
|
|
|
|
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
user.Collection().MFA.Enabled = false
|
|
user.Collection().AuthRule = types.Pointer("")
|
|
user.Collection().AuthAlert.Enabled = s.enabled
|
|
|
|
// ensure that there are no other auth origins
|
|
err = app.DeleteAllAuthOriginsByRecord(user)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
mockCreated := types.NowDateTime().Add(-time.Duration(len(s.devices)+1) * time.Second)
|
|
// insert the mock devices
|
|
for _, fingerprint := range s.devices {
|
|
mockCreated = mockCreated.Add(1 * time.Second)
|
|
d := core.NewAuthOrigin(app)
|
|
d.SetCollectionRef(user.Collection().Id)
|
|
d.SetRecordRef(user.Id)
|
|
d.SetFingerprint(fingerprint)
|
|
d.SetRaw("created", mockCreated)
|
|
d.SetRaw("updated", mockCreated)
|
|
if err = app.Save(d); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to resolve auth response: %v", err)
|
|
}
|
|
|
|
var expectTotalSend int
|
|
if s.expectEmail {
|
|
expectTotalSend = 1
|
|
}
|
|
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
|
|
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
|
|
}
|
|
|
|
devices, err := app.FindAllAuthOriginsByRecord(user)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve auth origins: %v", err)
|
|
}
|
|
|
|
if len(devices) != len(s.expectDevices) {
|
|
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
|
|
}
|
|
|
|
for _, fingerprint := range s.expectDevices {
|
|
var exists bool
|
|
fingerprints := make([]string, 0, len(devices))
|
|
for _, d := range devices {
|
|
if d.Fingerprint() == fingerprint {
|
|
exists = true
|
|
break
|
|
}
|
|
fingerprints = append(fingerprints, d.Fingerprint())
|
|
}
|
|
if !exists {
|
|
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRecordAuthResponseMFACheck(t *testing.T) {
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
event := new(core.RequestEvent)
|
|
event.App = app
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
|
event.Response = rec
|
|
|
|
resetMFAs := func(authRecord *core.Record) {
|
|
// ensure that mfa is enabled
|
|
user.Collection().MFA.Enabled = true
|
|
user.Collection().MFA.Duration = 5
|
|
user.Collection().MFA.Rule = ""
|
|
|
|
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve mfas: %v", err)
|
|
}
|
|
for _, mfa := range mfas {
|
|
if err := app.Delete(mfa); err != nil {
|
|
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
|
|
}
|
|
}
|
|
|
|
// reset response
|
|
rec = httptest.NewRecorder()
|
|
event.Response = rec
|
|
}
|
|
|
|
totalMFAs := func(authRecord *core.Record) int {
|
|
mfas, err := app.FindAllMFAsByRecord(authRecord)
|
|
if err != nil {
|
|
t.Fatalf("Failed to retrieve mfas: %v", err)
|
|
}
|
|
return len(mfas)
|
|
}
|
|
|
|
t.Run("no collection MFA enabled", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
user.Collection().MFA.Enabled = false
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil, got error: %v", err)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if strings.Contains(body, "mfaId") {
|
|
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
|
}
|
|
if !strings.Contains(body, "token") {
|
|
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("no explicit auth method", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "", nil)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil, got error: %v", err)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if strings.Contains(body, "mfaId") {
|
|
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
|
}
|
|
if !strings.Contains(body, "token") {
|
|
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
user.Collection().MFA.Rule = "1=2"
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil, got error: %v", err)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if strings.Contains(body, "mfaId") {
|
|
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
|
|
}
|
|
if !strings.Contains(body, "token") {
|
|
t.Fatalf("Expected auth token in the response body, got\n%v", body)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected no mfa records to be created, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
user.Collection().MFA.Rule = "1=1"
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if !errors.Is(err, apis.ErrMFA) {
|
|
t.Fatalf("Expected ErrMFA, got: %v", err)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, "mfaId") {
|
|
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 1 {
|
|
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("mfa first-time", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if !errors.Is(err, apis.ErrMFA) {
|
|
t.Fatalf("Expected ErrMFA, got: %v", err)
|
|
}
|
|
|
|
body := rec.Body.String()
|
|
if !strings.Contains(body, "mfaId") {
|
|
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 1 {
|
|
t.Fatalf("Expected a single mfa record to be created, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
// create a dummy mfa record
|
|
mfa := core.NewMFA(app)
|
|
mfa.SetCollectionRef(user.Collection().Id)
|
|
mfa.SetRecordRef(user.Id)
|
|
mfa.SetMethod("example")
|
|
if err = app.Save(mfa); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example", nil)
|
|
if err == nil {
|
|
t.Fatal("Expected error, got nil")
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 1 {
|
|
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
// create a dummy mfa record
|
|
mfa := core.NewMFA(app)
|
|
mfa.SetCollectionRef(user.Collection().Id)
|
|
mfa.SetRecordRef(user.Id)
|
|
mfa.SetMethod("example1")
|
|
if err = app.Save(mfa); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil, got error: %v", err)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
// create a dummy mfa record
|
|
mfa := core.NewMFA(app)
|
|
mfa.SetCollectionRef(user.Collection().Id)
|
|
mfa.SetRecordRef(user.Id)
|
|
mfa.SetMethod("example1")
|
|
if err = app.Save(mfa); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
|
|
event.Request.Header.Add("content-type", "application/json")
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
|
if err != nil {
|
|
t.Fatalf("Expected nil, got error: %v", err)
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("missing mfa", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
|
if err == nil {
|
|
t.Fatal("Expected error, got nil")
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected 0 mfa records, got %d", total)
|
|
}
|
|
})
|
|
|
|
t.Run("expired mfa", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
// create a dummy expired mfa record
|
|
mfa := core.NewMFA(app)
|
|
mfa.SetCollectionRef(user.Collection().Id)
|
|
mfa.SetRecordRef(user.Id)
|
|
mfa.SetMethod("example1")
|
|
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
|
|
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
|
|
if err = app.Save(mfa); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
|
if err == nil {
|
|
t.Fatal("Expected error, got nil")
|
|
}
|
|
|
|
if totalMFAs(user) != 0 {
|
|
t.Fatal("Expected the expired mfa record to be deleted")
|
|
}
|
|
})
|
|
|
|
t.Run("mfa for different auth record", func(t *testing.T) {
|
|
resetMFAs(user)
|
|
|
|
// create a dummy expired mfa record
|
|
mfa := core.NewMFA(app)
|
|
mfa.SetCollectionRef(user2.Collection().Id)
|
|
mfa.SetRecordRef(user2.Id)
|
|
mfa.SetMethod("example1")
|
|
if err = app.Save(mfa); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
|
|
|
|
err = apis.RecordAuthResponse(event, user, "example2", nil)
|
|
if err == nil {
|
|
t.Fatal("Expected error, got nil")
|
|
}
|
|
|
|
if total := totalMFAs(user); total != 0 {
|
|
t.Fatalf("Expected no user mfas, got %d", total)
|
|
}
|
|
|
|
if total := totalMFAs(user2); total != 1 {
|
|
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
|
|
}
|
|
})
|
|
}
|