1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-04-09 17:34:04 +02:00

updated mfa defaults and errors check

This commit is contained in:
Gani Georgiev 2024-11-13 20:14:27 +02:00
parent 396aa0f97c
commit cc833ad643

View File

@ -1,7 +1,6 @@
package apis
import (
"database/sql"
"errors"
"fmt"
"net/http"
@ -122,7 +121,8 @@ func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token str
})
}
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule.
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule
// (note: returns true even in case of an error as a safer default).
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
rule := record.Collection().MFA.Rule
if rule == "" {
@ -131,7 +131,7 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
requestInfo, err := e.RequestInfo()
if err != nil {
return false, err
return true, err
}
var exists bool
@ -144,13 +144,13 @@ func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
expr, err := search.FilterData(rule).BuildExpr(resolver)
if err != nil {
return false, err
return true, err
}
resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
if err != nil {
return true, err
}
return exists, nil
@ -166,11 +166,10 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s
}
ok, err := wantsMFA(e, authRecord)
if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
}
if !ok {
if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
}
return "", nil // no mfa needed for this auth record
}
@ -214,7 +213,7 @@ func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod s
}
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
deleteMFA()
return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err))
return "", e.BadRequestError("Invalid or expired MFA session.", err)
}
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {