1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-01-27 23:46:18 +02:00
pocketbase/core/mfa_model.go
2024-09-29 21:09:46 +03:00

158 lines
3.9 KiB
Go

package core
import (
"context"
"errors"
"time"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
const (
	// Identifiers of the supported MFA methods, used as values of the
	// "method" field of an MFA record (e.g. via MFA.SetMethod).
	MFAMethodPassword = "password"
	MFAMethodOAuth2   = "oauth2"
	MFAMethodOTP      = "otp"

	// CollectionNameMFAs is the name of the system collection that holds
	// the MFA records.
	CollectionNameMFAs = "_mfas"
)
// Compile-time assertions that *MFA implements the Model, PreValidator
// and RecordProxy interfaces; the build fails if any method goes missing.
var _ Model = (*MFA)(nil)
var _ PreValidator = (*MFA)(nil)
var _ RecordProxy = (*MFA)(nil)
// MFA defines a Record proxy for working with the mfas collection.
//
// It embeds *Record, so the Record methods (e.g. GetString, Set,
// GetDateTime) can be called directly on an MFA value. The typed
// accessors below are thin wrappers over those methods.
type MFA struct {
*Record
}
// NewMFA creates a new, unsaved *MFA backed by a blank record of the
// mfas collection.
//
// If the mfas collection cannot be loaded from the app's collection cache,
// a placeholder collection named "@__invalid__" is used instead, so this
// function never fails. Such a record is rejected later by MFA.PreValidate.
//
// Example usage:
//
//	mfa := core.NewMFA(app)
//	mfa.SetRecordRef(user.Id)
//	mfa.SetCollectionRef(user.Collection().Id)
//	mfa.SetMethod(core.MFAMethodPassword)
//	app.Save(mfa)
func NewMFA(app App) *MFA {
	collection, err := app.FindCachedCollectionByNameOrId(CollectionNameMFAs)
	if err != nil {
		// mfas is a system collection and is expected to be always available;
		// the fallback only exists to make tests easier.
		collection = NewBaseCollection("@__invalid__")
	}

	return &MFA{Record: NewRecord(collection)}
}
// PreValidate implements the [PreValidator] interface.
//
// It reports an error when the proxy has no underlying record, or when
// that record does not belong to the mfas collection. ctx and app are
// not used.
func (m *MFA) PreValidate(ctx context.Context, app App) error {
	if m.Record != nil && m.Record.Collection().Name == CollectionNameMFAs {
		return nil
	}

	return errors.New("missing or invalid mfa ProxyRecord")
}
// ProxyRecord returns the proxied Record model.
//
// The returned pointer is the same *Record embedded in m (not a copy),
// so changes made through it are visible via the MFA accessors.
func (m *MFA) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
//
// The record is stored as-is, with no copy and no validation (nil is
// accepted too). Whether it is non-nil and belongs to the mfas collection
// is checked by MFA.PreValidate.
func (m *MFA) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef reads the "collectionRef" field of the proxied record as
// a string (the collection id stored via SetCollectionRef).
func (m *MFA) CollectionRef() string {
	return m.Record.GetString("collectionRef")
}
// SetCollectionRef stores collectionId in the "collectionRef" field of the
// proxied record. The change is in-memory only until the MFA is saved.
func (m *MFA) SetCollectionRef(collectionId string) {
	m.Record.Set("collectionRef", collectionId)
}
// RecordRef reads the "recordRef" field of the proxied record as a string
// (the record id stored via SetRecordRef).
func (m *MFA) RecordRef() string {
	return m.Record.GetString("recordRef")
}
// SetRecordRef stores recordId in the "recordRef" field of the proxied
// record. The change is in-memory only until the MFA is saved.
func (m *MFA) SetRecordRef(recordId string) {
	m.Record.Set("recordRef", recordId)
}
// Method reads the "method" field of the proxied record as a string
// (e.g. MFAMethodPassword).
func (m *MFA) Method() string {
	return m.Record.GetString("method")
}
// SetMethod stores method in the "method" field of the proxied record.
// The value is not checked against the MFAMethod* constants here.
func (m *MFA) SetMethod(method string) {
	m.Record.Set("method", method)
}
// Created reads the "created" field of the proxied record as a
// types.DateTime.
func (m *MFA) Created() types.DateTime {
	return m.Record.GetDateTime("created")
}
// Updated reads the "updated" field of the proxied record as a
// types.DateTime.
func (m *MFA) Updated() types.DateTime {
	return m.Record.GetDateTime("updated")
}
// HasExpired reports whether strictly more than maxElapsed time has
// passed since the MFA's "created" timestamp.
func (m *MFA) HasExpired(maxElapsed time.Duration) bool {
	createdAt := m.Created().Time()
	elapsed := time.Since(createdAt)

	return elapsed > maxElapsed
}
// registerMFAHooks wires up the MFA related app hooks:
//   - recordRefHooks for the mfas collection (auth collection references);
//   - an hourly cron job ("__mfasCleanup__") that calls DeleteExpiredMFAs
//     and only logs a warning on failure;
//   - an OnRecordUpdate handler that, after the update completes without
//     error, deletes all MFAs of an auth record whose password hash changed.
func (app *BaseApp) registerMFAHooks() {
	recordRefHooks[*MFA](app, CollectionNameMFAs, CollectionTypeAuth)

	// run every hour (at minute 0) to cleanup the expired mfa sessions
	app.Cron().Add("__mfasCleanup__", "0 * * * *", func() {
		if err := app.DeleteExpiredMFAs(); err != nil {
			app.Logger().Warn("Failed to delete expired MFA sessions", "error", err)
		}
	})

	// delete the existing mfas on password change
	app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
		Func: func(e *RecordEvent) error {
			if err := e.Next(); err != nil {
				return err
			}

			if !e.Record.Collection().IsAuth() {
				return nil
			}

			// compare the stored hashes (not the plain values) to detect a password change
			oldHash := e.Record.Original().GetString(FieldNamePassword + ":hash")
			newHash := e.Record.GetString(FieldNamePassword + ":hash")
			if oldHash == newHash {
				return nil
			}

			// best-effort: the record update itself already succeeded,
			// so a cleanup failure is only logged and not returned
			if err := e.App.DeleteAllMFAsByRecord(e.Record); err != nil {
				e.App.Logger().Warn(
					"Failed to delete all previous mfas",
					"error", err,
					"recordId", e.Record.Id,
					"collectionId", e.Record.Collection().Id,
				)
			}

			return nil
		},
		Priority: 99,
	})
}