mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-01-27 23:46:18 +02:00
158 lines
3.9 KiB
Go
158 lines
3.9 KiB
Go
package core
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/pocketbase/pocketbase/tools/hook"
|
|
"github.com/pocketbase/pocketbase/tools/types"
|
|
)
|
|
|
|
// Supported values for the MFA "method" field, identifying which
// authentication method produced the mfa record.
const (
	// MFAMethodPassword marks an mfa created by a password authentication.
	MFAMethodPassword = "password"

	// MFAMethodOAuth2 marks an mfa created by an OAuth2 authentication.
	MFAMethodOAuth2 = "oauth2"

	// MFAMethodOTP marks an mfa created by a one-time password authentication.
	MFAMethodOTP = "otp"
)

// CollectionNameMFAs is the name of the system collection that stores mfa records.
const CollectionNameMFAs = "_mfas"
|
|
|
|
var (
|
|
_ Model = (*MFA)(nil)
|
|
_ PreValidator = (*MFA)(nil)
|
|
_ RecordProxy = (*MFA)(nil)
|
|
)
|
|
|
|
// MFA defines a Record proxy for working with the mfas collection.
|
|
type MFA struct {
|
|
*Record
|
|
}
|
|
|
|
// NewMFA instantiates and returns a new blank *MFA model.
|
|
//
|
|
// Example usage:
|
|
//
|
|
// mfa := core.NewMFA(app)
|
|
// mfa.SetRecordRef(user.Id)
|
|
// mfa.SetCollectionRef(user.Collection().Id)
|
|
// mfa.SetMethod(core.MFAMethodPassword)
|
|
// app.Save(mfa)
|
|
func NewMFA(app App) *MFA {
|
|
m := &MFA{}
|
|
|
|
c, err := app.FindCachedCollectionByNameOrId(CollectionNameMFAs)
|
|
if err != nil {
|
|
// this is just to make tests easier since mfa is a system collection and it is expected to be always accessible
|
|
// (note: the loaded record is further checked on MFA.PreValidate())
|
|
c = NewBaseCollection("@__invalid__")
|
|
}
|
|
|
|
m.Record = NewRecord(c)
|
|
|
|
return m
|
|
}
|
|
|
|
// PreValidate implements the [PreValidator] interface and checks
|
|
// whether the proxy is properly loaded.
|
|
func (m *MFA) PreValidate(ctx context.Context, app App) error {
|
|
if m.Record == nil || m.Record.Collection().Name != CollectionNameMFAs {
|
|
return errors.New("missing or invalid mfa ProxyRecord")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ProxyRecord returns the proxied Record model.
|
|
func (m *MFA) ProxyRecord() *Record {
|
|
return m.Record
|
|
}
|
|
|
|
// SetProxyRecord loads the specified record model into the current proxy.
|
|
func (m *MFA) SetProxyRecord(record *Record) {
|
|
m.Record = record
|
|
}
|
|
|
|
// CollectionRef returns the "collectionRef" field value.
|
|
func (m *MFA) CollectionRef() string {
|
|
return m.GetString("collectionRef")
|
|
}
|
|
|
|
// SetCollectionRef updates the "collectionRef" record field value.
|
|
func (m *MFA) SetCollectionRef(collectionId string) {
|
|
m.Set("collectionRef", collectionId)
|
|
}
|
|
|
|
// RecordRef returns the "recordRef" record field value.
|
|
func (m *MFA) RecordRef() string {
|
|
return m.GetString("recordRef")
|
|
}
|
|
|
|
// SetRecordRef updates the "recordRef" record field value.
|
|
func (m *MFA) SetRecordRef(recordId string) {
|
|
m.Set("recordRef", recordId)
|
|
}
|
|
|
|
// Method returns the "method" record field value.
|
|
func (m *MFA) Method() string {
|
|
return m.GetString("method")
|
|
}
|
|
|
|
// SetMethod updates the "method" record field value.
|
|
func (m *MFA) SetMethod(method string) {
|
|
m.Set("method", method)
|
|
}
|
|
|
|
// Created returns the "created" record field value.
|
|
func (m *MFA) Created() types.DateTime {
|
|
return m.GetDateTime("created")
|
|
}
|
|
|
|
// Updated returns the "updated" record field value.
|
|
func (m *MFA) Updated() types.DateTime {
|
|
return m.GetDateTime("updated")
|
|
}
|
|
|
|
// HasExpired checks if the mfa is expired, aka. whether it has been
|
|
// more than maxElapsed time since its creation.
|
|
func (m *MFA) HasExpired(maxElapsed time.Duration) bool {
|
|
return time.Since(m.Created().Time()) > maxElapsed
|
|
}
|
|
|
|
func (app *BaseApp) registerMFAHooks() {
|
|
recordRefHooks[*MFA](app, CollectionNameMFAs, CollectionTypeAuth)
|
|
|
|
// run on every hour to cleanup expired mfa sessions
|
|
app.Cron().Add("__mfasCleanup__", "0 * * * *", func() {
|
|
if err := app.DeleteExpiredMFAs(); err != nil {
|
|
app.Logger().Warn("Failed to delete expired MFA sessions", "error", err)
|
|
}
|
|
})
|
|
|
|
// delete existing mfas on password change
|
|
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
|
|
Func: func(e *RecordEvent) error {
|
|
err := e.Next()
|
|
if err != nil || !e.Record.Collection().IsAuth() {
|
|
return err
|
|
}
|
|
|
|
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
|
|
new := e.Record.GetString(FieldNamePassword + ":hash")
|
|
if old != new {
|
|
err = e.App.DeleteAllMFAsByRecord(e.Record)
|
|
if err != nil {
|
|
e.App.Logger().Warn(
|
|
"Failed to delete all previous mfas",
|
|
"error", err,
|
|
"recordId", e.Record.Id,
|
|
"collectionId", e.Record.Collection().Id,
|
|
)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
},
|
|
Priority: 99,
|
|
})
|
|
}
|