diff --git a/CHANGELOG.md b/CHANGELOG.md index 39e1c0c0..a1460117 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,27 @@ - Added LiveChat OAuth2 provider ([#1573](https://github.com/pocketbase/pocketbase/pull/1573); thanks @mariosant). +- Added new event hooks: + + ```go + OnRecordBeforeAuthWithPasswordRequest() + OnRecordAfterAuthWithPasswordRequest() + OnRecordBeforeAuthWithOAuth2Request() + OnRecordAfterAuthWithOAuth2Request() + OnRecordBeforeAuthRefreshRequest() + OnRecordAfterAuthRefreshRequest() + OnAdminBeforeAuthWithPasswordRequest() + OnAdminAfterAuthWithPasswordRequest() + OnAdminBeforeAuthRefreshRequest() + OnAdminAfterAuthRefreshRequest() + OnAdminBeforeRequestPasswordResetRequest() + OnAdminAfterRequestPasswordResetRequest() + OnAdminBeforeConfirmPasswordResetRequest() + OnAdminAfterConfirmPasswordResetRequest() + ``` + +- Refactored all `forms` Submit interceptors to use a Generic data type as their payload. + ## v0.11.2 diff --git a/README.md b/README.md index 3633f220..9ea58fd3 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ To build the minimal standalone executable, like the prebuilt ones in the releas 2. Navigate to `examples/base` 3. Run `GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build` (_https://go.dev/doc/install/source#environment_) -4. Start the generated executable by running `./base serve`. +4. Start the created executable by running `./base serve`. The supported build targets by the non-cgo driver at the moment are: ``` diff --git a/apis/admin.go b/apis/admin.go index 69119d39..d6301796 100644 --- a/apis/admin.go +++ b/apis/admin.go @@ -59,21 +59,57 @@ func (api *adminApi) authRefresh(c echo.Context) error { return NewNotFoundError("Missing auth admin context.", nil) } - return api.authResponse(c, admin) + event := &core.AdminAuthRefreshEvent{ + HttpContext: c, + Admin: admin, + } + + handlerErr := api.app.OnAdminBeforeAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error { + return api.authResponse(e.HttpContext, e.Admin) + }) + + if handlerErr == nil { + if err := api.app.OnAdminAfterAuthRefreshRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return handlerErr } func (api *adminApi) authWithPassword(c echo.Context) error { form := forms.NewAdminLogin(api.app) - if readErr := c.Bind(form); readErr != nil { - return NewBadRequestError("An error occurred while loading the submitted data.", readErr) + if err := c.Bind(form); err != nil { + return NewBadRequestError("An error occurred while loading the submitted data.", err) } - admin, submitErr := form.Submit() - if submitErr != nil { - return NewBadRequestError("Failed to authenticate.", submitErr) + event := &core.AdminAuthWithPasswordEvent{ + HttpContext: c, + Password: form.Password, + Identity: form.Identity, } - return api.authResponse(c, admin) + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + event.Admin = admin + + return api.app.OnAdminBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error { + if err := next(e.Admin); err != nil { + return NewBadRequestError("Failed to authenticate.", err) + } + + return api.authResponse(e.HttpContext, e.Admin) + }) + } + }) + + if submitErr == nil { + if err := api.app.OnAdminAfterAuthWithPasswordRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return submitErr } func (api *adminApi) requestPasswordReset(c echo.Context) error { @@ -86,15 +122,41 @@ func (api *adminApi) requestPasswordReset(c echo.Context) error { return NewBadRequestError("An error occurred while validating the form.", err) } - // run in background because we don't need to show the result - // (prevents admins enumeration) - routine.FireAndForget(func() { - if err := form.Submit(); err != nil && api.app.IsDebug() { - log.Println(err) + event := &core.AdminRequestPasswordResetEvent{ + HttpContext: c, + } + + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(Admin *models.Admin) error { + event.Admin = Admin + + return api.app.OnAdminBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error { + // run in background because we don't need to show the result to the client + routine.FireAndForget(func() { + if err := next(e.Admin); err != nil && api.app.IsDebug() { + log.Println(err) + } + }) + + return e.HttpContext.NoContent(http.StatusNoContent) + }) } }) - return c.NoContent(http.StatusNoContent) + if submitErr == nil { + if err := api.app.OnAdminAfterRequestPasswordResetRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } else if api.app.IsDebug() { + log.Println(submitErr) + } + + // don't return the response error to prevent emails enumeration + if !c.Response().Committed { + c.NoContent(http.StatusNoContent) + } + + return nil } func (api *adminApi) confirmPasswordReset(c echo.Context) error { @@ -103,12 +165,31 @@ func (api *adminApi) confirmPasswordReset(c echo.Context) error { return NewBadRequestError("An error occurred while loading the submitted data.", readErr) } - _, submitErr := form.Submit() - if submitErr != nil { - return NewBadRequestError("Failed to set new password.", submitErr) + event := &core.AdminConfirmPasswordResetEvent{ + HttpContext: c, } - return c.NoContent(http.StatusNoContent) + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + event.Admin = admin + + return api.app.OnAdminBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error { + if err := next(e.Admin); err != nil { + return NewBadRequestError("Failed to set new password.", err) + } + + return e.HttpContext.NoContent(http.StatusNoContent) + }) + } + }) + + if submitErr == nil { + if err := api.app.OnAdminAfterConfirmPasswordResetRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return submitErr } func (api *adminApi) list(c echo.Context) error { @@ -174,10 +255,12 @@ func (api *adminApi) create(c echo.Context) error { } // create the admin - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { + event.Admin = m + return api.app.OnAdminBeforeCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error { - if err := next(); err != nil { + if err := next(e.Admin); err != nil { return NewBadRequestError("Failed to create admin.", err) } @@ -187,7 +270,9 @@ func (api *adminApi) create(c echo.Context) error { }) if submitErr == nil { - api.app.OnAdminAfterCreateRequest().Trigger(event) + if err := api.app.OnAdminAfterCreateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -217,10 +302,12 @@ func (api *adminApi) update(c echo.Context) error { } // update the admin - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { + event.Admin = m + return api.app.OnAdminBeforeUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error { - if err := next(); err != nil { + if err := next(e.Admin); err != nil { return NewBadRequestError("Failed to update admin.", err) } @@ -230,7 +317,9 @@ func (api *adminApi) update(c echo.Context) error { }) if submitErr == nil { - api.app.OnAdminAfterUpdateRequest().Trigger(event) + if err := api.app.OnAdminAfterUpdateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -261,7 +350,9 @@ func (api *adminApi) delete(c echo.Context) error { }) if handlerErr == nil { - api.app.OnAdminAfterDeleteRequest().Trigger(event) + if err := api.app.OnAdminAfterDeleteRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return handlerErr diff --git a/apis/admin_test.go b/apis/admin_test.go index 0a94cb45..a750b1be 100644 --- a/apis/admin_test.go +++ b/apis/admin_test.go @@ -14,7 +14,7 @@ import ( "github.com/pocketbase/pocketbase/tools/types" ) -func TestAdminAuthWithEmail(t *testing.T) { +func TestAdminAuthWithPassword(t *testing.T) { scenarios := []tests.ApiScenario{ { Name: "empty data", @@ -39,6 +39,9 @@ func TestAdminAuthWithEmail(t *testing.T) { Body: strings.NewReader(`{"identity":"missing@example.com","password":"1234567890"}`), ExpectedStatus: 400, ExpectedContent: []string{`"data":{}`}, + ExpectedEvents: map[string]int{ + "OnAdminBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "wrong password", @@ -47,6 +50,9 @@ func TestAdminAuthWithEmail(t *testing.T) { Body: strings.NewReader(`{"identity":"test@example.com","password":"invalid"}`), ExpectedStatus: 400, ExpectedContent: []string{`"data":{}`}, + ExpectedEvents: map[string]int{ + "OnAdminBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid email/password (guest)", @@ -59,7 +65,9 @@ func TestAdminAuthWithEmail(t *testing.T) { `"token":`, }, ExpectedEvents: map[string]int{ - "OnAdminAuthRequest": 1, + "OnAdminBeforeAuthWithPasswordRequest": 1, + "OnAdminAfterAuthWithPasswordRequest": 1, + "OnAdminAuthRequest": 1, }, }, { @@ -76,7 +84,9 @@ func TestAdminAuthWithEmail(t *testing.T) { `"token":`, }, ExpectedEvents: map[string]int{ - "OnAdminAuthRequest": 1, + "OnAdminBeforeAuthWithPasswordRequest": 1, + "OnAdminAfterAuthWithPasswordRequest": 1, + "OnAdminAuthRequest": 1, }, }, } @@ -120,10 +130,12 @@ func TestAdminRequestPasswordReset(t *testing.T) { Delay: 100 * time.Millisecond, ExpectedStatus: 204, ExpectedEvents: map[string]int{ - "OnModelBeforeUpdate": 1, - "OnModelAfterUpdate": 1, - "OnMailerBeforeAdminResetPasswordSend": 1, - "OnMailerAfterAdminResetPasswordSend": 1, + "OnModelBeforeUpdate": 1, + "OnModelAfterUpdate": 1, + "OnMailerBeforeAdminResetPasswordSend": 1, + "OnMailerAfterAdminResetPasswordSend": 1, + "OnAdminBeforeRequestPasswordResetRequest": 1, + "OnAdminAfterRequestPasswordResetRequest": 1, }, }, { @@ -206,8 +218,10 @@ func TestAdminConfirmPasswordReset(t *testing.T) { }`), ExpectedStatus: 204, ExpectedEvents: map[string]int{ - "OnModelBeforeUpdate": 1, - "OnModelAfterUpdate": 1, + "OnModelBeforeUpdate": 1, + "OnModelAfterUpdate": 1, + "OnAdminBeforeConfirmPasswordResetRequest": 1, + "OnAdminAfterConfirmPasswordResetRequest": 1, }, }, } @@ -259,7 +273,9 @@ func TestAdminRefresh(t *testing.T) { `"token":`, }, ExpectedEvents: map[string]int{ - "OnAdminAuthRequest": 1, + "OnAdminAuthRequest": 1, + "OnAdminBeforeAuthRefreshRequest": 1, + "OnAdminAfterAuthRefreshRequest": 1, }, }, } diff --git a/apis/collection.go b/apis/collection.go index 146a041f..8bb0bf4a 100644 --- a/apis/collection.go +++ b/apis/collection.go @@ -1,6 +1,7 @@ package apis import ( + "log" "net/http" "github.com/labstack/echo/v5" @@ -85,10 +86,12 @@ func (api *collectionApi) create(c echo.Context) error { } // create the collection - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { + return func(m *models.Collection) error { + event.Collection = m + return api.app.OnCollectionBeforeCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error { - if err := next(); err != nil { + if err := next(e.Collection); err != nil { return NewBadRequestError("Failed to create the collection.", err) } @@ -98,7 +101,9 @@ func (api *collectionApi) create(c echo.Context) error { }) if submitErr == nil { - api.app.OnCollectionAfterCreateRequest().Trigger(event) + if err := api.app.OnCollectionAfterCreateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -123,10 +128,12 @@ func (api *collectionApi) update(c echo.Context) error { } // update the collection - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { + return func(m *models.Collection) error { + event.Collection = m + return api.app.OnCollectionBeforeUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error { - if err := next(); err != nil { + if err := next(e.Collection); err != nil { return NewBadRequestError("Failed to update the collection.", err) } @@ -136,7 +143,9 @@ func (api *collectionApi) update(c echo.Context) error { }) if submitErr == nil { - api.app.OnCollectionAfterUpdateRequest().Trigger(event) + if err := api.app.OnCollectionAfterUpdateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -162,7 +171,9 @@ func (api *collectionApi) delete(c echo.Context) error { }) if handlerErr == nil { - api.app.OnCollectionAfterDeleteRequest().Trigger(event) + if err := api.app.OnCollectionAfterDeleteRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return handlerErr @@ -182,12 +193,12 @@ func (api *collectionApi) bulkImport(c echo.Context) error { } // import collections - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { - return api.app.OnCollectionsBeforeImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error { - form.Collections = e.Collections // ensures that the form always has the latest changes + submitErr := form.Submit(func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] { + return func(imports []*models.Collection) error { + event.Collections = imports - if err := next(); err != nil { + return api.app.OnCollectionsBeforeImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error { + if err := next(e.Collections); err != nil { return NewBadRequestError("Failed to import the submitted collections.", err) } @@ -197,7 +208,9 @@ func (api *collectionApi) bulkImport(c echo.Context) error { }) if submitErr == nil { - api.app.OnCollectionsAfterImportRequest().Trigger(event) + if err := api.app.OnCollectionsAfterImportRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr diff --git a/apis/record_auth.go b/apis/record_auth.go index 575c48eb..94cf1206 100644 --- a/apis/record_auth.go +++ b/apis/record_auth.go @@ -104,7 +104,22 @@ func (api *recordAuthApi) authRefresh(c echo.Context) error { return NewNotFoundError("Missing auth record context.", nil) } - return api.authResponse(c, record, nil) + event := &core.RecordAuthRefreshEvent{ + HttpContext: c, + Record: record, + } + + handlerErr := api.app.OnRecordBeforeAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error { + return api.authResponse(e.HttpContext, e.Record, nil) + }) + + if handlerErr == nil { + if err := api.app.OnRecordAfterAuthRefreshRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return handlerErr } type providerInfo struct { @@ -202,7 +217,7 @@ func (api *recordAuthApi) authWithOAuth2(c echo.Context) error { return NewBadRequestError("An error occurred while loading the submitted data.", readErr) } - record, authData, submitErr := form.Submit(func(createForm *forms.RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error { + form.SetBeforeNewRecordCreateFunc(func(createForm *forms.RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error { return createForm.DrySubmit(func(txDao *daos.Dao) error { requestData := RequestData(c) requestData.Data = form.CreateData @@ -237,11 +252,36 @@ func (api *recordAuthApi) authWithOAuth2(c echo.Context) error { return nil }) }) - if submitErr != nil { - return NewBadRequestError("Failed to authenticate.", submitErr) + + event := &core.RecordAuthWithOAuth2Event{ + HttpContext: c, } - return api.authResponse(c, record, authData) + _, _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData]) forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData] { + return func(data *forms.RecordOAuth2LoginData) error { + event.Record = data.Record + event.OAuth2User = data.OAuth2User + + return api.app.OnRecordBeforeAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error { + data.Record = e.Record + data.OAuth2User = e.OAuth2User + + if err := next(data); err != nil { + return NewBadRequestError("Failed to authenticate.", err) + } + + return api.authResponse(e.HttpContext, e.Record, e.OAuth2User) + }) + } + }) + + if submitErr == nil { + if err := api.app.OnRecordAfterAuthWithOAuth2Request().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return submitErr } func (api *recordAuthApi) authWithPassword(c echo.Context) error { @@ -255,12 +295,33 @@ func (api *recordAuthApi) authWithPassword(c echo.Context) error { return NewBadRequestError("An error occurred while loading the submitted data.", readErr) } - record, submitErr := form.Submit() - if submitErr != nil { - return NewBadRequestError("Failed to authenticate.", submitErr) + event := &core.RecordAuthWithPasswordEvent{ + HttpContext: c, + Password: form.Password, + Identity: form.Identity, } - return api.authResponse(c, record, nil) + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(record *models.Record) error { + event.Record = record + + return api.app.OnRecordBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error { + if err := next(e.Record); err != nil { + return NewBadRequestError("Failed to authenticate.", err) + } + + return api.authResponse(e.HttpContext, e.Record, nil) + }) + } + }) + + if submitErr == nil { + if err := api.app.OnRecordAfterAuthWithPasswordRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } + } + + return submitErr } func (api *recordAuthApi) requestPasswordReset(c echo.Context) error { @@ -287,7 +348,7 @@ func (api *recordAuthApi) requestPasswordReset(c echo.Context) error { HttpContext: c, } - submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { event.Record = record @@ -305,7 +366,9 @@ func (api *recordAuthApi) requestPasswordReset(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterRequestPasswordResetRequest().Trigger(event) + if err := api.app.OnRecordAfterRequestPasswordResetRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } else if api.app.IsDebug() { log.Println(submitErr) } @@ -333,7 +396,7 @@ func (api *recordAuthApi) confirmPasswordReset(c echo.Context) error { HttpContext: c, } - _, submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { event.Record = record @@ -348,7 +411,9 @@ func (api *recordAuthApi) confirmPasswordReset(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterConfirmPasswordResetRequest().Trigger(event) + if err := api.app.OnRecordAfterConfirmPasswordResetRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -373,7 +438,7 @@ func (api *recordAuthApi) requestVerification(c echo.Context) error { HttpContext: c, } - submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { event.Record = record @@ -391,7 +456,9 @@ func (api *recordAuthApi) requestVerification(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterRequestVerificationRequest().Trigger(event) + if err := api.app.OnRecordAfterRequestVerificationRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } else if api.app.IsDebug() { log.Println(submitErr) } @@ -419,7 +486,7 @@ func (api *recordAuthApi) confirmVerification(c echo.Context) error { HttpContext: c, } - _, submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { event.Record = record @@ -434,7 +501,9 @@ func (api *recordAuthApi) confirmVerification(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterConfirmVerificationRequest().Trigger(event) + if err := api.app.OnRecordAfterConfirmVerificationRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -456,7 +525,7 @@ func (api *recordAuthApi) requestEmailChange(c echo.Context) error { Record: record, } - submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { return api.app.OnRecordBeforeRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error { if err := next(e.Record); err != nil { @@ -490,7 +559,7 @@ func (api *recordAuthApi) confirmEmailChange(c echo.Context) error { HttpContext: c, } - _, submitErr := form.Submit(func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { event.Record = record @@ -505,7 +574,9 @@ func (api *recordAuthApi) confirmEmailChange(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterConfirmEmailChangeRequest().Trigger(event) + if err := api.app.OnRecordAfterConfirmEmailChangeRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr diff --git a/apis/record_auth_test.go b/apis/record_auth_test.go index ccda842c..2e9e4be9 100644 --- a/apis/record_auth_test.go +++ b/apis/record_auth_test.go @@ -100,6 +100,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid username and invalid password", @@ -113,6 +116,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid username and valid password in restricted collection", @@ -126,6 +132,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid username and valid password in allowed collection", @@ -143,7 +152,9 @@ func TestRecordAuthWithPassword(t *testing.T) { `"email":"test2@example.com"`, }, ExpectedEvents: map[string]int{ - "OnRecordAuthRequest": 1, + "OnRecordBeforeAuthWithPasswordRequest": 1, + "OnRecordAfterAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, }, }, @@ -160,6 +171,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid email and invalid password", @@ -173,6 +187,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid email and valid password in restricted collection", @@ -186,6 +203,9 @@ func TestRecordAuthWithPassword(t *testing.T) { ExpectedContent: []string{ `"data":{}`, }, + ExpectedEvents: map[string]int{ + "OnRecordBeforeAuthWithPasswordRequest": 1, + }, }, { Name: "valid email and valid password in allowed collection", @@ -203,7 +223,9 @@ func TestRecordAuthWithPassword(t *testing.T) { `"email":"test@example.com"`, }, ExpectedEvents: map[string]int{ - "OnRecordAuthRequest": 1, + "OnRecordBeforeAuthWithPasswordRequest": 1, + "OnRecordAfterAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, }, }, @@ -227,7 +249,9 @@ func TestRecordAuthWithPassword(t *testing.T) { `"email":"test@example.com"`, }, ExpectedEvents: map[string]int{ - "OnRecordAuthRequest": 1, + "OnRecordBeforeAuthWithPasswordRequest": 1, + "OnRecordAfterAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, }, }, { @@ -249,7 +273,9 @@ func TestRecordAuthWithPassword(t *testing.T) { `"email":"test@example.com"`, }, ExpectedEvents: map[string]int{ - "OnRecordAuthRequest": 1, + "OnRecordBeforeAuthWithPasswordRequest": 1, + "OnRecordAfterAuthWithPasswordRequest": 1, + "OnRecordAuthRequest": 1, }, }, } @@ -320,7 +346,9 @@ func TestRecordAuthRefresh(t *testing.T) { `"missing":`, }, ExpectedEvents: map[string]int{ - "OnRecordAuthRequest": 1, + "OnRecordBeforeAuthRefreshRequest": 1, + "OnRecordAuthRequest": 1, + "OnRecordAfterAuthRefreshRequest": 1, }, }, } diff --git a/apis/record_crud.go b/apis/record_crud.go index ae326332..4e5c75bc 100644 --- a/apis/record_crud.go +++ b/apis/record_crud.go @@ -224,10 +224,12 @@ func (api *recordApi) create(c echo.Context) error { } // create the record - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(m *models.Record) error { + event.Record = m + return api.app.OnRecordBeforeCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error { - if err := next(); err != nil { + if err := next(e.Record); err != nil { return NewBadRequestError("Failed to create record.", err) } @@ -241,7 +243,9 @@ func (api *recordApi) create(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterCreateRequest().Trigger(event) + if err := api.app.OnRecordAfterCreateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -308,10 +312,12 @@ func (api *recordApi) update(c echo.Context) error { } // update the record - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(m *models.Record) error { + event.Record = m + return api.app.OnRecordBeforeUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error { - if err := next(); err != nil { + if err := next(e.Record); err != nil { return NewBadRequestError("Failed to update record.", err) } @@ -325,7 +331,9 @@ func (api *recordApi) update(c echo.Context) error { }) if submitErr == nil { - api.app.OnRecordAfterUpdateRequest().Trigger(event) + if err := api.app.OnRecordAfterUpdateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr @@ -382,7 +390,9 @@ func (api *recordApi) delete(c echo.Context) error { }) if handlerErr == nil { - api.app.OnRecordAfterDeleteRequest().Trigger(event) + if err := api.app.OnRecordAfterDeleteRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return handlerErr diff --git a/apis/settings.go b/apis/settings.go index 2874acd4..04477b78 100644 --- a/apis/settings.go +++ b/apis/settings.go @@ -2,12 +2,14 @@ package apis import ( "fmt" + "log" "net/http" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/labstack/echo/v5" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models/settings" "github.com/pocketbase/pocketbase/tools/security" ) @@ -53,14 +55,15 @@ func (api *settingsApi) set(c echo.Context) error { event := &core.SettingsUpdateEvent{ HttpContext: c, OldSettings: api.app.Settings(), - NewSettings: form.Settings, } // update the settings - submitErr := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + submitErr := form.Submit(func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] { + return func(s *settings.Settings) error { + event.NewSettings = s + return api.app.OnSettingsBeforeUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error { - if err := next(); err != nil { + if err := next(e.NewSettings); err != nil { return NewBadRequestError("An error occurred while submitting the form.", err) } @@ -75,7 +78,9 @@ func (api *settingsApi) set(c echo.Context) error { }) if submitErr == nil { - api.app.OnSettingsAfterUpdateRequest().Trigger(event) + if err := api.app.OnSettingsAfterUpdateRequest().Trigger(event); err != nil && api.app.IsDebug() { + log.Println(err) + } } return submitErr diff --git a/core/app.go b/core/app.go index 2f09e99e..6c99f09d 100644 --- a/core/app.go +++ b/core/app.go @@ -313,6 +313,50 @@ type App interface { // authenticated admin data and token. OnAdminAuthRequest() *hook.Hook[*AdminAuthEvent] + // OnAdminBeforeAuthWithPasswordRequest hook is triggered before each Admin + // auth with password API request (after request data load and before password validation). + // + // Could be used to implement for example a custom password validation + // or to locate a different Admin identity (by assigning [AdminAuthWithPasswordEvent.Admin]). + OnAdminBeforeAuthWithPasswordRequest() *hook.Hook[*AdminAuthWithPasswordEvent] + + // OnAdminAfterAuthWithPasswordRequest hook is triggered after each + // successful Admin auth with password API request. + OnAdminAfterAuthWithPasswordRequest() *hook.Hook[*AdminAuthWithPasswordEvent] + + // OnAdminBeforeAuthRefreshRequest hook is triggered before each Admin + // auth refresh API request (right before generating a new auth token). + // + // Could be used to additionally validate the request data or implement + // completely different auth refresh behavior (returning [hook.StopPropagation]). + OnAdminBeforeAuthRefreshRequest() *hook.Hook[*AdminAuthRefreshEvent] + + // OnAdminAfterAuthRefreshRequest hook is triggered after each + // successful auth refresh API request (right after generating a new auth token). + OnAdminAfterAuthRefreshRequest() *hook.Hook[*AdminAuthRefreshEvent] + + // OnAdminBeforeRequestPasswordResetRequest hook is triggered before each Admin + // request password reset API request (after request data load and before sending the reset email). + // + // Could be used to additionally validate the request data or implement + // completely different password reset behavior (returning [hook.StopPropagation]). + OnAdminBeforeRequestPasswordResetRequest() *hook.Hook[*AdminRequestPasswordResetEvent] + + // OnAdminAfterRequestPasswordResetRequest hook is triggered after each + // successful request password reset API request. + OnAdminAfterRequestPasswordResetRequest() *hook.Hook[*AdminRequestPasswordResetEvent] + + // OnAdminBeforeConfirmPasswordResetRequest hook is triggered before each Admin + // confirm password reset API request (after request data load and before persistence). + // + // Could be used to additionally validate the request data or implement + // completely different persistence behavior (returning [hook.StopPropagation]). + OnAdminBeforeConfirmPasswordResetRequest() *hook.Hook[*AdminConfirmPasswordResetEvent] + + // OnAdminAfterConfirmPasswordResetRequest hook is triggered after each + // successful confirm password reset API request. + OnAdminAfterConfirmPasswordResetRequest() *hook.Hook[*AdminConfirmPasswordResetEvent] + // --------------------------------------------------------------- // Record Auth API event hooks // --------------------------------------------------------------- @@ -324,6 +368,42 @@ type App interface { // record data and token. OnRecordAuthRequest() *hook.Hook[*RecordAuthEvent] + // OnRecordBeforeAuthWithPasswordRequest hook is triggered before each Record + // auth with password API request (after request data load and before password validation). + // + // Could be used to implement for example a custom password validation + // or to locate a different Record identity (by assigning [RecordAuthWithPasswordEvent.Record]). + OnRecordBeforeAuthWithPasswordRequest() *hook.Hook[*RecordAuthWithPasswordEvent] + + // OnRecordAfterAuthWithPasswordRequest hook is triggered after each + // successful Record auth with password API request. + OnRecordAfterAuthWithPasswordRequest() *hook.Hook[*RecordAuthWithPasswordEvent] + + // OnRecordBeforeAuthWithOAuth2Request hook is triggered before each Record + // OAuth2 sign-in/sign-up API request (after token exchange and before external provider linking). + // + // If the [RecordAuthWithOAuth2Event.Record] is nil, then the OAuth2 + // request will try to create a new auth Record. + // + // To assign or link a different existing record model you can + // overwrite/modify the [RecordAuthWithOAuth2Event.Record] field. + OnRecordBeforeAuthWithOAuth2Request() *hook.Hook[*RecordAuthWithOAuth2Event] + + // OnRecordAfterAuthWithOAuth2Request hook is triggered after each + // successful Record OAuth2 API request. + OnRecordAfterAuthWithOAuth2Request() *hook.Hook[*RecordAuthWithOAuth2Event] + + // OnRecordBeforeAuthRefreshRequest hook is triggered before each Record + // auth refresh API request (right before generating a new auth token). + // + // Could be used to additionally validate the request data or implement + // completely different auth refresh behavior (returning [hook.StopPropagation]). + OnRecordBeforeAuthRefreshRequest() *hook.Hook[*RecordAuthRefreshEvent] + + // OnRecordAfterAuthRefreshRequest hook is triggered after each + // successful auth refresh API request (right after generating a new auth token). + OnRecordAfterAuthRefreshRequest() *hook.Hook[*RecordAuthRefreshEvent] + // OnRecordBeforeRequestPasswordResetRequest hook is triggered before each Record // request password reset API request (after request data load and before sending the reset email). // diff --git a/core/base.go b/core/base.go index 5b6aaea0..2dd93a4c 100644 --- a/core/base.go +++ b/core/base.go @@ -91,18 +91,32 @@ type BaseApp struct { onFileDownloadRequest *hook.Hook[*FileDownloadEvent] // admin api event hooks - onAdminsListRequest *hook.Hook[*AdminsListEvent] - onAdminViewRequest *hook.Hook[*AdminViewEvent] - onAdminBeforeCreateRequest *hook.Hook[*AdminCreateEvent] - onAdminAfterCreateRequest *hook.Hook[*AdminCreateEvent] - onAdminBeforeUpdateRequest *hook.Hook[*AdminUpdateEvent] - onAdminAfterUpdateRequest *hook.Hook[*AdminUpdateEvent] - onAdminBeforeDeleteRequest *hook.Hook[*AdminDeleteEvent] - onAdminAfterDeleteRequest *hook.Hook[*AdminDeleteEvent] - onAdminAuthRequest *hook.Hook[*AdminAuthEvent] + onAdminsListRequest *hook.Hook[*AdminsListEvent] + onAdminViewRequest *hook.Hook[*AdminViewEvent] + onAdminBeforeCreateRequest *hook.Hook[*AdminCreateEvent] + onAdminAfterCreateRequest *hook.Hook[*AdminCreateEvent] + onAdminBeforeUpdateRequest *hook.Hook[*AdminUpdateEvent] + onAdminAfterUpdateRequest *hook.Hook[*AdminUpdateEvent] + onAdminBeforeDeleteRequest *hook.Hook[*AdminDeleteEvent] + onAdminAfterDeleteRequest *hook.Hook[*AdminDeleteEvent] + onAdminAuthRequest *hook.Hook[*AdminAuthEvent] + onAdminBeforeAuthWithPasswordRequest *hook.Hook[*AdminAuthWithPasswordEvent] + onAdminAfterAuthWithPasswordRequest *hook.Hook[*AdminAuthWithPasswordEvent] + onAdminBeforeAuthRefreshRequest *hook.Hook[*AdminAuthRefreshEvent] + onAdminAfterAuthRefreshRequest *hook.Hook[*AdminAuthRefreshEvent] + onAdminBeforeRequestPasswordResetRequest *hook.Hook[*AdminRequestPasswordResetEvent] + onAdminAfterRequestPasswordResetRequest *hook.Hook[*AdminRequestPasswordResetEvent] + onAdminBeforeConfirmPasswordResetRequest *hook.Hook[*AdminConfirmPasswordResetEvent] + onAdminAfterConfirmPasswordResetRequest *hook.Hook[*AdminConfirmPasswordResetEvent] // record auth API event hooks onRecordAuthRequest *hook.Hook[*RecordAuthEvent] + onRecordBeforeAuthWithPasswordRequest *hook.Hook[*RecordAuthWithPasswordEvent] + onRecordAfterAuthWithPasswordRequest *hook.Hook[*RecordAuthWithPasswordEvent] + onRecordBeforeAuthWithOAuth2Request *hook.Hook[*RecordAuthWithOAuth2Event] + onRecordAfterAuthWithOAuth2Request *hook.Hook[*RecordAuthWithOAuth2Event] + onRecordBeforeAuthRefreshRequest *hook.Hook[*RecordAuthRefreshEvent] + onRecordAfterAuthRefreshRequest *hook.Hook[*RecordAuthRefreshEvent] onRecordBeforeRequestPasswordResetRequest *hook.Hook[*RecordRequestPasswordResetEvent] onRecordAfterRequestPasswordResetRequest *hook.Hook[*RecordRequestPasswordResetEvent] onRecordBeforeConfirmPasswordResetRequest *hook.Hook[*RecordConfirmPasswordResetEvent] @@ -212,18 +226,32 @@ func NewBaseApp(config *BaseAppConfig) *BaseApp { onFileDownloadRequest: &hook.Hook[*FileDownloadEvent]{}, // admin API event hooks - onAdminsListRequest: &hook.Hook[*AdminsListEvent]{}, - onAdminViewRequest: &hook.Hook[*AdminViewEvent]{}, - onAdminBeforeCreateRequest: &hook.Hook[*AdminCreateEvent]{}, - onAdminAfterCreateRequest: &hook.Hook[*AdminCreateEvent]{}, - onAdminBeforeUpdateRequest: &hook.Hook[*AdminUpdateEvent]{}, - onAdminAfterUpdateRequest: &hook.Hook[*AdminUpdateEvent]{}, - onAdminBeforeDeleteRequest: &hook.Hook[*AdminDeleteEvent]{}, - onAdminAfterDeleteRequest: &hook.Hook[*AdminDeleteEvent]{}, - onAdminAuthRequest: &hook.Hook[*AdminAuthEvent]{}, + onAdminsListRequest: &hook.Hook[*AdminsListEvent]{}, + onAdminViewRequest: &hook.Hook[*AdminViewEvent]{}, + onAdminBeforeCreateRequest: &hook.Hook[*AdminCreateEvent]{}, + onAdminAfterCreateRequest: &hook.Hook[*AdminCreateEvent]{}, + onAdminBeforeUpdateRequest: &hook.Hook[*AdminUpdateEvent]{}, + onAdminAfterUpdateRequest: &hook.Hook[*AdminUpdateEvent]{}, + onAdminBeforeDeleteRequest: &hook.Hook[*AdminDeleteEvent]{}, + onAdminAfterDeleteRequest: &hook.Hook[*AdminDeleteEvent]{}, + onAdminAuthRequest: &hook.Hook[*AdminAuthEvent]{}, + onAdminBeforeAuthWithPasswordRequest: &hook.Hook[*AdminAuthWithPasswordEvent]{}, + onAdminAfterAuthWithPasswordRequest: &hook.Hook[*AdminAuthWithPasswordEvent]{}, + onAdminBeforeAuthRefreshRequest: &hook.Hook[*AdminAuthRefreshEvent]{}, + onAdminAfterAuthRefreshRequest: &hook.Hook[*AdminAuthRefreshEvent]{}, + onAdminBeforeRequestPasswordResetRequest: &hook.Hook[*AdminRequestPasswordResetEvent]{}, + onAdminAfterRequestPasswordResetRequest: &hook.Hook[*AdminRequestPasswordResetEvent]{}, + onAdminBeforeConfirmPasswordResetRequest: &hook.Hook[*AdminConfirmPasswordResetEvent]{}, + onAdminAfterConfirmPasswordResetRequest: &hook.Hook[*AdminConfirmPasswordResetEvent]{}, // record auth API event hooks onRecordAuthRequest: &hook.Hook[*RecordAuthEvent]{}, + onRecordBeforeAuthWithPasswordRequest: &hook.Hook[*RecordAuthWithPasswordEvent]{}, + onRecordAfterAuthWithPasswordRequest: &hook.Hook[*RecordAuthWithPasswordEvent]{}, + onRecordBeforeAuthWithOAuth2Request: &hook.Hook[*RecordAuthWithOAuth2Event]{}, + onRecordAfterAuthWithOAuth2Request: &hook.Hook[*RecordAuthWithOAuth2Event]{}, + onRecordBeforeAuthRefreshRequest: &hook.Hook[*RecordAuthRefreshEvent]{}, + onRecordAfterAuthRefreshRequest: &hook.Hook[*RecordAuthRefreshEvent]{}, onRecordBeforeRequestPasswordResetRequest: &hook.Hook[*RecordRequestPasswordResetEvent]{}, onRecordAfterRequestPasswordResetRequest: &hook.Hook[*RecordRequestPasswordResetEvent]{}, onRecordBeforeConfirmPasswordResetRequest: &hook.Hook[*RecordConfirmPasswordResetEvent]{}, @@ -665,6 +693,38 @@ func (app *BaseApp) OnAdminAuthRequest() *hook.Hook[*AdminAuthEvent] { return app.onAdminAuthRequest } +func (app *BaseApp) OnAdminBeforeAuthWithPasswordRequest() *hook.Hook[*AdminAuthWithPasswordEvent] { + return app.onAdminBeforeAuthWithPasswordRequest +} + +func (app *BaseApp) OnAdminAfterAuthWithPasswordRequest() *hook.Hook[*AdminAuthWithPasswordEvent] { + return app.onAdminAfterAuthWithPasswordRequest +} + +func (app *BaseApp) OnAdminBeforeAuthRefreshRequest() *hook.Hook[*AdminAuthRefreshEvent] { + return app.onAdminBeforeAuthRefreshRequest +} + +func (app *BaseApp) OnAdminAfterAuthRefreshRequest() *hook.Hook[*AdminAuthRefreshEvent] { + return app.onAdminAfterAuthRefreshRequest +} + +func (app *BaseApp) OnAdminBeforeRequestPasswordResetRequest() *hook.Hook[*AdminRequestPasswordResetEvent] { + return app.onAdminBeforeRequestPasswordResetRequest +} + +func (app *BaseApp) OnAdminAfterRequestPasswordResetRequest() *hook.Hook[*AdminRequestPasswordResetEvent] { + return app.onAdminAfterRequestPasswordResetRequest +} + +func (app *BaseApp) OnAdminBeforeConfirmPasswordResetRequest() *hook.Hook[*AdminConfirmPasswordResetEvent] { + return app.onAdminBeforeConfirmPasswordResetRequest +} + +func (app *BaseApp) OnAdminAfterConfirmPasswordResetRequest() *hook.Hook[*AdminConfirmPasswordResetEvent] { + return app.onAdminAfterConfirmPasswordResetRequest +} + // ------------------------------------------------------------------- // Record auth API event hooks // ------------------------------------------------------------------- @@ -673,6 +733,30 @@ func (app *BaseApp) OnRecordAuthRequest() *hook.Hook[*RecordAuthEvent] { return app.onRecordAuthRequest } +func (app *BaseApp) OnRecordBeforeAuthWithPasswordRequest() *hook.Hook[*RecordAuthWithPasswordEvent] { + return app.onRecordBeforeAuthWithPasswordRequest +} + +func (app *BaseApp) OnRecordAfterAuthWithPasswordRequest() *hook.Hook[*RecordAuthWithPasswordEvent] { + return app.onRecordAfterAuthWithPasswordRequest +} + +func (app *BaseApp) OnRecordBeforeAuthWithOAuth2Request() *hook.Hook[*RecordAuthWithOAuth2Event] { + return app.onRecordBeforeAuthWithOAuth2Request +} + +func (app *BaseApp) OnRecordAfterAuthWithOAuth2Request() *hook.Hook[*RecordAuthWithOAuth2Event] { + return app.onRecordAfterAuthWithOAuth2Request +} + +func (app *BaseApp) OnRecordBeforeAuthRefreshRequest() *hook.Hook[*RecordAuthRefreshEvent] { + return app.onRecordBeforeAuthRefreshRequest +} + +func (app *BaseApp) OnRecordAfterAuthRefreshRequest() *hook.Hook[*RecordAuthRefreshEvent] { + return app.onRecordAfterAuthRefreshRequest +} + func (app *BaseApp) OnRecordBeforeRequestPasswordResetRequest() *hook.Hook[*RecordRequestPasswordResetEvent] { return app.onRecordBeforeRequestPasswordResetRequest } diff --git a/core/events.go b/core/events.go index 62bf968d..e8d8321c 100644 --- a/core/events.go +++ b/core/events.go @@ -5,6 +5,7 @@ import ( "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/models/schema" "github.com/pocketbase/pocketbase/models/settings" + "github.com/pocketbase/pocketbase/tools/auth" "github.com/pocketbase/pocketbase/tools/mailer" "github.com/pocketbase/pocketbase/tools/search" "github.com/pocketbase/pocketbase/tools/subscriptions" @@ -140,6 +141,24 @@ type RecordAuthEvent struct { Meta any } +type RecordAuthWithPasswordEvent struct { + HttpContext echo.Context + Record *models.Record + Identity string + Password string +} + +type RecordAuthWithOAuth2Event struct { + HttpContext echo.Context + Record *models.Record + OAuth2User *auth.AuthUser +} + +type RecordAuthRefreshEvent struct { + HttpContext echo.Context + Record *models.Record +} + type RecordRequestPasswordResetEvent struct { HttpContext echo.Context Record *models.Record @@ -218,6 +237,28 @@ type AdminAuthEvent struct { Token string } +type AdminAuthWithPasswordEvent struct { + HttpContext echo.Context + Admin *models.Admin + Identity string + Password string +} + +type AdminAuthRefreshEvent struct { + HttpContext echo.Context + Admin *models.Admin +} + +type AdminRequestPasswordResetEvent struct { + HttpContext echo.Context + Admin *models.Admin +} + +type AdminConfirmPasswordResetEvent struct { + HttpContext echo.Context + Admin *models.Admin +} + // ------------------------------------------------------------------- // Collection API events data // ------------------------------------------------------------------- diff --git a/forms/admin_login.go b/forms/admin_login.go index a88d1ad5..da4631a7 100644 --- a/forms/admin_login.go +++ b/forms/admin_login.go @@ -1,6 +1,7 @@ package forms import ( + "database/sql" "errors" validation "github.com/go-ozzo/ozzo-validation/v4" @@ -46,19 +47,34 @@ func (form *AdminLogin) Validate() error { // Submit validates and submits the admin form. // On success returns the authorized admin model. -func (form *AdminLogin) Submit() (*models.Admin, error) { +// +// You can optionally provide a list of InterceptorFunc to +// further modify the form behavior before persisting it. +func (form *AdminLogin) Submit(interceptors ...InterceptorFunc[*models.Admin]) (*models.Admin, error) { if err := form.Validate(); err != nil { return nil, err } - admin, err := form.dao.FindAdminByEmail(form.Identity) - if err != nil { - return nil, err + admin, fetchErr := form.dao.FindAdminByEmail(form.Identity) + + // ignore not found errors to allow custom fetch implementations + if fetchErr != nil && !errors.Is(fetchErr, sql.ErrNoRows) { + return nil, fetchErr } - if admin.ValidatePassword(form.Password) { - return admin, nil + interceptorsErr := runInterceptors(admin, func(m *models.Admin) error { + admin = m + + if admin == nil || !admin.ValidatePassword(form.Password) { + return errors.New("Invalid login credentials.") + } + + return nil + }, interceptors...) + + if interceptorsErr != nil { + return nil, interceptorsErr } - return nil, errors.New("Invalid login credentials.") + return admin, nil } diff --git a/forms/admin_login_test.go b/forms/admin_login_test.go index bd63e7c2..6ff59137 100644 --- a/forms/admin_login_test.go +++ b/forms/admin_login_test.go @@ -1,9 +1,11 @@ package forms_test import ( + "errors" "testing" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" ) @@ -47,3 +49,48 @@ func TestAdminLoginValidateAndSubmit(t *testing.T) { } } } + +func TestAdminLoginInterceptors(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + form := forms.NewAdminLogin(testApp) + form.Identity = "test@example.com" + form.Password = "123456" + var interceptorAdmin *models.Admin + testErr := errors.New("test_error") + + interceptor1Called := false + interceptor1 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptor1Called = true + return next(admin) + } + } + + interceptor2Called := false + interceptor2 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptorAdmin = admin + interceptor2Called = true + return testErr + } + } + + _, submitErr := form.Submit(interceptor1, interceptor2) + if submitErr != testErr { + t.Fatalf("Expected submitError %v, got %v", testErr, submitErr) + } + + if !interceptor1Called { + t.Fatalf("Expected interceptor1 to be called") + } + + if !interceptor2Called { + t.Fatalf("Expected interceptor2 to be called") + } + + if interceptorAdmin == nil || interceptorAdmin.Email != form.Identity { + t.Fatalf("Expected Admin model with email %s, got %v", form.Identity, interceptorAdmin) + } +} diff --git a/forms/admin_password_reset_confirm.go b/forms/admin_password_reset_confirm.go index 134abc3f..0c5ee11f 100644 --- a/forms/admin_password_reset_confirm.go +++ b/forms/admin_password_reset_confirm.go @@ -63,7 +63,10 @@ func (form *AdminPasswordResetConfirm) checkToken(value any) error { // Submit validates and submits the admin password reset confirmation form. // On success returns the updated admin model associated to `form.Token`. -func (form *AdminPasswordResetConfirm) Submit() (*models.Admin, error) { +// +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *AdminPasswordResetConfirm) Submit(interceptors ...InterceptorFunc[*models.Admin]) (*models.Admin, error) { if err := form.Validate(); err != nil { return nil, err } @@ -80,8 +83,13 @@ func (form *AdminPasswordResetConfirm) Submit() (*models.Admin, error) { return nil, err } - if err := form.dao.SaveAdmin(admin); err != nil { - return nil, err + interceptorsErr := runInterceptors(admin, func(m *models.Admin) error { + admin = m + return form.dao.SaveAdmin(m) + }, interceptors...) + + if interceptorsErr != nil { + return nil, interceptorsErr } return admin, nil diff --git a/forms/admin_password_reset_confirm_test.go b/forms/admin_password_reset_confirm_test.go index fc825838..546b8e7c 100644 --- a/forms/admin_password_reset_confirm_test.go +++ b/forms/admin_password_reset_confirm_test.go @@ -1,9 +1,11 @@ package forms_test import ( + "errors" "testing" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/security" ) @@ -54,7 +56,24 @@ func TestAdminPasswordResetConfirmValidateAndSubmit(t *testing.T) { form.Password = s.password form.PasswordConfirm = s.passwordConfirm - admin, err := form.Submit() + interceptorCalls := 0 + interceptor := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { + interceptorCalls++ + return next(m) + } + } + + admin, err := form.Submit(interceptor) + + // check interceptor calls + expectInterceptorCalls := 1 + if s.expectError { + expectInterceptorCalls = 0 + } + if interceptorCalls != expectInterceptorCalls { + t.Errorf("[%d] Expected interceptor to be called %d, got %d", i, expectInterceptorCalls, interceptorCalls) + } hasErr := err != nil if hasErr != s.expectError { @@ -78,3 +97,54 @@ func TestAdminPasswordResetConfirmValidateAndSubmit(t *testing.T) { } } } + +func TestAdminPasswordResetConfirmInterceptors(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + admin, err := testApp.Dao().FindAdminByEmail("test@example.com") + if err != nil { + t.Fatal(err) + } + + form := forms.NewAdminPasswordResetConfirm(testApp) + form.Token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc" + form.Password = "1234567891" + form.PasswordConfirm = "1234567891" + interceptorTokenKey := admin.TokenKey + testErr := errors.New("test_error") + + interceptor1Called := false + interceptor1 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptor1Called = true + return next(admin) + } + } + + interceptor2Called := false + interceptor2 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptorTokenKey = admin.TokenKey + interceptor2Called = true + return testErr + } + } + + _, submitErr := form.Submit(interceptor1, interceptor2) + if submitErr != testErr { + t.Fatalf("Expected submitError %v, got %v", testErr, submitErr) + } + + if !interceptor1Called { + t.Fatalf("Expected interceptor1 to be called") + } + + if !interceptor2Called { + t.Fatalf("Expected interceptor2 to be called") + } + + if interceptorTokenKey == admin.TokenKey { + t.Fatalf("Expected the form model to be filled before calling the interceptors") + } +} diff --git a/forms/admin_password_reset_request.go b/forms/admin_password_reset_request.go index 1abfd9d8..99edaf41 100644 --- a/forms/admin_password_reset_request.go +++ b/forms/admin_password_reset_request.go @@ -9,6 +9,7 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/mails" + "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tools/types" ) @@ -55,7 +56,10 @@ func (form *AdminPasswordResetRequest) Validate() error { // Submit validates and submits the form. // On success sends a password reset email to the `form.Email` admin. -func (form *AdminPasswordResetRequest) Submit() error { +// +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *AdminPasswordResetRequest) Submit(interceptors ...InterceptorFunc[*models.Admin]) error { if err := form.Validate(); err != nil { return err } @@ -71,12 +75,14 @@ func (form *AdminPasswordResetRequest) Submit() error { return errors.New("You have already requested a password reset.") } - if err := mails.SendAdminPasswordReset(form.app, admin); err != nil { - return err - } - // update last sent timestamp admin.LastResetSentAt = types.NowDateTime() - return form.dao.SaveAdmin(admin) + return runInterceptors(admin, func(m *models.Admin) error { + if err := mails.SendAdminPasswordReset(form.app, m); err != nil { + return err + } + + return form.dao.SaveAdmin(m) + }, interceptors...) } diff --git a/forms/admin_password_reset_request_test.go b/forms/admin_password_reset_request_test.go index 0261c935..9bb4fd13 100644 --- a/forms/admin_password_reset_request_test.go +++ b/forms/admin_password_reset_request_test.go @@ -1,9 +1,11 @@ package forms_test import ( + "errors" "testing" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" ) @@ -31,7 +33,24 @@ func TestAdminPasswordResetRequestValidateAndSubmit(t *testing.T) { adminBefore, _ := testApp.Dao().FindAdminByEmail(s.email) - err := form.Submit() + interceptorCalls := 0 + interceptor := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { + interceptorCalls++ + return next(m) + } + } + + err := form.Submit(interceptor) + + // check interceptor calls + expectInterceptorCalls := 1 + if s.expectError { + expectInterceptorCalls = 0 + } + if interceptorCalls != expectInterceptorCalls { + t.Errorf("[%d] Expected interceptor to be called %d, got %d", i, expectInterceptorCalls, interceptorCalls) + } hasErr := err != nil if hasErr != s.expectError { @@ -53,3 +72,52 @@ func TestAdminPasswordResetRequestValidateAndSubmit(t *testing.T) { } } } + +func TestAdminPasswordResetRequestInterceptors(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + admin, err := testApp.Dao().FindAdminByEmail("test@example.com") + if err != nil { + t.Fatal(err) + } + + form := forms.NewAdminPasswordResetRequest(testApp) + form.Email = admin.Email + interceptorLastResetSentAt := admin.LastResetSentAt + testErr := errors.New("test_error") + + interceptor1Called := false + interceptor1 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptor1Called = true + return next(admin) + } + } + + interceptor2Called := false + interceptor2 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(admin *models.Admin) error { + interceptorLastResetSentAt = admin.LastResetSentAt + interceptor2Called = true + return testErr + } + } + + submitErr := form.Submit(interceptor1, interceptor2) + if submitErr != testErr { + t.Fatalf("Expected submitError %v, got %v", testErr, submitErr) + } + + if !interceptor1Called { + t.Fatalf("Expected interceptor1 to be called") + } + + if !interceptor2Called { + t.Fatalf("Expected interceptor2 to be called") + } + + if interceptorLastResetSentAt.String() == admin.LastResetSentAt.String() { + t.Fatalf("Expected the form model to be filled before calling the interceptors") + } +} diff --git a/forms/admin_upsert.go b/forms/admin_upsert.go index 1180b480..4afcb67e 100644 --- a/forms/admin_upsert.go +++ b/forms/admin_upsert.go @@ -99,7 +99,7 @@ func (form *AdminUpsert) checkUniqueEmail(value any) error { // // You can optionally provide a list of InterceptorFunc to further // modify the form behavior before persisting it. -func (form *AdminUpsert) Submit(interceptors ...InterceptorFunc) error { +func (form *AdminUpsert) Submit(interceptors ...InterceptorFunc[*models.Admin]) error { if err := form.Validate(); err != nil { return err } @@ -117,7 +117,7 @@ func (form *AdminUpsert) Submit(interceptors ...InterceptorFunc) error { form.admin.SetPassword(form.Password) } - return runInterceptors(func() error { - return form.dao.SaveAdmin(form.admin) + return runInterceptors(form.admin, func(admin *models.Admin) error { + return form.dao.SaveAdmin(admin) }, interceptors...) } diff --git a/forms/admin_upsert_test.go b/forms/admin_upsert_test.go index e92f029e..14657360 100644 --- a/forms/admin_upsert_test.go +++ b/forms/admin_upsert_test.go @@ -137,10 +137,10 @@ func TestAdminUpsertValidateAndSubmit(t *testing.T) { interceptorCalls := 0 - err := form.Submit(func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + err := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { interceptorCalls++ - return next() + return next(m) } }) @@ -196,16 +196,16 @@ func TestAdminUpsertSubmitInterceptors(t *testing.T) { interceptorAdminEmail := "" interceptor1Called := false - interceptor1 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { interceptor1Called = true - return next() + return next(m) } } interceptor2Called := false - interceptor2 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] { + return func(m *models.Admin) error { interceptorAdminEmail = admin.Email // to check if the record was filled interceptor2Called = true return testErr diff --git a/forms/base.go b/forms/base.go index 698feaa1..64c18883 100644 --- a/forms/base.go +++ b/forms/base.go @@ -4,8 +4,6 @@ package forms import ( "regexp" - - "github.com/pocketbase/pocketbase/models" ) // base ID value regex pattern @@ -13,32 +11,21 @@ var idRegex = regexp.MustCompile(`^[^\@\#\$\&\|\.\,\'\"\\\/\s]+$`) // InterceptorNextFunc is a interceptor handler function. // Usually used in combination with InterceptorFunc. -type InterceptorNextFunc = func() error +type InterceptorNextFunc[T any] func(t T) error // InterceptorFunc defines a single interceptor function that // will execute the provided next func handler. -type InterceptorFunc func(next InterceptorNextFunc) InterceptorNextFunc +type InterceptorFunc[T any] func(next InterceptorNextFunc[T]) InterceptorNextFunc[T] // runInterceptors executes the provided list of interceptors. -func runInterceptors(next InterceptorNextFunc, interceptors ...InterceptorFunc) error { +func runInterceptors[T any]( + data T, + next InterceptorNextFunc[T], + interceptors ...InterceptorFunc[T], +) error { for i := len(interceptors) - 1; i >= 0; i-- { next = interceptors[i](next) } - return next() -} - -// InterceptorWithRecordNextFunc is a Record interceptor handler function. -// Usually used in combination with InterceptorWithRecordFunc. -type InterceptorWithRecordNextFunc = func(record *models.Record) error - -// InterceptorWithRecordFunc defines a single Record interceptor function -// that will execute the provided next func handler. -type InterceptorWithRecordFunc func(next InterceptorWithRecordNextFunc) InterceptorWithRecordNextFunc - -// runInterceptorsWithRecord executes the provided list of Record interceptors. -func runInterceptorsWithRecord(record *models.Record, next InterceptorWithRecordNextFunc, interceptors ...InterceptorWithRecordFunc) error { - for i := len(interceptors) - 1; i >= 0; i-- { - next = interceptors[i](next) - } - return next(record) + + return next(data) } diff --git a/forms/collection_upsert.go b/forms/collection_upsert.go index 48ea4927..35c0bd38 100644 --- a/forms/collection_upsert.go +++ b/forms/collection_upsert.go @@ -345,7 +345,7 @@ func (form *CollectionUpsert) checkOptions(value any) error { // // You can optionally provide a list of InterceptorFunc to further // modify the form behavior before persisting it. -func (form *CollectionUpsert) Submit(interceptors ...InterceptorFunc) error { +func (form *CollectionUpsert) Submit(interceptors ...InterceptorFunc[*models.Collection]) error { if err := form.Validate(); err != nil { return err } @@ -377,7 +377,7 @@ func (form *CollectionUpsert) Submit(interceptors ...InterceptorFunc) error { form.collection.DeleteRule = form.DeleteRule form.collection.SetOptions(form.Options) - return runInterceptors(func() error { - return form.dao.SaveCollection(form.collection) + return runInterceptors(form.collection, func(collection *models.Collection) error { + return form.dao.SaveCollection(collection) }, interceptors...) } diff --git a/forms/collection_upsert_test.go b/forms/collection_upsert_test.go index adbe1c77..b79e8049 100644 --- a/forms/collection_upsert_test.go +++ b/forms/collection_upsert_test.go @@ -351,10 +351,10 @@ func TestCollectionUpsertValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor := func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { + return func(c *models.Collection) error { interceptorCalls++ - return next() + return next(c) } } @@ -451,16 +451,16 @@ func TestCollectionUpsertSubmitInterceptors(t *testing.T) { interceptorCollectionName := "" interceptor1Called := false - interceptor1 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { + return func(c *models.Collection) error { interceptor1Called = true - return next() + return next(c) } } interceptor2Called := false - interceptor2 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] { + return func(c *models.Collection) error { interceptorCollectionName = collection.Name // to check if the record was filled interceptor2Called = true return testErr diff --git a/forms/collections_import.go b/forms/collections_import.go index 083f4d14..b6618c66 100644 --- a/forms/collections_import.go +++ b/forms/collections_import.go @@ -56,15 +56,15 @@ func (form *CollectionsImport) Validate() error { // // You can optionally provide a list of InterceptorFunc to further // modify the form behavior before persisting it. -func (form *CollectionsImport) Submit(interceptors ...InterceptorFunc) error { +func (form *CollectionsImport) Submit(interceptors ...InterceptorFunc[[]*models.Collection]) error { if err := form.Validate(); err != nil { return err } - return runInterceptors(func() error { + return runInterceptors(form.Collections, func(collections []*models.Collection) error { return form.dao.RunInTransaction(func(txDao *daos.Dao) error { importErr := txDao.ImportCollections( - form.Collections, + collections, form.DeleteMissing, form.beforeRecordsSync, ) diff --git a/forms/collections_import_test.go b/forms/collections_import_test.go index 07611960..f96b88ca 100644 --- a/forms/collections_import_test.go +++ b/forms/collections_import_test.go @@ -404,16 +404,16 @@ func TestCollectionsImportSubmitInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor1 := func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] { + return func(imports []*models.Collection) error { interceptor1Called = true - return next() + return next(imports) } } interceptor2Called := false - interceptor2 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor2 := func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] { + return func(imports []*models.Collection) error { interceptor2Called = true return testErr } diff --git a/forms/record_email_change_confirm.go b/forms/record_email_change_confirm.go index 033252d8..37a7419b 100644 --- a/forms/record_email_change_confirm.go +++ b/forms/record_email_change_confirm.go @@ -114,9 +114,9 @@ func (form *RecordEmailChangeConfirm) parseToken(token string) (*models.Record, // Submit validates and submits the auth record email change confirmation form. // On success returns the updated auth record associated to `form.Token`. // -// You can optionally provide a list of InterceptorWithRecordFunc to +// You can optionally provide a list of InterceptorFunc to // further modify the form behavior before persisting it. -func (form *RecordEmailChangeConfirm) Submit(interceptors ...InterceptorWithRecordFunc) (*models.Record, error) { +func (form *RecordEmailChangeConfirm) Submit(interceptors ...InterceptorFunc[*models.Record]) (*models.Record, error) { if err := form.Validate(); err != nil { return nil, err } @@ -130,7 +130,8 @@ func (form *RecordEmailChangeConfirm) Submit(interceptors ...InterceptorWithReco authRecord.SetVerified(true) authRecord.RefreshTokenKey() // invalidate old tokens - interceptorsErr := runInterceptorsWithRecord(authRecord, func(m *models.Record) error { + interceptorsErr := runInterceptors(authRecord, func(m *models.Record) error { + authRecord = m return form.dao.SaveRecord(m) }, interceptors...) diff --git a/forms/record_email_change_confirm_test.go b/forms/record_email_change_confirm_test.go index 96152282..e64b89c0 100644 --- a/forms/record_email_change_confirm_test.go +++ b/forms/record_email_change_confirm_test.go @@ -85,7 +85,7 @@ func TestRecordEmailChangeConfirmValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -165,7 +165,7 @@ func TestRecordEmailChangeConfirmInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -173,7 +173,7 @@ func TestRecordEmailChangeConfirmInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptorEmail = record.Email() interceptor2Called = true diff --git a/forms/record_email_change_request.go b/forms/record_email_change_request.go index 8e77932c..9f29a74e 100644 --- a/forms/record_email_change_request.go +++ b/forms/record_email_change_request.go @@ -62,14 +62,14 @@ func (form *RecordEmailChangeRequest) checkUniqueEmail(value any) error { // Submit validates and sends the change email request. // -// You can optionally provide a list of InterceptorWithRecordFunc to +// You can optionally provide a list of InterceptorFunc to // further modify the form behavior before persisting it. -func (form *RecordEmailChangeRequest) Submit(interceptors ...InterceptorWithRecordFunc) error { +func (form *RecordEmailChangeRequest) Submit(interceptors ...InterceptorFunc[*models.Record]) error { if err := form.Validate(); err != nil { return err } - return runInterceptorsWithRecord(form.record, func(m *models.Record) error { + return runInterceptors(form.record, func(m *models.Record) error { return mails.SendRecordChangeEmail(form.app, m, form.NewEmail) }, interceptors...) } diff --git a/forms/record_email_change_request_test.go b/forms/record_email_change_request_test.go index daec3ffd..4b19b503 100644 --- a/forms/record_email_change_request_test.go +++ b/forms/record_email_change_request_test.go @@ -60,7 +60,7 @@ func TestRecordEmailChangeRequestValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -119,7 +119,7 @@ func TestRecordEmailChangeRequestInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -127,7 +127,7 @@ func TestRecordEmailChangeRequestInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor2Called = true return testErr diff --git a/forms/record_oauth2_login.go b/forms/record_oauth2_login.go index 8eb6bc7f..988bb944 100644 --- a/forms/record_oauth2_login.go +++ b/forms/record_oauth2_login.go @@ -14,12 +14,25 @@ import ( "golang.org/x/oauth2" ) +// RecordOAuth2LoginData defines the OA +type RecordOAuth2LoginData struct { + ExternalAuth *models.ExternalAuth + Record *models.Record + OAuth2User *auth.AuthUser +} + +// BeforeOAuth2RecordCreateFunc defines a callback function that will +// be called before OAuth2 new Record creation. +type BeforeOAuth2RecordCreateFunc func(createForm *RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error + // RecordOAuth2Login is an auth record OAuth2 login form. type RecordOAuth2Login struct { app core.App dao *daos.Dao collection *models.Collection + beforeOAuth2RecordCreateFunc BeforeOAuth2RecordCreateFunc + // Optional auth record that will be used if no external // auth relation is found (if it is from the same collection) loggedAuthRecord *models.Record @@ -62,6 +75,11 @@ func (form *RecordOAuth2Login) SetDao(dao *daos.Dao) { form.dao = dao } +// SetBeforeNewRecordCreateFunc sets a before OAuth2 record create callback handler. +func (form *RecordOAuth2Login) SetBeforeNewRecordCreateFunc(f BeforeOAuth2RecordCreateFunc) { + form.beforeOAuth2RecordCreateFunc = f +} + // Validate makes the form validatable by implementing [validation.Validatable] interface. func (form *RecordOAuth2Login) Validate() error { return validation.ValidateStruct(form, @@ -87,11 +105,14 @@ func (form *RecordOAuth2Login) checkProviderName(value any) error { // // If an auth record doesn't exist, it will make an attempt to create it // based on the fetched OAuth2 profile data via a local [RecordUpsert] form. -// You can intercept/modify the create form by setting the optional beforeCreateFuncs argument. +// You can intercept/modify the Record create form with [form.SetBeforeNewRecordCreateFunc()]. +// +// You can also optionally provide a list of InterceptorFunc to +// further modify the form behavior before persisting it. // // On success returns the authorized record model and the fetched provider's data. func (form *RecordOAuth2Login) Submit( - beforeCreateFuncs ...func(createForm *RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error, + interceptors ...InterceptorFunc[*RecordOAuth2LoginData], ) (*models.Record, *auth.AuthUser, error) { if err := form.Validate(); err != nil { return nil, nil, err @@ -147,16 +168,37 @@ func (form *RecordOAuth2Login) Submit( authRecord, _ = form.dao.FindAuthRecordByEmail(form.collection.Id, authUser.Email) } - saveErr := form.dao.RunInTransaction(func(txDao *daos.Dao) error { - if authRecord == nil { - authRecord = models.NewRecord(form.collection) - authRecord.RefreshId() - authRecord.MarkAsNew() - createForm := NewRecordUpsert(form.app, authRecord) + interceptorData := &RecordOAuth2LoginData{ + ExternalAuth: rel, + Record: authRecord, + OAuth2User: authUser, + } + + interceptorsErr := runInterceptors(interceptorData, func(newData *RecordOAuth2LoginData) error { + return form.submit(newData) + }, interceptors...) + + if interceptorsErr != nil { + return nil, interceptorData.OAuth2User, interceptorsErr + } + + return interceptorData.Record, interceptorData.OAuth2User, nil +} + +func (form *RecordOAuth2Login) submit(data *RecordOAuth2LoginData) error { + return form.dao.RunInTransaction(func(txDao *daos.Dao) error { + if data.Record == nil { + data.Record = models.NewRecord(form.collection) + data.Record.RefreshId() + data.Record.MarkAsNew() + createForm := NewRecordUpsert(form.app, data.Record) createForm.SetFullManageAccess(true) createForm.SetDao(txDao) - if authUser.Username != "" && usernameRegex.MatchString(authUser.Username) { - createForm.Username = form.dao.SuggestUniqueAuthRecordUsername(form.collection.Id, authUser.Username) + if data.OAuth2User.Username != "" && usernameRegex.MatchString(data.OAuth2User.Username) { + createForm.Username = form.dao.SuggestUniqueAuthRecordUsername( + form.collection.Id, + data.OAuth2User.Username, + ) } // load custom data @@ -164,10 +206,10 @@ func (form *RecordOAuth2Login) Submit( // load the OAuth2 profile data as fallback if createForm.Email == "" { - createForm.Email = authUser.Email + createForm.Email = data.OAuth2User.Email } createForm.Verified = false - if createForm.Email == authUser.Email { + if createForm.Email == data.OAuth2User.Email { // mark as verified as long as it matches the OAuth2 data (even if the email is empty) createForm.Verified = true } @@ -176,11 +218,8 @@ func (form *RecordOAuth2Login) Submit( createForm.PasswordConfirm = createForm.Password } - for _, f := range beforeCreateFuncs { - if f == nil { - continue - } - if err := f(createForm, authRecord, authUser); err != nil { + if form.beforeOAuth2RecordCreateFunc != nil { + if err := form.beforeOAuth2RecordCreateFunc(createForm, data.Record, data.OAuth2User); err != nil { return err } } @@ -190,45 +229,39 @@ func (form *RecordOAuth2Login) Submit( return err } } else { - // update the existing auth record empty email if the authUser has one + // update the existing auth record empty email if the data.OAuth2User has one // (this is in case previously the auth record was created // with an OAuth2 provider that didn't return an email address) - if authRecord.Email() == "" && authUser.Email != "" { - authRecord.SetEmail(authUser.Email) - if err := txDao.SaveRecord(authRecord); err != nil { + if data.Record.Email() == "" && data.OAuth2User.Email != "" { + data.Record.SetEmail(data.OAuth2User.Email) + if err := txDao.SaveRecord(data.Record); err != nil { return err } } // update the existing auth record verified state - // (only if the auth record doesn't have an email or the auth record email match with the one in authUser) - if !authRecord.Verified() && (authRecord.Email() == "" || authRecord.Email() == authUser.Email) { - authRecord.SetVerified(true) - if err := txDao.SaveRecord(authRecord); err != nil { + // (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User) + if !data.Record.Verified() && (data.Record.Email() == "" || data.Record.Email() == data.OAuth2User.Email) { + data.Record.SetVerified(true) + if err := txDao.SaveRecord(data.Record); err != nil { return err } } } // create ExternalAuth relation if missing - if rel == nil { - rel = &models.ExternalAuth{ - CollectionId: authRecord.Collection().Id, - RecordId: authRecord.Id, + if data.ExternalAuth == nil { + data.ExternalAuth = &models.ExternalAuth{ + CollectionId: data.Record.Collection().Id, + RecordId: data.Record.Id, Provider: form.Provider, - ProviderId: authUser.Id, + ProviderId: data.OAuth2User.Id, } - if err := txDao.SaveExternalAuth(rel); err != nil { + if err := txDao.SaveExternalAuth(data.ExternalAuth); err != nil { return err } } return nil }) - - if saveErr != nil { - return nil, authUser, saveErr - } - - return authRecord, authUser, nil } diff --git a/forms/record_password_login.go b/forms/record_password_login.go index 2c01e9a8..85f8caae 100644 --- a/forms/record_password_login.go +++ b/forms/record_password_login.go @@ -1,6 +1,7 @@ package forms import ( + "database/sql" "errors" validation "github.com/go-ozzo/ozzo-validation/v4" @@ -48,30 +49,47 @@ func (form *RecordPasswordLogin) Validate() error { // Submit validates and submits the form. // On success returns the authorized record model. -func (form *RecordPasswordLogin) Submit() (*models.Record, error) { +// +// You can optionally provide a list of InterceptorFunc to +// further modify the form behavior before persisting it. +func (form *RecordPasswordLogin) Submit(interceptors ...InterceptorFunc[*models.Record]) (*models.Record, error) { if err := form.Validate(); err != nil { return nil, err } authOptions := form.collection.AuthOptions() - if !authOptions.AllowEmailAuth && !authOptions.AllowUsernameAuth { - return nil, errors.New("Password authentication is not allowed for the collection.") - } - - var record *models.Record + var authRecord *models.Record var fetchErr error - if authOptions.AllowEmailAuth && - (!authOptions.AllowUsernameAuth || is.EmailFormat.Validate(form.Identity) == nil) { - record, fetchErr = form.dao.FindAuthRecordByEmail(form.collection.Id, form.Identity) - } else { - record, fetchErr = form.dao.FindAuthRecordByUsername(form.collection.Id, form.Identity) + isEmail := is.EmailFormat.Validate(form.Identity) == nil + + if isEmail { + if authOptions.AllowEmailAuth { + authRecord, fetchErr = form.dao.FindAuthRecordByEmail(form.collection.Id, form.Identity) + } + } else if authOptions.AllowUsernameAuth { + authRecord, fetchErr = form.dao.FindAuthRecordByUsername(form.collection.Id, form.Identity) } - if fetchErr != nil || !record.ValidatePassword(form.Password) { - return nil, errors.New("Invalid login credentials.") + // ignore not found errors to allow custom fetch implementations + if fetchErr != nil && !errors.Is(fetchErr, sql.ErrNoRows) { + return nil, fetchErr } - return record, nil + interceptorsErr := runInterceptors(authRecord, func(m *models.Record) error { + authRecord = m + + if authRecord == nil || !authRecord.ValidatePassword(form.Password) { + return errors.New("Invalid login credentials.") + } + + return nil + }, interceptors...) + + if interceptorsErr != nil { + return nil, interceptorsErr + } + + return authRecord, nil } diff --git a/forms/record_password_login_test.go b/forms/record_password_login_test.go index c36dc72d..a7d5a173 100644 --- a/forms/record_password_login_test.go +++ b/forms/record_password_login_test.go @@ -1,13 +1,15 @@ package forms_test import ( + "errors" "testing" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" ) -func TestRecordEmailLoginValidateAndSubmit(t *testing.T) { +func TestRecordPasswordLoginValidateAndSubmit(t *testing.T) { testApp, _ := tests.NewTestApp() defer testApp.Cleanup() @@ -128,3 +130,53 @@ func TestRecordEmailLoginValidateAndSubmit(t *testing.T) { } } } + +func TestRecordPasswordLoginInterceptors(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + authCollection, err := testApp.Dao().FindCollectionByNameOrId("users") + if err != nil { + t.Fatal(err) + } + + form := forms.NewRecordPasswordLogin(testApp, authCollection) + form.Identity = "test@example.com" + form.Password = "123456" + var interceptorRecord *models.Record + testErr := errors.New("test_error") + + interceptor1Called := false + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(record *models.Record) error { + interceptor1Called = true + return next(record) + } + } + + interceptor2Called := false + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(record *models.Record) error { + interceptorRecord = record + interceptor2Called = true + return testErr + } + } + + _, submitErr := form.Submit(interceptor1, interceptor2) + if submitErr != testErr { + t.Fatalf("Expected submitError %v, got %v", testErr, submitErr) + } + + if !interceptor1Called { + t.Fatalf("Expected interceptor1 to be called") + } + + if !interceptor2Called { + t.Fatalf("Expected interceptor2 to be called") + } + + if interceptorRecord == nil || interceptorRecord.Email() != form.Identity { + t.Fatalf("Expected auth Record model with email %s, got %v", form.Identity, interceptorRecord) + } +} diff --git a/forms/record_password_reset_confirm.go b/forms/record_password_reset_confirm.go index 89722c5d..e79a21f0 100644 --- a/forms/record_password_reset_confirm.go +++ b/forms/record_password_reset_confirm.go @@ -72,9 +72,9 @@ func (form *RecordPasswordResetConfirm) checkToken(value any) error { // Submit validates and submits the form. // On success returns the updated auth record associated to `form.Token`. // -// You can optionally provide a list of InterceptorWithRecordFunc to -// further modify the form behavior before persisting it. -func (form *RecordPasswordResetConfirm) Submit(interceptors ...InterceptorWithRecordFunc) (*models.Record, error) { +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *RecordPasswordResetConfirm) Submit(interceptors ...InterceptorFunc[*models.Record]) (*models.Record, error) { if err := form.Validate(); err != nil { return nil, err } @@ -91,7 +91,8 @@ func (form *RecordPasswordResetConfirm) Submit(interceptors ...InterceptorWithRe return nil, err } - interceptorsErr := runInterceptorsWithRecord(authRecord, func(m *models.Record) error { + interceptorsErr := runInterceptors(authRecord, func(m *models.Record) error { + authRecord = m return form.dao.SaveRecord(m) }, interceptors...) diff --git a/forms/record_password_reset_confirm_test.go b/forms/record_password_reset_confirm_test.go index 543a9c9c..2012e346 100644 --- a/forms/record_password_reset_confirm_test.go +++ b/forms/record_password_reset_confirm_test.go @@ -79,7 +79,7 @@ func TestRecordPasswordResetConfirmValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -157,7 +157,7 @@ func TestRecordPasswordResetConfirmInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -165,7 +165,7 @@ func TestRecordPasswordResetConfirmInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptorTokenKey = record.TokenKey() interceptor2Called = true diff --git a/forms/record_password_reset_request.go b/forms/record_password_reset_request.go index 9b5c4d1a..2dc900af 100644 --- a/forms/record_password_reset_request.go +++ b/forms/record_password_reset_request.go @@ -60,9 +60,9 @@ func (form *RecordPasswordResetRequest) Validate() error { // Submit validates and submits the form. // On success, sends a password reset email to the `form.Email` auth record. // -// You can optionally provide a list of InterceptorWithRecordFunc to -// further modify the form behavior before persisting it. -func (form *RecordPasswordResetRequest) Submit(interceptors ...InterceptorWithRecordFunc) error { +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *RecordPasswordResetRequest) Submit(interceptors ...InterceptorFunc[*models.Record]) error { if err := form.Validate(); err != nil { return err } @@ -81,7 +81,7 @@ func (form *RecordPasswordResetRequest) Submit(interceptors ...InterceptorWithRe // update last sent timestamp authRecord.Set(schema.FieldNameLastResetSentAt, types.NowDateTime()) - return runInterceptorsWithRecord(authRecord, func(m *models.Record) error { + return runInterceptors(authRecord, func(m *models.Record) error { if err := mails.SendRecordPasswordReset(form.app, m); err != nil { return err } diff --git a/forms/record_password_reset_request_test.go b/forms/record_password_reset_request_test.go index ff0db1fa..a81ec47c 100644 --- a/forms/record_password_reset_request_test.go +++ b/forms/record_password_reset_request_test.go @@ -67,7 +67,7 @@ func TestRecordPasswordResetRequestSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -135,7 +135,7 @@ func TestRecordPasswordResetRequestInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -143,7 +143,7 @@ func TestRecordPasswordResetRequestInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptorLastResetSentAt = record.LastResetSentAt() interceptor2Called = true diff --git a/forms/record_upsert.go b/forms/record_upsert.go index 4a8e8020..7334b7a5 100644 --- a/forms/record_upsert.go +++ b/forms/record_upsert.go @@ -718,12 +718,14 @@ func (form *RecordUpsert) DrySubmit(callback func(txDao *daos.Dao) error) error // // You can optionally provide a list of InterceptorFunc to further // modify the form behavior before persisting it. -func (form *RecordUpsert) Submit(interceptors ...InterceptorFunc) error { +func (form *RecordUpsert) Submit(interceptors ...InterceptorFunc[*models.Record]) error { if err := form.ValidateAndFill(); err != nil { return err } - return runInterceptors(func() error { + return runInterceptors(form.record, func(record *models.Record) error { + form.record = record + if !form.record.HasId() { form.record.RefreshId() form.record.MarkAsNew() diff --git a/forms/record_upsert_test.go b/forms/record_upsert_test.go index 80f08def..b4ef6522 100644 --- a/forms/record_upsert_test.go +++ b/forms/record_upsert_test.go @@ -428,10 +428,10 @@ func TestRecordUpsertSubmitFailure(t *testing.T) { form.LoadRequest(req, "") interceptorCalls := 0 - interceptor := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(r *models.Record) error { interceptorCalls++ - return next() + return next(r) } } @@ -505,10 +505,10 @@ func TestRecordUpsertSubmitSuccess(t *testing.T) { form.LoadRequest(req, "") interceptorCalls := 0 - interceptor := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(r *models.Record) error { interceptorCalls++ - return next() + return next(r) } } @@ -566,16 +566,16 @@ func TestRecordUpsertSubmitInterceptors(t *testing.T) { interceptorRecordTitle := "" interceptor1Called := false - interceptor1 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(r *models.Record) error { interceptor1Called = true - return next() + return next(r) } } interceptor2Called := false - interceptor2 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { + return func(r *models.Record) error { interceptorRecordTitle = record.GetString("title") // to check if the record was filled interceptor2Called = true return testErr diff --git a/forms/record_verification_confirm.go b/forms/record_verification_confirm.go index a4f15f46..2d0f7ad5 100644 --- a/forms/record_verification_confirm.go +++ b/forms/record_verification_confirm.go @@ -77,9 +77,9 @@ func (form *RecordVerificationConfirm) checkToken(value any) error { // Submit validates and submits the form. // On success returns the verified auth record associated to `form.Token`. // -// You can optionally provide a list of InterceptorWithRecordFunc to -// further modify the form behavior before persisting it. -func (form *RecordVerificationConfirm) Submit(interceptors ...InterceptorWithRecordFunc) (*models.Record, error) { +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *RecordVerificationConfirm) Submit(interceptors ...InterceptorFunc[*models.Record]) (*models.Record, error) { if err := form.Validate(); err != nil { return nil, err } @@ -98,7 +98,9 @@ func (form *RecordVerificationConfirm) Submit(interceptors ...InterceptorWithRec record.SetVerified(true) } - interceptorsErr := runInterceptorsWithRecord(record, func(m *models.Record) error { + interceptorsErr := runInterceptors(record, func(m *models.Record) error { + record = m + if wasVerified { return nil // already verified } diff --git a/forms/record_verification_confirm_test.go b/forms/record_verification_confirm_test.go index d927f5b7..487b9b35 100644 --- a/forms/record_verification_confirm_test.go +++ b/forms/record_verification_confirm_test.go @@ -57,7 +57,7 @@ func TestRecordVerificationConfirmValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -117,7 +117,7 @@ func TestRecordVerificationConfirmInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -125,7 +125,7 @@ func TestRecordVerificationConfirmInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptorVerified = record.Verified() interceptor2Called = true diff --git a/forms/record_verification_request.go b/forms/record_verification_request.go index 3868d610..7c1686bb 100644 --- a/forms/record_verification_request.go +++ b/forms/record_verification_request.go @@ -60,9 +60,9 @@ func (form *RecordVerificationRequest) Validate() error { // Submit validates and sends a verification request email // to the `form.Email` auth record. // -// You can optionally provide a list of InterceptorWithRecordFunc to -// further modify the form behavior before persisting it. -func (form *RecordVerificationRequest) Submit(interceptors ...InterceptorWithRecordFunc) error { +// You can optionally provide a list of InterceptorFunc to further +// modify the form behavior before persisting it. +func (form *RecordVerificationRequest) Submit(interceptors ...InterceptorFunc[*models.Record]) error { if err := form.Validate(); err != nil { return err } @@ -87,7 +87,7 @@ func (form *RecordVerificationRequest) Submit(interceptors ...InterceptorWithRec record.SetLastVerificationSentAt(types.NowDateTime()) } - return runInterceptorsWithRecord(record, func(m *models.Record) error { + return runInterceptors(record, func(m *models.Record) error { if m.Verified() { return nil // already verified } diff --git a/forms/record_verification_request_test.go b/forms/record_verification_request_test.go index 82a6dc25..797fd8ad 100644 --- a/forms/record_verification_request_test.go +++ b/forms/record_verification_request_test.go @@ -85,7 +85,7 @@ func TestRecordVerificationRequestSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(r *models.Record) error { interceptorCalls++ return next(r) @@ -153,7 +153,7 @@ func TestRecordVerificationRequestInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor1 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptor1Called = true return next(record) @@ -161,7 +161,7 @@ func TestRecordVerificationRequestInterceptors(t *testing.T) { } interceptor2Called := false - interceptor2 := func(next forms.InterceptorWithRecordNextFunc) forms.InterceptorWithRecordNextFunc { + interceptor2 := func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] { return func(record *models.Record) error { interceptorLastVerificationSentAt = record.LastVerificationSentAt() interceptor2Called = true diff --git a/forms/settings_upsert.go b/forms/settings_upsert.go index 6864d1c8..7ade4583 100644 --- a/forms/settings_upsert.go +++ b/forms/settings_upsert.go @@ -50,12 +50,14 @@ func (form *SettingsUpsert) Validate() error { // // You can optionally provide a list of InterceptorFunc to further // modify the form behavior before persisting it. -func (form *SettingsUpsert) Submit(interceptors ...InterceptorFunc) error { +func (form *SettingsUpsert) Submit(interceptors ...InterceptorFunc[*settings.Settings]) error { if err := form.Validate(); err != nil { return err } - return runInterceptors(func() error { + return runInterceptors(form.Settings, func(s *settings.Settings) error { + form.Settings = s + encryptionKey := os.Getenv(form.app.EncryptionEnv()) if err := form.dao.SaveSettings(form.Settings, encryptionKey); err != nil { return err diff --git a/forms/settings_upsert_test.go b/forms/settings_upsert_test.go index 494545c0..53041ffa 100644 --- a/forms/settings_upsert_test.go +++ b/forms/settings_upsert_test.go @@ -8,6 +8,7 @@ import ( validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/pocketbase/pocketbase/forms" + "github.com/pocketbase/pocketbase/models/settings" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/security" ) @@ -78,10 +79,10 @@ func TestSettingsUpsertValidateAndSubmit(t *testing.T) { } interceptorCalls := 0 - interceptor := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor := func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] { + return func(s *settings.Settings) error { interceptorCalls++ - return next() + return next(s) } } @@ -135,16 +136,16 @@ func TestSettingsUpsertSubmitInterceptors(t *testing.T) { testErr := errors.New("test_error") interceptor1Called := false - interceptor1 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor1 := func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] { + return func(s *settings.Settings) error { interceptor1Called = true - return next() + return next(s) } } interceptor2Called := false - interceptor2 := func(next forms.InterceptorNextFunc) forms.InterceptorNextFunc { - return func() error { + interceptor2 := func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] { + return func(s *settings.Settings) error { interceptor2Called = true return testErr } diff --git a/forms/validators/model.go b/forms/validators/model.go index 928207d5..337972a7 100644 --- a/forms/validators/model.go +++ b/forms/validators/model.go @@ -29,7 +29,7 @@ func UniqueId(dao *daos.Dao, tableName string) validation.RuleFunc { Limit(1). Row(&foundId) - if !errors.Is(err, sql.ErrNoRows) || foundId != "" { + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || foundId != "" { return validation.NewError("validation_invalid_id", "The model id is invalid or already exists.") } diff --git a/tests/app.go b/tests/app.go index 1b4d33f5..daaca1a9 100644 --- a/tests/app.go +++ b/tests/app.go @@ -182,6 +182,30 @@ func NewTestApp(optTestDataDir ...string) (*TestApp, error) { return t.registerEventCall("OnRecordAuthRequest") }) + t.OnRecordBeforeAuthWithPasswordRequest().Add(func(e *core.RecordAuthWithPasswordEvent) error { + return t.registerEventCall("OnRecordBeforeAuthWithPasswordRequest") + }) + + t.OnRecordAfterAuthWithPasswordRequest().Add(func(e *core.RecordAuthWithPasswordEvent) error { + return t.registerEventCall("OnRecordAfterAuthWithPasswordRequest") + }) + + t.OnRecordBeforeAuthWithOAuth2Request().Add(func(e *core.RecordAuthWithOAuth2Event) error { + return t.registerEventCall("OnRecordBeforeAuthWithOAuth2Request") + }) + + t.OnRecordAfterAuthWithOAuth2Request().Add(func(e *core.RecordAuthWithOAuth2Event) error { + return t.registerEventCall("OnRecordAfterAuthWithOAuth2Request") + }) + + t.OnRecordBeforeAuthRefreshRequest().Add(func(e *core.RecordAuthRefreshEvent) error { + return t.registerEventCall("OnRecordBeforeAuthRefreshRequest") + }) + + t.OnRecordAfterAuthRefreshRequest().Add(func(e *core.RecordAuthRefreshEvent) error { + return t.registerEventCall("OnRecordAfterAuthRefreshRequest") + }) + t.OnRecordBeforeRequestPasswordResetRequest().Add(func(e *core.RecordRequestPasswordResetEvent) error { return t.registerEventCall("OnRecordBeforeRequestPasswordResetRequest") }) @@ -386,6 +410,38 @@ func NewTestApp(optTestDataDir ...string) (*TestApp, error) { return t.registerEventCall("OnAdminAuthRequest") }) + t.OnAdminBeforeAuthWithPasswordRequest().Add(func(e *core.AdminAuthWithPasswordEvent) error { + return t.registerEventCall("OnAdminBeforeAuthWithPasswordRequest") + }) + + t.OnAdminAfterAuthWithPasswordRequest().Add(func(e *core.AdminAuthWithPasswordEvent) error { + return t.registerEventCall("OnAdminAfterAuthWithPasswordRequest") + }) + + t.OnAdminBeforeAuthRefreshRequest().Add(func(e *core.AdminAuthRefreshEvent) error { + return t.registerEventCall("OnAdminBeforeAuthRefreshRequest") + }) + + t.OnAdminAfterAuthRefreshRequest().Add(func(e *core.AdminAuthRefreshEvent) error { + return t.registerEventCall("OnAdminAfterAuthRefreshRequest") + }) + + t.OnAdminBeforeRequestPasswordResetRequest().Add(func(e *core.AdminRequestPasswordResetEvent) error { + return t.registerEventCall("OnAdminBeforeRequestPasswordResetRequest") + }) + + t.OnAdminAfterRequestPasswordResetRequest().Add(func(e *core.AdminRequestPasswordResetEvent) error { + return t.registerEventCall("OnAdminAfterRequestPasswordResetRequest") + }) + + t.OnAdminBeforeConfirmPasswordResetRequest().Add(func(e *core.AdminConfirmPasswordResetEvent) error { + return t.registerEventCall("OnAdminBeforeConfirmPasswordResetRequest") + }) + + t.OnAdminAfterConfirmPasswordResetRequest().Add(func(e *core.AdminConfirmPasswordResetEvent) error { + return t.registerEventCall("OnAdminAfterConfirmPasswordResetRequest") + }) + t.OnFileDownloadRequest().Add(func(e *core.FileDownloadEvent) error { return t.registerEventCall("OnFileDownloadRequest") })