1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-07-17 11:07:42 +02:00

merge v0.23.0-rc changes

This commit is contained in:
Gani Georgiev
2024-09-29 19:23:19 +03:00
parent ad92992324
commit 844f18cac3
753 changed files with 85141 additions and 63396 deletions

2
.github/SECURITY.md vendored
View File

@ -2,4 +2,4 @@
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
All reports will be promptly addressed, and you'll be credited accordingly.
All reports will be promptly addressed and you'll be credited in the fix release notes.

View File

@ -7,7 +7,14 @@ on:
jobs:
goreleaser:
runs-on: ubuntu-latest
env:
flags: ""
steps:
# re-enable auto-snapshot from goreleaser-action@v3
# (https://github.com/goreleaser/goreleaser-action-v4-auto-snapshot-example)
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
@ -16,12 +23,12 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: 20.11.0
node-version: 20.17.0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '>=1.22.5'
go-version: '>=1.23.0'
# This step usually is not needed because the /ui/dist is pregenerated locally
# but its here to ensure that each release embeds the latest admin ui artifacts.
@ -36,19 +43,14 @@ jobs:
# - name: Generate jsvm types
# run: go run ./plugins/jsvm/internal/types/types.go
# The prebuilt golangci-lint doesn't support go 1.18+ yet
# https://github.com/golangci/golangci-lint/issues/2649
# - name: Run linter
# uses: golangci/golangci-lint-action@v3
- name: Run tests
run: go test ./...
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v3
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser
version: latest
args: release --clean
version: '~> v2'
args: release --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,3 +1,5 @@
version: 2
project_name: pocketbase
dist: .builds
@ -58,7 +60,7 @@ checksum:
name_template: 'checksums.txt'
snapshot:
name_template: '{{ incpatch .Version }}-next'
version_template: '{{ incpatch .Version }}-next'
changelog:
sort: asc

View File

@ -1,3 +1,8 @@
## v0.23.0-RC (WIP)
...
## v0.22.21
- Lock the logs database during backup to prevent `database disk image is malformed` errors in case there is a log write running in the background ([#5541](https://github.com/pocketbase/pocketbase/discussions/5541)).

View File

@ -10,7 +10,7 @@
<a href="https://pkg.go.dev/github.com/pocketbase/pocketbase" target="_blank" rel="noopener"><img src="https://godoc.org/github.com/pocketbase/pocketbase?status.svg" alt="Go package documentation" /></a>
</p>
[PocketBase](https://pocketbase.io) is an open source Go backend, consisting of:
[PocketBase](https://pocketbase.io) is an open source Go backend that includes:
- embedded database (_SQLite_) with **realtime subscriptions**
- built-in **files and users management**
@ -46,7 +46,7 @@ your own custom app specific business logic and still have a single portable exe
Here is a minimal example:
0. [Install Go 1.21+](https://go.dev/doc/install) (_if you haven't already_)
0. [Install Go 1.23+](https://go.dev/doc/install) (_if you haven't already_)
1. Create a new project directory with the following `main.go` file inside it:
```go
@ -56,29 +56,20 @@ Here is a minimal example:
"log"
"net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
)
func main() {
app := pocketbase.New()
app.OnBeforeServe().Add(func(e *core.ServeEvent) error {
// add new "GET /hello" route to the app router (echo)
e.Router.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/hello",
Handler: func(c echo.Context) error {
return c.String(200, "Hello world!")
},
Middlewares: []echo.MiddlewareFunc{
apis.ActivityLogger(app),
},
app.OnServe().BindFunc(func(se *core.ServeEvent) error {
// registers new "GET /hello" route
se.Router.Get("/hello", func(re *core.RequestEvent) error {
return re.String(200, "Hello world!")
})
return nil
return se.Next()
})
if err := app.Start(); err != nil {
@ -145,7 +136,7 @@ Check also the [Testing guide](http://pocketbase.io/docs/testing) to learn how t
If you discover a security vulnerability within PocketBase, please send an e-mail to **support at pocketbase.io**.
All reports will be promptly addressed, and you'll be credited accordingly.
All reports will be promptly addressed and you'll be credited in the fix release notes.
## Contributing

View File

@ -1,353 +0,0 @@
package apis
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindAdminApi registers the admin api endpoints and the corresponding handlers.
func bindAdminApi(app core.App, rg *echo.Group) {
api := adminApi{app: app}
subGroup := rg.Group("/admins", ActivityLogger(app))
subGroup.POST("/auth-with-password", api.authWithPassword)
subGroup.POST("/request-password-reset", api.requestPasswordReset)
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
subGroup.POST("/auth-refresh", api.authRefresh, RequireAdminAuth())
subGroup.GET("", api.list, RequireAdminAuth())
subGroup.POST("", api.create, RequireAdminAuthOnlyIfAny(app))
subGroup.GET("/:id", api.view, RequireAdminAuth())
subGroup.PATCH("/:id", api.update, RequireAdminAuth())
subGroup.DELETE("/:id", api.delete, RequireAdminAuth())
}
type adminApi struct {
app core.App
}
func (api *adminApi) authResponse(c echo.Context, admin *models.Admin, finalizers ...func(token string) error) error {
token, tokenErr := tokens.NewAdminAuthToken(api.app, admin)
if tokenErr != nil {
return NewBadRequestError("Failed to create auth token.", tokenErr)
}
for _, f := range finalizers {
if err := f(token); err != nil {
return err
}
}
event := new(core.AdminAuthEvent)
event.HttpContext = c
event.Admin = admin
event.Token = token
return api.app.OnAdminAuthRequest().Trigger(event, func(e *core.AdminAuthEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(200, map[string]any{
"token": e.Token,
"admin": e.Admin,
})
})
}
func (api *adminApi) authRefresh(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin == nil {
return NewNotFoundError("Missing auth admin context.", nil)
}
event := new(core.AdminAuthRefreshEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminBeforeAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
return api.app.OnAdminAfterAuthRefreshRequest().Trigger(event, func(e *core.AdminAuthRefreshEvent) error {
return api.authResponse(e.HttpContext, e.Admin)
})
})
}
func (api *adminApi) authWithPassword(c echo.Context) error {
form := forms.NewAdminLogin(api.app)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
event := new(core.AdminAuthWithPasswordEvent)
event.HttpContext = c
event.Password = form.Password
event.Identity = form.Identity
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(admin *models.Admin) error {
event.Admin = admin
return api.app.OnAdminBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
return api.app.OnAdminAfterAuthWithPasswordRequest().Trigger(event, func(e *core.AdminAuthWithPasswordEvent) error {
return api.authResponse(e.HttpContext, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) requestPasswordReset(c echo.Context) error {
form := forms.NewAdminPasswordResetRequest(api.app)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.AdminRequestPasswordResetEvent)
event.HttpContext = c
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(Admin *models.Admin) error {
event.Admin = Admin
return api.app.OnAdminBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Admin); err != nil {
api.app.Logger().Error("Failed to send admin password reset request.", "error", err)
}
})
return api.app.OnAdminAfterRequestPasswordResetRequest().Trigger(event, func(e *core.AdminRequestPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
// eagerly write 204 response and skip submit errors
// as a measure against admins enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
return submitErr
}
func (api *adminApi) confirmPasswordReset(c echo.Context) error {
form := forms.NewAdminPasswordResetConfirm(api.app)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.AdminConfirmPasswordResetEvent)
event.HttpContext = c
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(admin *models.Admin) error {
event.Admin = admin
return api.app.OnAdminBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to set new password.", err)
}
return api.app.OnAdminAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.AdminConfirmPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *adminApi) list(c echo.Context) error {
fieldResolver := search.NewSimpleFieldResolver(
"id", "created", "updated", "name", "email",
)
admins := []*models.Admin{}
result, err := search.NewProvider(fieldResolver).
Query(api.app.Dao().AdminQuery()).
ParseAndExec(c.QueryParams().Encode(), &admins)
if err != nil {
return NewBadRequestError("", err)
}
event := new(core.AdminsListEvent)
event.HttpContext = c
event.Admins = admins
event.Result = result
return api.app.OnAdminsListRequest().Trigger(event, func(e *core.AdminsListEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
})
}
func (api *adminApi) view(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
event := new(core.AdminViewEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminViewRequest().Trigger(event, func(e *core.AdminViewEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
}
func (api *adminApi) create(c echo.Context) error {
admin := &models.Admin{}
form := forms.NewAdminUpsert(api.app, admin)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.AdminCreateEvent)
event.HttpContext = c
event.Admin = admin
// create the admin
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(m *models.Admin) error {
event.Admin = m
return api.app.OnAdminBeforeCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to create admin.", err)
}
return api.app.OnAdminAfterCreateRequest().Trigger(event, func(e *core.AdminCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) update(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
form := forms.NewAdminUpsert(api.app, admin)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.AdminUpdateEvent)
event.HttpContext = c
event.Admin = admin
// update the admin
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Admin]) forms.InterceptorNextFunc[*models.Admin] {
return func(m *models.Admin) error {
event.Admin = m
return api.app.OnAdminBeforeUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
if err := next(e.Admin); err != nil {
return NewBadRequestError("Failed to update admin.", err)
}
return api.app.OnAdminAfterUpdateRequest().Trigger(event, func(e *core.AdminUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Admin)
})
})
}
})
return submitErr
}
func (api *adminApi) delete(c echo.Context) error {
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
admin, err := api.app.Dao().FindAdminById(id)
if err != nil || admin == nil {
return NewNotFoundError("", err)
}
event := new(core.AdminDeleteEvent)
event.HttpContext = c
event.Admin = admin
return api.app.OnAdminBeforeDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
if err := api.app.Dao().DeleteAdmin(e.Admin); err != nil {
return NewBadRequestError("Failed to delete admin.", err)
}
return api.app.OnAdminAfterDeleteRequest().Trigger(event, func(e *core.AdminDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}

View File

@ -1,925 +0,0 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestAdminAuthWithPassword(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"identity":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "wrong email",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"missing@example.com","password":"1234567890"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
},
},
{
Name: "wrong password",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"invalid"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
},
},
{
Name: "valid email/password (guest)",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
"OnAdminAuthRequest": 1,
},
},
{
Name: "valid email/password (already authorized)",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
"OnAdminAuthRequest": 1,
},
},
{
Name: "OnAdminAfterAuthWithPasswordRequest error response",
Method: http.MethodPost,
Url: "/api/admins/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4MTYwMH0.han3_sG65zLddpcX2ic78qgy7FKecuPfOpFa8Dvi5Bg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterAuthWithPasswordRequest().Add(func(e *core.AdminAuthWithPasswordEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthWithPasswordRequest": 1,
"OnAdminAfterAuthWithPasswordRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminRequestPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "missing admin",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
},
{
Name: "existing admin",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnMailerBeforeAdminResetPasswordSend": 1,
"OnMailerAfterAdminResetPasswordSend": 1,
"OnAdminBeforeRequestPasswordResetRequest": 1,
"OnAdminAfterRequestPasswordResetRequest": 1,
},
},
{
Name: "existing admin (after already sent)",
Method: http.MethodPost,
Url: "/api/admins/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// simulate recent password request
admin, err := app.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
admin.LastResetSentAt = types.NowDateTime()
dao := daos.New(app.Dao().DB()) // new dao to ignore hooks
if err := dao.Save(admin); err != nil {
t.Fatal(err)
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminConfirmPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"password":{"code":"validation_required","message":"Cannot be blank."},"passwordConfirm":{"code":"validation_required","message":"Cannot be blank."},"token":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "invalid data",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "expired token",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MTY0MDk5MTY2MX0.GLwCOsgWTTEKXTK-AyGW838de1OeZGIjfHH0FoRLqZg",
"password":"1234567890",
"passwordConfirm":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"token":{"code":"validation_invalid_token","message":"Invalid or expired token."}}}`},
},
{
Name: "valid token + invalid password",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"123456",
"passwordConfirm":"123456"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"password":{"code":"validation_length_out_of_range"`},
},
{
Name: "valid token + valid password",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"1234567891",
"passwordConfirm":"1234567891"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeConfirmPasswordResetRequest": 1,
"OnAdminAfterConfirmPasswordResetRequest": 1,
},
},
{
Name: "OnAdminAfterConfirmPasswordResetRequest error response",
Method: http.MethodPost,
Url: "/api/admins/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsImV4cCI6MjIwODk4MTYwMH0.kwFEler6KSMKJNstuaSDvE1QnNdCta5qSnjaIQ0hhhc",
"password":"1234567891",
"passwordConfirm":"1234567891"
}`),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterConfirmPasswordResetRequest().Add(func(e *core.AdminConfirmPasswordResetEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeConfirmPasswordResetRequest": 1,
"OnAdminAfterConfirmPasswordResetRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminRefresh(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin (expired token)",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTY0MDk5MTY2MX0.I7w8iktkleQvC7_UIRpD7rNzcU4OnF7i7SFIUu6lD_4",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin (valid token)",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"admin":{"id":"sywbhecnh46rhm0"`,
`"token":`,
},
ExpectedEvents: map[string]int{
"OnAdminAuthRequest": 1,
"OnAdminBeforeAuthRefreshRequest": 1,
"OnAdminAfterAuthRefreshRequest": 1,
},
},
{
Name: "OnAdminAfterAuthRefreshRequest error response",
Method: http.MethodPost,
Url: "/api/admins/auth-refresh",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterAuthRefreshRequest().Add(func(e *core.AdminAuthRefreshEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeAuthRefreshRequest": 1,
"OnAdminAfterAuthRefreshRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/admins",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodGet,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin",
Method: http.MethodGet,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":3`,
`"items":[{`,
`"id":"sywbhecnh46rhm0"`,
`"id":"sbmbsdb40jyxf7h"`,
`"id":"9q2trqumvlyr3bd"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
{
Name: "authorized as admin + paging and sorting",
Method: http.MethodGet,
Url: "/api/admins?page=2&perPage=1&sort=-created",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":2`,
`"perPage":1`,
`"totalItems":3`,
`"items":[{`,
`"id":"sbmbsdb40jyxf7h"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
{
Name: "authorized as admin + invalid filter",
Method: http.MethodGet,
Url: "/api/admins?filter=invalidfield~'test2'",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + valid filter",
Method: http.MethodGet,
Url: "/api/admins?filter=email~'test3'",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"items":[{`,
`"id":"9q2trqumvlyr3bd"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + nonexisting admin id",
Method: http.MethodGet,
Url: "/api/admins/nonexisting",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + existing admin id",
Method: http.MethodGet,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
},
NotExpectedContent: []string{
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnAdminViewRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + missing admin id",
Method: http.MethodDelete,
Url: "/api/admins/missing",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + existing admin id",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnModelBeforeDelete": 1,
"OnModelAfterDelete": 1,
"OnAdminBeforeDeleteRequest": 1,
"OnAdminAfterDeleteRequest": 1,
},
},
{
Name: "authorized as admin - try to delete the only remaining admin",
Method: http.MethodDelete,
Url: "/api/admins/sywbhecnh46rhm0",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// delete all admins except the authorized one
adminModel := &models.Admin{}
_, err := app.Dao().DB().Delete(adminModel.TableName(), dbx.Not(dbx.HashExp{
"id": "sywbhecnh46rhm0",
})).Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnAdminBeforeDeleteRequest": 1,
},
},
{
Name: "OnAdminAfterDeleteRequest error response",
Method: http.MethodDelete,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterDeleteRequest().Add(func(e *core.AdminDeleteEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeDelete": 1,
"OnModelAfterDelete": 1,
"OnAdminBeforeDeleteRequest": 1,
"OnAdminAfterDeleteRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminCreate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized (while having at least 1 existing admin)",
Method: http.MethodPost,
Url: "/api/admins",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "unauthorized (while having 0 existing admins)",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{"email":"testnew@example.com","password":"1234567890","passwordConfirm":"1234567890","avatar":3}`),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
// delete all admins
_, err := app.Dao().DB().NewQuery("DELETE FROM {{_admins}}").Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":`,
`"email":"testnew@example.com"`,
`"avatar":3`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
{
Name: "authorized as user",
Method: http.MethodPost,
Url: "/api/admins",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + empty data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."},"password":{"code":"validation_required","message":"Cannot be blank."}}`},
},
{
Name: "authorized as admin + invalid data format",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + invalid data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"test@example.com",
"password":"1234",
"passwordConfirm":"4321",
"avatar":99
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"avatar":{"code":"validation_max_less_equal_than_required"`,
`"email":{"code":"validation_admin_email_exists"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
},
{
Name: "authorized as admin + valid data",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567890",
"passwordConfirm":"1234567890",
"avatar":3
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":`,
`"email":"testnew@example.com"`,
`"avatar":3`,
},
NotExpectedContent: []string{
`"password"`,
`"passwordConfirm"`,
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
{
Name: "OnAdminAfterCreateRequest error response",
Method: http.MethodPost,
Url: "/api/admins",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567890",
"passwordConfirm":"1234567890",
"avatar":3
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterCreateRequest().Add(func(e *core.AdminCreateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
"OnAdminBeforeCreateRequest": 1,
"OnAdminAfterCreateRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestAdminUpdate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as user",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + missing admin",
Method: http.MethodPatch,
Url: "/api/admins/missing",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + empty data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
`"email":"test2@example.com"`,
`"avatar":2`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
{
Name: "authorized as admin + invalid formatted data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "authorized as admin + invalid data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"test@example.com",
"password":"1234",
"passwordConfirm":"4321",
"avatar":99
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"avatar":{"code":"validation_max_less_equal_than_required"`,
`"email":{"code":"validation_admin_email_exists"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
},
{
Name: "authorized as admin + valid data",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567891",
"passwordConfirm":"1234567891",
"avatar":5
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sbmbsdb40jyxf7h"`,
`"email":"testnew@example.com"`,
`"avatar":5`,
},
NotExpectedContent: []string{
`"password"`,
`"passwordConfirm"`,
`"tokenKey"`,
`"passwordHash"`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
{
Name: "OnAdminAfterUpdateRequest error response",
Method: http.MethodPatch,
Url: "/api/admins/sbmbsdb40jyxf7h",
Body: strings.NewReader(`{
"email":"testnew@example.com",
"password":"1234567891",
"passwordConfirm":"1234567891",
"avatar":5
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnAdminAfterUpdateRequest().Add(func(e *core.AdminUpdateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnAdminBeforeUpdateRequest": 1,
"OnAdminAfterUpdateRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,132 +0,0 @@
package apis
import (
"net/http"
"strings"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/tools/inflector"
)
// ApiError defines the struct for a basic api error response.
type ApiError struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]any `json:"data"`
// stores unformatted error data (could be an internal error, text, etc.)
rawData any
}
// Error makes it compatible with the `error` interface.
func (e *ApiError) Error() string {
return e.Message
}
// RawData returns the unformatted error data (could be an internal error, text, etc.)
func (e *ApiError) RawData() any {
return e.rawData
}
// NewNotFoundError creates and returns 404 `ApiError`.
func NewNotFoundError(message string, data any) *ApiError {
if message == "" {
message = "The requested resource wasn't found."
}
return NewApiError(http.StatusNotFound, message, data)
}
// NewBadRequestError creates and returns 400 `ApiError`.
func NewBadRequestError(message string, data any) *ApiError {
if message == "" {
message = "Something went wrong while processing your request."
}
return NewApiError(http.StatusBadRequest, message, data)
}
// NewForbiddenError creates and returns 403 `ApiError`.
func NewForbiddenError(message string, data any) *ApiError {
if message == "" {
message = "You are not allowed to perform this request."
}
return NewApiError(http.StatusForbidden, message, data)
}
// NewUnauthorizedError creates and returns 401 `ApiError`.
func NewUnauthorizedError(message string, data any) *ApiError {
if message == "" {
message = "Missing or invalid authentication token."
}
return NewApiError(http.StatusUnauthorized, message, data)
}
// NewApiError creates and returns new normalized `ApiError` instance.
func NewApiError(status int, message string, data any) *ApiError {
return &ApiError{
rawData: data,
Data: safeErrorsData(data),
Code: status,
Message: strings.TrimSpace(inflector.Sentenize(message)),
}
}
func safeErrorsData(data any) map[string]any {
switch v := data.(type) {
case validation.Errors:
return resolveSafeErrorsData[error](v)
case map[string]validation.Error:
return resolveSafeErrorsData[validation.Error](v)
case map[string]error:
return resolveSafeErrorsData[error](v)
case map[string]any:
return resolveSafeErrorsData[any](v)
default:
return map[string]any{} // not nil to ensure that is json serialized as object
}
}
func resolveSafeErrorsData[T any](data map[string]T) map[string]any {
result := map[string]any{}
for name, err := range data {
if isNestedError(err) {
result[name] = safeErrorsData(err)
continue
}
result[name] = resolveSafeErrorItem(err)
}
return result
}
func isNestedError(err any) bool {
switch err.(type) {
case validation.Errors, map[string]validation.Error, map[string]error, map[string]any:
return true
}
return false
}
// resolveSafeErrorItem extracts from each validation error its
// public safe error code and message.
func resolveSafeErrorItem(err any) map[string]string {
// default public safe error values
code := "validation_invalid_value"
msg := "Invalid value."
// only validation errors are public safe
if obj, ok := err.(validation.Error); ok {
code = obj.Code()
msg = inflector.Sentenize(obj.Error())
}
return map[string]string{
"code": code,
"message": msg,
}
}

42
apis/api_error_aliases.go Normal file
View File

@ -0,0 +1,42 @@
package apis
import "github.com/pocketbase/pocketbase/tools/router"
// ApiError aliases to minimize the breaking changes with earlier versions
// and for consistency with the JSVM binds.
// -------------------------------------------------------------------
// NewApiError is an alias for [router.NewApiError].
func NewApiError(status int, message string, errData any) *router.ApiError {
return router.NewApiError(status, message, errData)
}
// NewBadRequestError is an alias for [router.NewBadRequestError].
func NewBadRequestError(message string, errData any) *router.ApiError {
return router.NewBadRequestError(message, errData)
}
// NewNotFoundError is an alias for [router.NewNotFoundError].
func NewNotFoundError(message string, errData any) *router.ApiError {
return router.NewNotFoundError(message, errData)
}
// NewForbiddenError is an alias for [router.NewForbiddenError].
func NewForbiddenError(message string, errData any) *router.ApiError {
return router.NewForbiddenError(message, errData)
}
// NewUnauthorizedError is an alias for [router.NewUnauthorizedError].
func NewUnauthorizedError(message string, errData any) *router.ApiError {
return router.NewUnauthorizedError(message, errData)
}
// NewTooManyRequestsError is an alias for [router.NewTooManyRequestsError].
func NewTooManyRequestsError(message string, errData any) *router.ApiError {
return router.NewTooManyRequestsError(message, errData)
}
// NewInternalServerError is an alias for [router.NewInternalServerError].
func NewInternalServerError(message string, errData any) *router.ApiError {
return router.NewInternalServerError(message, errData)
}

View File

@ -1,162 +0,0 @@
package apis_test
import (
"encoding/json"
"errors"
"testing"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/apis"
)
func TestNewApiErrorWithRawData(t *testing.T) {
t.Parallel()
e := apis.NewApiError(
300,
"message_test",
"rawData_test",
)
result, _ := json.Marshal(e)
expected := `{"code":300,"message":"Message_test.","data":{}}`
if string(result) != expected {
t.Errorf("Expected %v, got %v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() != "rawData_test" {
t.Errorf("Expected rawData %v, got %v", "rawData_test", e.RawData())
}
}
func TestNewApiErrorWithValidationData(t *testing.T) {
t.Parallel()
e := apis.NewApiError(
300,
"message_test",
validation.Errors{
"err1": errors.New("test error"), // should be normalized
"err2": validation.ErrRequired,
"err3": validation.Errors{
"sub1": errors.New("test error"), // should be normalized
"sub2": validation.ErrRequired,
"sub3": validation.Errors{
"sub11": validation.ErrRequired,
},
},
},
)
result, _ := json.Marshal(e)
expected := `{"code":300,"message":"Message_test.","data":{"err1":{"code":"validation_invalid_value","message":"Invalid value."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"sub1":{"code":"validation_invalid_value","message":"Invalid value."},"sub2":{"code":"validation_required","message":"Cannot be blank."},"sub3":{"sub11":{"code":"validation_required","message":"Cannot be blank."}}}}}`
if string(result) != expected {
t.Errorf("Expected \n%v, \ngot \n%v", expected, string(result))
}
if e.Error() != "Message_test." {
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
}
if e.RawData() == nil {
t.Error("Expected non-nil rawData")
}
}
func TestNewNotFoundError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":404,"message":"The requested resource wasn't found.","data":{}}`},
{"demo", "rawData_test", `{"code":404,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":404,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewNotFoundError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewBadRequestError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
{"demo", "rawData_test", `{"code":400,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":400,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewBadRequestError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewForbiddenError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":403,"message":"You are not allowed to perform this request.","data":{}}`},
{"demo", "rawData_test", `{"code":403,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":403,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewForbiddenError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}
func TestNewUnauthorizedError(t *testing.T) {
t.Parallel()
scenarios := []struct {
message string
data any
expected string
}{
{"", nil, `{"code":401,"message":"Missing or invalid authentication token.","data":{}}`},
{"demo", "rawData_test", `{"code":401,"message":"Demo.","data":{}}`},
{"demo", validation.Errors{"err1": validation.NewError("test_code", "test_message")}, `{"code":401,"message":"Demo.","data":{"err1":{"code":"test_code","message":"Test_message."}}}`},
}
for i, scenario := range scenarios {
e := apis.NewUnauthorizedError(scenario.message, scenario.data)
result, _ := json.Marshal(e)
if string(result) != scenario.expected {
t.Errorf("(%d) Expected \n%v, \ngot \n%v", i, scenario.expected, string(result))
}
}
}

View File

@ -6,42 +6,37 @@ import (
"path/filepath"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
// bindBackupApi registers the file api endpoints and the corresponding handlers.
//
// @todo add hooks once the app hooks api restructuring is finalized
func bindBackupApi(app core.App, rg *echo.Group) {
api := backupApi{app: app}
subGroup := rg.Group("/backups", ActivityLogger(app))
subGroup.GET("", api.list, RequireAdminAuth())
subGroup.POST("", api.create, RequireAdminAuth())
subGroup.POST("/upload", api.upload, RequireAdminAuth())
subGroup.GET("/:key", api.download)
subGroup.DELETE("/:key", api.delete, RequireAdminAuth())
subGroup.POST("/:key/restore", api.restore, RequireAdminAuth())
func bindBackupApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/backups")
sub.GET("", backupsList).Bind(RequireSuperuserAuth())
sub.POST("", backupCreate).Bind(RequireSuperuserAuth())
sub.POST("/upload", backupUpload).Bind(RequireSuperuserAuthOnlyIfAny())
sub.GET("/{key}", backupDownload) // relies on superuser file token
sub.DELETE("/{key}", backupDelete).Bind(RequireSuperuserAuth())
sub.POST("/{key}/restore", backupRestore).Bind(RequireSuperuserAuthOnlyIfAny())
}
type backupApi struct {
app core.App
type backupFileInfo struct {
Modified types.DateTime `json:"modified"`
Key string `json:"key"`
Size int64 `json:"size"`
}
func (api *backupApi) list(c echo.Context) error {
func backupsList(e *core.RequestEvent) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := api.app.NewBackupsFilesystem()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err)
return e.BadRequestError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
@ -49,166 +44,112 @@ func (api *backupApi) list(c echo.Context) error {
backups, err := fsys.List("")
if err != nil {
return NewBadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
return e.BadRequestError("Failed to retrieve backup items. Raw error: \n"+err.Error(), nil)
}
result := make([]models.BackupFileInfo, len(backups))
result := make([]backupFileInfo, len(backups))
for i, obj := range backups {
modified, _ := types.ParseDateTime(obj.ModTime)
result[i] = models.BackupFileInfo{
result[i] = backupFileInfo{
Key: obj.Key,
Size: obj.Size,
Modified: modified,
}
}
return c.JSON(http.StatusOK, result)
return e.JSON(http.StatusOK, result)
}
func (api *backupApi) create(c echo.Context) error {
if api.app.Store().Has(core.StoreKeyActiveBackup) {
return NewBadRequestError("Try again later - another backup/restore process has already been started", nil)
}
func backupDownload(e *core.RequestEvent) error {
fileToken := e.Request.URL.Query().Get("token")
form := forms.NewBackupCreate(api.app)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
return form.Submit(func(next forms.InterceptorNextFunc[string]) forms.InterceptorNextFunc[string] {
return func(name string) error {
if err := next(name); err != nil {
return NewBadRequestError("Failed to create backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return c.NoContent(http.StatusNoContent)
}
})
}
func (api *backupApi) upload(c echo.Context) error {
files, err := rest.FindUploadedFiles(c.Request(), "file")
if err != nil {
return NewBadRequestError("Missing or invalid uploaded file.", err)
}
form := forms.NewBackupUpload(api.app)
form.File = files[0]
return form.Submit(func(next forms.InterceptorNextFunc[*filesystem.File]) forms.InterceptorNextFunc[*filesystem.File] {
return func(file *filesystem.File) error {
if err := next(file); err != nil {
return NewBadRequestError("Failed to upload backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return c.NoContent(http.StatusNoContent)
}
})
}
func (api *backupApi) download(c echo.Context) error {
fileToken := c.QueryParam("token")
_, err := api.app.Dao().FindAdminByToken(
fileToken,
api.app.Settings().AdminFileToken.Secret,
)
if err != nil {
return NewForbiddenError("Insufficient permissions to access the resource.", err)
authRecord, err := e.App.FindAuthRecordByToken(fileToken, core.TokenTypeFile)
if err != nil || !authRecord.IsSuperuser() {
return e.ForbiddenError("Insufficient permissions to access the resource.", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
fsys, err := api.app.NewBackupsFilesystem()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err)
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := c.PathParam("key")
br, err := fsys.GetFile(key)
if err != nil {
return NewBadRequestError("Failed to retrieve backup item. Raw error: \n"+err.Error(), nil)
}
defer br.Close()
key := e.Request.PathValue("key")
return fsys.Serve(
c.Response(),
c.Request(),
e.Response,
e.Request,
key,
filepath.Base(key), // without the path prefix (if any)
)
}
func (api *backupApi) restore(c echo.Context) error {
if api.app.Store().Has(core.StoreKeyActiveBackup) {
return NewBadRequestError("Try again later - another backup/restore process has already been started.", nil)
func backupDelete(e *core.RequestEvent) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(ctx)
key := e.Request.PathValue("key")
if key != "" && cast.ToString(e.App.Store().Get(core.StoreKeyActiveBackup)) == key {
return e.BadRequestError("The backup is currently being used and cannot be deleted.", nil)
}
key := c.PathParam("key")
if err := fsys.Delete(key); err != nil {
return e.BadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
}
return e.NoContent(http.StatusNoContent)
}
func backupRestore(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started.", nil)
}
key := e.Request.PathValue("key")
existsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
fsys, err := api.app.NewBackupsFilesystem()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err)
return e.InternalServerError("Failed to load backups filesystem.", err)
}
defer fsys.Close()
fsys.SetContext(existsCtx)
if exists, err := fsys.Exists(key); !exists {
return NewBadRequestError("Missing or invalid backup file.", err)
return e.BadRequestError("Missing or invalid backup file.", err)
}
go func() {
// wait max 15 minutes to fetch the backup
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
// give some optimistic time to write the response
routine.FireAndForget(func() {
// give some optimistic time to write the response before restarting the app
time.Sleep(1 * time.Second)
if err := api.app.RestoreBackup(ctx, key); err != nil {
api.app.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
}
}()
return c.NoContent(http.StatusNoContent)
}
func (api *backupApi) delete(c echo.Context) error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// wait max 10 minutes to fetch the backup
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
fsys, err := api.app.NewBackupsFilesystem()
if err != nil {
return NewBadRequestError("Failed to load backups filesystem.", err)
if err := e.App.RestoreBackup(ctx, key); err != nil {
e.App.Logger().Error("Failed to restore backup", "key", key, "error", err.Error())
}
defer fsys.Close()
})
fsys.SetContext(ctx)
key := c.PathParam("key")
if key != "" && cast.ToString(api.app.Store().Get(core.StoreKeyActiveBackup)) == key {
return NewBadRequestError("The backup is currently being used and cannot be deleted.", nil)
}
if err := fsys.Delete(key); err != nil {
return NewBadRequestError("Invalid or already deleted backup file. Raw error: \n"+err.Error(), nil)
}
return c.NoContent(http.StatusNoContent)
return e.NoContent(http.StatusNoContent)
}

78
apis/backup_create.go Normal file
View File

@ -0,0 +1,78 @@
package apis
import (
"context"
"net/http"
"regexp"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func backupCreate(e *core.RequestEvent) error {
if e.App.Store().Has(core.StoreKeyActiveBackup) {
return e.BadRequestError("Try again later - another backup/restore process has already been started", nil)
}
form := new(backupCreateForm)
form.app = e.App
err := e.BindBody(form)
if err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = e.App.CreateBackup(context.Background(), form.Name)
if err != nil {
return e.BadRequestError("Failed to create backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
var backupNameRegex = regexp.MustCompile(`^[a-z0-9_-]+\.zip$`)
type backupCreateForm struct {
app core.App
Name string `form:"name" json:"name"`
}
func (form *backupCreateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.Name,
validation.Length(1, 150),
validation.Match(backupNameRegex),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupCreateForm) checkUniqueName(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
fsys, err := form.app.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
if exists, err := fsys.Exists(v); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "The backup file name is invalid or already exists.")
}
return nil
}

View File

@ -10,7 +10,6 @@ import (
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"gocloud.dev/blob"
@ -23,50 +22,51 @@ func TestBackupsList(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/backups",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodGet,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (empty list)",
Name: "authorized as superuser (empty list)",
Method: http.MethodGet,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[]`,
},
ExpectedContent: []string{`[]`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin",
Name: "authorized as superuser",
Method: http.MethodGet,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
@ -77,6 +77,7 @@ func TestBackupsList(t *testing.T) {
`"test2.zip"`,
`"test3.zip"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -92,50 +93,53 @@ func TestBackupsCreate(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/backups",
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
URL: "/api/backups",
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (pending backup)",
Name: "authorized as superuser (pending backup)",
Method: http.MethodPost,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Store().Set(core.StoreKeyActiveBackup, "")
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (autogenerated name)",
Name: "authorized as superuser (autogenerated name)",
Method: http.MethodPost,
Url: "/api/backups",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
@ -151,16 +155,20 @@ func TestBackupsCreate(t *testing.T) {
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
},
{
Name: "authorized as admin (invalid name)",
Name: "authorized as superuser (invalid name)",
Method: http.MethodPost,
Url: "/api/backups",
URL: "/api/backups",
Body: strings.NewReader(`{"name":"!test.zip"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
@ -168,16 +176,17 @@ func TestBackupsCreate(t *testing.T) {
`"data":{`,
`"name":{"code":"validation_match_invalid"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (valid name)",
Name: "authorized as superuser (valid name)",
Method: http.MethodPost,
Url: "/api/backups",
URL: "/api/backups",
Body: strings.NewReader(`{"name":"test.zip"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
@ -193,6 +202,10 @@ func TestBackupsCreate(t *testing.T) {
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnBackupCreate": 1,
},
},
}
@ -201,7 +214,7 @@ func TestBackupsCreate(t *testing.T) {
}
}
func TestBackupsUpload(t *testing.T) {
func TestBackupUpload(t *testing.T) {
t.Parallel()
// create dummy form data bodies
@ -243,55 +256,58 @@ func TestBackupsUpload(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/backups/upload",
URL: "/api/backups/upload",
Body: bodies[0].buffer,
RequestHeaders: map[string]string{
Headers: map[string]string{
"Content-Type": bodies[0].contentType,
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/backups/upload",
URL: "/api/backups/upload",
Body: bodies[1].buffer,
RequestHeaders: map[string]string{
Headers: map[string]string{
"Content-Type": bodies[1].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (missing file)",
Name: "authorized as superuser (missing file)",
Method: http.MethodPost,
Url: "/api/backups/upload",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/upload",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ensureNoBackups(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (existing backup name)",
Name: "authorized as superuser (existing backup name)",
Method: http.MethodPost,
Url: "/api/backups/upload",
URL: "/api/backups/upload",
Body: bodies[3].buffer,
RequestHeaders: map[string]string{
Headers: map[string]string{
"Content-Type": bodies[3].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
fsys, err := app.NewBackupsFilesystem()
if err != nil {
t.Fatal(err)
@ -302,7 +318,7 @@ func TestBackupsUpload(t *testing.T) {
t.Fatal(err)
}
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
@ -310,23 +326,49 @@ func TestBackupsUpload(t *testing.T) {
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"file":{`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (valid file)",
Name: "authorized as superuser (valid file)",
Method: http.MethodPost,
Url: "/api/backups/upload",
URL: "/api/backups/upload",
Body: bodies[4].buffer,
RequestHeaders: map[string]string{
Headers: map[string]string{
"Content-Type": bodies[4].contentType,
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unauthorized with 0 superusers (valid file)",
Method: http.MethodPost,
URL: "/api/backups/upload",
Body: bodies[5].buffer,
Headers: map[string]string{
"Content-Type": bodies[5].contentType,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, _ := getBackupFiles(app)
if total := len(files); total != 1 {
t.Fatalf("Expected %d backup file, got %d", 1, total)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -342,148 +384,159 @@ func TestBackupsDownload(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/backups/test1.zip",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with record auth header",
Method: http.MethodGet,
Url: "/api/backups/test1.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with admin auth header",
Name: "with superuser auth header",
Method: http.MethodGet,
Url: "/api/backups/test1.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with empty or invalid token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid record auth token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid record file token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid admin auth token",
Name: "with valid superuser auth token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with expired admin file token",
Name: "with expired superuser file token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid admin file token but missing backup name",
Name: "with valid superuser file token but missing backup name",
Method: http.MethodGet,
Url: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/missing?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid admin file token",
Name: "with valid superuser file token",
Method: http.MethodGet,
Url: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`storage/`,
`data.db`,
`logs.db`,
"storage/",
"data.db",
"aux.db",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "with valid admin file token and backup name with escaped char",
Name: "with valid superuser file token and backup name with escaped char",
Method: http.MethodGet,
Url: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/%40test4.zip?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`storage/`,
`data.db`,
`logs.db`,
"storage/",
"data.db",
"aux.db",
},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -495,7 +548,7 @@ func TestBackupsDownload(t *testing.T) {
func TestBackupsDelete(t *testing.T) {
t.Parallel()
noTestBackupFilesChanges := func(t *testing.T, app *tests.TestApp) {
noTestBackupFilesChanges := func(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
@ -511,62 +564,65 @@ func TestBackupsDelete(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodDelete,
Url: "/api/backups/test1.zip",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodDelete,
Url: "/api/backups/test1.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (missing file)",
Name: "authorized as superuser (missing file)",
Method: http.MethodDelete,
Url: "/api/backups/missing.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/missing.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (existing file with matching active backup)",
Name: "authorized as superuser (existing file with matching active backup)",
Method: http.MethodDelete,
Url: "/api/backups/test1.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
@ -574,20 +630,21 @@ func TestBackupsDelete(t *testing.T) {
// mock active backup with the same name to delete
app.Store().Set(core.StoreKeyActiveBackup, "test1.zip")
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
noTestBackupFilesChanges(t, app)
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (existing file and no matching active backup)",
Name: "authorized as superuser (existing file and no matching active backup)",
Method: http.MethodDelete,
Url: "/api/backups/test1.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/test1.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
@ -595,7 +652,7 @@ func TestBackupsDelete(t *testing.T) {
// mock active backup with different name
app.Store().Set(core.StoreKeyActiveBackup, "new.zip")
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
@ -614,20 +671,21 @@ func TestBackupsDelete(t *testing.T) {
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (backup with escaped character)",
Name: "authorized as superuser (backup with escaped character)",
Method: http.MethodDelete,
Url: "/api/backups/%40test4.zip",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/%40test4.zip",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)
@ -646,6 +704,7 @@ func TestBackupsDelete(t *testing.T) {
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -661,53 +720,56 @@ func TestBackupsRestore(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
URL: "/api/backups/test1.zip/restore",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/backups/test1.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (missing file)",
Name: "authorized as superuser (missing file)",
Method: http.MethodPost,
Url: "/api/backups/missing.zip/restore",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/missing.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (active backup process)",
Name: "authorized as superuser (active backup process)",
Method: http.MethodPost,
Url: "/api/backups/test1.zip/restore",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/backups/test1.zip/restore",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
@ -716,6 +778,26 @@ func TestBackupsRestore(t *testing.T) {
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unauthorized with no superusers (checks only access)",
Method: http.MethodPost,
URL: "/api/backups/missing.zip/restore",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
if err := createTestBackups(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -758,7 +840,7 @@ func getBackupFiles(app core.App) ([]*blob.ListObject, error) {
return fsys.List("")
}
func ensureNoBackups(t *testing.T, app *tests.TestApp) {
func ensureNoBackups(t testing.TB, app *tests.TestApp) {
files, err := getBackupFiles(app)
if err != nil {
t.Fatal(err)

72
apis/backup_upload.go Normal file
View File

@ -0,0 +1,72 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/filesystem"
)
func backupUpload(e *core.RequestEvent) error {
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
form := new(backupUploadForm)
form.fsys = fsys
files, _ := FindUploadedFiles(e.Request, "file")
if len(files) > 0 {
form.File = files[0]
}
err = form.validate()
if err != nil {
return e.BadRequestError("An error occurred while validating the submitted data.", err)
}
err = fsys.UploadFile(form.File, form.File.OriginalName)
if err != nil {
return e.BadRequestError("Failed to upload backup.", err)
}
// we don't retrieve the generated backup file because it may not be
// available yet due to the eventually consistent nature of some S3 providers
return e.NoContent(http.StatusNoContent)
}
// -------------------------------------------------------------------
type backupUploadForm struct {
fsys *filesystem.System
File *filesystem.File `json:"file"`
}
func (form *backupUploadForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(
&form.File,
validation.Required,
validation.By(validators.UploadedFileMimeType([]string{"application/zip"})),
validation.By(form.checkUniqueName),
),
)
}
func (form *backupUploadForm) checkUniqueName(value any) error {
v, _ := value.(*filesystem.File)
if v == nil {
return nil // nothing to check
}
// note: we use the original name because that is what we upload
if exists, err := form.fsys.Exists(v.OriginalName); err != nil || exists {
return validation.NewError("validation_backup_name_exists", "Backup file with the specified name already exists.")
}
return nil
}

View File

@ -1,266 +1,202 @@
// Package apis implements the default PocketBase api services and middlewares.
package apis
import (
"database/sql"
"errors"
"fmt"
"io/fs"
"log/slog"
"net/http"
"net/url"
"path/filepath"
"strings"
"time"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/ui"
"github.com/spf13/cast"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const trailedAdminPath = "/_/"
// StaticWildcardParam is the name of Static handler wildcard parameter.
const StaticWildcardParam = "path"
// InitApi creates a configured echo instance with registered
// system and app specific routes and middlewares.
func InitApi(app core.App) (*echo.Echo, error) {
e := echo.New()
e.Debug = false
e.Binder = &rest.MultiBinder{}
e.JSONSerializer = &rest.Serializer{
FieldsParam: fieldsQueryParam,
}
// NewRouter returns a new router instance loaded with the default app middlewares and api routes.
func NewRouter(app core.App) (*router.Router[*core.RequestEvent], error) {
pbRouter := router.NewRouter(func(w http.ResponseWriter, r *http.Request) (*core.RequestEvent, router.EventCleanupFunc) {
event := new(core.RequestEvent)
event.Response = w
event.Request = r
event.App = app
// configure a custom router
e.ResetRouterCreator(func(ec *echo.Echo) echo.Router {
return echo.NewRouter(echo.RouterConfig{
UnescapePathParamValues: true,
AllowOverwritingRoute: true,
})
return event, nil
})
// default middlewares
e.Pre(middleware.RemoveTrailingSlashWithConfig(middleware.RemoveTrailingSlashConfig{
Skipper: func(c echo.Context) bool {
// enable by default only for the API routes
return !strings.HasPrefix(c.Request().URL.Path, "/api/")
},
}))
e.Pre(LoadAuthContext(app))
e.Use(middleware.Recover())
e.Use(middleware.Secure())
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set(ContextExecStartKey, time.Now())
// register default middlewares
pbRouter.Bind(activityLogger())
pbRouter.Bind(loadAuthToken())
pbRouter.Bind(securityHeaders())
pbRouter.Bind(rateLimit())
pbRouter.Bind(BodyLimit(DefaultMaxBodySize))
return next(c)
}
})
apiGroup := pbRouter.Group("/api")
bindSettingsApi(app, apiGroup)
bindCollectionApi(app, apiGroup)
bindRecordCrudApi(app, apiGroup)
bindRecordAuthApi(app, apiGroup)
bindLogsApi(app, apiGroup)
bindBackupApi(app, apiGroup)
bindFileApi(app, apiGroup)
bindBatchApi(app, apiGroup)
bindRealtimeApi(app, apiGroup)
bindHealthApi(app, apiGroup)
// custom error handler
e.HTTPErrorHandler = func(c echo.Context, err error) {
if err == nil {
return // no error
}
var apiErr *ApiError
if errors.As(err, &apiErr) {
// already an api error...
} else if v := new(echo.HTTPError); errors.As(err, &v) {
msg := fmt.Sprintf("%v", v.Message)
apiErr = NewApiError(v.Code, msg, v)
} else {
if errors.Is(err, sql.ErrNoRows) {
apiErr = NewNotFoundError("", err)
} else {
apiErr = NewBadRequestError("", err)
}
}
logRequest(app, c, apiErr)
if c.Response().Committed {
return // already committed
}
event := new(core.ApiErrorEvent)
event.HttpContext = c
event.Error = apiErr
// send error response
hookErr := app.OnBeforeApiError().Trigger(event, func(e *core.ApiErrorEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
// @see https://github.com/labstack/echo/issues/608
if e.HttpContext.Request().Method == http.MethodHead {
return e.HttpContext.NoContent(apiErr.Code)
}
return e.HttpContext.JSON(apiErr.Code, apiErr)
})
if hookErr == nil {
if err := app.OnAfterApiError().Trigger(event); err != nil {
app.Logger().Debug("OnAfterApiError failure", slog.String("error", err.Error()))
}
} else {
app.Logger().Debug("OnBeforeApiError error (truly rare case, eg. client already disconnected)", slog.String("error", hookErr.Error()))
}
}
// admin ui routes
bindStaticAdminUI(app, e)
// default routes
api := e.Group("/api", eagerRequestInfoCache(app))
bindSettingsApi(app, api)
bindAdminApi(app, api)
bindCollectionApi(app, api)
bindRecordCrudApi(app, api)
bindRecordAuthApi(app, api)
bindFileApi(app, api)
bindRealtimeApi(app, api)
bindLogsApi(app, api)
bindHealthApi(app, api)
bindBackupApi(app, api)
// catch all any route
api.Any("/*", func(c echo.Context) error {
return echo.ErrNotFound
}, ActivityLogger(app))
return e, nil
return pbRouter, nil
}
// StaticDirectoryHandler is similar to `echo.StaticDirectoryHandler`
// but without the directory redirect which conflicts with RemoveTrailingSlash middleware.
// WrapStdHandler wraps Go [http.Handler] into a PocketBase handler func.
func WrapStdHandler(h http.Handler) hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
h.ServeHTTP(e.Response, e.Request)
return nil
}
}
// WrapStdMiddleware wraps Go [func(http.Handler) http.Handle] into a PocketBase middleware func.
func WrapStdMiddleware(m func(http.Handler) http.Handler) hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) (err error) {
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
e.Response = w
e.Request = r
err = e.Next()
})).ServeHTTP(e.Response, e.Request)
return err
}
}
// MustSubFS returns an [fs.FS] corresponding to the subtree rooted at fsys's dir.
//
// This is similar to [fs.Sub] but panics on failure.
func MustSubFS(fsys fs.FS, dir string) fs.FS {
dir = filepath.ToSlash(filepath.Clean(dir)) // ToSlash in case of Windows path
sub, err := fs.Sub(fsys, dir)
if err != nil {
panic(fmt.Errorf("failed to create sub FS: %w", err))
}
return sub
}
// Static is a handler function to serve static directory content from fsys.
//
// If a file resource is missing and indexFallback is set, the request
// will be forwarded to the base index.html (useful also for SPA).
// will be forwarded to the base index.html (useful for SPA with pretty urls).
//
// @see https://github.com/labstack/echo/issues/2211
func StaticDirectoryHandler(fileSystem fs.FS, indexFallback bool) echo.HandlerFunc {
return func(c echo.Context) error {
p := c.PathParam("*")
// escape url path
tmpPath, err := url.PathUnescape(p)
if err != nil {
return fmt.Errorf("failed to unescape path variable: %w", err)
// NB! Expects the route to have a "{path...}" wildcard parameter.
//
// Special redirects:
// - if "path" is a file that ends in index.html, it is redirected to its non-index.html version (eg. /test/index.html -> /test/)
// - if "path" is a directory that has index.html, the index.html file is rendered,
// otherwise if missing - returns 404 or fallback to the root index.html if indexFallback is set
//
// Example:
//
// fsys := os.DirFS("./pb_public")
// router.GET("/files/{path...}", apis.Static(fsys, false))
func Static(fsys fs.FS, indexFallback bool) hook.HandlerFunc[*core.RequestEvent] {
if fsys == nil {
panic("Static: the provided fs.FS argument is nil")
}
p = tmpPath
// fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
return func(e *core.RequestEvent) error {
// disable the activity logger to avoid flooding with messages
//
// note: errors are still logged
if e.Get(requestEventKeySkipSuccessActivityLog) == nil {
e.Set(requestEventKeySkipSuccessActivityLog, true)
}
fileErr := c.FileFS(name, fileSystem)
filename := e.Request.PathValue(StaticWildcardParam)
filename = filepath.ToSlash(filepath.Clean(strings.TrimPrefix(filename, "/")))
if fileErr != nil && indexFallback && errors.Is(fileErr, echo.ErrNotFound) {
return c.FileFS("index.html", fileSystem)
// eagerly check for directory traversal
//
// note: this is just out of an abundance of caution because the fs.FS implementation could be non-std,
// but usually shouldn't be necessary since os.DirFS.Open is expected to fail if the filename starts with dots
if len(filename) > 2 && filename[0] == '.' && filename[1] == '.' && (filename[2] == '/' || filename[2] == '\\') {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
fi, err := fs.Stat(fsys, filename)
if err != nil {
if indexFallback && filename != router.IndexPage {
return e.FileFS(fsys, router.IndexPage)
}
return router.ErrFileNotFound
}
if fi.IsDir() {
// redirect to a canonical dir url, aka. with trailing slash
if !strings.HasSuffix(e.Request.URL.Path, "/") {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(e.Request.URL.Path+"/"))
}
} else {
urlPath := e.Request.URL.Path
if strings.HasSuffix(urlPath, "/") {
// redirect to a non-trailing slash file route
urlPath = strings.TrimRight(urlPath, "/")
if len(urlPath) > 0 {
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(urlPath))
}
} else if stripped, ok := strings.CutSuffix(urlPath, router.IndexPage); ok {
// redirect without the index.html
return e.Redirect(http.StatusMovedPermanently, safeRedirectPath(stripped))
}
}
fileErr := e.FileFS(fsys, filename)
if fileErr != nil && indexFallback && filename != router.IndexPage && errors.Is(fileErr, router.ErrFileNotFound) {
return e.FileFS(fsys, router.IndexPage)
}
return fileErr
}
}
// bindStaticAdminUI registers the endpoints that serves the static admin UI.
func bindStaticAdminUI(app core.App, e *echo.Echo) error {
// redirect to trailing slash to ensure that relative urls will still work properly
e.GET(
strings.TrimRight(trailedAdminPath, "/"),
func(c echo.Context) error {
return c.Redirect(http.StatusTemporaryRedirect, strings.TrimLeft(trailedAdminPath, "/"))
},
)
// serves static files from the /ui/dist directory
// (similar to echo.StaticFS but with gzip middleware enabled)
e.GET(
trailedAdminPath+"*",
echo.StaticDirectoryHandler(ui.DistDirFS, false),
installerRedirect(app),
uiCacheControl(),
middleware.Gzip(),
)
return nil
// safeRedirectPath normalizes the path string by replacing all beginning slashes
// (`\\`, `//`, `\/`) with a single forward slash to prevent open redirect attacks
func safeRedirectPath(path string) string {
if len(path) > 1 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
path = "/" + strings.TrimLeft(path, `/\`)
}
return path
}
func uiCacheControl() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// add default Cache-Control header for all Admin UI resources
// (ignoring the root admin path)
if c.Request().URL.Path != trailedAdminPath {
c.Response().Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
}
return next(c)
}
}
}
const hasAdminsCacheKey = "@hasAdmins"
func updateHasAdminsCache(app core.App) error {
total, err := app.Dao().TotalAdmins()
// FindUploadedFiles extracts all form files of "key" from a http request
// and returns a slice with filesystem.File instances (if any).
func FindUploadedFiles(r *http.Request, key string) ([]*filesystem.File, error) {
if r.MultipartForm == nil {
err := r.ParseMultipartForm(router.DefaultMaxMemory)
if err != nil {
return err
return nil, err
}
}
app.Store().Set(hasAdminsCacheKey, total > 0)
if r.MultipartForm == nil || r.MultipartForm.File == nil || len(r.MultipartForm.File[key]) == 0 {
return nil, http.ErrMissingFile
}
return nil
}
// installerRedirect redirects the user to the installer admin UI page
// when the application needs some preliminary configurations to be done.
func installerRedirect(app core.App) echo.MiddlewareFunc {
// keep hasAdminsCacheKey value up-to-date
app.OnAdminAfterCreateRequest().Add(func(data *core.AdminCreateEvent) error {
return updateHasAdminsCache(app)
})
app.OnAdminAfterDeleteRequest().Add(func(data *core.AdminDeleteEvent) error {
return updateHasAdminsCache(app)
})
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// skip redirect checks for non-root level index.html requests
path := c.Request().URL.Path
if path != trailedAdminPath && path != trailedAdminPath+"index.html" {
return next(c)
}
hasAdmins := cast.ToBool(app.Store().Get(hasAdminsCacheKey))
if !hasAdmins {
// update the cache to make sure that the admin wasn't created by another process
if err := updateHasAdminsCache(app); err != nil {
return err
}
hasAdmins = cast.ToBool(app.Store().Get(hasAdminsCacheKey))
}
_, hasInstallerParam := c.Request().URL.Query()["installer"]
if !hasAdmins && !hasInstallerParam {
// redirect to the installer page
return c.Redirect(http.StatusTemporaryRedirect, "?installer#")
}
if hasAdmins && hasInstallerParam {
// clear the installer param
return c.Redirect(http.StatusTemporaryRedirect, "?")
}
return next(c)
}
}
result := make([]*filesystem.File, 0, len(r.MultipartForm.File[key]))
for _, fh := range r.MultipartForm.File[key] {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result = append(result, file)
}
return result, nil
}

View File

@ -1,422 +1,386 @@
package apis_test
import (
"database/sql"
"errors"
"bytes"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/spf13/cast"
"github.com/pocketbase/pocketbase/tools/router"
)
func Test404(t *testing.T) {
func TestWrapStdHandler(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Method: http.MethodGet,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodPost,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodPatch,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodDelete,
Url: "/api/missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Method: http.MethodHead,
Url: "/api/missing",
ExpectedStatus: 404,
},
}
app, _ := tests.NewTestApp()
defer app.Cleanup()
for _, scenario := range scenarios {
scenario.Test(t)
}
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
func TestCustomRoutesAndErrorsHandling(t *testing.T) {
t.Parallel()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
scenarios := []tests.ApiScenario{
{
Name: "custom route",
Method: http.MethodGet,
Url: "/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "custom route with url encoded parameter",
Method: http.MethodGet,
Url: "/a%2Bb%2Bc",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/:param",
Handler: func(c echo.Context) error {
return c.String(200, c.PathParam("param"))
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"a+b+c"},
},
{
Name: "route with HTTPError",
Method: http.MethodGet,
Url: "/http-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/http-error",
Handler: func(c echo.Context) error {
return echo.ErrBadRequest
},
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`{"code":400,"message":"Bad Request.","data":{}}`},
},
{
Name: "route with api error",
Method: http.MethodGet,
Url: "/api-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api-error",
Handler: func(c echo.Context) error {
return apis.NewApiError(500, "test message", errors.New("internal_test"))
},
})
},
ExpectedStatus: 500,
ExpectedContent: []string{`{"code":500,"message":"Test message.","data":{}}`},
},
{
Name: "route with plain error",
Method: http.MethodGet,
Url: "/plain-error",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/plain-error",
Handler: func(c echo.Context) error {
return errors.New("Test error")
},
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRemoveTrailingSlashMiddleware(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "non /api/* route (exact match)",
Method: http.MethodGet,
Url: "/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "non /api/* route (with trailing slash)",
Method: http.MethodGet,
Url: "/custom/",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "/api/* route (exact match)",
Method: http.MethodGet,
Url: "/api/custom",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
{
Name: "/api/* route (with trailing slash)",
Method: http.MethodGet,
Url: "/api/custom/",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: http.MethodGet,
Path: "/api/custom",
Handler: func(c echo.Context) error {
return c.String(200, "test123")
},
})
},
ExpectedStatus: 200,
ExpectedContent: []string{"test123"},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestMultiBinder(t *testing.T) {
t.Parallel()
rawJson := `{"name":"test123"}`
formData, mp, err := tests.MockMultipartData(map[string]string{
rest.MultipartJsonKey: rawJson,
})
err := apis.WrapStdHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
}))(e)
if err != nil {
t.Fatal(err)
}
scenarios := []tests.ApiScenario{
{
Name: "non-api group route",
Method: "POST",
Url: "/custom",
Body: strings.NewReader(rawJson),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: "POST",
Path: "/custom",
Handler: func(c echo.Context) error {
data := &struct {
Name string `json:"name"`
}{}
if err := c.Bind(data); err != nil {
return err
}
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return c.NoContent(200)
},
})
},
ExpectedStatus: 200,
},
{
Name: "api group route",
Method: "GET",
Url: "/api/admins",
Body: strings.NewReader(rawJson),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// it is not important whether the route handler return an error since
// we just need to ensure that the eagerRequestInfoCache was registered
next(c)
// ensure that the body was read at least once
data := &struct {
Name string `json:"name"`
}{}
c.Bind(data)
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return nil
}
})
},
ExpectedStatus: 200,
},
{
Name: "custom route with @jsonPayload as multipart body",
Method: "POST",
Url: "/custom",
Body: formData,
RequestHeaders: map[string]string{
"Content-Type": mp.FormDataContentType(),
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.AddRoute(echo.Route{
Method: "POST",
Path: "/custom",
Handler: func(c echo.Context) error {
data := &struct {
Name string `json:"name"`
}{}
if err := c.Bind(data); err != nil {
return err
}
// try to read the body again
r := apis.RequestInfo(c)
if v := cast.ToString(r.Data["name"]); v != "test123" {
t.Fatalf("Expected request data with name %q, got, %q", "test123", v)
}
return c.NoContent(200)
},
})
},
ExpectedStatus: 200,
},
}
for _, scenario := range scenarios {
scenario.Test(t)
if body := rec.Body.String(); body != "test" {
t.Fatalf("Expected body %q, got %q", "test", body)
}
}
func TestErrorHandler(t *testing.T) {
func TestWrapStdMiddleware(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "apis.ApiError",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return apis.NewApiError(418, "test", nil)
app, _ := tests.NewTestApp()
defer app.Cleanup()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.WrapStdMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
})(e)
if err != nil {
t.Fatal(err)
}
if body := rec.Body.String(); body != "test" {
t.Fatalf("Expected body %q, got %q", "test", body)
}
}
func TestStatic(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
dir := createTestDir(t)
defer os.RemoveAll(dir)
fsys := os.DirFS(filepath.Join(dir, "sub"))
type staticScenario struct {
path string
indexFallback bool
expectedStatus int
expectBody string
expectError bool
}
scenarios := []staticScenario{
{
path: "",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
{
Name: "wrapped apis.ApiError",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return fmt.Errorf("example 123: %w", apis.NewApiError(418, "test", nil))
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
NotExpectedContent: []string{"example", "123"},
path: "missing/a/b/c",
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
{
Name: "echo.HTTPError",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return echo.NewHTTPError(418, "test")
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
path: "missing/a/b/c",
indexFallback: true,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
{
Name: "wrapped echo.HTTPError",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return fmt.Errorf("example 123: %w", echo.NewHTTPError(418, "test"))
})
},
ExpectedStatus: 418,
ExpectedContent: []string{`"message":"Test."`},
NotExpectedContent: []string{"example", "123"},
path: "testroot", // parent directory file
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
{
Name: "wrapped sql.ErrNoRows",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return fmt.Errorf("example 123: %w", sql.ErrNoRows)
})
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
NotExpectedContent: []string{"example", "123"},
path: "test",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub test",
expectError: false,
},
{
Name: "custom error",
Method: http.MethodGet,
Url: "/test",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
e.GET("/test", func(c echo.Context) error {
return fmt.Errorf("example 123")
})
path: "sub2",
indexFallback: false,
expectedStatus: 301,
expectBody: "",
expectError: false,
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
NotExpectedContent: []string{"example", "123"},
{
path: "sub2/",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub2 index.html",
expectError: false,
},
{
path: "sub2/test",
indexFallback: false,
expectedStatus: 200,
expectBody: "sub2 test",
expectError: false,
},
{
path: "sub2/test/",
indexFallback: false,
expectedStatus: 301,
expectBody: "",
expectError: false,
},
}
for _, scenario := range scenarios {
scenario.Test(t)
// extra directory traversal checks
dtp := []string{
"/../",
"\\../",
"../",
"../../",
"..\\",
"..\\..\\",
"../..\\",
"..\\..//",
`%2e%2e%2f`,
`%2e%2e%2f%2e%2e%2f`,
`%2e%2e/`,
`%2e%2e/%2e%2e/`,
`..%2f`,
`..%2f..%2f`,
`%2e%2e%5c`,
`%2e%2e%5c%2e%2e%5c`,
`%2e%2e\`,
`%2e%2e\%2e%2e\`,
`..%5c`,
`..%5c..%5c`,
`%252e%252e%255c`,
`%252e%252e%255c%252e%252e%255c`,
`..%255c`,
`..%255c..%255c`,
}
for _, p := range dtp {
scenarios = append(scenarios,
staticScenario{
path: p + "testroot",
indexFallback: false,
expectedStatus: 404,
expectBody: "",
expectError: true,
},
staticScenario{
path: p + "testroot",
indexFallback: true,
expectedStatus: 200,
expectBody: "sub index.html",
expectError: false,
},
)
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%v", i, s.path, s.indexFallback), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/"+s.path, nil)
req.SetPathValue(apis.StaticWildcardParam, s.path)
rec := httptest.NewRecorder()
e := new(core.RequestEvent)
e.App = app
e.Request = req
e.Response = rec
err := apis.Static(fsys, s.indexFallback)(e)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
body := rec.Body.String()
if body != s.expectBody {
t.Fatalf("Expected body %q, got %q", s.expectBody, body)
}
if hasErr {
apiErr := router.ToApiError(err)
if apiErr.Status != s.expectedStatus {
t.Fatalf("Expected status code %d, got %d", s.expectedStatus, apiErr.Status)
}
}
})
}
}
func TestFindUploadedFiles(t *testing.T) {
scenarios := []struct {
filename string
expectedPattern string
}{
{"ab.png", `^ab\w{10}_\w{10}\.png$`},
{"test", `^test_\w{10}\.txt$`},
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
}
for _, s := range scenarios {
t.Run(s.filename, func(t *testing.T) {
// create multipart form file body
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
w, err := mp.CreateFormFile("test", s.filename)
if err != nil {
t.Fatal(err)
}
w.Write([]byte("test"))
mp.Close()
// ---
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := apis.FindUploadedFiles(req, "test")
if err != nil {
t.Fatal(err)
}
if len(result) != 1 {
t.Fatalf("Expected 1 file, got %d", len(result))
}
if result[0].Size != 4 {
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Size)
}
pattern, err := regexp.Compile(s.expectedPattern)
if err != nil {
t.Fatalf("Invalid filename pattern %q: %v", s.expectedPattern, err)
}
if !pattern.MatchString(result[0].Name) {
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name)
}
})
}
}
func TestFindUploadedFilesMissing(t *testing.T) {
body := new(bytes.Buffer)
mp := multipart.NewWriter(body)
mp.Close()
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Add("Content-Type", mp.FormDataContentType())
result, err := apis.FindUploadedFiles(req, "test")
if err == nil {
t.Error("Expected error, got nil")
}
if result != nil {
t.Errorf("Expected result to be nil, got %v", result)
}
}
func TestMustSubFS(t *testing.T) {
t.Parallel()
dir := createTestDir(t)
defer os.RemoveAll(dir)
// invalid path (no beginning and ending slashes)
if !hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "/test/")
}) {
t.Fatalf("Expected to panic")
}
// valid path
if hasPanicked(func() {
apis.MustSubFS(os.DirFS(dir), "./////a/b/c") // checks if ToSlash was called
}) {
t.Fatalf("Didn't expect to panic")
}
// check sub content
sub := apis.MustSubFS(os.DirFS(dir), "sub")
_, err := sub.Open("test")
if err != nil {
t.Fatalf("Missing expected file sub/test")
}
}
// -------------------------------------------------------------------
func hasPanicked(f func()) (didPanic bool) {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
f()
return
}
// note: make sure to call os.RemoveAll(dir) after you are done
// working with the created test dir.
func createTestDir(t *testing.T) string {
dir, err := os.MkdirTemp(os.TempDir(), "test_dir")
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("root index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "testroot"), []byte("root test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/index.html"), []byte("sub index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/test"), []byte("sub test"), 0644); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Join(dir, "sub", "sub2"), os.ModePerm); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/index.html"), []byte("sub2 index.html"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(dir, "sub/sub2/test"), []byte("sub2 test"), 0644); err != nil {
t.Fatal(err)
}
return dir
}

542
apis/batch.go Normal file
View File

@ -0,0 +1,542 @@
package apis
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/batch")
sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
}
type HandleFunc func(e *core.RequestEvent) error
type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
//
// Note: when adding new routes make sure that their middlewares are inlined!
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
// "upsert" handler
regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
var id string
if len(ir.Body) > 0 && ir.Body["id"] != "" {
id = cast.ToString(ir.Body["id"])
}
if id != "" {
_, err := app.FindRecordById(params["collection"], id)
if err == nil {
// update
// ---
params["id"] = id // required for the path value
ir.Method = "PATCH"
ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"]
return recordUpdate(next)
}
}
// create
// ---
ir.Method = "POST"
ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"]
return recordCreate(next)
},
regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordCreate(next)
},
regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordUpdate(next)
},
regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func() error) HandleFunc {
return recordDelete(next)
},
}
type BatchRequestResult struct {
Body any `json:"body"`
Status int `json:"status"`
}
type batchRequestsForm struct {
Requests []*core.InternalRequest `form:"requests" json:"requests"`
max int
}
func (brs batchRequestsForm) validate() error {
return validation.ValidateStruct(&brs,
validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)),
)
}
// NB! When the request is submitted as multipart/form-data,
// the regular fields data is expected to be submitted as serailized
// json under the @jsonPayload field and file keys need to follow the
// pattern "requests.N.fileField" or requests[N].fileField.
func batchTransaction(e *core.RequestEvent) error {
maxRequests := e.App.Settings().Batch.MaxRequests
if !e.App.Settings().Batch.Enabled || maxRequests <= 0 {
return e.ForbiddenError("Batch requests are not allowed.", nil)
}
txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second
if txTimeout <= 0 {
txTimeout = 3 * time.Second // for now always limit
}
maxBodySize := e.App.Settings().Batch.MaxBodySize
if maxBodySize <= 0 {
maxBodySize = 128 << 20
}
err := applyBodyLimit(e, maxBodySize)
if err != nil {
return err
}
form := &batchRequestsForm{max: maxRequests}
// load base requests data
err = e.BindBody(form)
if err != nil {
return e.BadRequestError("Failed to read the submitted batch data.", err)
}
// load uploaded files into each request item
// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
// (the other regular fields must be put under `@jsonPayload` as serialized json)
if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
for i, ir := range form.Requests {
iStr := strconv.Itoa(i)
files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
if err != nil {
return e.BadRequestError("Failed to read the submitted batch files data.", err)
}
for key, files := range files {
if ir.Body == nil {
ir.Body = map[string]any{}
}
ir.Body[key] = files
}
}
}
// validate batch request form
err = form.validate()
if err != nil {
return e.BadRequestError("Invalid batch request data.", err)
}
event := new(core.BatchRequestEvent)
event.RequestEvent = e
event.Batch = form.Requests
return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error {
bp := batchProcessor{
app: e.App,
baseEvent: e.RequestEvent,
infoContext: core.RequestInfoContextBatch,
}
if err := bp.Process(e.Batch, txTimeout); err != nil {
return firstApiError(err, e.BadRequestError("Batch transaction failed.", err))
}
return e.JSON(http.StatusOK, bp.results)
})
}
type batchProcessor struct {
app core.App
baseEvent *core.RequestEvent
infoContext string
results []*BatchRequestResult
failedIndex int
errCh chan error
stopCh chan struct{}
}
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
p.results = make([]*BatchRequestResult, 0, len(batch))
if p.stopCh != nil {
close(p.stopCh)
}
p.stopCh = make(chan struct{}, 1)
if p.errCh != nil {
close(p.errCh)
}
p.errCh = make(chan error, 1)
return p.app.RunInTransaction(func(txApp core.App) error {
// used to interupts the recursive processing calls in case of a timeout or connection close
defer func() {
p.stopCh <- struct{}{}
}()
go func() {
err := p.process(txApp, batch, 0)
if err != nil {
err = validation.Errors{
"requests": validation.Errors{
strconv.Itoa(p.failedIndex): &BatchResponseError{
code: "batch_request_failed",
message: "Batch request failed.",
err: router.ToApiError(err),
},
},
}
}
// note: to avoid copying and due to the process recursion the final results order is reversed
if err == nil {
slices.Reverse(p.results)
}
p.errCh <- err
}()
select {
case responseErr := <-p.errCh:
return responseErr
case <-time.After(timeout):
// note: we don't return 408 Reques Timeout error because
// some browsers perform automatic retry behind the scenes
// which are hard to debug and unnecessary
return errors.New("batch transaction timeout")
case <-p.baseEvent.Request.Context().Done():
return errors.New("batch request interrupted")
}
})
}
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
select {
case <-p.stopCh:
return nil
default:
if len(batch) == 0 {
return nil
}
result, err := processInternalRequest(
activeApp,
p.baseEvent,
batch[0],
p.infoContext,
func() error {
if len(batch) == 1 {
return nil
}
err := p.process(activeApp, batch[1:], i+1)
// update the failed batch index (if not already)
if err != nil && p.failedIndex == 0 {
p.failedIndex = i + 1
}
return err
},
)
if err != nil {
return err
}
p.results = append(p.results, result)
return nil
}
}
func processInternalRequest(
activeApp core.App,
baseEvent *core.RequestEvent,
ir *core.InternalRequest,
infoContext string,
optNext func() error,
) (*BatchRequestResult, error) {
handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
if !ok {
return nil, errors.New("unknown batch request action")
}
// construct a new http.Request
// ---------------------------------------------------------------
buf, mw, err := multipartDataFromInternalRequest(ir)
if err != nil {
return nil, err
}
r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
if err != nil {
return nil, err
}
// cleanup multipart temp files
defer func() {
if r.MultipartForm != nil {
if err := r.MultipartForm.RemoveAll(); err != nil {
activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
}
}
}()
// load batch request path params
// ---
for k, v := range params {
r.SetPathValue(k, v)
}
// clone original request
// ---
r.RequestURI = r.URL.RequestURI()
r.Proto = baseEvent.Request.Proto
r.ProtoMajor = baseEvent.Request.ProtoMajor
r.ProtoMinor = baseEvent.Request.ProtoMinor
r.Host = baseEvent.Request.Host
r.RemoteAddr = baseEvent.Request.RemoteAddr
r.TLS = baseEvent.Request.TLS
if s := baseEvent.Request.TransferEncoding; s != nil {
s2 := make([]string, len(s))
copy(s2, s)
r.TransferEncoding = s2
}
if baseEvent.Request.Trailer != nil {
r.Trailer = baseEvent.Request.Trailer.Clone()
}
if baseEvent.Request.Header != nil {
r.Header = baseEvent.Request.Header.Clone()
}
// apply batch request specific headers
// ---
for k, v := range ir.Headers {
r.Header.Set(k, v)
}
r.Header.Set("Content-Type", mw.FormDataContentType())
// construct a new RequestEvent
// ---------------------------------------------------------------
event := &core.RequestEvent{}
event.App = activeApp
event.Auth = baseEvent.Auth
event.SetAll(baseEvent.GetAll())
// load RequestInfo context
if infoContext == "" {
infoContext = core.RequestInfoContextDefault
}
event.Set(core.RequestEventKeyInfoContext, infoContext)
// assign request
event.Request = r
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
// assign response
rec := httptest.NewRecorder()
event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking
// execute
// ---------------------------------------------------------------
if err := handle(event); err != nil {
return nil, err
}
result := rec.Result()
defer result.Body.Close()
body, _ := types.ParseJSONRaw(rec.Body.Bytes())
return &BatchRequestResult{
Status: result.StatusCode,
Body: body,
}, nil
}
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
buf := &bytes.Buffer{}
mw := multipart.NewWriter(buf)
regularFields := map[string]any{}
fileFields := map[string][]*filesystem.File{}
// separate regular fields from files
// ---
for k, rawV := range ir.Body {
switch v := rawV.(type) {
case *filesystem.File:
fileFields[k] = append(fileFields[k], v)
case []*filesystem.File:
fileFields[k] = append(fileFields[k], v...)
default:
regularFields[k] = v
}
}
// submit regularFields as @jsonPayload
// ---
rawBody, err := json.Marshal(regularFields)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
jsonPayload, err := mw.CreateFormField("@jsonPayload")
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = jsonPayload.Write(rawBody)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
// submit fileFields as multipart files
// ---
for key, files := range fileFields {
for _, file := range files {
part, err := mw.CreateFormFile(key, file.Name)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
fr, err := file.Reader.Open()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = io.Copy(part, fr)
if err != nil {
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
}
err = fr.Close()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
}
}
return buf, mw, mw.Close()
}
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
if request.MultipartForm == nil {
if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil {
return nil, err
}
}
result := make(map[string][]*filesystem.File)
for k, fhs := range request.MultipartForm.File {
for _, p := range prefixes {
if strings.HasPrefix(k, p) {
resultKey := strings.TrimPrefix(k, p)
for _, fh := range fhs {
file, err := filesystem.NewFileFromMultipart(fh)
if err != nil {
return nil, err
}
result[resultKey] = append(result[resultKey], file)
}
}
}
}
return result, nil
}
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func() error) (HandleFunc, map[string]string, bool) {
full := strings.ToUpper(ir.Method) + " " + ir.URL
for re, actionFactory := range ValidBatchActions {
params, ok := findNamedMatches(re, full)
if ok {
return actionFactory(activeApp, ir, params, optNext), params, true
}
}
return nil, nil, false
}
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
match := re.FindStringSubmatch(str)
if match == nil {
return nil, false
}
result := map[string]string{}
names := re.SubexpNames()
for i, m := range match {
if names[i] != "" {
result[names[i]] = m
}
}
return result, true
}
// -------------------------------------------------------------------
var (
_ router.SafeErrorItem = (*BatchResponseError)(nil)
_ router.SafeErrorResolver = (*BatchResponseError)(nil)
)
type BatchResponseError struct {
err *router.ApiError
code string
message string
}
func (e *BatchResponseError) Error() string {
return e.message
}
func (e *BatchResponseError) Code() string {
return e.code
}
func (e *BatchResponseError) Resolve(errData map[string]any) any {
errData["response"] = e.err
return errData
}
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"message": e.message,
"code": e.code,
"response": e.err,
})
}

691
apis/batch_test.go Normal file
View File

@ -0,0 +1,691 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestBatchRequest(t *testing.T) {
t.Parallel()
formData, mp, err := tests.MockMultipartData(
map[string]string{
router.JSONPayloadKey: `{
"requests":[
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo3/records", "body": {"title": "batch3"}},
{"method":"PATCH", "url":"/api/collections/demo3/records/lcl9d87w22ml6jy", "body": {"files-": "test_FLurQTgrY8.txt"}}
]
}`,
},
"requests.0.files",
"requests.0.files",
"requests.0.files",
"requests[2].files",
)
if err != nil {
t.Fatal(err)
}
scenarios := []tests.ApiScenario{
{
Name: "disabled batch requets",
Method: http.MethodPost,
URL: "/api/batch",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = false
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "max request limits reached",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/test1"},
{"method":"GET", "url":"/test2"},
{"method":"GET", "url":"/test3"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 2
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{"code":"validation_length_too_long"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "trigger requests validations",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{},
{"method":"GET", "url":"/valid"},
{"method":"invalid", "url":"/valid"},
{"method":"POST", "url":"` + strings.Repeat("a", 2001) + `"}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Enabled = true
app.Settings().Batch.MaxRequests = 100
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"0":{"method":{"code":"validation_required"`,
`"2":{"method":{"code":"validation_in_invalid"`,
`"3":{"url":{"code":"validation_length_too_long"`,
},
NotExpectedContent: []string{
`"1":`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unknown batch request action",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"GET", "url":"/api/health"}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`0":{"code":"batch_request_failed"`,
`"response":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
},
},
{
Name: "base 2 successful and 1 failed (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": ""}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"response":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{"data":{"title":{"code":"validation_required"`,
},
NotExpectedContent: []string{
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 3,
"OnModelValidate": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 2,
"OnRecordAfterCreateError": 3,
"OnRecordValidate": 3,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected no batch records to be persisted, got %d", len(records))
}
},
},
{
Name: "base 4 successful (public collection)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}},
{"method":"PUT", "url":"/api/collections/demo2/records", "body": {"title": "batch3"}},
{"method":"PUT", "url":"/api/collections/demo2/records?fields=*,id:excerpt(4,true)", "body": {"id":"achvryl401bhse3","title": "batch4"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"title":"batch4"`,
`"id":"achv..."`,
`"active":false`,
`"active":true`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelValidate": 4,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
"OnRecordCreateRequest": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 4 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 3, len(records))
}
},
},
{
Name: "mixed create/update/delete (rules failure)",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"requests":{`,
`"2":{"code":"batch_request_failed"`,
`"response":{`,
},
NotExpectedContent: []string{
// only demo3 requires authentication
`"0":`,
`"1":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteError": 1,
"OnModelValidate": 1,
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err == nil {
t.Fatal("Expected record to not be created")
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err == nil {
t.Fatal("Expected record to not be updated")
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err != nil {
t.Fatal("Expected record to not be deleted")
}
},
},
{
Name: "mixed create/update/delete (rules success)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "mixed create/update/delete (superuser auth)",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch_create"}},
{"method":"DELETE", "url":"/api/collections/demo2/records/achvryl401bhse3"},
{"method":"PATCH", "url":"/api/collections/demo3/records/1tmknxy2868d869", "body": {"title": "batch_update"}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch_create"`,
`"title":"batch_update"`,
`"status":200`,
`"status":204`,
`"body":{`,
`"body":null`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDeleteRequest": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
"OnRecordUpdateRequest": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindFirstRecordByFilter("demo2", `title="batch_create"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindFirstRecordByFilter("demo3", `title="batch_update"`)
if err != nil {
t.Fatal(err)
}
_, err = app.FindRecordById("demo2", "achvryl401bhse3")
if err == nil {
t.Fatal("Expected record to be deleted")
}
},
},
{
Name: "cascade delete/update",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"DELETE", "url":"/api/collections/demo3/records/1tmknxy2868d869"},
{"method":"DELETE", "url":"/api/collections/demo3/records/mk5fmymtx4wsprk"}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"status":204`,
`"body":null`,
},
NotExpectedContent: []string{
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelDelete": 3, // 2 batch + 1 cascade delete
"OnModelDeleteExecute": 3,
"OnModelAfterDeleteSuccess": 3,
"OnModelUpdate": 5, // 5 cascade update
"OnModelUpdateExecute": 5,
"OnModelAfterUpdateSuccess": 5,
// ---
"OnRecordDeleteRequest": 2,
"OnRecordDelete": 3,
"OnRecordDeleteExecute": 3,
"OnRecordAfterDeleteSuccess": 3,
"OnRecordUpdate": 5,
"OnRecordUpdateExecute": 5,
"OnRecordAfterUpdateSuccess": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
ids := []string{
"1tmknxy2868d869",
"mk5fmymtx4wsprk",
"qzaqccwrmva4o1n",
}
for _, id := range ids {
_, err := app.FindRecordById("demo2", id)
if err == nil {
t.Fatalf("Expected record %q to be deleted", id)
}
}
},
},
{
Name: "transaction timeout",
Method: http.MethodPost,
URL: "/api/batch",
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch1"}},
{"method":"POST", "url":"/api/collections/demo2/records", "body": {"title": "batch2"}}
]
}`),
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.Timeout = 1
app.OnRecordCreateRequest("demo2").BindFunc(func(e *core.RecordRequestEvent) error {
time.Sleep(600 * time.Millisecond) // < 1s so that the first request can succeed
return e.Next()
})
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
"OnRecordCreateRequest": 2,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateError": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateError": 1,
"OnRecordEnrich": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
records, err := app.FindRecordsByFilter("demo2", `title~"batch"`, "", 0, 0)
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatalf("Expected %d batch records to be persisted, got %d", 0, len(records))
}
},
},
{
Name: "multipart/form-data + file upload",
Method: http.MethodPost,
URL: "/api/batch",
Body: formData,
Headers: map[string]string{
// test@example.com, clients
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
"Content-Type": mp.FormDataContentType(),
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"title":"batch1"`,
`"title":"batch2"`,
`"title":"batch3"`,
`"id":"lcl9d87w22ml6jy"`,
`"files":["300_UhLKX91HVb.png"]`,
`"tmpfile_`,
`"status":200`,
`"body":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 4,
// ---
"OnRecordCreateRequest": 3,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 3,
"OnRecordCreateExecute": 3,
"OnRecordAfterCreateSuccess": 3,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 4,
"OnRecordEnrich": 4,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
batch1, err := app.FindFirstRecordByFilter("demo3", `title="batch1"`)
if err != nil {
t.Fatalf("missing batch1: %v", err)
}
batch1Files := batch1.GetStringSlice("files")
if len(batch1Files) != 3 {
t.Fatalf("Expected %d batch1 file(s), got %d", 3, len(batch1Files))
}
batch2, err := app.FindFirstRecordByFilter("demo3", `title="batch2"`)
if err != nil {
t.Fatalf("missing batch2: %v", err)
}
batch2Files := batch2.GetStringSlice("files")
if len(batch2Files) != 0 {
t.Fatalf("Expected %d batch2 file(s), got %d", 0, len(batch2Files))
}
batch3, err := app.FindFirstRecordByFilter("demo3", `title="batch3"`)
if err != nil {
t.Fatalf("missing batch3: %v", err)
}
batch3Files := batch3.GetStringSlice("files")
if len(batch3Files) != 1 {
t.Fatalf("Expected %d batch3 file(s), got %d", 1, len(batch3Files))
}
batch4, err := app.FindRecordById("demo3", "lcl9d87w22ml6jy")
if err != nil {
t.Fatalf("missing batch4: %v", err)
}
batch4Files := batch4.GetStringSlice("files")
if len(batch4Files) != 1 {
t.Fatalf("Expected %d batch4 file(s), got %d", 1, len(batch4Files))
}
},
},
{
Name: "create/update with expand query params",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"body":{`,
`"id":"qjeql998mtp1azp"`,
`"id":"qzaqccwrmva4o1n"`,
`"id":"i9naidtvr6qsgb4"`,
`"expand":{"rel_one"`,
`"expand":{"rel_many"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnBatchRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 2,
// ---
"OnRecordCreateRequest": 1,
"OnRecordUpdateRequest": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 2,
"OnRecordEnrich": 5,
},
},
{
Name: "check body limit middleware",
Method: http.MethodPost,
URL: "/api/batch",
Headers: map[string]string{
// test@example.com, superusers
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{
"requests": [
{"method":"POST", "url":"/api/collections/demo5/records?expand=rel_one", "body": {"total": 9, "rel_one":"qzaqccwrmva4o1n"}},
{"method":"PATCH", "url":"/api/collections/demo5/records/qjeql998mtp1azp?expand=rel_many", "body": {"total": 10}}
]
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().Batch.MaxBodySize = 10
},
ExpectedStatus: 413,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,210 +1,186 @@
package apis
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v5"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindCollectionApi registers the collection api endpoints and the corresponding handlers.
func bindCollectionApi(app core.App, rg *echo.Group) {
api := collectionApi{app: app}
subGroup := rg.Group("/collections", ActivityLogger(app), RequireAdminAuth())
subGroup.GET("", api.list)
subGroup.POST("", api.create)
subGroup.GET("/:collection", api.view)
subGroup.PATCH("/:collection", api.update)
subGroup.DELETE("/:collection", api.delete)
subGroup.PUT("/import", api.bulkImport)
func bindCollectionApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/collections").Bind(RequireSuperuserAuth())
subGroup.GET("", collectionsList)
subGroup.POST("", collectionCreate)
subGroup.GET("/{collection}", collectionView)
subGroup.PATCH("/{collection}", collectionUpdate)
subGroup.DELETE("/{collection}", collectionDelete)
subGroup.DELETE("/{collection}/truncate", collectionTruncate)
subGroup.PUT("/import", collectionsImport)
subGroup.GET("/meta/scaffolds", collectionScaffolds)
}
type collectionApi struct {
app core.App
}
func (api *collectionApi) list(c echo.Context) error {
func collectionsList(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(
"id", "created", "updated", "name", "system", "type",
)
collections := []*models.Collection{}
collections := []*core.Collection{}
result, err := search.NewProvider(fieldResolver).
Query(api.app.Dao().CollectionQuery()).
ParseAndExec(c.QueryParams().Encode(), &collections)
Query(e.App.CollectionQuery()).
ParseAndExec(e.Request.URL.Query().Encode(), &collections)
if err != nil {
return NewBadRequestError("", err)
return e.BadRequestError("", err)
}
event := new(core.CollectionsListEvent)
event.HttpContext = c
event := new(core.CollectionsListRequestEvent)
event.RequestEvent = e
event.Collections = collections
event.Result = result
return api.app.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
return event.App.OnCollectionsListRequest().Trigger(event, func(e *core.CollectionsListRequestEvent) error {
return e.JSON(http.StatusOK, e.Result)
})
}
func (api *collectionApi) view(c echo.Context) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
func collectionView(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return NewNotFoundError("", err)
return e.NotFoundError("", err)
}
event := new(core.CollectionViewEvent)
event.HttpContext = c
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return api.app.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionViewEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
return e.App.OnCollectionViewRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
return e.JSON(http.StatusOK, e.Collection)
})
}
func (api *collectionApi) create(c echo.Context) error {
collection := &models.Collection{}
form := forms.NewCollectionUpsert(api.app, collection)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
func collectionCreate(e *core.RequestEvent) error {
// populate the minimal required factory collection data (if any)
factoryExtract := struct {
Type string `form:"type" json:"type"`
Name string `form:"name" json:"name"`
}{}
if err := e.BindBody(&factoryExtract); err != nil {
return e.BadRequestError("Failed to load the collection type data due to invalid formatting.", err)
}
event := new(core.CollectionCreateEvent)
event.HttpContext = c
// create scaffold
collection := core.NewCollection(factoryExtract.Type, factoryExtract.Name)
// merge the scaffold with the submitted request data
if err := e.BindBody(collection); err != nil {
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
// create the collection
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] {
return func(m *models.Collection) error {
event.Collection = m
return api.app.OnCollectionBeforeCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error {
if err := next(e.Collection); err != nil {
return NewBadRequestError("Failed to create the collection.", err)
return e.App.OnCollectionCreateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Save(e.Collection); err != nil {
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to create collection.", validationErrors)
}
return api.app.OnCollectionAfterCreateRequest().Trigger(event, func(e *core.CollectionCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
// other generic db error
return e.BadRequestError("Failed to create collection. Raw error: \n"+err.Error(), nil)
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
})
})
}
return e.JSON(http.StatusOK, e.Collection)
})
}
func (api *collectionApi) update(c echo.Context) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
func collectionUpdate(e *core.RequestEvent) error {
collection, err := e.App.FindCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return NewNotFoundError("", err)
return e.NotFoundError("", err)
}
form := forms.NewCollectionUpsert(api.app, collection)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
if err := e.BindBody(collection); err != nil {
return e.BadRequestError("Failed to load the submitted data due to invalid formatting.", err)
}
event := new(core.CollectionUpdateEvent)
event.HttpContext = c
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
// update the collection
return form.Submit(func(next forms.InterceptorNextFunc[*models.Collection]) forms.InterceptorNextFunc[*models.Collection] {
return func(m *models.Collection) error {
event.Collection = m
return api.app.OnCollectionBeforeUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
if err := next(e.Collection); err != nil {
return NewBadRequestError("Failed to update the collection.", err)
return event.App.OnCollectionUpdateRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Save(e.Collection); err != nil {
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to update collection.", validationErrors)
}
return api.app.OnCollectionAfterUpdateRequest().Trigger(event, func(e *core.CollectionUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
// other generic db error
return e.BadRequestError("Failed to update collection. Raw error: \n"+err.Error(), nil)
}
return e.HttpContext.JSON(http.StatusOK, e.Collection)
})
})
}
return e.JSON(http.StatusOK, e.Collection)
})
}
func (api *collectionApi) delete(c echo.Context) error {
collection, err := api.app.Dao().FindCollectionByNameOrId(c.PathParam("collection"))
func collectionDelete(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return NewNotFoundError("", err)
return e.NotFoundError("", err)
}
event := new(core.CollectionDeleteEvent)
event.HttpContext = c
event := new(core.CollectionRequestEvent)
event.RequestEvent = e
event.Collection = collection
return api.app.OnCollectionBeforeDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
if err := api.app.Dao().DeleteCollection(e.Collection); err != nil {
return NewBadRequestError("Failed to delete collection due to existing dependency.", err)
return e.App.OnCollectionDeleteRequest().Trigger(event, func(e *core.CollectionRequestEvent) error {
if err := e.App.Delete(e.Collection); err != nil {
msg := "Failed to delete collection"
// check fo references
refs, _ := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
if len(refs) > 0 {
names := make([]string, 0, len(refs))
for ref := range refs {
names = append(names, ref.Name)
}
msg += " probably due to existing reference in " + strings.Join(names, ", ")
}
return api.app.OnCollectionAfterDeleteRequest().Trigger(event, func(e *core.CollectionDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
return e.BadRequestError(msg, err)
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
return e.NoContent(http.StatusNoContent)
})
}
func (api *collectionApi) bulkImport(c echo.Context) error {
form := forms.NewCollectionsImport(api.app)
// load request data
if err := c.Bind(form); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
func collectionTruncate(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("", err)
}
event := new(core.CollectionsImportEvent)
event.HttpContext = c
event.Collections = form.Collections
// import collections
return form.Submit(func(next forms.InterceptorNextFunc[[]*models.Collection]) forms.InterceptorNextFunc[[]*models.Collection] {
return func(imports []*models.Collection) error {
event.Collections = imports
return api.app.OnCollectionsBeforeImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error {
if err := next(e.Collections); err != nil {
return NewBadRequestError("Failed to import the submitted collections.", err)
err = e.App.TruncateCollection(collection)
if err != nil {
return e.BadRequestError("Failed to truncate collection (most likely due to required cascade delete record references).", err)
}
return api.app.OnCollectionsAfterImportRequest().Trigger(event, func(e *core.CollectionsImportEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.NoContent(http.StatusNoContent)
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
func collectionScaffolds(e *core.RequestEvent) error {
return e.JSON(http.StatusOK, map[string]*core.Collection{
core.CollectionTypeBase: core.NewBaseCollection(""),
core.CollectionTypeAuth: core.NewAuthCollection(""),
core.CollectionTypeView: core.NewViewCollection(""),
})
}

60
apis/collection_import.go Normal file
View File

@ -0,0 +1,60 @@
package apis
import (
"errors"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func collectionsImport(e *core.RequestEvent) error {
form := new(collectionsImportForm)
err := e.BindBody(form)
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
err = form.validate()
if err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.CollectionsImportRequestEvent)
event.RequestEvent = e
event.CollectionsData = form.Collections
event.DeleteMissing = form.DeleteMissing
return event.App.OnCollectionsImportRequest().Trigger(event, func(e *core.CollectionsImportRequestEvent) error {
importErr := e.App.ImportCollections(e.CollectionsData, form.DeleteMissing)
if importErr == nil {
return e.NoContent(http.StatusNoContent)
}
// validation failure
var validationErrors validation.Errors
if errors.As(err, &validationErrors) {
return e.BadRequestError("Failed to import collections.", validationErrors)
}
// generic/db failure
return e.BadRequestError("Failed to import collections.", validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
"Failed to import the collections configuration. Raw error:\n"+importErr.Error(),
)})
})
}
// -------------------------------------------------------------------
type collectionsImportForm struct {
Collections []map[string]any `form:"collections" json:"collections"`
DeleteMissing bool `form:"deleteMissing" json:"deleteMissing"`
}
func (form *collectionsImportForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Collections, validation.Required),
)
}

View File

@ -0,0 +1,257 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestCollectionsImport(t *testing.T) {
t.Parallel()
totalCollections := 16
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPut,
URL: "/api/collections/import",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as regular user",
Method: http.MethodPut,
URL: "/api/collections/import",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser + empty collections",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{"collections":[]}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + collections validator failure",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{"name": "import1"},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "expand",
"type": "text"
}
]
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"collections":{"code":"validation_collections_import_failure"`,
`import2`,
`fields`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 2,
"OnCollectionCreateExecute": 2,
"OnCollectionAfterCreateError": 2,
"OnModelCreate": 2,
"OnModelCreateExecute": 2,
"OnModelAfterCreateError": 2,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
{
Name: "authorized as superuser + successful collections create",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"collections":[
{
"name": "import1",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
]
},
{
"name": "import2",
"fields": [
{
"id": "koih1lqx",
"name": "test",
"type": "text"
}
],
"indexes": [
"create index idx_test on import2 (test)"
]
},
{
"name": "auth_without_fields",
"type": "auth"
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
"OnCollectionCreate": 3,
"OnCollectionCreateExecute": 3,
"OnCollectionAfterCreateSuccess": 3,
"OnModelCreate": 3,
"OnModelCreateExecute": 3,
"OnModelAfterCreateSuccess": 3,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
expected := totalCollections + 3
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
indexes, err := app.TableIndexes("import2")
if err != nil || indexes["idx_test"] == "" {
t.Fatalf("Missing index %s (%v)", "idx_test", err)
}
},
},
{
Name: "authorized as superuser + create/update/delete",
Method: http.MethodPut,
URL: "/api/collections/import",
Body: strings.NewReader(`{
"deleteMissing": true,
"collections":[
{"name": "test123"},
{
"id":"wsmn24bux7wo113",
"name":"demo1",
"fields":[
{
"id":"_2hlxbmp",
"name":"title",
"type":"text",
"required":true
}
],
"indexes": []
}
]
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnCollectionsImportRequest": 1,
// ---
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnCollectionCreate": 1,
"OnCollectionCreateExecute": 1,
"OnCollectionAfterCreateSuccess": 1,
// ---
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnCollectionUpdate": 1,
"OnCollectionUpdateExecute": 1,
"OnCollectionAfterUpdateSuccess": 1,
// ---
"OnModelDelete": 14,
"OnModelAfterDeleteSuccess": 14,
"OnModelDeleteExecute": 14,
"OnCollectionDelete": 9,
"OnCollectionDeleteExecute": 9,
"OnCollectionAfterDeleteSuccess": 9,
"OnRecordAfterDeleteSuccess": 5,
"OnRecordDelete": 5,
"OnRecordDeleteExecute": 5,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
collections := []*core.Collection{}
if err := app.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
systemCollections := 0
for _, c := range collections {
if c.System {
systemCollections++
}
}
expected := systemCollections + 2
if len(collections) != expected {
t.Fatalf("Expected %d collections, got %d", expected, len(collections))
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

138
apis/dashboard.go Normal file
View File

@ -0,0 +1,138 @@
package apis
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const installerParam = "pbinstal"
var wildcardPlaceholderRegex = regexp.MustCompile(`/{.+\.\.\.}$`)
func stripWildcard(pattern string) string {
return wildcardPlaceholderRegex.ReplaceAllString(pattern, "")
}
// installerRedirect redirects the user to the installer dashboard UI page
// when the application needs some preliminary configurations to be done.
func installerRedirect(app core.App, cpPath string) hook.HandlerFunc[*core.RequestEvent] {
// note: to avoid locks contention it is not concurrent safe but it
// is expected to be updated only once during initialization
var hasSuperuser bool
// strip named wildcard
cpPath = stripWildcard(cpPath)
updateHasSuperuser := func(app core.App) error {
total, err := app.CountRecords(core.CollectionNameSuperusers)
if err != nil {
return err
}
hasSuperuser = total > 0
return nil
}
// load initial state on app init
app.OnBootstrap().BindFunc(func(e *core.BootstrapEvent) error {
err := e.Next()
if err != nil {
return err
}
err = updateHasSuperuser(e.App)
if err != nil {
return fmt.Errorf("failed to check for existing superuser: %w", err)
}
return nil
})
// update on superuser create
app.OnRecordCreateRequest(core.CollectionNameSuperusers).BindFunc(func(e *core.RecordRequestEvent) error {
err := e.Next()
if err != nil {
return err
}
if !hasSuperuser {
hasSuperuser = true
}
return nil
})
return func(e *core.RequestEvent) error {
if hasSuperuser {
return e.Next()
}
isAPI := strings.HasPrefix(e.Request.URL.Path, "/api/")
isControlPanel := strings.HasPrefix(e.Request.URL.Path, cpPath)
wildcard := e.Request.PathValue(StaticWildcardParam)
// skip redirect checks for API and non-root level dashboard index.html requests (css, images, etc.)
if isAPI || (isControlPanel && wildcard != "" && wildcard != router.IndexPage) {
return e.Next()
}
// check again in case the superuser was created by some other process
if err := updateHasSuperuser(e.App); err != nil {
return err
}
if hasSuperuser {
return e.Next()
}
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
// redirect to the installer page
if !hasInstallerParam {
return e.Redirect(http.StatusTemporaryRedirect, cpPath+"?"+installerParam+"#")
}
return e.Next()
}
}
// dashboardRemoveInstallerParam redirects to a non-installer
// query param in case there is already a superuser created.
//
// Note: intended to be registered only for the dashboard route
// to prevent excessive checks for every other route in installerRedirect.
func dashboardRemoveInstallerParam() hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
_, hasInstallerParam := e.Request.URL.Query()[installerParam]
if !hasInstallerParam {
return e.Next() // nothing to remove
}
// clear installer param
total, _ := e.App.CountRecords(core.CollectionNameSuperusers)
if total > 0 {
return e.Redirect(http.StatusTemporaryRedirect, "?")
}
return e.Next()
}
}
// dashboardCacheControl adds default Cache-Control header for all
// dashboard UI resources (ignoring the root index.html path)
func dashboardCacheControl() hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
if e.Request.PathValue(StaticWildcardParam) != "" {
e.Response.Header().Set("Cache-Control", "max-age=1209600, stale-while-revalidate=86400")
}
return e.Next()
}
}

View File

@ -7,18 +7,12 @@ import (
"log/slog"
"net/http"
"runtime"
"strings"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
"github.com/pocketbase/pocketbase/tools/router"
"golang.org/x/sync/semaphore"
"golang.org/x/sync/singleflight"
)
@ -27,23 +21,19 @@ var imageContentTypes = []string{"image/png", "image/jpg", "image/jpeg", "image/
var defaultThumbSizes = []string{"100x100"}
// bindFileApi registers the file api endpoints and the corresponding handlers.
func bindFileApi(app core.App, rg *echo.Group) {
func bindFileApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
api := fileApi{
app: app,
thumbGenSem: semaphore.NewWeighted(int64(runtime.NumCPU() + 2)), // the value is arbitrary chosen and may change in the future
thumbGenPending: new(singleflight.Group),
thumbGenMaxWait: 60 * time.Second,
}
subGroup := rg.Group("/files", ActivityLogger(app))
subGroup.POST("/token", api.fileToken)
subGroup.HEAD("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app))
subGroup.GET("/:collection/:recordId/:filename", api.download, LoadCollectionContext(api.app))
sub := rg.Group("/files")
sub.POST("/token", api.fileToken).Bind(RequireAuth())
sub.GET("/{collection}/{recordId}/{filename}", api.download).Bind(collectionPathRateLimit("", "file"))
}
type fileApi struct {
app core.App
// thumbGenSem is a semaphore to prevent too much concurrent
// requests generating new thumbs at the same time.
thumbGenSem *semaphore.Weighted
@ -57,84 +47,67 @@ type fileApi struct {
thumbGenMaxWait time.Duration
}
func (api *fileApi) fileToken(c echo.Context) error {
event := new(core.FileTokenEvent)
event.HttpContext = c
if admin, _ := c.Get(ContextAdminKey).(*models.Admin); admin != nil {
event.Model = admin
event.Token, _ = tokens.NewAdminFileToken(api.app, admin)
} else if record, _ := c.Get(ContextAuthRecordKey).(*models.Record); record != nil {
event.Model = record
event.Token, _ = tokens.NewRecordFileToken(api.app, record)
func (api *fileApi) fileToken(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("Missing auth context.", nil)
}
return api.app.OnFileBeforeTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error {
if e.Model == nil || e.Token == "" {
return NewBadRequestError("Failed to generate file token.", nil)
token, err := e.Auth.NewFileToken()
if err != nil {
return e.InternalServerError("Failed to generate file token", err)
}
return api.app.OnFileAfterTokenRequest().Trigger(event, func(e *core.FileTokenEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
event := new(core.FileTokenRequestEvent)
event.RequestEvent = e
event.Token = token
return e.HttpContext.JSON(http.StatusOK, map[string]string{
return e.App.OnFileTokenRequest().Trigger(event, func(e *core.FileTokenRequestEvent) error {
return e.JSON(http.StatusOK, map[string]string{
"token": e.Token,
})
})
})
}
func (api *fileApi) download(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", nil)
}
recordId := c.PathParam("recordId")
if recordId == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, recordId)
func (api *fileApi) download(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil {
return NewNotFoundError("", err)
return e.NotFoundError("", nil)
}
filename := c.PathParam("filename")
recordId := e.Request.PathValue("recordId")
if recordId == "" {
return e.NotFoundError("", nil)
}
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return e.NotFoundError("", err)
}
filename := e.Request.PathValue("filename")
fileField := record.FindFileFieldByFile(filename)
if fileField == nil {
return NewNotFoundError("", nil)
}
options, ok := fileField.Options.(*schema.FileOptions)
if !ok {
return NewBadRequestError("", errors.New("failed to load file options"))
return e.NotFoundError("", nil)
}
// check whether the request is authorized to view the protected file
if options.Protected {
token := c.QueryParam("token")
adminOrAuthRecord, _ := api.findAdminOrAuthRecordByFileToken(token)
// create a copy of the cached request data and adjust it for the current auth model
requestInfo := *RequestInfo(c)
requestInfo.Context = models.RequestInfoContextProtectedFile
requestInfo.Admin = nil
requestInfo.AuthRecord = nil
if adminOrAuthRecord != nil {
if admin, _ := adminOrAuthRecord.(*models.Admin); admin != nil {
requestInfo.Admin = admin
} else if record, _ := adminOrAuthRecord.(*models.Record); record != nil {
requestInfo.AuthRecord = record
}
if fileField.Protected {
originalRequestInfo, err := e.RequestInfo()
if err != nil {
return e.InternalServerError("Failed to load request info", err)
}
if ok, _ := api.app.Dao().CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
return NewForbiddenError("Insufficient permissions to access the file resource.", nil)
token := e.Request.URL.Query().Get("token")
authRecord, _ := e.App.FindAuthRecordByToken(token, core.TokenTypeFile)
// create a shallow copy of the cached request data and adjust it to the current auth record (if any)
requestInfo := *originalRequestInfo
requestInfo.Context = core.RequestInfoContextProtectedFile
requestInfo.Auth = authRecord
if ok, _ := e.App.CanAccessRecord(record, &requestInfo, record.Collection().ViewRule); !ok {
return e.NotFoundError("", errors.New("insufficient permissions to access the file resource"))
}
}
@ -142,16 +115,16 @@ func (api *fileApi) download(c echo.Context) error {
// fetch the original view file field related record
if collection.IsView() {
fileRecord, err := api.app.Dao().FindRecordByViewFile(collection.Id, fileField.Name, filename)
fileRecord, err := e.App.FindRecordByViewFile(collection.Id, fileField.Name, filename)
if err != nil {
return NewNotFoundError("", fmt.Errorf("Failed to fetch view file field record: %w", err))
return e.NotFoundError("", fmt.Errorf("failed to fetch view file field record: %w", err))
}
baseFilesPath = fileRecord.BaseFilesPath()
}
fsys, err := api.app.NewFilesystem()
fsys, err := e.App.NewFilesystem()
if err != nil {
return NewBadRequestError("Filesystem initialization failure.", err)
return e.InternalServerError("Filesystem initialization failure.", err)
}
defer fsys.Close()
@ -160,12 +133,12 @@ func (api *fileApi) download(c echo.Context) error {
servedName := filename
// check for valid thumb size param
thumbSize := c.QueryParam("thumb")
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, options.Thumbs)) {
thumbSize := e.Request.URL.Query().Get("thumb")
if thumbSize != "" && (list.ExistInSlice(thumbSize, defaultThumbSizes) || list.ExistInSlice(thumbSize, fileField.Thumbs)) {
// extract the original file meta attributes and check it existence
oAttrs, oAttrsErr := fsys.Attributes(originalPath)
if oAttrsErr != nil {
return NewNotFoundError("", err)
return e.NotFoundError("", err)
}
// check if it is an image
@ -176,8 +149,8 @@ func (api *fileApi) download(c echo.Context) error {
// create a new thumb if it doesn't exist
if exists, _ := fsys.Exists(servedPath); !exists {
if err := api.createThumb(c, fsys, originalPath, servedPath, thumbSize); err != nil {
api.app.Logger().Warn(
if err := api.createThumb(e, fsys, originalPath, servedPath, thumbSize); err != nil {
e.App.Logger().Warn(
"Fallback to original - failed to create thumb "+servedName,
slog.Any("error", err),
slog.String("original", originalPath),
@ -192,8 +165,8 @@ func (api *fileApi) download(c echo.Context) error {
}
}
event := new(core.FileDownloadEvent)
event.HttpContext = c
event := new(core.FileDownloadRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.FileField = fileField
@ -203,61 +176,26 @@ func (api *fileApi) download(c echo.Context) error {
// clickjacking shouldn't be a concern when serving uploaded files,
// so it safe to unset the global X-Frame-Options to allow files embedding
// (note: it is out of the hook to allow users to customize the behavior)
c.Response().Header().Del("X-Frame-Options")
e.Response.Header().Del("X-Frame-Options")
return api.app.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
if err := fsys.Serve(e.HttpContext.Response(), e.HttpContext.Request(), e.ServedPath, e.ServedName); err != nil {
return NewNotFoundError("", err)
return e.App.OnFileDownloadRequest().Trigger(event, func(e *core.FileDownloadRequestEvent) error {
if err := fsys.Serve(e.Response, e.Request, e.ServedPath, e.ServedName); err != nil {
return e.NotFoundError("", err)
}
return nil
})
}
func (api *fileApi) findAdminOrAuthRecordByFileToken(fileToken string) (models.Model, error) {
fileToken = strings.TrimSpace(fileToken)
if fileToken == "" {
return nil, errors.New("missing file token")
}
claims, _ := security.ParseUnverifiedJWT(strings.TrimSpace(fileToken))
tokenType := cast.ToString(claims["type"])
switch tokenType {
case tokens.TypeAdmin:
admin, err := api.app.Dao().FindAdminByToken(
fileToken,
api.app.Settings().AdminFileToken.Secret,
)
if err == nil && admin != nil {
return admin, nil
}
case tokens.TypeAuthRecord:
record, err := api.app.Dao().FindAuthRecordByToken(
fileToken,
api.app.Settings().RecordFileToken.Secret,
)
if err == nil && record != nil {
return record, nil
}
}
return nil, errors.New("missing or invalid file token")
}
func (api *fileApi) createThumb(
c echo.Context,
e *core.RequestEvent,
fsys *filesystem.System,
originalPath string,
thumbPath string,
thumbSize string,
) error {
ch := api.thumbGenPending.DoChan(thumbPath, func() (any, error) {
ctx, cancel := context.WithTimeout(c.Request().Context(), api.thumbGenMaxWait)
ctx, cancel := context.WithTimeout(e.Request.Context(), api.thumbGenMaxWait)
defer cancel()
if err := api.thumbGenSem.Acquire(ctx, 1); err != nil {

View File

@ -10,11 +10,8 @@ import (
"sync"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
@ -26,23 +23,54 @@ func TestFileToken(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/files/token",
ExpectedStatus: 400,
URL: "/api/files/token",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "regular user",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"*": 0,
"OnFileTokenRequest": 1,
},
},
{
Name: "unauthorized with model and token via hook",
Name: "superuser",
Method: http.MethodPost,
Url: "/api/files/token",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnFileBeforeTokenRequest().Add(func(e *core.FileTokenEvent) error {
record, _ := app.Dao().FindAuthRecordByEmail("users", "test@example.com")
e.Model = record
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileTokenRequest": 1,
},
},
{
Name: "hook token overwrite",
Method: http.MethodPost,
URL: "/api/files/token",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnFileTokenRequest().BindFunc(func(e *core.FileTokenRequestEvent) error {
e.Token = "test"
return nil
return e.Next()
})
},
ExpectedStatus: 200,
@ -50,40 +78,8 @@ func TestFileToken(t *testing.T) {
`"token":"test"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"OnFileAfterTokenRequest": 1,
},
},
{
Name: "auth record",
Method: http.MethodPost,
Url: "/api/files/token",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"OnFileAfterTokenRequest": 1,
},
},
{
Name: "admin",
Method: http.MethodPost,
Url: "/api/files/token",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
},
ExpectedEvents: map[string]int{
"OnFileBeforeTokenRequest": 1,
"OnFileAfterTokenRequest": 1,
"*": 0,
"OnFileTokenRequest": 1,
},
},
}
@ -152,233 +148,271 @@ func TestFileDownload(t *testing.T) {
{
Name: "missing collection",
Method: http.MethodGet,
Url: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
URL: "/api/files/missing/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing record",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
URL: "/api/files/_pb_users_auth_/missing/300_1SEi6Q6U72.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing file",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/missing.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing image",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - missing thumb (should fallback to the original)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=999x999",
ExpectedStatus: 200,
ExpectedContent: []string{string(testImg)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop center)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropCenter)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop top)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50t",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropTop)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (crop bottom)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50b",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbCropBottom)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (fit)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x50f",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbFit)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (zero width)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=0x50",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroWidth)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing image - existing thumb (zero height)",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png?thumb=70x0",
ExpectedStatus: 200,
ExpectedContent: []string{string(testThumbZeroHeight)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "existing non image file - thumb parameter should be ignored",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
URL: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
ExpectedStatus: 200,
ExpectedContent: []string{string(testFile)},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
// protected file access checks
{
Name: "protected file - expired token",
Name: "protected file - superuser with expired file token",
Method: http.MethodGet,
Url: "/api/files/_pb_users_auth_/oap640cot4yru2s/test_kfd2wYLxkz.txt?thumb=100x100",
ExpectedStatus: 200,
ExpectedContent: []string{string(testFile)},
ExpectedEvents: map[string]int{
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - admin with expired file token",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImFkbWluIn0.g7Q_3UX6H--JWJ7yt1Hoe-1ugTX1KpbKzdt0zjGSe-E",
ExpectedStatus: 403,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.hTNDzikwJdcoWrLnRnp7xbaifZ2vuYZ0oOYRHtJfnk4",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - admin with valid file token",
Name: "protected file - superuser with valid file token",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MTg5MzQ1MjQ2MSwidHlwZSI6ImFkbWluIn0.LyAMpSfaHVsuUqIlqqEbhDQSdFzoPz_EIDcb2VJMBsU",
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJjXzMzMjM4NjYzMzkifQ.C8m3aRZNOxUDhMiuZuDTRIIjRl7wsOyzoxs8EjvKNgY",
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - guest without view access",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
ExpectedStatus: 403,
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - guest with view access",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
dao := daos.New(app.Dao().DB())
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock public view access
c, err := dao.FindCollectionByNameOrId("demo1")
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("")
if err := dao.SaveCollection(c); err != nil {
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file - auth record without view access",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
dao := daos.New(app.Dao().DB())
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock restricted user view access
c, err := dao.FindCollectionByNameOrId("demo1")
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("@request.auth.verified = true")
if err := dao.SaveCollection(c); err != nil {
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 403,
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file - auth record with view access",
Method: http.MethodGet,
Url: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
dao := daos.New(app.Dao().DB())
URL: "/api/files/demo1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// mock user view access
c, err := dao.FindCollectionByNameOrId("demo1")
c, err := app.FindCachedCollectionByNameOrId("demo1")
if err != nil {
t.Fatalf("Failed to fetch mock collection: %v", err)
}
c.ViewRule = types.Pointer("@request.auth.verified = false")
if err := dao.SaveCollection(c); err != nil {
if err := app.UnsafeWithoutHooks().Save(c); err != nil {
t.Fatalf("Failed to update mock collection: %v", err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{"PNG"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
{
Name: "protected file in view (view's View API rule failure)",
Method: http.MethodGet,
Url: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
ExpectedStatus: 403,
URL: "/api/files/view1/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "protected file in view (view's View API rule success)",
Method: http.MethodGet,
Url: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTg5MzQ1MjQ2MSwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwidHlwZSI6ImF1dGhSZWNvcmQifQ.0d_0EO6kfn9ijZIQWAqgRi8Bo1z7MKcg1LQpXhQsEPk",
URL: "/api/files/view1/84nmscqy84lsi1t/test_d61b33QdDU.txt?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6ImZpbGUiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8ifQ.nSTLuCPcGpWn2K2l-BFkC3Vlzc-ZTDPByYq8dN1oPSo",
ExpectedStatus: 200,
ExpectedContent: []string{"test"},
ExpectedEvents: map[string]int{
"*": 0,
"OnFileDownloadRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:file"},
{MaxRequests: 0, Label: "users:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:file",
Method: http.MethodGet,
URL: "/api/files/_pb_users_auth_/4q1xlclmfloku33/300_1SEi6Q6U72.png",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:file"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
@ -410,30 +444,23 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
defer fsys.Close()
// create a dummy file field collection
demo1, err := app.Dao().FindCollectionByNameOrId("demo1")
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
fileField := demo1.Schema.GetFieldByName("file_one")
fileField.Options = &schema.FileOptions{
Protected: false,
MaxSelect: 1,
MaxSize: 999999,
fileField := demo1.Fields.GetByName("file_one").(*core.FileField)
fileField.Protected = false
fileField.MaxSelect = 1
fileField.MaxSize = 999999
// new thumbs
Thumbs: []string{"111x111", "111x222", "111x333"},
}
demo1.Schema.AddField(fileField)
if err := app.Dao().SaveCollection(demo1); err != nil {
fileField.Thumbs = []string{"111x111", "111x222", "111x333"}
demo1.Fields.Add(fileField)
if err = app.Save(demo1); err != nil {
t.Fatal(err)
}
fileKey := "wsmn24bux7wo113/al1h9ijdeojtsjy/300_Jsjq7RdBgA.png"
e, err := apis.InitApi(app)
if err != nil {
t.Fatal(err)
}
urls := []string{
"/api/files/" + fileKey + "?thumb=111x111",
"/api/files/" + fileKey + "?thumb=111x111", // should still result in single thumb
@ -446,7 +473,6 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
wg.Add(len(urls))
for _, url := range urls {
url := url
go func() {
defer wg.Done()
@ -454,7 +480,11 @@ func TestConcurrentThumbsGeneration(t *testing.T) {
req := httptest.NewRequest("GET", url, nil)
e.ServeHTTP(recorder, req)
pbRouter, _ := apis.NewRouter(app)
mux, _ := pbRouter.BuildMux()
if mux != nil {
mux.ServeHTTP(recorder, req)
}
}()
}

View File

@ -2,42 +2,52 @@ package apis
import (
"net/http"
"slices"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindHealthApi registers the health api endpoint.
func bindHealthApi(app core.App, rg *echo.Group) {
api := healthApi{app: app}
func bindHealthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/health")
subGroup.HEAD("", api.healthCheck)
subGroup.GET("", api.healthCheck)
}
type healthApi struct {
app core.App
}
type healthCheckResponse struct {
Message string `json:"message"`
Code int `json:"code"`
Data struct {
CanBackup bool `json:"canBackup"`
} `json:"data"`
subGroup.GET("", healthCheck)
}
// healthCheck returns a 200 OK response if the server is healthy.
func (api *healthApi) healthCheck(c echo.Context) error {
if c.Request().Method == http.MethodHead {
return c.NoContent(http.StatusOK)
func healthCheck(e *core.RequestEvent) error {
resp := struct {
Message string `json:"message"`
Code int `json:"code"`
Data map[string]any `json:"data"`
}{
Code: http.StatusOK,
Message: "API is healthy.",
}
resp := new(healthCheckResponse)
resp.Code = http.StatusOK
resp.Message = "API is healthy."
resp.Data.CanBackup = !api.app.Store().Has(core.StoreKeyActiveBackup)
if e.HasSuperuserAuth() {
resp.Data = make(map[string]any, 3)
resp.Data["canBackup"] = !e.App.Store().Has(core.StoreKeyActiveBackup)
resp.Data["realIP"] = e.RealIP()
return c.JSON(http.StatusOK, resp)
// loosely check if behind a reverse proxy
// (usually used in the dashboard to remind superusers in case deployed behind reverse-proxy)
possibleProxyHeader := ""
headersToCheck := append(
slices.Clone(e.App.Settings().TrustedProxy.Headers),
// common proxy headers
"CF-Connecting-IP", "Fly-Client-IP", "X‑Forwarded-For",
)
for _, header := range headersToCheck {
if e.Request.Header.Get(header) != "" {
possibleProxyHeader = header
break
}
}
resp.Data["possibleProxyHeader"] = possibleProxyHeader
} else {
resp.Data = map[string]any{} // ensure that it is returned as object
}
return e.JSON(http.StatusOK, resp)
}

View File

@ -12,21 +12,56 @@ func TestHealthAPI(t *testing.T) {
scenarios := []tests.ApiScenario{
{
Name: "HEAD health status",
Method: http.MethodHead,
Url: "/api/health",
Name: "GET health status (guest)",
Method: http.MethodGet, // automatically matches also HEAD as a side-effect of the Go std mux
URL: "/api/health",
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "GET health status",
Name: "GET health status (regular user)",
Method: http.MethodGet,
Url: "/api/health",
URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{}`,
},
NotExpectedContent: []string{
"canBackup",
"realIP",
"possibleProxyHeader",
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "GET health status (superuser)",
Method: http.MethodGet,
URL: "/api/health",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"code":200`,
`"data":{`,
`"canBackup":true`,
`"realIP"`,
`"possibleProxyHeader"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}

View File

@ -3,79 +3,71 @@ package apis
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindLogsApi registers the request logs api endpoints.
func bindLogsApi(app core.App, rg *echo.Group) {
api := logsApi{app: app}
subGroup := rg.Group("/logs", RequireAdminAuth())
subGroup.GET("", api.list)
subGroup.GET("/stats", api.stats)
subGroup.GET("/:id", api.view)
}
type logsApi struct {
app core.App
func bindLogsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/logs").Bind(RequireSuperuserAuth(), SkipSuccessActivityLog())
sub.GET("", logsList)
sub.GET("/stats", logsStats)
sub.GET("/{id}", logsView)
}
var logFilterFields = []string{
"rowid", "id", "created", "updated",
"level", "message", "data",
"id", "created", "level", "message", "data",
`^data\.[\w\.\:]*\w+$`,
}
func (api *logsApi) list(c echo.Context) error {
func logsList(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
result, err := search.NewProvider(fieldResolver).
Query(api.app.LogsDao().LogQuery()).
ParseAndExec(c.QueryParams().Encode(), &[]*models.Log{})
Query(e.App.AuxModelQuery(&core.Log{})).
ParseAndExec(e.Request.URL.Query().Encode(), &[]*core.Log{})
if err != nil {
return NewBadRequestError("", err)
return e.BadRequestError("", err)
}
return c.JSON(http.StatusOK, result)
return e.JSON(http.StatusOK, result)
}
func (api *logsApi) stats(c echo.Context) error {
func logsStats(e *core.RequestEvent) error {
fieldResolver := search.NewSimpleFieldResolver(logFilterFields...)
filter := c.QueryParam(search.FilterQueryParam)
filter := e.Request.URL.Query().Get(search.FilterQueryParam)
var expr dbx.Expression
if filter != "" {
var err error
expr, err = search.FilterData(filter).BuildExpr(fieldResolver)
if err != nil {
return NewBadRequestError("Invalid filter format.", err)
return e.BadRequestError("Invalid filter format.", err)
}
}
stats, err := api.app.LogsDao().LogsStats(expr)
stats, err := e.App.LogsStats(expr)
if err != nil {
return NewBadRequestError("Failed to generate logs stats.", err)
return e.BadRequestError("Failed to generate logs stats.", err)
}
return c.JSON(http.StatusOK, stats)
return e.JSON(http.StatusOK, stats)
}
func (api *logsApi) view(c echo.Context) error {
id := c.PathParam("id")
func logsView(e *core.RequestEvent) error {
id := e.Request.PathValue("id")
if id == "" {
return NewNotFoundError("", nil)
return e.NotFoundError("", nil)
}
log, err := api.app.LogsDao().FindLogById(id)
log, err := e.App.FindLogById(id)
if err != nil || log == nil {
return NewNotFoundError("", err)
return e.NotFoundError("", err)
}
return c.JSON(http.StatusOK, log)
return e.JSON(http.StatusOK, log)
}

View File

@ -4,7 +4,7 @@ import (
"net/http"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
@ -15,29 +15,31 @@ func TestLogsList(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/logs",
URL: "/api/logs",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodGet,
Url: "/api/logs",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/logs",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin",
Name: "authorized as superuser",
Method: http.MethodGet,
Url: "/api/logs",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
@ -50,16 +52,17 @@ func TestLogsList(t *testing.T) {
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin + filter",
Name: "authorized as superuser + filter",
Method: http.MethodGet,
Url: "/api/logs?filter=data.status>200",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs?filter=data.status>200",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
@ -71,6 +74,7 @@ func TestLogsList(t *testing.T) {
`"items":[{`,
`"id":"f2133873-44fb-9f38-bf82-c918f53b310d"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -86,44 +90,47 @@ func TestLogView(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (nonexisting request log)",
Name: "authorized as superuser (nonexisting request log)",
Method: http.MethodGet,
Url: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs/missing1-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (existing request log)",
Name: "authorized as superuser (existing request log)",
Method: http.MethodGet,
Url: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs/873f2133-9f38-44fb-bf82-c8f53b310d91",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
@ -131,6 +138,7 @@ func TestLogView(t *testing.T) {
ExpectedContent: []string{
`"id":"873f2133-9f38-44fb-bf82-c8f53b310d91"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -146,52 +154,54 @@ func TestLogsStats(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/logs/stats",
URL: "/api/logs/stats",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodGet,
Url: "/api/logs/stats",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/logs/stats",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin",
Name: "authorized as superuser",
Method: http.MethodGet,
Url: "/api/logs/stats",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs/stats",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[{"total":1,"date":"2022-05-01 10:00:00.000Z"},{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`,
`[{"date":"2022-05-01 10:00:00.000Z","total":1},{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
},
},
{
Name: "authorized as admin + filter",
Name: "authorized as superuser + filter",
Method: http.MethodGet,
Url: "/api/logs/stats?filter=data.status>200",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/logs/stats?filter=data.status>200",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
if err := tests.MockLogsData(app); err != nil {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubLogsData(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`[{"total":1,"date":"2022-05-02 10:00:00.000Z"}]`,
`[{"date":"2022-05-02 10:00:00.000Z","total":1}]`,
},
},
}

View File

@ -3,303 +3,321 @@ package apis
import (
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
// Common request context keys used by the middlewares and api handlers.
// Common request event store keys used by the middlewares and api handlers.
const (
ContextAdminKey string = "admin"
ContextAuthRecordKey string = "authRecord"
ContextCollectionKey string = "collection"
ContextExecStartKey string = "execStart"
RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log
requestEventKeyExecStart = "__execStart" // the value must be time.Time
requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
)
const (
DefaultWWWRedirectMiddlewarePriority = -99999
DefaultWWWRedirectMiddlewareId = "pbWWWRedirect"
DefaultActivityLoggerMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 30
DefaultActivityLoggerMiddlewareId = "pbActivityLogger"
DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"
DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
DefaultLoadAuthTokenMiddlewareId = "pbLoadAuthToken"
DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
DefaultSecurityHeadersMiddlewareId = "pbSecurityHeaders"
DefaultRequireGuestOnlyMiddlewareId = "pbRequireGuestOnly"
DefaultRequireAuthMiddlewareId = "pbRequireAuth"
DefaultRequireSuperuserAuthMiddlewareId = "pbRequireSuperuserAuth"
DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId = "pbRequireSuperuserAuthOnlyIfAny"
DefaultRequireSuperuserOrOwnerAuthMiddlewareId = "pbRequireSuperuserOrOwnerAuth"
DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
)
// RequireGuestOnly middleware requires a request to NOT have a valid
// Authorization header.
//
// This middleware is the opposite of [apis.RequireAdminOrRecordAuth()].
func RequireGuestOnly() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
err := NewBadRequestError("The request can be accessed only by guests.", nil)
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record != nil {
return err
// This middleware is the opposite of [apis.RequireAuth()].
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireGuestOnlyMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth != nil {
return router.NewBadRequestError("The request can be accessed only by guests.", nil)
}
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return err
}
return next(c)
}
return e.Next()
},
}
}
// RequireRecordAuth middleware requires a request to have
// a valid record auth Authorization header.
// RequireAuth middleware requires a request to have a valid record Authorization header.
//
// The auth record could be from any collection.
//
// You can further filter the allowed record auth collections by
// specifying their names.
// You can further filter the allowed record auth collections by specifying their names.
//
// Example:
//
// apis.RequireRecordAuth()
//
// Or:
//
// apis.RequireRecordAuth("users", "supervisors")
//
// To restrict the auth record only to the loaded context collection,
// use [apis.RequireSameContextRecordAuth()] instead.
func RequireRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
// apis.RequireAuth() // any auth collection
// apis.RequireAuth("_superusers", "users") // only the listed auth collections
func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireAuthMiddlewareId,
Func: requireAuth(optCollectionNames...),
}
}
func requireAuth(optCollectionNames ...string) hook.HandlerFunc[*core.RequestEvent] {
return func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
}
// check record collection name
if len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) {
return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil)
}
return next(c)
}
return e.Next()
}
}
// RequireSameContextRecordAuth middleware requires a request to have
// a valid record Authorization header.
//
// The auth record must be from the same collection already loaded in the context.
func RequireSameContextRecordAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid record authorization token to be set.", nil)
}
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil || record.Collection().Id != collection.Id {
return NewForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", record.Collection().Name), nil)
}
return next(c)
}
// RequireSuperuserAuth middleware requires a request to have
// a valid superuser Authorization header.
func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSuperuserAuthMiddlewareId,
Func: requireAuth(core.CollectionNameSuperusers),
}
}
// RequireAdminAuth middleware requires a request to have
// a valid admin Authorization header.
func RequireAdminAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin == nil {
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil)
// RequireSuperuserAuthOnlyIfAny middleware requires a request to have
// a valid superuser Authorization header ONLY if the application has
// at least 1 existing superuser.
func RequireSuperuserAuthOnlyIfAny() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.HasSuperuserAuth() {
return e.Next()
}
return next(c)
}
}
}
// RequireAdminAuthOnlyIfAny middleware requires a request to have
// a valid admin Authorization header ONLY if the application has
// at least 1 existing Admin model.
func RequireAdminAuthOnlyIfAny(app core.App) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return next(c)
}
totalAdmins, err := app.Dao().TotalAdmins()
totalSuperusers, err := e.App.CountRecords(core.CollectionNameSuperusers)
if err != nil {
return NewBadRequestError("Failed to fetch admins info.", err)
return e.InternalServerError("Failed to fetch superusers info.", err)
}
if totalAdmins == 0 {
return next(c)
if totalSuperusers == 0 {
return e.Next()
}
return NewUnauthorizedError("The request requires valid admin authorization token to be set.", nil)
}
return requireAuth(core.CollectionNameSuperusers)(e)
},
}
}
// RequireAdminOrRecordAuth middleware requires a request to have
// a valid admin or record Authorization header set.
// RequireSuperuserOrOwnerAuth middleware requires a request to have
// a valid superuser or regular record owner Authorization header set.
//
// You can further filter the allowed auth record collections by providing their names.
//
// This middleware is the opposite of [apis.RequireGuestOnly()].
func RequireAdminOrRecordAuth(optCollectionNames ...string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if admin == nil && record == nil {
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
}
if record != nil && len(optCollectionNames) > 0 && !list.ExistInSlice(record.Collection().Name, optCollectionNames) {
return NewForbiddenError("The authorized record model is not allowed to perform this action.", nil)
}
return next(c)
}
}
}
// RequireAdminOrOwnerAuth middleware requires a request to have
// a valid admin or auth record owner Authorization header set.
//
// This middleware is similar to [apis.RequireAdminOrRecordAuth()] but
// This middleware is similar to [apis.RequireAuth()] but
// for the auth record token expects to have the same id as the path
// parameter ownerIdParam (default to "id" if empty).
func RequireAdminOrOwnerAuth(ownerIdParam string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return next(c)
// parameter ownerIdPathParam (default to "id" if empty).
func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires superuser or record authorization token.", nil)
}
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires admin or record authorization token to be set.", nil)
if e.Auth.IsSuperuser() {
return e.Next()
}
if ownerIdParam == "" {
ownerIdParam = "id"
if ownerIdPathParam == "" {
ownerIdPathParam = "id"
}
ownerId := c.PathParam(ownerIdParam)
ownerId := e.Request.PathValue(ownerIdPathParam)
// note: it is "safe" to compare only the record id since the auth
// record ids are treated as unique across all auth collections
if record.Id != ownerId {
return NewForbiddenError("You are not allowed to perform this request.", nil)
// note: it is considered "safe" to compare only the record id
// since the auth record ids are treated as unique across all auth collections
if e.Auth.Id != ownerId {
return e.ForbiddenError("You are not allowed to perform this request.", nil)
}
return next(c)
}
return e.Next()
},
}
}
// LoadAuthContext middleware reads the Authorization request header
// and loads the token related record or admin instance into the
// request's context.
//
// This middleware is expected to be already registered by default for all routes.
func LoadAuthContext(app core.App) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
token := c.Request().Header.Get("Authorization")
if token == "" {
return next(c)
// RequireSameCollectionContextAuth middleware requires a request to have
// a valid record Authorization header and the auth record's collection to
// match the one from the route path parameter (default to "collection" if collectionParam is empty).
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRequireSameCollectionContextAuthMiddlewareId,
Func: func(e *core.RequestEvent) error {
if e.Auth == nil {
return e.UnauthorizedError("The request requires valid record authorization token.", nil)
}
// the schema is not required and it is only for
if collectionPathParam == "" {
collectionPathParam = "collection"
}
collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if collection == nil || e.Auth.Collection().Id != collection.Id {
return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil)
}
return e.Next()
},
}
}
// loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value.
//
// This middleware does nothing in case of missing, invalid or expired token.
//
// This middleware is registered by default for all routes.
//
// Note: We don't throw an error on invalid or expired token to allow
// users to extend with their own custom handling in external middleware(s).
func loadAuthToken() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultLoadAuthTokenMiddlewareId,
Priority: DefaultLoadAuthTokenMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
token := getAuthTokenFromRequest(e)
if token == "" {
return e.Next()
}
record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth)
if err != nil {
e.App.Logger().Debug("loadAuthToken failure", "error", err)
} else if record != nil {
e.Auth = record
}
return e.Next()
},
}
}
func getAuthTokenFromRequest(e *core.RequestEvent) string {
token := e.Request.Header.Get("Authorization")
if token != "" {
// the schema prefix is not required and it is only for
// compatibility with the defaults of some HTTP clients
token = strings.TrimPrefix(token, "Bearer ")
claims, _ := security.ParseUnverifiedJWT(token)
tokenType := cast.ToString(claims["type"])
switch tokenType {
case tokens.TypeAdmin:
admin, err := app.Dao().FindAdminByToken(
token,
app.Settings().AdminAuthToken.Secret,
)
if err == nil && admin != nil {
c.Set(ContextAdminKey, admin)
}
case tokens.TypeAuthRecord:
record, err := app.Dao().FindAuthRecordByToken(
token,
app.Settings().RecordAuthToken.Secret,
)
if err == nil && record != nil {
c.Set(ContextAuthRecordKey, record)
}
}
return next(c)
}
}
return token
}
// LoadCollectionContext middleware finds the collection with related
// path identifier and loads it into the request context.
// wwwRedirect performs www->non-www redirect(s) if the request host
// matches with one of the values in redirectHosts.
//
// Set optCollectionTypes to further filter the found collection by its type.
func LoadCollectionContext(app core.App, optCollectionTypes ...string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if param := c.PathParam("collection"); param != "" {
collection, err := core.FindCachedCollectionByNameOrId(app, param)
if err != nil || collection == nil {
return NewNotFoundError("", err)
// This middleware is registered by default on Serve for all routes.
func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultWWWRedirectMiddlewareId,
Priority: DefaultWWWRedirectMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
host := e.Request.Host
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) {
return e.Redirect(
http.StatusTemporaryRedirect,
(e.Request.URL.Scheme + "://" + host[4:] + e.Request.RequestURI),
)
}
if len(optCollectionTypes) > 0 && !list.ExistInSlice(collection.Type, optCollectionTypes) {
return NewBadRequestError("Unsupported collection type.", nil)
}
c.Set(ContextCollectionKey, collection)
}
return next(c)
}
return e.Next()
},
}
}
// ActivityLogger middleware takes care to save the request information
// securityHeaders middleware adds common security headers to the response.
//
// This middleware is registered by default for all routes.
func securityHeaders() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSecurityHeadersMiddlewareId,
Priority: DefaultSecurityHeadersMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Response.Header().Set("X-XSS-Protection", "1; mode=block")
e.Response.Header().Set("X-Content-Type-Options", "nosniff")
e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN")
// @todo consider a default HSTS?
// (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/)
return e.Next()
},
}
}
// SkipSuccessActivityLog is a helper middleware that instructs the global
// activity logger to log only requests that have failed/returned an error.
func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultSkipSuccessActivityLogMiddlewareId,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeySkipSuccessActivityLog, true)
return e.Next()
},
}
}
// activityLogger middleware takes care to save the request information
// into the logs database.
//
// This middleware is registered by default for all routes.
//
// The middleware does nothing if the app logs retention period is zero
// (aka. app.Settings().Logs.MaxDays = 0).
func ActivityLogger(app core.App) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if err := next(c); err != nil {
//
// Users can attach the [apis.SkipSuccessActivityLog()] middleware if
// you want to log only the failed requests.
func activityLogger() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultActivityLoggerMiddlewareId,
Priority: DefaultActivityLoggerMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
e.Set(requestEventKeyExecStart, time.Now())
err := e.Next()
logRequest(e, err)
return err
}
logRequest(app, c, nil)
return nil
}
},
}
}
func logRequest(app core.App, c echo.Context, err *ApiError) {
func logRequest(event *core.RequestEvent, err error) {
// no logs retention
if app.Settings().Logs.MaxDays == 0 {
if event.App.Settings().Logs.MaxDays == 0 {
return
}
// the non-error route has explicitly disabled the activity logger
if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil {
return
}
@ -307,32 +325,31 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
attrs = append(attrs, slog.String("type", "request"))
started := cast.ToTime(c.Get(ContextExecStartKey))
started := cast.ToTime(event.Get(requestEventKeyExecStart))
if !started.IsZero() {
attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond)))
}
httpRequest := c.Request()
httpResponse := c.Response()
method := strings.ToUpper(httpRequest.Method)
status := httpResponse.Status
requestUri := httpRequest.URL.RequestURI()
if meta := event.Get(RequestEventKeyLogMeta); meta != nil {
attrs = append(attrs, slog.Any("meta", meta))
}
status := event.Status()
method := cutStr(strings.ToUpper(event.Request.Method), 50)
requestUri := cutStr(event.Request.URL.RequestURI(), 3000)
// parse the request error
if err != nil {
status = err.Code
if apiErr, ok := err.(*router.ApiError); ok {
status = apiErr.Status
attrs = append(
attrs,
slog.String("error", err.Message),
slog.Any("details", err.RawData()),
slog.String("error", apiErr.Message),
slog.Any("details", apiErr.RawData()),
)
} else {
attrs = append(attrs, slog.String("error", err.Error()))
}
requestAuth := models.RequestAuthGuest
if c.Get(ContextAuthRecordKey) != nil {
requestAuth = models.RequestAuthRecord
} else if c.Get(ContextAdminKey) != nil {
requestAuth = models.RequestAuthAdmin
}
attrs = append(
@ -340,17 +357,33 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
slog.String("url", requestUri),
slog.String("method", method),
slog.Int("status", status),
slog.String("auth", requestAuth),
slog.String("referer", httpRequest.Referer()),
slog.String("userAgent", httpRequest.UserAgent()),
slog.String("referer", cutStr(event.Request.Referer(), 2000)),
slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)),
)
if app.Settings().Logs.LogIp {
ip, _, _ := net.SplitHostPort(httpRequest.RemoteAddr)
if event.Auth != nil {
attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name))
if event.App.Settings().Logs.LogAuthId {
attrs = append(attrs, slog.String("authId", event.Auth.Id))
}
} else {
attrs = append(attrs, slog.String("auth", ""))
}
if event.App.Settings().Logs.LogIP {
var userIP string
if len(event.App.Settings().TrustedProxy.Headers) > 0 {
userIP = event.RealIP()
} else {
// fallback to the legacy behavior (it is "safe" since it is only for log purposes)
userIP = cutStr(event.UnsafeRealIP(), 50)
}
attrs = append(
attrs,
slog.String("userIp", realUserIp(httpRequest, ip)),
slog.String("remoteIp", ip),
slog.String("userIP", userIP),
slog.String("remoteIP", event.RemoteIP()),
)
}
@ -358,64 +391,23 @@ func logRequest(app core.App, c echo.Context, err *ApiError) {
routine.FireAndForget(func() {
message := method + " "
if escaped, err := url.PathUnescape(requestUri); err == nil {
if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil {
message += escaped
} else {
message += requestUri
}
if err != nil {
app.Logger().Error(message, attrs...)
event.App.Logger().Error(message, attrs...)
} else {
app.Logger().Info(message, attrs...)
event.App.Logger().Info(message, attrs...)
}
})
}
// Returns the "real" user IP from common proxy headers (or fallbackIp if none is found).
//
// The returned IP value shouldn't be trusted if not behind a trusted reverse proxy!
func realUserIp(r *http.Request, fallbackIp string) string {
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("Fly-Client-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ipsList := r.Header.Get("X-Forwarded-For"); ipsList != "" {
// extract the first non-empty leftmost-ish ip
ips := strings.Split(ipsList, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" {
return ip
}
}
}
return fallbackIp
}
// @todo consider removing as this may no longer be needed due to the custom rest.MultiBinder.
//
// eagerRequestInfoCache ensures that the request data is cached in the request
// context to allow reading for example the json request body data more than once.
func eagerRequestInfoCache(app core.App) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
switch c.Request().Method {
// currently we are eagerly caching only the requests with body
case "POST", "PUT", "PATCH", "DELETE":
RequestInfo(c)
}
return next(c)
}
func cutStr(str string, max int) string {
if len(str) > max {
return str[:max] + "..."
}
return str
}

View File

@ -0,0 +1,123 @@
package apis
import (
"io"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
var ErrRequestEntityTooLarge = router.NewApiError(http.StatusRequestEntityTooLarge, "Request entity too large", nil)
const DefaultMaxBodySize int64 = 32 << 20
const (
DefaultBodyLimitMiddlewareId = "pbBodyLimit"
DefaultBodyLimitMiddlewarePriority = DefaultRateLimitMiddlewarePriority + 10
)
// BodyLimit returns a middleware function that changes the default request body size limit.
//
// Note that in order to have effect this middleware should be registered
// before other middlewares that reads the request body.
//
// If limitBytes <= 0, no limit is applied.
//
// Otherwise, if the request body size exceeds the configured limitBytes,
// it sends 413 error response.
func BodyLimit(limitBytes int64) *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
err := applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func dynamicCollectionBodyLimit(collectionPathParam string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultBodyLimitMiddlewareId,
Priority: DefaultBodyLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
limitBytes := DefaultMaxBodySize
if !collection.IsView() {
for _, f := range collection.Fields {
if calc, ok := f.(core.MaxBodySizeCalculator); ok {
limitBytes += calc.CalculateMaxBodySize()
}
}
}
err = applyBodyLimit(e, limitBytes)
if err != nil {
return err
}
return e.Next()
},
}
}
func applyBodyLimit(e *core.RequestEvent, limitBytes int64) error {
// no limit
if limitBytes <= 0 {
return nil
}
// optimistically check the submitted request content length
if e.Request.ContentLength > limitBytes {
return ErrRequestEntityTooLarge
}
// replace the request body
//
// note: we don't use sync.Pool since the size of the elements could vary too much
// and it might not be efficient (see https://github.com/golang/go/issues/23199)
e.Request.Body = &limitedReader{ReadCloser: e.Request.Body, limit: limitBytes}
return nil
}
type limitedReader struct {
io.ReadCloser
limit int64
totalRead int64
}
func (r *limitedReader) Read(b []byte) (int, error) {
n, err := r.ReadCloser.Read(b)
if err != nil {
return n, err
}
r.totalRead += int64(n)
if r.totalRead > r.limit {
return n, ErrRequestEntityTooLarge
}
return n, nil
}
func (r *limitedReader) Reread() {
rr, ok := r.ReadCloser.(router.Rereader)
if ok {
rr.Reread()
}
}

View File

@ -0,0 +1,60 @@
package apis_test
import (
"bytes"
"fmt"
"net/http/httptest"
"testing"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestBodyLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.POST("/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
}) // default global BodyLimit check
pbRouter.POST("/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
}).Bind(apis.BodyLimit(20))
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
size int64
expectedStatus int
}{
{"/a", 21, 200},
{"/a", apis.DefaultMaxBodySize + 1, 413},
{"/b", 20, 200},
{"/b", 21, 413},
}
for _, s := range scenarios {
t.Run(fmt.Sprintf("%s_%d", s.url, s.size), func(t *testing.T) {
rec := httptest.NewRecorder()
req := httptest.NewRequest("POST", s.url, bytes.NewReader(make([]byte, s.size)))
mux.ServeHTTP(rec, req)
result := rec.Result()
defer result.Body.Close()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

307
apis/middlewares_cors.go Normal file
View File

@ -0,0 +1,307 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go).
//
// I doubt that this would matter for most cases, but the only major difference
// is that for non-supported routes this middleware doesn't return 405 and fallbacks
// to the default catch-all PocketBase route (aka. returns 404) to avoid
// the extra overhead of further hijacking and wrapping the Go default mux
// (https://github.com/golang/go/issues/65648#issuecomment-1955328807).
// -------------------------------------------------------------------
import (
"net/http"
"regexp"
"strconv"
"strings"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
)
const (
DefaultCorsMiddlewareId = "pbCors"
DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted
)
// CORSConfig defines the config for CORS middleware.
type CORSConfig struct {
// AllowOrigins determines the value of the Access-Control-Allow-Origin
// response header. This header defines a list of origins that may access the
// resource. The wildcard characters '*' and '?' are supported and are
// converted to regex fragments '.*' and '.' accordingly.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional. Default value []string{"*"}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
//
// Security: use extreme caution when handling the origin, and carefully
// validate any logic. Remember that attackers may register hostile domain names.
// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// Optional.
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods determines the value of the Access-Control-Allow-Methods
// response header. This header specified the list of methods allowed when
// accessing the resource. This is used in response to a preflight request.
//
// Optional. Default value DefaultCORSConfig.AllowMethods.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
AllowMethods []string
// AllowHeaders determines the value of the Access-Control-Allow-Headers
// response header. This header is used in response to a preflight request to
// indicate which HTTP headers can be used when making the actual request.
//
// Optional. Default value []string{}.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
AllowHeaders []string
// AllowCredentials determines the value of the
// Access-Control-Allow-Credentials response header. This header indicates
// whether or not the response to the request can be exposed when the
// credentials mode (Request.credentials) is true. When used as part of a
// response to a preflight request, this indicates whether or not the actual
// request can be made using credentials. See also
// [MDN: Access-Control-Allow-Credentials].
//
// Optional. Default value false, in which case the header is not set.
//
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
AllowCredentials bool
// UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials
// flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header.
//
// This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties)
// attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject.
//
// Optional. Default value is false.
UnsafeWildcardOriginWithAllowCredentials bool
// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
// defines a list of headers that clients are allowed to access.
//
// Optional. Default value []string{}, in which case the header is not set.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
ExposeHeaders []string
// MaxAge determines the value of the Access-Control-Max-Age response header.
// This header indicates how long (in seconds) the results of a preflight
// request can be cached.
// The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
//
// Optional. Default value 0 - meaning header is not sent.
//
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
MaxAge int
}
// DefaultCORSConfig is the default CORS middleware config.
var DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
// CORSWithConfig returns a CORS middleware with config.
func CORSWithConfig(config CORSConfig) hook.HandlerFunc[*core.RequestEvent] {
// Defaults
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := "0"
if config.MaxAge > 0 {
maxAge = strconv.Itoa(config.MaxAge)
}
return func(e *core.RequestEvent) error {
req := e.Request
res := e.Response
origin := req.Header.Get("Origin")
allowOrigin := ""
res.Header().Add("Vary", "Origin")
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
checkPatterns := false
if allowOrigin == "" {
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if len(origin) <= (253+3+5) && strings.Contains(origin, "://") {
checkPatterns = true
}
}
if checkPatterns {
for _, re := range allowOriginPatterns {
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return e.Next()
}
return e.NoContent(http.StatusNoContent)
}
res.Header().Set("Access-Control-Allow-Origin", allowOrigin)
if config.AllowCredentials {
res.Header().Set("Access-Control-Allow-Credentials", "true")
}
// Simple request
if !preflight {
if exposeHeaders != "" {
res.Header().Set("Access-Control-Expose-Headers", exposeHeaders)
}
return e.Next()
}
// Preflight request
res.Header().Add("Vary", "Access-Control-Request-Method")
res.Header().Add("Vary", "Access-Control-Request-Headers")
res.Header().Set("Access-Control-Allow-Methods", allowMethods)
if allowHeaders != "" {
res.Header().Set("Access-Control-Allow-Headers", allowHeaders)
} else {
h := req.Header.Get("Access-Control-Request-Headers")
if h != "" {
res.Header().Set("Access-Control-Allow-Headers", h)
}
}
if config.MaxAge != 0 {
res.Header().Set("Access-Control-Max-Age", maxAge)
}
return e.NoContent(http.StatusNoContent)
}
}
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}

237
apis/middlewares_gzip.go Normal file
View File

@ -0,0 +1,237 @@
package apis
// -------------------------------------------------------------------
// This middleware is ported from echo/middleware to minimize the breaking
// changes and differences in the API behavior from earlier PocketBase versions
// (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/compress.go).
// -------------------------------------------------------------------
import (
"bufio"
"bytes"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/router"
)
const (
gzipScheme = "gzip"
)
// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
// Gzip compression level.
// Optional. Default value -1.
Level int
// Length threshold before gzip compression is applied.
// Optional. Default value 0.
//
// Most of the time you will not need to change the default. Compressing
// a short response might increase the transmitted data because of the
// gzip format overhead. Compressing the response will also consume CPU
// and time on the server and the client (for decompressing). Depending on
// your use case such a threshold might be useful.
//
// See also:
// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
MinLength int
}
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() hook.HandlerFunc[*core.RequestEvent] {
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) hook.HandlerFunc[*core.RequestEvent] {
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
panic(errors.New("invalid gzip level"))
}
if config.Level == 0 {
config.Level = -1
}
if config.MinLength < 0 {
config.MinLength = 0
}
pool := sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(io.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
bpool := sync.Pool{
New: func() interface{} {
b := &bytes.Buffer{}
return b
},
}
return func(e *core.RequestEvent) error {
e.Response.Header().Add("Vary", "Accept-Encoding")
if strings.Contains(e.Request.Header.Get("Accept-Encoding"), gzipScheme) {
w, ok := pool.Get().(*gzip.Writer)
if !ok {
return e.InternalServerError("", errors.New("failed to get gzip.Writer"))
}
rw := e.Response
w.Reset(rw)
buf := bpool.Get().(*bytes.Buffer)
buf.Reset()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf}
defer func() {
// There are different reasons for cases when we have not yet written response to the client and now need to do so.
// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
if !grw.wroteBody {
if rw.Header().Get("Content-Encoding") == gzipScheme {
rw.Header().Del("Content-Encoding")
}
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue echo#424, echo#407.
e.Response = rw
w.Reset(io.Discard)
} else if !grw.minLengthExceeded {
// Write uncompressed response
e.Response = rw
if grw.wroteHeader {
rw.WriteHeader(grw.code)
}
grw.buffer.WriteTo(rw)
w.Reset(io.Discard)
}
w.Close()
bpool.Put(buf)
pool.Put(w)
}()
e.Response = grw
}
return e.Next()
}
}
type gzipResponseWriter struct {
http.ResponseWriter
io.Writer
buffer *bytes.Buffer
minLength int
code int
wroteHeader bool
wroteBody bool
minLengthExceeded bool
}
func (w *gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length") // Issue echo#444
w.wroteHeader = true
// Delay writing of the header until we know if we'll actually compress the response
w.code = code
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get("Content-Type") == "" {
w.Header().Set("Content-Type", http.DetectContentType(b))
}
w.wroteBody = true
if !w.minLengthExceeded {
n, err := w.buffer.Write(b)
if w.buffer.Len() >= w.minLength {
w.minLengthExceeded = true
// The minimum length is exceeded, add Content-Encoding header and write the header
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
return w.Writer.Write(w.buffer.Bytes())
}
return n, err
}
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
if !w.minLengthExceeded {
// Enforce compression because we will not know how much more data will come
w.minLengthExceeded = true
w.Header().Set("Content-Encoding", gzipScheme)
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
_, _ = w.Writer.Write(w.buffer.Bytes())
}
_ = w.Writer.(*gzip.Writer).Flush()
_ = http.NewResponseController(w.ResponseWriter).Flush()
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(w.ResponseWriter).Hijack()
}
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
rw := w.ResponseWriter
for {
switch p := rw.(type) {
case http.Pusher:
return p.Push(target, opts)
case router.RWUnwrapper:
rw = p.Unwrap()
default:
return http.ErrNotSupported
}
}
}
func (w *gzipResponseWriter) ReadFrom(r io.Reader) (n int64, err error) {
if w.wroteHeader {
w.ResponseWriter.WriteHeader(w.code)
}
rw := w.ResponseWriter
for {
switch rf := rw.(type) {
case io.ReaderFrom:
return rf.ReadFrom(r)
case router.RWUnwrapper:
rw = rf.Unwrap()
default:
return io.Copy(w.ResponseWriter, r)
}
}
}
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@ -0,0 +1,298 @@
package apis
import (
"sync"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/store"
)
const (
DefaultRateLimitMiddlewareId = "pbRateLimit"
DefaultRateLimitMiddlewarePriority = -1000
)
const (
rateLimitersStoreKey = "__pbRateLimiters__"
rateLimitersCronKey = "__pbRateLimitersCleanup__"
rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
)
// rateLimit defines the global rate limit middleware.
//
// This middleware is registered by default for all routes.
func rateLimit() *hook.Handler[*core.RequestEvent] {
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
if skipRateLimit(e) {
return e.Next()
}
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(defaultRateLimitLabels(e))
if ok {
err := checkRateLimit(e, e.Request.Pattern, rule)
if err != nil {
return err
}
}
return e.Next()
},
}
}
// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
if collectionPathParam == "" {
collectionPathParam = "collection"
}
return &hook.Handler[*core.RequestEvent]{
Id: DefaultRateLimitMiddlewareId,
Priority: DefaultRateLimitMiddlewarePriority,
Func: func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
if err != nil {
return e.NotFoundError("Missing or invalid collection context.", err)
}
if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
return err
}
return e.Next()
},
}
}
// checkCollectionRateLimit checks whether the current request satisfy the
// rate limit configuration for the specific collection.
//
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
if skipRateLimit(e) {
return nil
}
labels := make([]string, 0, 2+len(baseTags)*2)
rtId := collection.Id + e.Request.Pattern
// add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"])
for _, baseTag := range baseTags {
rtId += baseTag
labels = append(labels, collection.Name+":"+baseTag)
}
// add the wildcard labels (aka. [..., "*:action1","*:action2", "*"])
for _, baseTag := range baseTags {
labels = append(labels, "*:"+baseTag)
}
labels = append(labels, defaultRateLimitLabels(e)...)
rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels)
if ok {
return checkRateLimit(e, rtId, rule)
}
return nil
}
// -------------------------------------------------------------------
// @todo consider exporting as RateLimit helper?
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
return initRateLimitersStore(e.App)
}).(*store.Store[*rateLimiter])
if rateLimiters == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiters store")
return nil
}
rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
})
if rt == nil {
e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
return nil
}
key := e.RealIP()
if key == "" {
e.App.Logger().Warn("Empty rate limit client key")
return nil
}
if !rt.isAllowed(key) {
return e.TooManyRequestsError("", nil)
}
return nil
}
func skipRateLimit(e *core.RequestEvent) bool {
return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
}
func defaultRateLimitLabels(e *core.RequestEvent) []string {
return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
}
func destroyRateLimitersStore(app core.App) {
app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
app.Cron().Remove(rateLimitersCronKey)
app.Store().Remove(rateLimitersStoreKey)
}
func initRateLimitersStore(app core.App) *store.Store[*rateLimiter] {
app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00
limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter])
if !ok {
return
}
limiters := limitersStore.GetAll()
for _, limiter := range limiters {
limiter.clean()
}
})
app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
Id: rateLimitersSettingsHookId,
Func: func(e *core.SettingsReloadEvent) error {
err := e.Next()
if err != nil {
return err
}
// reset
destroyRateLimitersStore(e.App)
return nil
},
})
return store.New[*rateLimiter](nil)
}
func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
return &rateLimiter{
maxAllowed: maxAllowed,
interval: intervalInSec,
minDeleteInterval: minDeleteIntervalInSec,
clients: map[string]*fixedWindow{},
}
}
type rateLimiter struct {
clients map[string]*fixedWindow
maxAllowed int
interval int64
minDeleteInterval int64
totalDeleted int64
sync.RWMutex
}
func (rt *rateLimiter) isAllowed(key string) bool {
// lock only reads to minimize locks contention
rt.RLock()
client, ok := rt.clients[key]
rt.RUnlock()
if !ok {
rt.Lock()
// check again in case the client was added by another request
client, ok = rt.clients[key]
if !ok {
client = newFixedWindow(rt.maxAllowed, rt.interval)
rt.clients[key] = client
}
rt.Unlock()
}
return client.consume()
}
func (rt *rateLimiter) clean() {
rt.Lock()
defer rt.Unlock()
nowUnix := time.Now().Unix()
for k, client := range rt.clients {
if client.hasExpired(nowUnix, rt.minDeleteInterval) {
delete(rt.clients, k)
rt.totalDeleted++
}
}
// "shrink" the map if too may items were deleted
//
// @todo remove after https://github.com/golang/go/issues/20135
if rt.totalDeleted >= 300 {
shrunk := make(map[string]*fixedWindow, len(rt.clients))
for k, v := range rt.clients {
shrunk[k] = v
}
rt.clients = shrunk
rt.totalDeleted = 0
}
}
func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
return &fixedWindow{
maxAllowed: maxAllowed,
interval: intervalInSec,
}
}
type fixedWindow struct {
// use plain Mutex instead of RWMutex since the operations are expected
// to be mostly writes (e.g. consume()) and it should perform better
sync.Mutex
maxAllowed int // the max allowed tokens per interval
available int // the total available tokens
interval int64 // in seconds
lastConsume int64 // the time of the last consume
}
// hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time.
// (usually used to perform periodic cleanup of staled instances).
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
l.Lock()
defer l.Unlock()
return relativeNow-l.lastConsume > minElapsed
}
// consume decrease the current window allowance with 1 (if not exhausted already).
//
// It returns false if the allowance has been already exhausted and the user
// has to wait until it resets back to its maxAllowed value.
func (l *fixedWindow) consume() bool {
l.Lock()
defer l.Unlock()
nowUnix := time.Now().Unix()
// reset consumed counter
if nowUnix-l.lastConsume >= l.interval {
l.available = l.maxAllowed
}
if l.available > 0 {
l.available--
l.lastConsume = nowUnix
return true
}
return false
}

View File

@ -0,0 +1,103 @@
package apis_test
import (
"net/http/httptest"
"testing"
"time"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestDefaultRateLimitMiddleware(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{
Label: "/rate/",
MaxRequests: 2,
Duration: 1,
},
{
Label: "/rate/b",
MaxRequests: 3,
Duration: 1,
},
{
Label: "POST /rate/b",
MaxRequests: 1,
Duration: 1,
},
}
pbRouter, err := apis.NewRouter(app)
if err != nil {
t.Fatal(err)
}
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
return e.String(200, "norate")
}).BindFunc(func(e *core.RequestEvent) error {
return e.Next()
})
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
return e.String(200, "a")
})
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
return e.String(200, "b")
})
mux, err := pbRouter.BuildMux()
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
url string
wait float64
expectedStatus int
}{
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/norate", 0, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/a", 0, 429},
{"/rate/a", 1.1, 200},
{"/rate/a", 0, 200},
{"/rate/a", 0, 429},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
{"/rate/b", 1.1, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 200},
{"/rate/b", 0, 429},
}
for _, s := range scenarios {
t.Run(s.url, func(t *testing.T) {
if s.wait > 0 {
time.Sleep(time.Duration(s.wait) * time.Second)
}
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", s.url, nil)
mux.ServeHTTP(rec, req)
result := rec.Result()
if result.StatusCode != s.expectedStatus {
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -9,198 +9,196 @@ import (
"strings"
"time"
"github.com/labstack/echo/v5"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/picker"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/spf13/cast"
"golang.org/x/sync/errgroup"
)
// note: the chunk size is arbitrary chosen and may change in the future
const clientsChunkSize = 150
// RealtimeClientAuthKey is the name of the realtime client store key that holds its auth state.
const RealtimeClientAuthKey = "auth"
// bindRealtimeApi registers the realtime api endpoints.
func bindRealtimeApi(app core.App, rg *echo.Group) {
api := realtimeApi{app: app}
func bindRealtimeApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
sub := rg.Group("/realtime")
sub.GET("", realtimeConnect).Bind(SkipSuccessActivityLog())
sub.POST("", realtimeSetSubscriptions)
subGroup := rg.Group("/realtime")
subGroup.GET("", api.connect)
subGroup.POST("", api.setSubscriptions, ActivityLogger(app))
api.bindEvents()
bindRealtimeEvents(app)
}
type realtimeApi struct {
app core.App
}
func realtimeConnect(e *core.RequestEvent) error {
// disable global write deadline for the SSE connection
rc := http.NewResponseController(e.Response)
writeDeadlineErr := rc.SetWriteDeadline(time.Time{})
if writeDeadlineErr != nil {
if !errors.Is(writeDeadlineErr, http.ErrNotSupported) {
return e.InternalServerError("Failed to initialize SSE connection.", writeDeadlineErr)
}
func (api *realtimeApi) connect(c echo.Context) error {
cancelCtx, cancelRequest := context.WithCancel(c.Request().Context())
// only log since there are valid cases where it may not be implement (e.g. httptest.ResponseRecorder)
e.App.Logger().Warn("SetWriteDeadline is not supported, fallback to the default server WriteTimeout")
}
// create cancellable request
cancelCtx, cancelRequest := context.WithCancel(e.Request.Context())
defer cancelRequest()
c.SetRequest(c.Request().Clone(cancelCtx))
e.Request = e.Request.Clone(cancelCtx)
// register new subscription client
client := subscriptions.NewDefaultClient()
api.app.SubscriptionsBroker().Register(client)
defer func() {
disconnectEvent := &core.RealtimeDisconnectEvent{
HttpContext: c,
Client: client,
}
if err := api.app.OnRealtimeDisconnectRequest().Trigger(disconnectEvent); err != nil {
api.app.Logger().Debug(
"OnRealtimeDisconnectRequest error",
slog.String("clientId", client.Id()),
slog.String("error", err.Error()),
)
}
api.app.SubscriptionsBroker().Unregister(client.Id())
}()
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-store")
e.Response.Header().Set("Content-Type", "text/event-stream")
e.Response.Header().Set("Cache-Control", "no-store")
// https://github.com/pocketbase/pocketbase/discussions/480#discussioncomment-3657640
// https://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
c.Response().Header().Set("X-Accel-Buffering", "no")
e.Response.Header().Set("X-Accel-Buffering", "no")
connectEvent := &core.RealtimeConnectEvent{
HttpContext: c,
Client: client,
IdleTimeout: 5 * time.Minute,
}
connectEvent := new(core.RealtimeConnectRequestEvent)
connectEvent.RequestEvent = e
connectEvent.Client = subscriptions.NewDefaultClient()
connectEvent.IdleTimeout = 5 * time.Minute
if err := api.app.OnRealtimeConnectRequest().Trigger(connectEvent); err != nil {
return err
}
return e.App.OnRealtimeConnectRequest().Trigger(connectEvent, func(ce *core.RealtimeConnectRequestEvent) error {
// register new subscription client
ce.App.SubscriptionsBroker().Register(ce.Client)
defer func() {
e.App.SubscriptionsBroker().Unregister(ce.Client.Id())
}()
api.app.Logger().Debug("Realtime connection established.", slog.String("clientId", client.Id()))
ce.App.Logger().Debug("Realtime connection established.", slog.String("clientId", ce.Client.Id()))
// signalize established connection (aka. fire "connect" message)
connectMsgEvent := &core.RealtimeMessageEvent{
HttpContext: c,
Client: client,
Message: &subscriptions.Message{
connectMsgEvent := new(core.RealtimeMessageEvent)
connectMsgEvent.RequestEvent = ce.RequestEvent
connectMsgEvent.Client = ce.Client
connectMsgEvent.Message = &subscriptions.Message{
Name: "PB_CONNECT",
Data: []byte(`{"clientId":"` + client.Id() + `"}`),
},
Data: []byte(`{"clientId":"` + ce.Client.Id() + `"}`),
}
connectMsgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(connectMsgEvent, func(e *core.RealtimeMessageEvent) error {
w := e.HttpContext.Response()
w.Write([]byte("id:" + client.Id() + "\n"))
w.Write([]byte("event:" + e.Message.Name + "\n"))
w.Write([]byte("data:"))
w.Write(e.Message.Data)
w.Write([]byte("\n\n"))
w.Flush()
return api.app.OnRealtimeAfterMessageSend().Trigger(e)
connectMsgErr := ce.App.OnRealtimeMessageSend().Trigger(connectMsgEvent, func(me *core.RealtimeMessageEvent) error {
me.Response.Write([]byte("id:" + me.Client.Id() + "\n"))
me.Response.Write([]byte("event:" + me.Message.Name + "\n"))
me.Response.Write([]byte("data:"))
me.Response.Write(me.Message.Data)
me.Response.Write([]byte("\n\n"))
return me.Flush()
})
if connectMsgErr != nil {
api.app.Logger().Debug(
ce.App.Logger().Debug(
"Realtime connection closed (failed to deliver PB_CONNECT)",
slog.String("clientId", client.Id()),
slog.String("clientId", ce.Client.Id()),
slog.String("error", connectMsgErr.Error()),
)
return nil
}
// start an idle timer to keep track of inactive/forgotten connections
idleTimeout := connectEvent.IdleTimeout
idleTimer := time.NewTimer(idleTimeout)
idleTimer := time.NewTimer(ce.IdleTimeout)
defer idleTimer.Stop()
for {
select {
case <-idleTimer.C:
cancelRequest()
case msg, ok := <-client.Channel():
case msg, ok := <-ce.Client.Channel():
if !ok {
// channel is closed
api.app.Logger().Debug(
ce.App.Logger().Debug(
"Realtime connection closed (closed channel)",
slog.String("clientId", client.Id()),
slog.String("clientId", ce.Client.Id()),
)
return nil
}
msgEvent := &core.RealtimeMessageEvent{
HttpContext: c,
Client: client,
Message: &msg,
}
msgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(msgEvent, func(e *core.RealtimeMessageEvent) error {
w := e.HttpContext.Response()
w.Write([]byte("id:" + e.Client.Id() + "\n"))
w.Write([]byte("event:" + e.Message.Name + "\n"))
w.Write([]byte("data:"))
w.Write(e.Message.Data)
w.Write([]byte("\n\n"))
w.Flush()
return api.app.OnRealtimeAfterMessageSend().Trigger(msgEvent)
msgEvent := new(core.RealtimeMessageEvent)
msgEvent.RequestEvent = ce.RequestEvent
msgEvent.Client = ce.Client
msgEvent.Message = &msg
msgErr := ce.App.OnRealtimeMessageSend().Trigger(msgEvent, func(me *core.RealtimeMessageEvent) error {
me.Response.Write([]byte("id:" + me.Client.Id() + "\n"))
me.Response.Write([]byte("event:" + me.Message.Name + "\n"))
me.Response.Write([]byte("data:"))
me.Response.Write(me.Message.Data)
me.Response.Write([]byte("\n\n"))
return me.Flush()
})
if msgErr != nil {
api.app.Logger().Debug(
ce.App.Logger().Debug(
"Realtime connection closed (failed to deliver message)",
slog.String("clientId", client.Id()),
slog.String("clientId", ce.Client.Id()),
slog.String("error", msgErr.Error()),
)
return nil
}
idleTimer.Stop()
idleTimer.Reset(idleTimeout)
case <-c.Request().Context().Done():
idleTimer.Reset(ce.IdleTimeout)
case <-ce.Request.Context().Done():
// connection is closed
api.app.Logger().Debug(
ce.App.Logger().Debug(
"Realtime connection closed (cancelled request)",
slog.String("clientId", client.Id()),
slog.String("clientId", ce.Client.Id()),
)
return nil
}
}
})
}
type realtimeSubscribeForm struct {
ClientId string `form:"clientId" json:"clientId"`
Subscriptions []string `form:"subscriptions" json:"subscriptions"`
}
func (form *realtimeSubscribeForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.ClientId, validation.Required, validation.Length(1, 255)),
)
}
// note: in case of reconnect, clients will have to resubmit all subscriptions again
func (api *realtimeApi) setSubscriptions(c echo.Context) error {
form := forms.NewRealtimeSubscribe()
func realtimeSetSubscriptions(e *core.RequestEvent) error {
form := new(realtimeSubscribeForm)
// read request data
if err := c.Bind(form); err != nil {
return NewBadRequestError("", err)
err := e.BindBody(form)
if err != nil {
return e.BadRequestError("", err)
}
// validate request data
if err := form.Validate(); err != nil {
return NewBadRequestError("", err)
err = form.validate()
if err != nil {
return e.BadRequestError("", err)
}
// find subscription client
client, err := api.app.SubscriptionsBroker().ClientById(form.ClientId)
client, err := e.App.SubscriptionsBroker().ClientById(form.ClientId)
if err != nil {
return NewNotFoundError("Missing or invalid client id.", err)
return e.NotFoundError("Missing or invalid client id.", err)
}
// check if the previous request was authorized
oldAuthId := extractAuthIdFromGetter(client)
newAuthId := extractAuthIdFromGetter(c)
newAuthId := extractAuthIdFromGetter(e)
if oldAuthId != "" && oldAuthId != newAuthId {
return NewForbiddenError("The current and the previous request authorization don't match.", nil)
return e.ForbiddenError("The current and the previous request authorization don't match.", nil)
}
event := &core.RealtimeSubscribeEvent{
HttpContext: c,
Client: client,
Subscriptions: form.Subscriptions,
}
event := new(core.RealtimeSubscribeRequestEvent)
event.RequestEvent = e
event.Client = client
event.Subscriptions = form.Subscriptions
return api.app.OnRealtimeBeforeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeEvent) error {
return e.App.OnRealtimeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeRequestEvent) error {
// update auth state
e.Client.Set(ContextAdminKey, e.HttpContext.Get(ContextAdminKey))
e.Client.Set(ContextAuthRecordKey, e.HttpContext.Get(ContextAuthRecordKey))
e.Client.Set(RealtimeClientAuthKey, e.Auth)
// unsubscribe from any previous existing subscriptions
e.Client.Unsubscribe()
@ -208,81 +206,113 @@ func (api *realtimeApi) setSubscriptions(c echo.Context) error {
// subscribe to the new subscriptions
e.Client.Subscribe(e.Subscriptions...)
api.app.Logger().Debug(
e.App.Logger().Debug(
"Realtime subscriptions updated.",
slog.String("clientId", e.Client.Id()),
slog.Any("subscriptions", e.Subscriptions),
)
return api.app.OnRealtimeAfterSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
return e.NoContent(http.StatusNoContent)
})
}
// updateClientsAuthModel updates the existing clients auth model with the new one (matched by ID).
func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error {
for _, client := range api.app.SubscriptionsBroker().Clients() {
clientModel, _ := client.Get(contextKey).(models.Model)
if clientModel != nil &&
clientModel.TableName() == newModel.TableName() &&
clientModel.GetId() == newModel.GetId() {
client.Set(contextKey, newModel)
// updateClientsAuth updates the existing clients auth record with the new one (matched by ID).
func realtimeUpdateClientsAuth(app core.App, newAuthRecord *core.Record) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
if clientAuth != nil &&
clientAuth.Id == newAuthRecord.Id &&
clientAuth.Collection().Name == newAuthRecord.Collection().Name {
client.Set(RealtimeClientAuthKey, newAuthRecord)
}
}
return nil
})
}
return group.Wait()
}
// unregisterClientsByAuthModel unregister all clients that has the provided auth model.
func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error {
for _, client := range api.app.SubscriptionsBroker().Clients() {
clientModel, _ := client.Get(contextKey).(models.Model)
if clientModel != nil &&
clientModel.TableName() == model.TableName() &&
clientModel.GetId() == model.GetId() {
api.app.SubscriptionsBroker().Unregister(client.Id())
func realtimeUnregisterClientsByAuth(app core.App, authModel core.Model) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
clientAuth, _ := client.Get(RealtimeClientAuthKey).(*core.Record)
if clientAuth != nil &&
clientAuth.Id == authModel.PK() &&
clientAuth.Collection().Name == authModel.TableName() {
app.SubscriptionsBroker().Unregister(client.Id())
}
}
return nil
})
}
return group.Wait()
}
func (api *realtimeApi) bindEvents() {
// update the clients that has admin or auth record association
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
if record := api.resolveRecord(e.Model); record != nil && record.Collection().IsAuth() {
return api.updateClientsAuthModel(ContextAuthRecordKey, record)
func bindRealtimeEvents(app core.App) {
// update the clients that has auth record association
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
authRecord := realtimeResolveRecord(e.App, e.Model, core.CollectionTypeAuth)
if authRecord != nil {
if err := realtimeUpdateClientsAuth(e.App, authRecord); err != nil {
app.Logger().Warn(
"Failed to update client(s) associated to the updated auth record",
slog.Any("id", authRecord.Id),
slog.String("collectionName", authRecord.Collection().Name),
slog.String("error", err.Error()),
)
}
}
if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
return api.updateClientsAuthModel(ContextAdminKey, admin)
}
return nil
return e.Next()
},
Priority: -99,
})
// remove the client(s) associated to the deleted admin or auth record
api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error {
if collection := api.resolveRecordCollection(e.Model); collection != nil && collection.IsAuth() {
return api.unregisterClientsByAuthModel(ContextAuthRecordKey, e.Model)
// remove the client(s) associated to the deleted auth model
// (note: works also with custom model for backward compatibility)
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
collection := realtimeResolveRecordCollection(e.App, e.Model)
if collection != nil && collection.IsAuth() {
if err := realtimeUnregisterClientsByAuth(e.App, e.Model); err != nil {
app.Logger().Warn(
"Failed to remove client(s) associated to the deleted auth model",
slog.Any("id", e.Model.PK()),
slog.String("collectionName", e.Model.TableName()),
slog.String("error", err.Error()),
)
}
}
if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
return api.unregisterClientsByAuthModel(ContextAdminKey, admin)
}
return nil
return e.Next()
},
Priority: -99,
})
api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error {
if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("create", record, false); err != nil {
api.app.Logger().Debug(
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastRecord(e.App, "create", record, false)
if err != nil {
app.Logger().Debug(
"Failed to broadcast record create",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
@ -290,13 +320,19 @@ func (api *realtimeApi) bindEvents() {
)
}
}
return nil
return e.Next()
},
Priority: -99,
})
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("update", record, false); err != nil {
api.app.Logger().Debug(
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastRecord(e.App, "update", record, false)
if err != nil {
app.Logger().Debug(
"Failed to broadcast record update",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
@ -304,13 +340,20 @@ func (api *realtimeApi) bindEvents() {
)
}
}
return nil
return e.Next()
},
Priority: -99,
})
api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error {
if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("delete", record, true); err != nil {
api.app.Logger().Debug(
// delete: dry cache
app.OnModelDelete().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastRecord(e.App, "delete", record, true)
if err != nil {
app.Logger().Debug(
"Failed to dry cache record delete",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
@ -318,13 +361,20 @@ func (api *realtimeApi) bindEvents() {
)
}
}
return nil
return e.Next()
},
Priority: 99, // execute as later as possible
})
api.app.OnModelAfterDelete().Add(func(e *core.ModelEvent) error {
if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastDryCachedRecord("delete", record); err != nil {
api.app.Logger().Debug(
// delete: broadcast
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*core.ModelEvent]{
Func: func(e *core.ModelEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeBroadcastDryCachedRecord(e.App, "delete", record)
if err != nil {
app.Logger().Debug(
"Failed to broadcast record delete",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
@ -332,31 +382,71 @@ func (api *realtimeApi) bindEvents() {
)
}
}
return nil
return e.Next()
},
Priority: -99,
})
// delete: failure
app.OnModelAfterDeleteError().Bind(&hook.Handler[*core.ModelErrorEvent]{
Func: func(e *core.ModelErrorEvent) error {
record := realtimeResolveRecord(e.App, e.Model, "")
if record != nil {
err := realtimeUnsetDryCachedRecord(e.App, "delete", record)
if err != nil {
app.Logger().Debug(
"Failed to cleanup after broadcast record delete failure",
slog.String("id", record.Id),
slog.String("collectionName", record.Collection().Name),
slog.String("error", err.Error()),
)
}
}
return e.Next()
},
Priority: -99,
})
}
// resolveRecord converts *if possible* the provided model interface to a Record.
// This is usually helpful if the provided model is a custom Record model struct.
func (api *realtimeApi) resolveRecord(model models.Model) (record *models.Record) {
record, _ = model.(*models.Record)
func realtimeResolveRecord(app core.App, model core.Model, optCollectionType string) *core.Record {
record, _ := model.(*core.Record)
if record != nil {
if optCollectionType == "" || record.Collection().Type == optCollectionType {
return record
}
return nil
}
// check if it is custom Record model struct (ignore "private" tables)
if record == nil && !strings.HasPrefix(model.TableName(), "_") {
record, _ = api.app.Dao().FindRecordById(model.TableName(), model.GetId())
tblName := model.TableName()
// skip Log model checks
if tblName == core.LogsTableName {
return nil
}
// check if it is custom Record model struct
collection, _ := app.FindCachedCollectionByNameOrId(tblName)
if collection != nil && (optCollectionType == "" || collection.Type == optCollectionType) {
if id, ok := model.PK().(string); ok {
record, _ = app.FindRecordById(collection, id)
}
}
return record
}
// resolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
// realtimeResolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
// This is usually helpful if the provided model is a custom Record model struct.
func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection *models.Collection) {
if record, ok := model.(*models.Record); ok {
func realtimeResolveRecordCollection(app core.App, model core.Model) (collection *core.Collection) {
if record, ok := model.(*core.Record); ok {
collection = record.Collection()
} else if !strings.HasPrefix(model.TableName(), "_") {
} else {
// check if it is custom Record model struct (ignore "private" tables)
collection, _ = api.app.Dao().FindCollectionByNameOrId(model.TableName())
collection, _ = app.FindCachedCollectionByNameOrId(model.TableName())
}
return collection
@ -364,18 +454,18 @@ func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection
// recordData represents the broadcasted record subscrition message data.
type recordData struct {
Record any `json:"record"` /* map or models.Record */
Record any `json:"record"` /* map or core.Record */
Action string `json:"action"`
}
func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dryCache bool) error {
func realtimeBroadcastRecord(app core.App, action string, record *core.Record, dryCache bool) error {
collection := record.Collection()
if collection == nil {
return errors.New("[broadcastRecord] Record collection not set")
}
clients := api.app.SubscriptionsBroker().Clients()
if len(clients) == 0 {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
@ -384,6 +474,7 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
(collection.Name + "/*?"): collection.ListRule,
(collection.Id + "/*?"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
(collection.Name + "?"): collection.ListRule,
(collection.Id + "?"): collection.ListRule,
@ -391,9 +482,11 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
dryCacheKey := action + "/" + record.Id
for _, client := range clients {
client := client
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
// note: not executed concurrently to avoid races and to ensure
// that the access checks are applied for the current record db state
for prefix, rule := range subscriptionRuleMap {
@ -405,27 +498,29 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
for sub, options := range subs {
// create a clean record copy without expand and unknown fields
// because we don't know yet which exact fields the client subscription has permissions to access
cleanRecord := record.CleanCopy()
cleanRecord := record.Fresh()
// mock request data
requestInfo := &models.RequestInfo{
Context: models.RequestInfoContextRealtime,
requestInfo := &core.RequestInfo{
Context: core.RequestInfoContextRealtime,
Method: "GET",
Query: options.Query,
Headers: options.Headers,
}
requestInfo.Admin, _ = client.Get(ContextAdminKey).(*models.Admin)
requestInfo.AuthRecord, _ = client.Get(ContextAuthRecordKey).(*models.Record)
requestInfo.Auth, _ = client.Get(RealtimeClientAuthKey).(*core.Record)
if !api.canAccessRecord(cleanRecord, requestInfo, rule) {
if !realtimeCanAccessRecord(app, cleanRecord, requestInfo, rule) {
continue
}
rawExpand := cast.ToString(options.Query[expandQueryParam])
// trigger the enrich hooks
enrichErr := triggerRecordEnrichHooks(app, requestInfo, []*core.Record{cleanRecord}, func() error {
// apply expand
rawExpand := options.Query[expandQueryParam]
if rawExpand != "" {
expandErrs := api.app.Dao().ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(api.app.Dao(), requestInfo))
expandErrs := app.ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(app, requestInfo))
if len(expandErrs) > 0 {
api.app.Logger().Debug(
app.Logger().Debug(
"[broadcastRecord] expand errors",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
@ -437,14 +532,26 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
}
// ignore the auth record email visibility checks
// for auth owner, admin or manager
// for auth owner, superuser or manager
if collection.IsAuth() {
authId := extractAuthIdFromGetter(client)
if authId == cleanRecord.Id {
if api.canAccessRecord(cleanRecord, requestInfo, collection.AuthOptions().ManageRule) {
if authId == cleanRecord.Id ||
realtimeCanAccessRecord(app, cleanRecord, requestInfo, collection.ManageRule) {
cleanRecord.IgnoreEmailVisibility(true)
}
}
return nil
})
if enrichErr != nil {
app.Logger().Debug(
"[broadcastRecord] record enrich error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
slog.String("sub", sub),
slog.Any("error", enrichErr),
)
continue
}
data := &recordData{
@ -453,13 +560,13 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
}
// check fields
rawFields := cast.ToString(options.Query[fieldsQueryParam])
rawFields := options.Query[fieldsQueryParam]
if rawFields != "" {
decoded, err := rest.PickFields(cleanRecord, rawFields)
decoded, err := picker.Pick(cleanRecord, rawFields)
if err == nil {
data.Record = decoded
} else {
api.app.Logger().Debug(
app.Logger().Debug(
"[broadcastRecord] pick fields error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
@ -472,7 +579,7 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
dataBytes, err := json.Marshal(data)
if err != nil {
api.app.Logger().Debug(
app.Logger().Debug(
"[broadcastRecord] data marshal error",
slog.String("id", cleanRecord.Id),
slog.String("collectionName", cleanRecord.Collection().Name),
@ -504,15 +611,26 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
}
return nil
})
}
return group.Wait()
}
// broadcastDryCachedRecord broadcasts all cached record related messages.
func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.Record) error {
// realtimeBroadcastDryCachedRecord broadcasts all cached record related messages.
func realtimeBroadcastDryCachedRecord(app core.App, action string, record *core.Record) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
key := action + "/" + record.Id
clients := api.app.SubscriptionsBroker().Clients()
group := new(errgroup.Group)
for _, client := range clients {
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
messages, ok := client.Get(key).([]subscriptions.Message)
if !ok {
continue
@ -530,6 +648,36 @@ func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.R
}
return nil
})
}
return group.Wait()
}
// realtimeUnsetDryCachedRecord removes the dry cached record related messages.
func realtimeUnsetDryCachedRecord(app core.App, action string, record *core.Record) error {
chunks := app.SubscriptionsBroker().ChunkedClients(clientsChunkSize)
if len(chunks) == 0 {
return nil // no subscribers
}
key := action + "/" + record.Id
group := new(errgroup.Group)
for _, chunk := range chunks {
group.Go(func() error {
for _, client := range chunk {
if client.Get(key) != nil {
client.Unset(key)
}
}
return nil
})
}
return group.Wait()
}
type getter interface {
@ -537,28 +685,24 @@ type getter interface {
}
func extractAuthIdFromGetter(val getter) string {
record, _ := val.Get(ContextAuthRecordKey).(*models.Record)
record, _ := val.Get(RealtimeClientAuthKey).(*core.Record)
if record != nil {
return record.Id
}
admin, _ := val.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return admin.Id
}
return ""
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(
record *models.Record,
requestInfo *models.RequestInfo,
// realtimeCanAccessRecord checks if the subscription client has access to the specified record model.
func realtimeCanAccessRecord(
app core.App,
record *core.Record,
requestInfo *core.RequestInfo,
accessRule *string,
) bool {
// check the access rule
// ---
if ok, _ := api.app.Dao().CanAccessRecord(record, requestInfo, accessRule); !ok {
if ok, _ := app.CanAccessRecord(record, requestInfo, accessRule); !ok {
return false
}
@ -569,25 +713,27 @@ func (api *realtimeApi) canAccessRecord(
return true // no further checks needed
}
if err := checkForAdminOnlyRuleFields(requestInfo); err != nil {
err := checkForSuperuserOnlyRuleFields(requestInfo)
if err != nil {
return false
}
ruleFunc := func(q *dbx.SelectQuery) error {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestInfo, false)
var exists bool
q := app.DB().Select("(1)").
From(record.Collection().Name).
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, false)
expr, err := search.FilterData(filter).BuildExpr(resolver)
if err != nil {
return err
return false
}
q.AndWhere(expr)
q.AndWhere(expr)
resolver.UpdateQuery(q)
return nil
}
err = q.Limit(1).Row(&exists)
_, err := api.app.Dao().FindRecordById(record.Collection().Id, record.Id, ruleFunc)
return err == nil
return err == nil && exists
}

View File

@ -1,20 +1,17 @@
package apis_test
import (
"context"
"errors"
"net/http"
"strings"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
@ -22,7 +19,7 @@ func TestRealtimeConnect(t *testing.T) {
scenarios := []tests.ApiScenario{
{
Method: http.MethodGet,
Url: "/api/realtime",
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedContent: []string{
@ -31,12 +28,11 @@ func TestRealtimeConnect(t *testing.T) {
`data:{"clientId":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeBeforeMessageSend": 1,
"OnRealtimeAfterMessageSend": 1,
"OnRealtimeDisconnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
@ -45,23 +41,23 @@ func TestRealtimeConnect(t *testing.T) {
{
Name: "PB_CONNECT interrupt",
Method: http.MethodGet,
Url: "/api/realtime",
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeBeforeMessageSend": 1,
"OnRealtimeDisconnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
if e.Message.Name == "PB_CONNECT" {
return errors.New("PB_CONNECT error")
}
return nil
return e.Next()
})
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
@ -70,20 +66,20 @@ func TestRealtimeConnect(t *testing.T) {
{
Name: "Skipping/ignoring messages",
Method: http.MethodGet,
Url: "/api/realtime",
URL: "/api/realtime",
Timeout: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRealtimeConnectRequest": 1,
"OnRealtimeBeforeMessageSend": 1,
"OnRealtimeDisconnectRequest": 1,
"OnRealtimeMessageSend": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error {
return hook.StopPropagation
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRealtimeMessageSend().BindFunc(func(e *core.RealtimeMessageEvent) error {
return nil
})
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(app.SubscriptionsBroker().Clients()) != 0 {
t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients()))
}
@ -101,34 +97,34 @@ func TestRealtimeSubscribe(t *testing.T) {
resetClient := func() {
client.Unsubscribe()
client.Set(apis.ContextAdminKey, nil)
client.Set(apis.ContextAuthRecordKey, nil)
client.Set(apis.RealtimeClientAuthKey, nil)
}
scenarios := []tests.ApiScenario{
{
Name: "missing client",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing client - empty subscriptions",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1,
"OnRealtimeAfterSubscribeRequest": 1,
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0")
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if len(client.Subscriptions()) != 0 {
t.Errorf("Expected no subscriptions, got %v", client.Subscriptions())
}
@ -138,18 +134,18 @@ func TestRealtimeSubscribe(t *testing.T) {
{
Name: "existing client - 2 new subscriptions",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1,
"OnRealtimeAfterSubscribeRequest": 1,
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client.Subscribe("test0")
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
expectedSubs := []string{"test1", "test2"}
if len(expectedSubs) != len(client.Subscriptions()) {
t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions())
@ -164,49 +160,49 @@ func TestRealtimeSubscribe(t *testing.T) {
},
},
{
Name: "existing client - authorized admin",
Name: "existing client - authorized superuser",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1,
"OnRealtimeAfterSubscribeRequest": 1,
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
admin, _ := client.Get(apis.ContextAdminKey).(*models.Admin)
if admin == nil {
t.Errorf("Expected admin auth model, got nil")
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil || !authRecord.IsSuperuser() {
t.Errorf("Expected superuser auth record, got %v", authRecord)
}
resetClient()
},
},
{
Name: "existing client - authorized record",
Name: "existing client - authorized regular record",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"OnRealtimeBeforeSubscribeRequest": 1,
"OnRealtimeAfterSubscribeRequest": 1,
"*": 0,
"OnRealtimeSubscribeRequest": 1,
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
t.Errorf("Expected regular user auth record, got %v", authRecord)
}
resetClient()
},
@ -214,22 +210,50 @@ func TestRealtimeSubscribe(t *testing.T) {
{
Name: "existing client - mismatched auth",
Method: http.MethodPost,
Url: "/api/realtime",
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
initialAuth := &models.Record{}
initialAuth.RefreshId()
client.Set(apis.ContextAuthRecordKey, initialAuth)
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
resetClient()
},
},
{
Name: "existing client - unauthorized client",
Method: http.MethodPost,
URL: "/api/realtime",
Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
client.Set(apis.RealtimeClientAuthKey, user)
app.SubscriptionsBroker().Register(client)
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
authRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if authRecord == nil {
t.Errorf("Expected auth record model, got nil")
}
@ -247,24 +271,29 @@ func TestRealtimeAuthRecordDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
// init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord)
client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// mock delete event
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.App = testApp
e.Type = core.ModelEventTypeDelete
e.Context = context.Background()
e.Model = authRecord
testApp.OnModelAfterDelete().Trigger(e)
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
testApp.OnModelAfterDeleteSuccess().Trigger(e)
if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
t.Fatalf("Expected no subscription clients, found %d", total)
}
}
@ -272,111 +301,58 @@ func TestRealtimeAuthRecordUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
// init realtime handlers
apis.NewRouter(testApp)
authRecord1, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
authRecord1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord1)
client.Set(apis.RealtimeClientAuthKey, authRecord1)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord and change its email
authRecord2, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
authRecord2, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
authRecord2.SetEmail("new@example.com")
// mock update event
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.App = testApp
e.Type = core.ModelEventTypeUpdate
e.Context = context.Background()
e.Model = authRecord2
testApp.OnModelAfterUpdate().Trigger(e)
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
testApp.OnModelAfterUpdateSuccess().Trigger(e)
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != authRecord2.Email() {
t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email())
}
}
func TestRealtimeAdminDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
admin, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAdminKey, admin)
testApp.SubscriptionsBroker().Register(client)
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.Model = admin
testApp.OnModelAfterDelete().Trigger(e)
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
}
}
func TestRealtimeAdminUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
admin1, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAdminKey, admin1)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord and change its email
admin2, err := testApp.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
admin2.Email = "new@example.com"
e := new(core.ModelEvent)
e.Dao = testApp.Dao()
e.Model = admin2
testApp.OnModelAfterUpdate().Trigger(e)
clientAdmin, _ := client.Get(apis.ContextAdminKey).(*models.Admin)
if clientAdmin.Email != admin2.Email {
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
}
}
// Custom auth record model struct
// -------------------------------------------------------------------
var _ models.Model = (*CustomUser)(nil)
var _ core.Model = (*CustomUser)(nil)
type CustomUser struct {
models.BaseModel
core.BaseModel
Email string `db:"email" json:"email"`
}
func (m *CustomUser) TableName() string {
return "users" // the name of your collection
return "users"
}
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) {
func findCustomUserByEmail(app core.App, email string) (*CustomUser, error) {
model := &CustomUser{}
err := dao.ModelQuery(model).
err := app.ModelQuery(model).
AndWhere(dbx.HashExp{"email": email}).
Limit(1).
One(model)
@ -392,30 +368,31 @@ func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
// init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord)
client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
if err != nil {
t.Fatal(err)
}
// delete the custom user (should unset the client auth record)
if err := testApp.Dao().Delete(customUser); err != nil {
if err := testApp.Delete(customUser); err != nil {
t.Fatal(err)
}
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
if total := len(testApp.SubscriptionsBroker().Clients()); total != 0 {
t.Fatalf("Expected no subscription clients, found %d", total)
}
}
@ -423,30 +400,31 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
// init realtime handlers
apis.NewRouter(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
authRecord, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord)
client.Set(apis.RealtimeClientAuthKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
customUser, err := findCustomUserByEmail(testApp, "test@example.com")
if err != nil {
t.Fatal(err)
}
// change its email
customUser.Email = "new@example.com"
if err := testApp.Dao().Save(customUser); err != nil {
if err := testApp.Save(customUser); err != nil {
t.Fatal(err)
}
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
clientAuthRecord, _ := client.Get(apis.RealtimeClientAuthKey).(*core.Record)
if clientAuthRecord.Email() != customUser.Email {
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
}

View File

@ -1,765 +1,75 @@
package apis
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"sort"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/pocketbase/pocketbase/tools/types"
"golang.org/x/oauth2"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindRecordAuthApi registers the auth record api endpoints and
// the corresponding handlers.
func bindRecordAuthApi(app core.App, rg *echo.Group) {
api := recordAuthApi{app: app}
func bindRecordAuthApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
// global oauth2 subscription redirect handler
rg.GET("/oauth2-redirect", api.oauth2SubscriptionRedirect)
rg.POST("/oauth2-redirect", api.oauth2SubscriptionRedirect) // needed in case of response_mode=form_post
rg.GET("/oauth2-redirect", oauth2SubscriptionRedirect)
// add again as POST in case of response_mode=form_post
rg.POST("/oauth2-redirect", oauth2SubscriptionRedirect)
// common collection record related routes
subGroup := rg.Group(
"/collections/:collection",
ActivityLogger(app),
LoadCollectionContext(app, models.CollectionTypeAuth),
sub := rg.Group("/collections/{collection}")
sub.GET("/auth-methods", recordAuthMethods).Bind(
collectionPathRateLimit("", "listAuthMethods"),
)
subGroup.GET("/auth-methods", api.authMethods)
subGroup.POST("/auth-refresh", api.authRefresh, RequireSameContextRecordAuth())
subGroup.POST("/auth-with-oauth2", api.authWithOAuth2)
subGroup.POST("/auth-with-password", api.authWithPassword)
subGroup.POST("/request-password-reset", api.requestPasswordReset)
subGroup.POST("/confirm-password-reset", api.confirmPasswordReset)
subGroup.POST("/request-verification", api.requestVerification)
subGroup.POST("/confirm-verification", api.confirmVerification)
subGroup.POST("/request-email-change", api.requestEmailChange, RequireSameContextRecordAuth())
subGroup.POST("/confirm-email-change", api.confirmEmailChange)
subGroup.GET("/records/:id/external-auths", api.listExternalAuths, RequireAdminOrOwnerAuth("id"))
subGroup.DELETE("/records/:id/external-auths/:provider", api.unlinkExternalAuth, RequireAdminOrOwnerAuth("id"))
}
type recordAuthApi struct {
app core.App
}
func (api *recordAuthApi) authRefresh(c echo.Context) error {
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewNotFoundError("Missing auth record context.", nil)
}
event := new(core.RecordAuthRefreshEvent)
event.HttpContext = c
event.Collection = record.Collection()
event.Record = record
return api.app.OnRecordBeforeAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
return api.app.OnRecordAfterAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshEvent) error {
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
})
})
}
type providerInfo struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
func (api *recordAuthApi) authMethods(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
authOptions := collection.AuthOptions()
result := struct {
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
OnlyVerified bool `json:"onlyVerified"`
}{
UsernamePassword: authOptions.AllowUsernameAuth,
EmailPassword: authOptions.AllowEmailAuth,
OnlyVerified: authOptions.OnlyVerified,
AuthProviders: []providerInfo{},
}
if !authOptions.AllowOAuth2Auth {
return c.JSON(http.StatusOK, result)
}
nameConfigMap := api.app.Settings().NamedAuthProviderConfigs()
for name, config := range nameConfigMap {
if !config.Enabled {
continue
}
provider, err := auth.NewProviderByName(name)
if err != nil {
api.app.Logger().Debug("Missing or invalid provider name", slog.String("name", name))
continue // skip provider
}
if err := config.SetupProvider(provider); err != nil {
api.app.Logger().Debug(
"Failed to setup provider",
slog.String("name", name),
slog.String("error", err.Error()),
sub.POST("/auth-refresh", recordAuthRefresh).Bind(
collectionPathRateLimit("", "authRefresh"),
RequireSameCollectionContextAuth(""),
)
continue // skip provider
}
info := providerInfo{
Name: name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}
if info.DisplayName == "" {
info.DisplayName = name
}
urlOpts := []oauth2.AuthCodeOption{}
// custom providers url options
switch name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
sub.POST("/auth-with-password", recordAuthWithPassword).Bind(
collectionPathRateLimit("", "authWithPassword", "auth"),
)
}
info.AuthUrl = provider.BuildAuthUrl(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
result.AuthProviders = append(result.AuthProviders, info)
}
// sort providers
sort.SliceStable(result.AuthProviders, func(i, j int) bool {
return result.AuthProviders[i].Name < result.AuthProviders[j].Name
})
return c.JSON(http.StatusOK, result)
}
func (api *recordAuthApi) authWithOAuth2(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
if !collection.AuthOptions().AllowOAuth2Auth {
return NewBadRequestError("The collection is not configured to allow OAuth2 authentication.", nil)
}
var fallbackAuthRecord *models.Record
loggedAuthRecord, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if loggedAuthRecord != nil && loggedAuthRecord.Collection().Id == collection.Id {
fallbackAuthRecord = loggedAuthRecord
}
form := forms.NewRecordOAuth2Login(api.app, collection, fallbackAuthRecord)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordAuthWithOAuth2Event)
event.HttpContext = c
event.Collection = collection
event.ProviderName = form.Provider
form.SetBeforeNewRecordCreateFunc(func(createForm *forms.RecordUpsert, authRecord *models.Record, authUser *auth.AuthUser) error {
return createForm.DrySubmit(func(txDao *daos.Dao) error {
event.IsNewRecord = true
// clone the current request data and assign the form create data as its body data
requestInfo := *RequestInfo(c)
requestInfo.Context = models.RequestInfoContextOAuth2
requestInfo.Data = form.CreateData
createRuleFunc := func(q *dbx.SelectQuery) error {
admin, _ := c.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
return nil // either admin or the rule is empty
}
if collection.CreateRule == nil {
return errors.New("Only admins can create new accounts with OAuth2")
}
if *collection.CreateRule != "" {
resolver := resolvers.NewRecordFieldResolver(txDao, collection, &requestInfo, true)
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
}
return nil
}
if _, err := txDao.FindRecordById(collection.Id, createForm.Id, createRuleFunc); err != nil {
return fmt.Errorf("Failed create rule constraint: %w", err)
}
return nil
})
})
_, _, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData]) forms.InterceptorNextFunc[*forms.RecordOAuth2LoginData] {
return func(data *forms.RecordOAuth2LoginData) error {
event.Record = data.Record
event.OAuth2User = data.OAuth2User
event.ProviderClient = data.ProviderClient
event.IsNewRecord = data.Record == nil
return api.app.OnRecordBeforeAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
data.Record = e.Record
data.OAuth2User = e.OAuth2User
if err := next(data); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
e.Record = data.Record
e.OAuth2User = data.OAuth2User
meta := struct {
*auth.AuthUser
IsNew bool `json:"isNew"`
}{
AuthUser: e.OAuth2User,
IsNew: event.IsNewRecord,
}
return api.app.OnRecordAfterAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2Event) error {
// clear the lastLoginAlertSentAt field so that we can enforce password auth notifications
if !e.Record.LastLoginAlertSentAt().IsZero() {
e.Record.Set(schema.FieldNameLastLoginAlertSentAt, "")
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
api.app.Logger().Warn("Failed to reset lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
}
}
return RecordAuthResponse(api.app, e.HttpContext, e.Record, meta)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) authWithPassword(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordPasswordLogin(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordAuthWithPasswordEvent)
event.HttpContext = c
event.Collection = collection
event.Password = form.Password
event.Identity = form.Identity
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
// @todo remove after the refactoring
if collection.AuthOptions().AllowOAuth2Auth && e.Record.Email() != "" {
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(e.Record)
if err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
if len(externalAuths) > 0 {
lastLoginAlert := e.Record.LastLoginAlertSentAt().Time()
// send an email alert if the password auth is after OAuth2 auth (lastLoginAlert will be empty)
// or if it has been ~7 days since the last alert
if lastLoginAlert.IsZero() || time.Now().UTC().Sub(lastLoginAlert).Hours() > 168 {
providerNames := make([]string, len(externalAuths))
for i, ea := range externalAuths {
var name string
if provider, err := auth.NewProviderByName(ea.Provider); err == nil {
name = provider.DisplayName()
}
if name == "" {
name = ea.Provider
}
providerNames[i] = name
}
if err := mails.SendRecordPasswordLoginAlert(api.app, e.Record, providerNames...); err != nil {
return NewBadRequestError("Failed to authenticate.", err)
}
e.Record.SetLastLoginAlertSentAt(types.NowDateTime())
if err := api.app.Dao().SaveRecord(e.Record); err != nil {
api.app.Logger().Warn("Failed to update lastLoginAlertSentAt", "error", err, "recordId", e.Record.Id)
}
}
}
}
return api.app.OnRecordAfterAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordEvent) error {
return RecordAuthResponse(api.app, e.HttpContext, e.Record, nil)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestPasswordReset(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
authOptions := collection.AuthOptions()
if !authOptions.AllowUsernameAuth && !authOptions.AllowEmailAuth {
return NewBadRequestError("The collection is not configured to allow password authentication.", nil)
}
form := forms.NewRecordPasswordResetRequest(api.app, collection)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.RecordRequestPasswordResetEvent)
event.HttpContext = c
event.Collection = collection
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Record); err != nil {
api.app.Logger().Debug(
"Failed to send password reset email",
slog.String("error", err.Error()),
sub.POST("/auth-with-oauth2", recordAuthWithOAuth2).Bind(
collectionPathRateLimit("", "authWithOAuth2", "auth"),
)
}
})
return api.app.OnRecordAfterRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
// eagerly write 204 response and skip submit errors
// as a measure against emails enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
return submitErr
}
func (api *recordAuthApi) confirmPasswordReset(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordPasswordResetConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmPasswordResetEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to set new password.", err)
}
return api.app.OnRecordAfterConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestVerification(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordVerificationRequest(api.app, collection)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
if err := form.Validate(); err != nil {
return NewBadRequestError("An error occurred while validating the form.", err)
}
event := new(core.RecordRequestVerificationEvent)
event.HttpContext = c
event.Collection = collection
submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
// run in background because we don't need to show the result to the client
routine.FireAndForget(func() {
if err := next(e.Record); err != nil {
api.app.Logger().Debug(
"Failed to send verification email",
slog.String("error", err.Error()),
sub.POST("/request-otp", recordRequestOTP).Bind(
collectionPathRateLimit("", "requestOTP"),
)
sub.POST("/auth-with-otp", recordAuthWithOTP).Bind(
collectionPathRateLimit("", "authWithOTP", "auth"),
)
}
})
return api.app.OnRecordAfterRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
sub.POST("/request-password-reset", recordRequestPasswordReset).Bind(
collectionPathRateLimit("", "requestPasswordReset"),
)
sub.POST("/confirm-password-reset", recordConfirmPasswordReset).Bind(
collectionPathRateLimit("", "confirmPasswordReset"),
)
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
sub.POST("/request-verification", recordRequestVerification).Bind(
collectionPathRateLimit("", "requestVerification"),
)
sub.POST("/confirm-verification", recordConfirmVerification).Bind(
collectionPathRateLimit("", "confirmVerification"),
)
// eagerly write 204 response and skip submit errors
// as a measure against users enumeration
if !c.Response().Committed {
c.NoContent(http.StatusNoContent)
}
sub.POST("/request-email-change", recordRequestEmailChange).Bind(
collectionPathRateLimit("", "requestEmailChange"),
RequireSameCollectionContextAuth(""),
)
sub.POST("/confirm-email-change", recordConfirmEmailChange).Bind(
collectionPathRateLimit("", "confirmEmailChange"),
)
return submitErr
sub.POST("/impersonate/{id}", recordAuthImpersonate).Bind(RequireSuperuserAuth())
}
func (api *recordAuthApi) confirmVerification(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
func findAuthCollection(e *core.RequestEvent) (*core.Collection, error) {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || !collection.IsAuth() {
return nil, e.NotFoundError("Missing or invalid auth collection context.", err)
}
form := forms.NewRecordVerificationConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmVerificationEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("An error occurred while submitting the form.", err)
}
return api.app.OnRecordAfterConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) requestEmailChange(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
record, _ := c.Get(ContextAuthRecordKey).(*models.Record)
if record == nil {
return NewUnauthorizedError("The request requires valid auth record.", nil)
}
form := forms.NewRecordEmailChangeRequest(api.app, record)
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
event := new(core.RecordRequestEmailChangeEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
return api.app.OnRecordBeforeRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to request email change.", err)
}
return api.app.OnRecordAfterRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
}
func (api *recordAuthApi) confirmEmailChange(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
form := forms.NewRecordEmailChangeConfirm(api.app, collection)
if readErr := c.Bind(form); readErr != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", readErr)
}
event := new(core.RecordConfirmEmailChangeEvent)
event.HttpContext = c
event.Collection = collection
_, submitErr := form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(record *models.Record) error {
event.Record = record
return api.app.OnRecordBeforeConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to confirm email change.", err)
}
return api.app.OnRecordAfterConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
})
return submitErr
}
func (api *recordAuthApi) listExternalAuths(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
id := c.PathParam("id")
if id == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, id)
if err != nil || record == nil {
return NewNotFoundError("", err)
}
externalAuths, err := api.app.Dao().FindAllExternalAuthsByRecord(record)
if err != nil {
return NewBadRequestError("Failed to fetch the external auths for the specified auth record.", err)
}
event := new(core.RecordListExternalAuthsEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.ExternalAuths = externalAuths
return api.app.OnRecordListExternalAuthsRequest().Trigger(event, func(e *core.RecordListExternalAuthsEvent) error {
return e.HttpContext.JSON(http.StatusOK, e.ExternalAuths)
})
}
func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("Missing collection context.", nil)
}
id := c.PathParam("id")
provider := c.PathParam("provider")
if id == "" || provider == "" {
return NewNotFoundError("", nil)
}
record, err := api.app.Dao().FindRecordById(collection.Id, id)
if err != nil || record == nil {
return NewNotFoundError("", err)
}
externalAuth, err := api.app.Dao().FindExternalAuthByRecordAndProvider(record, provider)
if err != nil {
return NewNotFoundError("Missing external auth provider relation.", err)
}
event := new(core.RecordUnlinkExternalAuthEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.ExternalAuth = externalAuth
return api.app.OnRecordBeforeUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
if err := api.app.Dao().DeleteExternalAuth(externalAuth); err != nil {
return NewBadRequestError("Cannot unlink the external auth provider.", err)
}
return api.app.OnRecordAfterUnlinkExternalAuthRequest().Trigger(event, func(e *core.RecordUnlinkExternalAuthEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2RedirectData struct {
State string `form:"state" query:"state" json:"state"`
Code string `form:"code" query:"code" json:"code"`
Error string `form:"error" query:"error" json:"error,omitempty"`
}
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
redirectStatusCode := http.StatusTemporaryRedirect
if c.Request().Method != http.MethodGet {
redirectStatusCode = http.StatusSeeOther
}
data := oauth2RedirectData{}
if err := c.Bind(&data); err != nil {
api.app.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
if data.State == "" {
api.app.Logger().Debug("Missing OAuth2 state parameter")
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
client, err := api.app.SubscriptionsBroker().ClientById(data.State)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
api.app.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
encodedData, err := json.Marshal(data)
if err != nil {
api.app.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
msg := subscriptions.Message{
Name: oauth2SubscriptionTopic,
Data: encodedData,
}
client.Send(msg)
if data.Error != "" || data.Code == "" {
api.app.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return c.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
return c.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
return collection, nil
}

View File

@ -0,0 +1,121 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordConfirmEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
form := newEmailChangeConfirmForm(e.App, collection)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, newEmail, err := form.parseToken()
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired token.", err))
}
event := new(core.RecordConfirmEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
event.NewEmail = newEmail
return e.App.OnRecordConfirmEmailChangeRequest().Trigger(event, func(e *core.RecordConfirmEmailChangeRequestEvent) error {
authRecord.Set(core.FieldNameEmail, e.NewEmail)
authRecord.Set(core.FieldNameVerified, true)
authRecord.RefreshTokenKey() // invalidate old tokens
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to confirm email change.", err))
}
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
func newEmailChangeConfirmForm(app core.App, collection *core.Collection) *EmailChangeConfirmForm {
return &EmailChangeConfirmForm{
app: app,
collection: collection,
}
}
type EmailChangeConfirmForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
}
func (form *EmailChangeConfirmForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 100), validation.By(form.checkPassword)),
)
}
func (form *EmailChangeConfirmForm) checkToken(value any) error {
_, _, err := form.parseToken()
return err
}
func (form *EmailChangeConfirmForm) checkPassword(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
authRecord, _, _ := form.parseToken()
if authRecord == nil || !authRecord.ValidatePassword(v) {
return validation.NewError("validation_invalid_password", "Missing or invalid auth record password.")
}
return nil
}
func (form *EmailChangeConfirmForm) parseToken() (*core.Record, string, error) {
// check token payload
claims, _ := security.ParseUnverifiedJWT(form.Token)
newEmail, _ := claims[core.TokenClaimNewEmail].(string)
if newEmail == "" {
return nil, "", validation.NewError("validation_invalid_token_payload", "Invalid token payload - newEmail must be set.")
}
// ensure that there aren't other users with the new email
_, err := form.app.FindAuthRecordByEmail(form.collection, newEmail)
if err == nil {
return nil, "", validation.NewError("validation_existing_token_email", "The new email address is already registered: "+newEmail)
}
// verify that the token is not expired and its signature is valid
authRecord, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeEmailChange)
if err != nil {
return nil, "", validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if authRecord.Collection().Id != form.collection.Id {
return nil, "", validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return authRecord, newEmail, nil
}

View File

@ -0,0 +1,205 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-email-change",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"token":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{"token`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoxNjQwOTkxNjYxfQ.dff842MO0mgRTHY8dktp0dqG9-7LGQOgRuiAbQpYBls",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-email change token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{`,
`"code":"validation_invalid_token_payload"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and incorrect password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567891"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{`,
`"code":"validation_invalid_password"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and correct password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByEmail("users", "change@example.com")
if err != nil {
t.Fatalf("Expected to find user with email %q, got error: %v", "change@example.com", err)
}
},
},
{
Name: "valid token in different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAfterConfirmEmailChangeRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmEmailChangeRequest().BindFunc(func(e *core.RecordConfirmEmailChangeRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmEmailChangeRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmEmailChange"},
{MaxRequests: 0, Label: "users:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-email-change",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsInR5cGUiOiJlbWFpbENoYW5nZSIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSIsIm5ld0VtYWlsIjoiY2hhbmdlQGV4YW1wbGUuY29tIiwiZXhwIjoyNTI0NjA0NDYxfQ.Y7mVlaEPhJiNPoIvIqbIosZU4c4lEhwysOrRR8c95iU",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,90 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
)
func recordRequestEmailChange(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers can change their emails directly.", nil)
}
record := e.Auth
if record == nil {
return e.UnauthorizedError("The request requires valid auth record.", nil)
}
form := newEmailChangeRequestForm(e.App, record)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.RecordRequestEmailChangeRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.NewEmail = form.NewEmail
return e.App.OnRecordRequestEmailChangeRequest().Trigger(event, func(e *core.RecordRequestEmailChangeRequestEvent) error {
if err := mails.SendRecordChangeEmail(e.App, e.Record, e.NewEmail); err != nil {
return firstApiError(err, e.BadRequestError("Failed to request email change.", err))
}
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
func newEmailChangeRequestForm(app core.App, record *core.Record) *emailChangeRequestForm {
return &emailChangeRequestForm{
app: app,
record: record,
}
}
type emailChangeRequestForm struct {
app core.App
record *core.Record
NewEmail string `form:"newEmail" json:"newEmail"`
}
func (form *emailChangeRequestForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.NewEmail,
validation.Required,
validation.Length(1, 255),
is.EmailFormat,
validation.NotIn(form.record.Email()),
validation.By(form.checkUniqueEmail),
),
)
}
func (form *emailChangeRequestForm) checkUniqueEmail(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
found, _ := form.app.FindAuthRecordByEmail(form.record.Collection(), v)
if found != nil && found.Id != form.record.Id {
return validation.NewError("validation_invalid_new_email", "Invalid new email address.")
}
return nil
}

View File

@ -0,0 +1,168 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestEmailChange(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "record authentication but from different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser authentication",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (existing email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"test2@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":`,
`"newEmail":{"code":"validation_invalid_new_email"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid data (new email)",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestEmailChangeRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-email-change") {
t.Fatalf("Expected email change email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestEmailChange"},
{MaxRequests: 0, Label: "users:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestEmailChange",
Method: http.MethodPost,
URL: "/api/collections/users/request-email-change",
Body: strings.NewReader(`{"newEmail":"change@example.com"}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestEmailChange"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,54 @@
package apis
import (
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
// note: for now allow superusers but it may change in the future to allow access
// also to users with "Manage API" rule access depending on the use cases that will arise
func recordAuthImpersonate(e *core.RequestEvent) error {
if !e.HasSuperuserAuth() {
return e.ForbiddenError("", nil)
}
collection, err := findAuthCollection(e)
if err != nil {
return err
}
record, err := e.App.FindRecordById(collection, e.Request.PathValue("id"))
if err != nil {
return e.NotFoundError("", err)
}
form := &impersonateForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
token, err := record.NewStaticAuthToken(time.Duration(form.Duration) * time.Second)
if err != nil {
e.InternalServerError("Failed to generate static auth token", err)
}
return recordAuthResponse(e, record, token, "", nil)
}
// -------------------------------------------------------------------
type impersonateForm struct {
// Duration is the optional custom token duration in seconds.
Duration int64 `form:"duration" json:"duration"`
}
func (form *impersonateForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Duration, validation.Min(0)),
)
}

View File

@ -0,0 +1,109 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthImpersonate(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as different user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as the same user",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
NotExpectedContent: []string{
// hidden fields should remain hidden even though we are authenticated as superuser
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "authorized as superuser with custom invalid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{"duration":-1}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"duration":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as superuser with custom valid duration",
Method: http.MethodPost,
URL: "/api/collections/users/impersonate/4q1xlclmfloku33",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: strings.NewReader(`{"duration":100}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"id":"4q1xlclmfloku33"`,
`"record":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

170
apis/record_auth_methods.go Normal file
View File

@ -0,0 +1,170 @@
package apis
import (
"log/slog"
"net/http"
"slices"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/security"
"golang.org/x/oauth2"
)
type otpResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type mfaResponse struct {
Enabled bool `json:"enabled"`
Duration int64 `json:"duration"` // in seconds
}
type passwordResponse struct {
IdentityFields []string `json:"identityFields"`
Enabled bool `json:"enabled"`
}
type oauth2Response struct {
Providers []providerInfo `json:"providers"`
Enabled bool `json:"enabled"`
}
type providerInfo struct {
Name string `json:"name"`
DisplayName string `json:"displayName"`
State string `json:"state"`
AuthURL string `json:"authURL"`
// @todo
// deprecated: use AuthURL instead
// AuthUrl will be removed after dropping v0.22 support
AuthUrl string `json:"authUrl"`
// technically could be omitted if the provider doesn't support PKCE,
// but to avoid breaking existing typed clients we'll return them as empty string
CodeVerifier string `json:"codeVerifier"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
type authMethodsResponse struct {
Password passwordResponse `json:"password"`
OAuth2 oauth2Response `json:"oauth2"`
MFA mfaResponse `json:"mfa"`
OTP otpResponse `json:"otp"`
// legacy fields
// @todo remove after dropping v0.22 support
AuthProviders []providerInfo `json:"authProviders"`
UsernamePassword bool `json:"usernamePassword"`
EmailPassword bool `json:"emailPassword"`
}
func (amr *authMethodsResponse) fillLegacyFields() {
amr.EmailPassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "email")
amr.UsernamePassword = amr.Password.Enabled && slices.Contains(amr.Password.IdentityFields, "username")
if amr.OAuth2.Enabled {
amr.AuthProviders = amr.OAuth2.Providers
}
}
func recordAuthMethods(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
result := authMethodsResponse{
Password: passwordResponse{
IdentityFields: make([]string, 0, len(collection.PasswordAuth.IdentityFields)),
},
OAuth2: oauth2Response{
Providers: make([]providerInfo, 0, len(collection.OAuth2.Providers)),
},
OTP: otpResponse{
Enabled: collection.OTP.Enabled,
},
MFA: mfaResponse{
Enabled: collection.MFA.Enabled,
},
}
if collection.PasswordAuth.Enabled {
result.Password.Enabled = true
result.Password.IdentityFields = collection.PasswordAuth.IdentityFields
}
if collection.OTP.Enabled {
result.OTP.Duration = collection.OTP.Duration
}
if collection.MFA.Enabled {
result.MFA.Duration = collection.MFA.Duration
}
if !collection.OAuth2.Enabled {
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}
result.OAuth2.Enabled = true
for _, config := range collection.OAuth2.Providers {
provider, err := config.InitProvider()
if err != nil {
e.App.Logger().Debug(
"Failed to setup OAuth2 provider",
slog.String("name", config.Name),
slog.String("error", err.Error()),
)
continue // skip provider
}
info := providerInfo{
Name: config.Name,
DisplayName: provider.DisplayName(),
State: security.RandomString(30),
}
if info.DisplayName == "" {
info.DisplayName = config.Name
}
urlOpts := []oauth2.AuthCodeOption{}
// custom providers url options
switch config.Name {
case auth.NameApple:
// see https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms#3332113
urlOpts = append(urlOpts, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
if provider.PKCE() {
info.CodeVerifier = security.RandomString(43)
info.CodeChallenge = security.S256Challenge(info.CodeVerifier)
info.CodeChallengeMethod = "S256"
urlOpts = append(urlOpts,
oauth2.SetAuthURLParam("code_challenge", info.CodeChallenge),
oauth2.SetAuthURLParam("code_challenge_method", info.CodeChallengeMethod),
)
}
info.AuthURL = provider.BuildAuthURL(
info.State,
urlOpts...,
) + "&redirect_uri=" // empty redirect_uri so that users can append their redirect url
info.AuthUrl = info.AuthURL
result.OAuth2.Providers = append(result.OAuth2.Providers, info)
}
result.fillLegacyFields()
return e.JSON(http.StatusOK, result)
}

View File

@ -0,0 +1,106 @@
package apis_test
import (
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthMethodsList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "missing collection",
Method: http.MethodGet,
URL: "/api/collections/missing/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodGet,
URL: "/api/collections/demo1/auth-methods",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with none auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":[],"enabled":false}`,
`"oauth2":{"providers":[],"enabled":false}`,
`"mfa":{"enabled":false,"duration":0}`,
`"otp":{"enabled":false,"duration":0}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with all auth methods allowed",
Method: http.MethodGet,
URL: "/api/collections/users/auth-methods",
ExpectedStatus: 200,
ExpectedContent: []string{
`"password":{"identityFields":["email","username"],"enabled":true}`,
`"mfa":{"enabled":true,"duration":1800}`,
`"otp":{"enabled":true,"duration":300}`,
`"oauth2":{`,
`"providers":[{`,
`"name":"google"`,
`"name":"gitlab"`,
`"state":`,
`"displayName":`,
`"codeVerifier":`,
`"codeChallenge":`,
`"codeChallengeMethod":`,
`"authURL":`,
`redirect_uri="`, // ensures that the redirect_uri is the last url param
},
ExpectedEvents: map[string]int{"*": 0},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:listAuthMethods"},
{MaxRequests: 0, Label: "nologin:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:listAuthMethods",
Method: http.MethodGet,
URL: "/api/collections/nologin/auth-methods",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:listAuthMethods"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,118 @@
package apis
import (
"errors"
"fmt"
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/security"
)
func recordRequestOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &createOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write a dummy 200 response as a very rudimentary user emails enumeration protection
e.JSON(http.StatusOK, map[string]string{
"otpId": core.GenerateDefaultRandomId(),
})
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
event := new(core.RecordCreateOTPRequestEvent)
event.RequestEvent = e
event.Password = security.RandomStringWithAlphabet(collection.OTP.Length, "1234567890")
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestOTPRequest().Trigger(event, func(e *core.RecordCreateOTPRequestEvent) error {
var otp *core.OTP
// limit the new OTP creations for a single user
if !e.App.IsDev() {
otps, err := e.App.FindAllOTPsByRecord(e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to fetch previous record OTPs.", err))
}
totalRecent := 0
for _, existingOTP := range otps {
if !existingOTP.HasExpired(collection.OTP.DurationTime()) {
totalRecent++
}
// use the last issued one
if totalRecent > 9 {
otp = otps[0] // otps are DESC sorted
e.App.Logger().Warn(
"Too many OTP requests - reusing the last issued",
"email", form.Email,
"recordId", e.Record.Id,
"otpId", existingOTP.Id,
)
break
}
}
}
if otp == nil {
// create new OTP
// ---
otp = core.NewOTP(e.App)
otp.SetCollectionRef(e.Record.Collection().Id)
otp.SetRecordRef(e.Record.Id)
otp.SetPassword(e.Password)
err = e.App.Save(otp)
if err != nil {
return err
}
// send OTP email
// (in the background as a very basic timing attacks and emails enumeration protection)
// ---
app := e.App
routine.FireAndForget(func() {
err = mails.SendRecordOTP(app, e.Record, otp.Id, e.Password)
if err != nil {
app.Logger().Error("Failed to send OTP email", "error", errors.Join(err, e.App.Delete(otp)))
}
})
}
return e.JSON(http.StatusOK, map[string]string{
"otpId": otp.Id,
})
})
}
// -------------------------------------------------------------------
type createOTPForm struct {
Email string `form:"email" json:"email"`
}
func (form createOTPForm) validate() error {
return validation.ValidateStruct(&form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}

View File

@ -0,0 +1,231 @@
package apis_test
import (
"net/http"
"strconv"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordRequestOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"invalid"}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"email":{"code":"validation_is_email`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`, // some fake random generated string
},
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record (with < 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 8 non-expired and 2 expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if i >= 8 {
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
}
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"`,
},
NotExpectedContent: []string{
`"otpId":"otp_`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("Expected 1 email, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record (with > 9 non-expired)",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert 10 non-expired
for i := 0; i < 10; i++ {
otp := core.NewOTP(app)
otp.Id = "otp_" + strconv.Itoa(i)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.SaveNoValidate(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"otpId":"otp_9"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestOTPRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected 0 sent emails, got %d", app.TestMailer.TotalSend())
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestOTP"},
{MaxRequests: 0, Label: "users:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestOTP",
Method: http.MethodPost,
URL: "/api/collections/users/request-otp",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,102 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/core/validators"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
form := new(recordConfirmPasswordResetForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
authRecord, err := e.App.FindAuthRecordByToken(form.Token, core.TokenTypePasswordReset)
if err != nil {
return firstApiError(err, e.BadRequestError("Invalid or expired password reset token.", err))
}
event := new(core.RecordConfirmPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = authRecord
return e.App.OnRecordConfirmPasswordResetRequest().Trigger(event, func(e *core.RecordConfirmPasswordResetRequestEvent) error {
authRecord.SetPassword(form.Password)
if !authRecord.Verified() {
payload, err := security.ParseUnverifiedJWT(form.Token)
if err == nil && authRecord.Email() == cast.ToString(payload[core.TokenClaimEmail]) {
// mark as verified if the email hasn't changed
authRecord.SetVerified(true)
}
}
err = form.app.Save(authRecord)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to set new password.", err))
}
form.app.Store().Remove(getPasswordResetResendKey(authRecord))
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordConfirmPasswordResetForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
Password string `form:"password" json:"password"`
PasswordConfirm string `form:"passwordConfirm" json:"passwordConfirm"`
}
func (form *recordConfirmPasswordResetForm) validate() error {
min := 1
passField, ok := form.collection.Fields.GetByName(core.FieldNamePassword).(*core.PasswordField)
if ok && passField != nil && passField.Min > 0 {
min = passField.Min
}
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
validation.Field(&form.Password, validation.Required, validation.Length(min, 255)), // the FieldPassword validator will check further the specicic length constraints
validation.Field(&form.PasswordConfirm, validation.Required, validation.By(validators.Equal(form.Password))),
)
}
func (form *recordConfirmPasswordResetForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypePasswordReset)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
return nil
}

View File

@ -0,0 +1,345 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"password":{"code":"validation_required"`,
`"passwordConfirm":{"code":"validation_required"`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token and invalid password",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.5Tm6_6amQqOlX3urAnXlEdmxwG5qQJfiTg6U0hHR1hk",
"password":"1234567",
"passwordConfirm":"7654321"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
`"password":{"code":"validation_length_out_of_range"`,
`"passwordConfirm":{"code":"validation_values_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-password reset token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-password-reset?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token and data (unverified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to be marked as verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (unverified user with different email from the one in the token)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to be unverified")
}
// manually change the email to check whether the verified state will be updated
user.SetEmail("test_update@example.com")
if err := app.Save(user); err != nil {
t.Fatalf("Failed to update user test email: %v", err)
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatalf("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test_update@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if user.Verified() {
t.Fatal("Expected the user to remain unverified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "valid token and data (verified user)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
// ensure that the user is already verified
user.SetVerified(true)
if err := app.Save(user); err != nil {
t.Fatalf("Failed to update user verified state")
}
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
_, err := app.FindAuthRecordByToken(
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
core.TokenTypePasswordReset,
)
if err == nil {
t.Fatal("Expected the password reset token to be invalidated")
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatalf("Failed to fetch confirm password user: %v", err)
}
if !user.Verified() {
t.Fatal("Expected the user to remain verified")
}
if !user.ValidatePassword("1234567!") {
t.Fatal("Password wasn't changed")
}
},
},
{
Name: "OnRecordAfterConfirmPasswordResetRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmPasswordResetRequest().BindFunc(func(e *core.RecordConfirmPasswordResetRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmPasswordResetRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmPasswordReset"},
{MaxRequests: 0, Label: "users:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-password-reset",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY",
"password":"1234567!",
"passwordConfirm":"1234567!"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,86 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestPasswordReset(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.BadRequestError("The collection is not configured to allow password authentication.", nil)
}
form := new(recordRequestPasswordResetForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getPasswordResetResendKey(record)
if e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a password reset email")
}
event := new(core.RecordRequestPasswordResetRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestPasswordResetRequest().Trigger(event, func(e *core.RecordRequestPasswordResetRequestEvent) error {
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordPasswordReset(app, e.Record); err != nil {
app.Logger().Error("Failed to send password reset email", "error", err)
return
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordRequestPasswordResetForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestPasswordResetForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getPasswordResetResendKey(record *core.Record) string {
return "@limitPasswordResetEmail_" + record.Collection().Id + record.Id
}

View File

@ -0,0 +1,145 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestPasswordReset(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "existing auth record in a collection with disabled password login",
Method: http.MethodPost,
URL: "/api/collections/nologin/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestPasswordResetRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-password-reset") {
t.Fatalf("Expected password reset email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitPasswordResetEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestPasswordReset"},
{MaxRequests: 0, Label: "users:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestPasswordReset",
Method: http.MethodPost,
URL: "/api/collections/users/request-password-reset",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestPasswordReset"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,29 @@
package apis
import (
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordAuthRefresh(e *core.RequestEvent) error {
record := e.Auth
if record == nil {
return e.NotFoundError("Missing auth record context.", nil)
}
currentToken := getAuthTokenFromRequest(e)
claims, _ := security.ParseUnverifiedJWT(currentToken)
if v, ok := claims[core.TokenClaimRefreshable]; !ok || !cast.ToBool(v) {
return e.ForbiddenError("The current auth token is not refreshable.", nil)
}
event := new(core.RecordAuthRefreshRequestEvent)
event.RequestEvent = e
event.Collection = record.Collection()
event.Record = record
return e.App.OnRecordAuthRefreshRequest().Trigger(event, func(e *core.RecordAuthRefreshRequestEvent) error {
return RecordAuthResponse(e.RequestEvent, e.Record, "", nil)
})
}

View File

@ -0,0 +1,196 @@
package apis_test
import (
"errors"
"net/http"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthRefresh(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superuser trying to refresh the auth of another auth collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth record + same auth collection as the token",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"4q1xlclmfloku33"`,
`"emailVisibility":false`,
`"email":"test@example.com"`, // the owner can always view their email address
`"expand":`,
`"rel":`,
`"id":"llvuca81nly1qls"`,
},
NotExpectedContent: []string{
`"missing":`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 2,
},
},
{
Name: "auth record + same auth collection as the token but static/unrefreshable",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6ZmFsc2V9.4IsO6YMsR19crhwl_YWzvRH8pfq2Ri4Gv2dzGyneLak",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "unverified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im8xeTBkZDBzcGQ3ODZtZCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.Zi0yXE-CNmnbTdVaQEzYZVuECqRdn3LgEM6pmB3XWBE",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
},
{
Name: "verified auth record in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":`,
`"record":`,
`"id":"gk390qegs4y47wn"`,
`"verified":true`,
`"email":"test@example.com"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "OnRecordAfterAuthRefreshRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh?expand=rel,missing",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthRefreshRequest().BindFunc(func(e *core.RecordAuthRefreshRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthRefreshRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authRefresh"},
{MaxRequests: 0, Label: "users:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authRefresh",
Method: http.MethodPost,
URL: "/api/collections/users/auth-refresh",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:authRefresh"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,102 @@
package apis
import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/spf13/cast"
)
func recordConfirmVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordConfirmVerificationForm)
form.app = e.App
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := form.app.FindAuthRecordByToken(form.Token, core.TokenTypeVerification)
if err != nil {
return e.BadRequestError("Invalid or expired verification token.", err)
}
wasVerified := record.Verified()
event := new(core.RecordConfirmVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordConfirmVerificationRequest().Trigger(event, func(e *core.RecordConfirmVerificationRequestEvent) error {
if wasVerified {
return e.NoContent(http.StatusNoContent)
}
e.Record.SetVerified(true)
if err := e.App.Save(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while saving the verified state.", err))
}
e.App.Store().Remove(getVerificationResendKey(e.Record))
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordConfirmVerificationForm struct {
app core.App
collection *core.Collection
Token string `form:"token" json:"token"`
}
func (form *recordConfirmVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Token, validation.Required, validation.By(form.checkToken)),
)
}
func (form *recordConfirmVerificationForm) checkToken(value any) error {
v, _ := value.(string)
if v == "" {
return nil // nothing to check
}
claims, _ := security.ParseUnverifiedJWT(v)
email := cast.ToString(claims["email"])
if email == "" {
return validation.NewError("validation_invalid_token_claims", "Missing email token claim.")
}
record, err := form.app.FindAuthRecordByToken(v, core.TokenTypeVerification)
if err != nil || record == nil {
return validation.NewError("validation_invalid_token", "Invalid or expired token.")
}
if record.Collection().Id != form.collection.Id {
return validation.NewError("validation_token_collection_mismatch", "The provided token is for different auth collection.")
}
if record.Email() != email {
return validation.NewError("validation_token_email_mismatch", "The record email doesn't match with the requested token claims.")
}
return nil
}

View File

@ -0,0 +1,210 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordConfirmVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data format",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{"password`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MTY0MDk5MTY2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.qqelNNL2Udl6K_TJ282sNHYCpASgA6SIuSVKGfBHMZU"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-verification token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InBhc3N3b3JkUmVzZXQiLCJjb2xsZWN0aW9uSWQiOiJfcGJfdXNlcnNfYXV0aF8iLCJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.xR-xq1oHDy0D8Q4NDOAEyYKGHWd_swzoiSoL8FLFBHY"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"token":{"code":"validation_invalid_token"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "different auth collection",
Method: http.MethodPost,
URL: "/api/collections/clients/confirm-verification?expand=rel,missing",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{"token":{"code":"validation_token_collection_mismatch"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid token",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "valid token (already verified)",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdDJAZXhhbXBsZS5jb20ifQ.QQmM3odNFVk6u4J4-5H8IBM3dfk9YCD7mPW-8PhBAI8"
}`),
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
},
{
Name: "valid verification token from a collection without allowed login",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
"OnModelUpdate": 1,
"OnModelValidate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnRecordUpdate": 1,
"OnRecordValidate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
},
},
{
Name: "OnRecordAfterConfirmVerificationRequest error response",
Method: http.MethodPost,
URL: "/api/collections/users/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6Il9wYl91c2Vyc19hdXRoXyIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.SetHpu2H-x-q4TIUz-xiQjwi7MNwLCLvSs4O0hUSp0E"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordConfirmVerificationRequest().BindFunc(func(e *core.RecordConfirmVerificationRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordConfirmVerificationRequest": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - nologin:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:confirmVerification"},
{MaxRequests: 0, Label: "nologin:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:confirmVerification",
Method: http.MethodPost,
URL: "/api/collections/nologin/confirm-verification",
Body: strings.NewReader(`{
"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6ImRjNDlrNmpnZWpuNDBoMyIsImV4cCI6MjUyNDYwNDQ2MSwidHlwZSI6InZlcmlmaWNhdGlvbiIsImNvbGxlY3Rpb25JZCI6ImtwdjcwOXNrMmxxYnFrOCIsImVtYWlsIjoidGVzdEBleGFtcGxlLmNvbSJ9.5GmuZr4vmwk3Cb_3ZZWNxwbE75KZC-j71xxIPR9AsVw"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:confirmVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,89 @@
package apis
import (
"errors"
"fmt"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/routine"
)
func recordRequestVerification(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if collection.Name == core.CollectionNameSuperusers {
return e.BadRequestError("All superusers are verified by default.", nil)
}
form := new(recordRequestVerificationForm)
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
record, err := e.App.FindAuthRecordByEmail(collection, form.Email)
if err != nil {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return fmt.Errorf("failed to fetch %s record with email %s: %w", collection.Name, form.Email, err)
}
resendKey := getVerificationResendKey(record)
if !record.Verified() && e.App.Store().Has(resendKey) {
// eagerly write 204 response as a very basic measure against emails enumeration
e.NoContent(http.StatusNoContent)
return errors.New("try again later - you've already requested a verification email")
}
event := new(core.RecordRequestVerificationRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return e.App.OnRecordRequestVerificationRequest().Trigger(event, func(e *core.RecordRequestVerificationRequestEvent) error {
if e.Record.Verified() {
return e.NoContent(http.StatusNoContent)
}
// run in background because we don't need to show the result to the client
app := e.App
routine.FireAndForget(func() {
if err := mails.SendRecordVerification(app, e.Record); err != nil {
app.Logger().Error("Failed to send verification email", "error", err)
}
app.Store().Set(resendKey, struct{}{})
time.AfterFunc(2*time.Minute, func() {
app.Store().Remove(resendKey)
})
})
return e.NoContent(http.StatusNoContent)
})
}
// -------------------------------------------------------------------
type recordRequestVerificationForm struct {
Email string `form:"email" json:"email"`
}
func (form *recordRequestVerificationForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Email, validation.Required, validation.Length(1, 255), is.EmailFormat),
)
}
func getVerificationResendKey(record *core.Record) string {
return "@limitVerificationEmail_" + record.Collection().Id + record.Id
}

View File

@ -0,0 +1,162 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordRequestVerification(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{"email":{"code":"validation_required","message":"Cannot be blank."}}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid data",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"missing@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "already verified auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test2@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
{
Name: "existing auth record",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordRequestVerificationRequest": 1,
"OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if !strings.Contains(app.TestMailer.LastMessage().HTML, "/auth/confirm-verification") {
t.Fatalf("Expected verification email, got\n%v", app.TestMailer.LastMessage().HTML)
}
},
},
{
Name: "existing auth record (after already sent)",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
Delay: 100 * time.Millisecond,
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
// terminated before firing the event
// "OnRecordRequestVerificationRequest": 1,
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// simulate recent verification sent
authRecord, err := app.FindFirstRecordByData("users", "email", "test@example.com")
if err != nil {
t.Fatal(err)
}
resendKey := "@limitVerificationEmail_" + authRecord.Collection().Id + authRecord.Id
app.Store().Set(resendKey, struct{}{})
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected zero emails, got %d", app.TestMailer.TotalSend())
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:requestVerification"},
{MaxRequests: 0, Label: "users:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:requestVerification",
Method: http.MethodPost,
URL: "/api/collections/users/request-verification",
Body: strings.NewReader(`{"email":"test@example.com"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:requestVerification"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,355 @@
package apis
import (
"context"
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/security"
"golang.org/x/oauth2"
)
func recordAuthWithOAuth2(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OAuth2.Enabled {
return e.ForbiddenError("The collection is not configured to allow OAuth2 authentication.", nil)
}
var fallbackAuthRecord *core.Record
if e.Auth != nil && e.Auth.Collection().Id == collection.Id {
fallbackAuthRecord = e.Auth
}
form := new(recordOAuth2LoginForm)
form.collection = collection
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if form.RedirectUrl != "" && form.RedirectURL == "" {
e.App.Logger().Warn("[recordAuthWithOAuth2] redirectUrl body param is deprecated and will be removed in the future. Please replace it with redirectURL.")
form.RedirectURL = form.RedirectUrl
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
// exchange token for OAuth2 user info and locate existing ExternalAuth rel
// ---------------------------------------------------------------
// load provider configuration
providerConfig, ok := collection.OAuth2.GetProviderConfig(form.Provider)
if !ok {
return e.InternalServerError("Missing or invalid provider config.", nil)
}
provider, err := providerConfig.InitProvider()
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to init provider "+form.Provider, err))
}
ctx, cancel := context.WithTimeout(e.Request.Context(), 30*time.Second)
defer cancel()
provider.SetContext(ctx)
provider.SetRedirectURL(form.RedirectURL)
var opts []oauth2.AuthCodeOption
if provider.PKCE() {
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", form.CodeVerifier))
}
// fetch token
token, err := provider.FetchToken(form.Code, opts...)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 token.", err))
}
// fetch external auth user
authUser, err := provider.FetchAuthUser(token)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to fetch OAuth2 user.", err))
}
var authRecord *core.Record
// check for existing relation with the auth record
externalAuthRel, err := e.App.FindFirstExternalAuthByExpr(dbx.HashExp{
"collectionRef": form.collection.Id,
"provider": form.Provider,
"providerId": authUser.Id,
})
switch {
case err == nil && externalAuthRel != nil:
authRecord, err = e.App.FindRecordById(form.collection, externalAuthRel.RecordRef())
if err != nil {
return err
}
case fallbackAuthRecord != nil && fallbackAuthRecord.Collection().Id == form.collection.Id:
// fallback to the logged auth record (if any)
authRecord = fallbackAuthRecord
case authUser.Email != "":
// look for an existing auth record by the external auth record's email
authRecord, _ = e.App.FindAuthRecordByEmail(form.collection.Id, authUser.Email)
}
// ---------------------------------------------------------------
event := new(core.RecordAuthWithOAuth2RequestEvent)
event.RequestEvent = e
event.Collection = collection
event.ProviderName = form.Provider
event.ProviderClient = provider
event.OAuth2User = authUser
event.CreateData = form.CreateData
event.Record = authRecord
event.IsNewRecord = authRecord == nil
return e.App.OnRecordAuthWithOAuth2Request().Trigger(event, func(e *core.RecordAuthWithOAuth2RequestEvent) error {
if err := oauth2Submit(e, externalAuthRel); err != nil {
return firstApiError(err, e.BadRequestError("Failed to authenticate.", err))
}
meta := struct {
*auth.AuthUser
IsNew bool `json:"isNew"`
}{
AuthUser: e.OAuth2User,
IsNew: e.IsNewRecord,
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOAuth2, meta)
})
}
// -------------------------------------------------------------------
type recordOAuth2LoginForm struct {
collection *core.Collection
// Additional data that will be used for creating a new auth record
// if an existing OAuth2 account doesn't exist.
CreateData map[string]any `form:"createData" json:"createData"`
// The name of the OAuth2 client provider (eg. "google")
Provider string `form:"provider" json:"provider"`
// The authorization code returned from the initial request.
Code string `form:"code" json:"code"`
// The optional PKCE code verifier as part of the code_challenge sent with the initial request.
CodeVerifier string `form:"codeVerifier" json:"codeVerifier"`
// The redirect url sent with the initial request.
RedirectURL string `form:"redirectURL" json:"redirectURL"`
// @todo
// deprecated: use RedirectURL instead
// RedirectUrl will be removed after dropping v0.22 support
RedirectUrl string `form:"redirectUrl" json:"redirectUrl"`
}
func (form *recordOAuth2LoginForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.Provider, validation.Required, validation.By(form.checkProviderName)),
validation.Field(&form.Code, validation.Required),
validation.Field(&form.RedirectURL, validation.Required),
)
}
func (form *recordOAuth2LoginForm) checkProviderName(value any) error {
name, _ := value.(string)
_, ok := form.collection.OAuth2.GetProviderConfig(name)
if !ok {
return validation.NewError("validation_invalid_provider", fmt.Sprintf("Provider with name %q is missing or is not enabled.", name)).
SetParams(map[string]any{"name": name})
}
return nil
}
func oldCanAssignUsername(txApp core.App, collection *core.Collection, username string) bool {
// ensure that username is unique
checkUnique := dbutils.HasSingleColumnUniqueIndex(collection.OAuth2.MappedFields.Username, collection.Indexes)
if checkUnique {
if _, err := txApp.FindFirstRecordByData(collection, collection.OAuth2.MappedFields.Username, username); err == nil {
return false // already exist
}
}
// ensure that the value matches the pattern of the username field (if text)
txtField, _ := collection.Fields.GetByName(collection.OAuth2.MappedFields.Username).(*core.TextField)
return txtField != nil && txtField.ValidatePlainValue(username) == nil
}
func oauth2Submit(e *core.RecordAuthWithOAuth2RequestEvent, optExternalAuth *core.ExternalAuth) error {
return e.App.RunInTransaction(func(txApp core.App) error {
if e.Record == nil {
// extra check to prevent creating a superuser record via
// OAuth2 in case the method is used by another action
if e.Collection.Name == core.CollectionNameSuperusers {
return errors.New("superusers are not allowed to sign-up with OAuth2")
}
payload := maps.Clone(e.CreateData)
if payload == nil {
payload = map[string]any{}
}
payload[core.FieldNameEmail] = e.OAuth2User.Email
// set a random password if none is set
if v, _ := payload[core.FieldNamePassword].(string); v == "" {
payload[core.FieldNamePassword] = security.RandomString(30)
payload[core.FieldNamePassword+"Confirm"] = payload[core.FieldNamePassword]
}
// map known fields (unless the field was explicitly submitted as part of CreateData)
if _, ok := payload[e.Collection.OAuth2.MappedFields.Id]; !ok && e.Collection.OAuth2.MappedFields.Id != "" {
payload[e.Collection.OAuth2.MappedFields.Id] = e.OAuth2User.Id
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Name]; !ok && e.Collection.OAuth2.MappedFields.Name != "" {
payload[e.Collection.OAuth2.MappedFields.Name] = e.OAuth2User.Name
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.Username]; !ok &&
// no explicit username payload value and existing OAuth2 mapping
e.Collection.OAuth2.MappedFields.Username != "" &&
// extra checks for backward compatibility with earlier versions
oldCanAssignUsername(txApp, e.Collection, e.OAuth2User.Username) {
payload[e.Collection.OAuth2.MappedFields.Username] = e.OAuth2User.Username
}
if _, ok := payload[e.Collection.OAuth2.MappedFields.AvatarURL]; !ok && e.Collection.OAuth2.MappedFields.AvatarURL != "" {
mappedField := e.Collection.Fields.GetByName(e.Collection.OAuth2.MappedFields.AvatarURL)
if mappedField != nil && mappedField.Type() == core.FieldTypeFile {
// download the avatar if the mapped field is a file
avatarFile, err := func() (*filesystem.File, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return filesystem.NewFileFromURL(ctx, e.OAuth2User.AvatarURL)
}()
if err != nil {
return err
}
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = avatarFile
} else {
// otherwise - assign the url string
payload[e.Collection.OAuth2.MappedFields.AvatarURL] = e.OAuth2User.AvatarURL
}
}
createdRecord, err := sendOAuth2RecordCreateRequest(txApp, e, payload)
if err != nil {
return err
}
e.Record = createdRecord
if e.Record.Email() == e.OAuth2User.Email && !e.Record.Verified() {
// mark as verified as long as it matches the OAuth2 data (even if the email is empty)
e.Record.SetVerified(true)
if err := txApp.Save(e.Record); err != nil {
return err
}
}
} else {
var needUpdate bool
isLoggedAuthRecord := e.Auth != nil &&
e.Auth.Id == e.Record.Id &&
e.Auth.Collection().Id == e.Record.Collection().Id
// set random password for users with unverified email
// (this is in case a malicious actor has registered previously with the user email)
if !isLoggedAuthRecord && e.Record.Email() != "" && !e.Record.Verified() {
e.Record.SetPassword(security.RandomString(30))
needUpdate = true
}
// update the existing auth record empty email if the data.OAuth2User has one
// (this is in case previously the auth record was created
// with an OAuth2 provider that didn't return an email address)
if e.Record.Email() == "" && e.OAuth2User.Email != "" {
e.Record.SetEmail(e.OAuth2User.Email)
needUpdate = true
}
// update the existing auth record verified state
// (only if the auth record doesn't have an email or the auth record email match with the one in data.OAuth2User)
if !e.Record.Verified() && (e.Record.Email() == "" || e.Record.Email() == e.OAuth2User.Email) {
e.Record.SetVerified(true)
needUpdate = true
}
if needUpdate {
if err := txApp.Save(e.Record); err != nil {
return err
}
}
}
// create ExternalAuth relation if missing
if optExternalAuth == nil {
optExternalAuth = core.NewExternalAuth(txApp)
optExternalAuth.SetCollectionRef(e.Record.Collection().Id)
optExternalAuth.SetRecordRef(e.Record.Id)
optExternalAuth.SetProvider(e.ProviderName)
optExternalAuth.SetProviderId(e.OAuth2User.Id)
if err := txApp.Save(optExternalAuth); err != nil {
return fmt.Errorf("failed to save linked rel: %w", err)
}
}
return nil
})
}
func sendOAuth2RecordCreateRequest(txApp core.App, e *core.RecordAuthWithOAuth2RequestEvent, payload map[string]any) (*core.Record, error) {
ir := &core.InternalRequest{
Method: http.MethodPost,
URL: "/api/collections/" + e.Collection.Name + "/records",
Body: payload,
}
response, err := processInternalRequest(txApp, e.RequestEvent, ir, core.RequestInfoContextOAuth2, nil)
if err != nil {
return nil, err
}
if response.Status != http.StatusOK {
return nil, errors.New("failed to create OAuth2 auth record")
}
recordResponse := struct {
Id string `json:"id"`
}{}
raw, err := json.Marshal(response.Body)
if err != nil {
return nil, err
}
if err = json.Unmarshal(raw, &recordResponse); err != nil {
return nil, err
}
return txApp.FindRecordById(e.Collection, recordResponse.Id)
}

View File

@ -0,0 +1,74 @@
package apis
import (
"encoding/json"
"net/http"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2RedirectFailurePath string = "../_/#/auth/oauth2-redirect-failure"
oauth2RedirectSuccessPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2RedirectData struct {
State string `form:"state" json:"state"`
Code string `form:"code" json:"code"`
Error string `form:"error" json:"error,omitempty"`
}
func oauth2SubscriptionRedirect(e *core.RequestEvent) error {
redirectStatusCode := http.StatusTemporaryRedirect
if e.Request.Method != http.MethodGet {
redirectStatusCode = http.StatusSeeOther
}
data := oauth2RedirectData{}
if e.Request.Method == http.MethodPost {
if err := e.BindBody(&data); err != nil {
e.App.Logger().Debug("Failed to read OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
} else {
query := e.Request.URL.Query()
data.State = query.Get("state")
data.Code = query.Get("code")
data.Error = query.Get("error")
}
if data.State == "" {
e.App.Logger().Debug("Missing OAuth2 state parameter")
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
client, err := e.App.SubscriptionsBroker().ClientById(data.State)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
e.App.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
encodedData, err := json.Marshal(data)
if err != nil {
e.App.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
msg := subscriptions.Message{
Name: oauth2SubscriptionTopic,
Data: encodedData,
}
client.Send(msg)
if data.Error != "" || data.Code == "" {
e.App.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return e.Redirect(redirectStatusCode, oauth2RedirectFailurePath)
}
return e.Redirect(redirectStatusCode, oauth2RedirectSuccessPath)
}

View File

@ -0,0 +1,252 @@
package apis_test
import (
"context"
"net/http"
"strings"
"testing"
"time"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
)
func TestRecordAuthWithOAuth2Redirect(t *testing.T) {
t.Parallel()
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
for i := 0; i < 10; i++ {
c1 := subscriptions.NewDefaultClient()
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
clientStubs = append(clientStubs, map[string]subscriptions.Client{
"c1": c1,
"c2": c2,
"c3": c3,
"c4": c4,
"c5": c5,
})
}
checkFailureRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-failure") {
t.Fatalf("Expected failure redirect, got %q", loc)
}
}
checkSuccessRedirect := func(t testing.TB, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-success") {
t.Fatalf("Expected success redirect, got %q", loc)
}
}
checkClientMessages := func(t testing.TB, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
if len(expectedMessages[clientId]) == 0 {
t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data)
}
if msg.Name != "@oauth2" {
t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name)
}
for _, txt := range expectedMessages[clientId] {
if !strings.Contains(string(msg.Data), txt) {
t.Fatalf("Failed to find %q in \n%s", txt, msg.Data)
}
}
}
beforeTestFunc := func(
clients map[string]subscriptions.Client,
expectedMessages map[string][]string,
) func(testing.TB, *tests.TestApp, *core.ServeEvent) {
return func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for _, client := range clients {
app.SubscriptionsBroker().Register(client)
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
// add to the app store so that it can be cancelled manually after test completion
app.Store().Set("cancelFunc", cancelFunc)
go func() {
defer cancelFunc()
for {
select {
case msg := <-clients["c1"].Channel():
checkClientMessages(t, "c1", msg, expectedMessages)
case msg := <-clients["c2"].Channel():
checkClientMessages(t, "c2", msg, expectedMessages)
case msg := <-clients["c3"].Channel():
checkClientMessages(t, "c3", msg, expectedMessages)
case msg := <-clients["c4"].Channel():
checkClientMessages(t, "c4", msg, expectedMessages)
case msg := <-clients["c5"].Channel():
checkClientMessages(t, "c5", msg, expectedMessages)
case <-ctx.Done():
for _, c := range clients {
close(c.Channel())
}
return
}
}
}()
}
}
scenarios := []tests.ApiScenario{
{
Name: "no state query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123",
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "invalid or missing client",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=missing",
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "no code query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "error query param",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client with @oauth2 subscription",
Method: http.MethodGet,
URL: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "(POST) client with @oauth2 subscription",
Method: http.MethodPost,
URL: "/api/oauth2-redirect",
Body: strings.NewReader("code=123&state=" + clientStubs[7]["c3"].Id()),
Headers: map[string]string{
"content-type": "application/x-www-form-urlencoded",
},
BeforeTestFunc: beforeTestFunc(clientStubs[7], map[string][]string{
"c3": {`"state":"` + clientStubs[7]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusSeeOther,
ExpectedEvents: map[string]int{"*": 0},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[7]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,99 @@
package apis
import (
"errors"
"fmt"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
)
func recordAuthWithOTP(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.OTP.Enabled {
return e.ForbiddenError("The collection is not configured to allow OTP authentication.", nil)
}
form := &authWithOTPForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
event := new(core.RecordAuthWithOTPRequestEvent)
event.RequestEvent = e
event.Collection = collection
// extra validations
// (note: returns a generic 400 as a very basic OTPs enumeration protection)
// ---
event.OTP, err = e.App.FindOTPById(form.OTPId)
if err != nil {
return e.BadRequestError("Invalid or expired OTP", err)
}
if event.OTP.CollectionRef() != collection.Id {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is for a different collection"))
}
if event.OTP.HasExpired(collection.OTP.DurationTime()) {
return e.BadRequestError("Invalid or expired OTP", errors.New("the OTP is expired"))
}
event.Record, err = e.App.FindRecordById(event.OTP.CollectionRef(), event.OTP.RecordRef())
if err != nil {
return e.BadRequestError("Invalid or expired OTP", fmt.Errorf("missing auth record: %w", err))
}
// since otps are usually simple digit numbers we enforce an extra rate limit rule to prevent enumerations
err = checkRateLimit(e, "@pb_otp_"+event.OTP.Id+event.Record.Id, core.RateLimitRule{MaxRequests: 4, Duration: 180})
if err != nil {
return e.TooManyRequestsError("Too many attempts, please try again later with a new OTP.", nil)
}
if !event.OTP.ValidatePassword(form.Password) {
return e.BadRequestError("Invalid or expired OTP", errors.New("incorrect password"))
}
// ---
return e.App.OnRecordAuthWithOTPRequest().Trigger(event, func(e *core.RecordAuthWithOTPRequestEvent) error {
err = RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodOTP, nil)
if err != nil {
return err
}
// try to delete the used otp
if e.OTP != nil {
err = e.App.Delete(e.OTP)
if err != nil {
e.App.Logger().Error("Failed to delete used OTP", "error", err, "otpId", e.OTP.Id)
}
}
// note: we don't update the user verified state the same way as in the password reset confirmation
// at the moment because it is not clear whether the otp confirmation came from the user email
// (e.g. it could be from an sms or some other channel)
return nil
})
}
// -------------------------------------------------------------------
type authWithOTPForm struct {
OTPId string `form:"otpId" json:"otpId"`
Password string `form:"password" json:"password"`
}
func (form *authWithOTPForm) validate() error {
return validation.ValidateStruct(form,
validation.Field(&form.OTPId, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 71)),
)
}

View File

@ -0,0 +1,438 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordAuthWithOTP(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "not an auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "auth collection with disabled otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"otpId":"test","password":"123456"}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
usersCol, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
usersCol.OTP.Enabled = false
if err := app.Save(usersCol); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{"email`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(``),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_required"`,
`"password":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid request data",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 256) + `",
"password":"` + strings.Repeat("a", 72) + `"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"otpId":{"code":"validation_length_out_of_range"`,
`"password":{"code":"validation_length_out_of_range"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "missing otp",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"missing",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp for different collection",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
client, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(client.Collection().Id)
otp.SetRecordRef(client.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "otp with wrong password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("1234567890")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "expired otp with valid password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
expiredDate := types.NowDateTime().AddDate(-3, 0, 0)
otp.SetRaw("created", expiredDate)
otp.SetRaw("updated", expiredDate)
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "valid otp with valid password",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 401,
ExpectedContent: []string{`"mfaId":"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // mfa record
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "valid otp with valid password (disabled MFA)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + strings.Repeat("a", 15) + `",
"password":"123456"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
otp := core.NewOTP(app)
otp.Id = strings.Repeat("a", 15)
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"token":"`,
`"record":{`,
`"email":"test@example.com"`,
},
NotExpectedContent: []string{
`"meta":`,
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithOTPRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// ---
"OnModelValidate": 1,
"OnModelCreate": 1, // authOrigin
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelDelete": 1, // otp delete
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
// ---
"OnRecordValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithOTP",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithOTP"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithOTP"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordAuthWithOTPManualRateLimiterCheck(t *testing.T) {
t.Parallel()
var storeCache map[string]any
otpAId := strings.Repeat("a", 15)
otpBId := strings.Repeat("b", 15)
scenarios := []struct {
otpId string
password string
expectedStatus int
}{
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "12345", 400},
{otpAId, "123456", 429},
{otpBId, "12345", 400},
{otpBId, "123456", 200},
}
for _, s := range scenarios {
(&tests.ApiScenario{
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-otp",
Body: strings.NewReader(`{
"otpId":"` + s.otpId + `",
"password":"` + s.password + `"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
for k, v := range storeCache {
app.Store().Set(k, v)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user.Collection().MFA.Enabled = false
if err := app.Save(user.Collection()); err != nil {
t.Fatal(err)
}
for _, id := range []string{otpAId, otpBId} {
otp := core.NewOTP(app)
otp.Id = id
otp.SetCollectionRef(user.Collection().Id)
otp.SetRecordRef(user.Id)
otp.SetPassword("123456")
if err := app.Save(otp); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: s.expectedStatus,
ExpectedContent: []string{`"`}, // it doesn't matter anything non-empty
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
storeCache = app.Store().GetAll()
},
}).Test(t)
}
}

View File

@ -0,0 +1,97 @@
package apis
import (
"database/sql"
"errors"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/list"
)
func recordAuthWithPassword(e *core.RequestEvent) error {
collection, err := findAuthCollection(e)
if err != nil {
return err
}
if !collection.PasswordAuth.Enabled {
return e.ForbiddenError("The collection is not configured to allow password authentication.", nil)
}
form := &authWithPasswordForm{}
if err = e.BindBody(form); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
}
if err = form.validate(collection); err != nil {
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
}
var foundRecord *core.Record
var foundErr error
if form.IdentityField != "" {
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, form.IdentityField, form.Identity)
} else {
// prioritize email lookup
isEmail := is.EmailFormat.Validate(form.Identity) == nil
if isEmail && list.ExistInSlice(core.FieldNameEmail, collection.PasswordAuth.IdentityFields) {
foundRecord, foundErr = e.App.FindAuthRecordByEmail(collection.Id, form.Identity)
}
// search by the other identity fields
if !isEmail || foundErr != nil {
for _, name := range collection.PasswordAuth.IdentityFields {
if !isEmail && name == core.FieldNameEmail {
continue // no need to search by the email field if it is not an email
}
foundRecord, foundErr = e.App.FindFirstRecordByData(collection.Id, name, form.Identity)
if foundErr == nil {
break
}
}
}
}
// ignore not found errors to allow custom record find implementations
if foundErr != nil && !errors.Is(foundErr, sql.ErrNoRows) {
return e.InternalServerError("", foundErr)
}
event := new(core.RecordAuthWithPasswordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = foundRecord
event.Identity = form.Identity
event.Password = form.Password
event.IdentityField = form.IdentityField
return e.App.OnRecordAuthWithPasswordRequest().Trigger(event, func(e *core.RecordAuthWithPasswordRequestEvent) error {
if e.Record == nil || !e.Record.ValidatePassword(e.Password) {
return e.BadRequestError("Failed to authenticate.", errors.New("invalid login credentials"))
}
return RecordAuthResponse(e.RequestEvent, e.Record, core.MFAMethodPassword, nil)
})
}
// -------------------------------------------------------------------
type authWithPasswordForm struct {
Identity string `form:"identity" json:"identity"`
Password string `form:"password" json:"password"`
// IdentityField specifies the field to use to search for the identity
// (leave it empty for "auto" detection).
IdentityField string `form:"identityField" json:"identityField"`
}
func (form *authWithPasswordForm) validate(collection *core.Collection) error {
return validation.ValidateStruct(form,
validation.Field(&form.Identity, validation.Required, validation.Length(1, 255)),
validation.Field(&form.Password, validation.Required, validation.Length(1, 255)),
validation.Field(&form.IdentityField, validation.In(list.ToInterfaceSlice(collection.PasswordAuth.IdentityFields)...)),
)
}

View File

@ -0,0 +1,514 @@
package apis_test
import (
"errors"
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordAuthWithPassword(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "disabled password auth",
Method: http.MethodPost,
URL: "/api/collections/nologin/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-auth collection",
Method: http.MethodPost,
URL: "/api/collections/demo1/auth-with-password",
Body: strings.NewReader(`{"identity":"test@example.com","password":"1234567890"}`),
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "invalid body format",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "empty body params",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{"identity":"","password":""}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"identity":{`,
`"password":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "OnRecordAuthWithPasswordRequest error response",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.OnRecordAuthWithPasswordRequest().BindFunc(func(e *core.RecordAuthWithPasswordRequestEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field and invalid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"invalid"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field (email) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity field (username) and valid password",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity field and valid password with mismatched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "valid identity field and valid password with matched explicit identityField",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identityField": "username",
"identity":"clients57772",
"password":"1234567890"
}`),
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"username":"clients57772"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "valid identity (unverified) and valid password in onlyVerified collection",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test2@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
},
},
{
Name: "already authenticated record",
Method: http.MethodPost,
URL: "/api/collections/clients/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"gk390qegs4y47wn"`,
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 1,
"OnMailerRecordAuthAlertSend": 1,
},
},
{
Name: "with mfa first auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
ExpectedStatus: 401,
ExpectedContent: []string{
`"mfaId":"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
// mfa create
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 1 {
t.Fatalf("Expected 1 mfa record to be created, got %d", v)
}
},
},
{
Name: "with mfa second auth check",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"mfaId": "` + strings.Repeat("a", 15) + `",
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
// insert a dummy mfa record
mfa := core.NewMFA(app)
mfa.Id = strings.Repeat("a", 15)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("test")
if err := app.Save(mfa); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
// mfa delete
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
{
Name: "with enabled mfa but unsatisfied mfa rule (aka. skip the mfa check)",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
Body: strings.NewReader(`{
"identity":"test@example.com",
"password":"1234567890"
}`),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
users, err := app.FindCollectionByNameOrId("users")
if err != nil {
t.Fatal(err)
}
users.MFA.Enabled = true
users.MFA.Rule = "1=2"
if err := app.Save(users); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"email":"test@example.com"`,
`"token":`,
},
NotExpectedContent: []string{
// hidden fields
`"tokenKey"`,
`"password"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordAuthWithPasswordRequest": 1,
"OnRecordAuthRequest": 1,
"OnRecordEnrich": 1,
// authOrigin track
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
"OnMailerSend": 0, // disabled auth email alerts
"OnMailerRecordAuthAlertSend": 0,
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
mfas, err := app.FindAllMFAsByRecord(user)
if err != nil {
t.Fatal(err)
}
if v := len(mfas); v != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", v)
}
},
},
// rate limit checks
// -----------------------------------------------------------
{
Name: "RateLimit rule - users:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 100, Label: "users:auth"},
{MaxRequests: 0, Label: "users:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:authWithPassword",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:auth"},
{MaxRequests: 0, Label: "*:authWithPassword"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - users:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 100, Label: "*:authWithPassword"},
{MaxRequests: 0, Label: "users:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "RateLimit rule - *:auth",
Method: http.MethodPost,
URL: "/api/collections/users/auth-with-password",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
app.Settings().RateLimits.Enabled = true
app.Settings().RateLimits.Rules = []core.RateLimitRule{
{MaxRequests: 100, Label: "abc"},
{MaxRequests: 0, Label: "*:auth"},
}
},
ExpectedStatus: 429,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -1,121 +1,123 @@
package apis
import (
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
)
// bindRecordCrudApi registers the record crud api endpoints and
// the corresponding handlers.
func bindRecordCrudApi(app core.App, rg *echo.Group) {
api := recordApi{app: app}
subGroup := rg.Group(
"/collections/:collection",
ActivityLogger(app),
)
subGroup.GET("/records", api.list, LoadCollectionContext(app))
subGroup.GET("/records/:id", api.view, LoadCollectionContext(app))
subGroup.POST("/records", api.create, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
subGroup.PATCH("/records/:id", api.update, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
subGroup.DELETE("/records/:id", api.delete, LoadCollectionContext(app, models.CollectionTypeBase, models.CollectionTypeAuth))
//
// note: the rate limiter is "inlined" because some of the crud actions are also used in the batch APIs
func bindRecordCrudApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/collections/{collection}/records").Unbind(DefaultRateLimitMiddlewareId)
subGroup.GET("", recordsList)
subGroup.GET("/{id}", recordView)
subGroup.POST("", recordCreate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.PATCH("/{id}", recordUpdate(nil)).Bind(dynamicCollectionBodyLimit(""))
subGroup.DELETE("/{id}", recordDelete(nil))
}
type recordApi struct {
app core.App
}
func (api *recordApi) list(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
func recordsList(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
requestInfo := RequestInfo(c)
// forbid users and guests to query special filter/sort fields
if err := checkForAdminOnlyRuleFields(requestInfo); err != nil {
err = checkCollectionRateLimit(e, collection, "list")
if err != nil {
return err
}
if requestInfo.Admin == nil && collection.ListRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
fieldsResolver := resolvers.NewRecordFieldResolver(
api.app.Dao(),
if collection.ListRule == nil && !requestInfo.HasSuperuserAuth() {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
// forbid users and guests to query special filter/sort fields
err = checkForSuperuserOnlyRuleFields(requestInfo)
if err != nil {
return err
}
fieldsResolver := core.NewRecordFieldResolver(
e.App,
collection,
requestInfo,
// hidden fields are searchable only by admins
requestInfo.Admin != nil,
// hidden fields are searchable only by superusers
requestInfo.HasSuperuserAuth(),
)
searchProvider := search.NewProvider(fieldsResolver).
Query(api.app.Dao().RecordQuery(collection))
Query(e.App.RecordQuery(collection))
if requestInfo.Admin == nil && collection.ListRule != nil {
if !requestInfo.HasSuperuserAuth() && collection.ListRule != nil {
searchProvider.AddFilter(search.FilterData(*collection.ListRule))
}
records := []*models.Record{}
records := []*core.Record{}
result, err := searchProvider.ParseAndExec(c.QueryParams().Encode(), &records)
result, err := searchProvider.ParseAndExec(e.Request.URL.Query().Encode(), &records)
if err != nil {
return NewBadRequestError("", err)
return firstApiError(err, e.BadRequestError("", err))
}
event := new(core.RecordsListEvent)
event.HttpContext = c
event := new(core.RecordsListRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Records = records
event.Result = result
return api.app.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListEvent) error {
if e.HttpContext.Response().Committed {
return nil
return e.App.OnRecordsListRequest().Trigger(event, func(e *core.RecordsListRequestEvent) error {
if err := EnrichRecords(e.RequestEvent, e.Records); err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich records", err))
}
if err := EnrichRecords(e.HttpContext, api.app.Dao(), e.Records); err != nil {
api.app.Logger().Debug("Failed to enrich list records", slog.String("error", err.Error()))
}
return e.HttpContext.JSON(http.StatusOK, e.Result)
return e.JSON(http.StatusOK, e.Result)
})
}
func (api *recordApi) view(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
func recordView(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
recordId := c.PathParam("id")
err = checkCollectionRateLimit(e, collection, "view")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return NewNotFoundError("", nil)
return e.NotFoundError("", nil)
}
requestInfo := RequestInfo(c)
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if requestInfo.Admin == nil && collection.ViewRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
if collection.ViewRule == nil && !requestInfo.HasSuperuserAuth() {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if requestInfo.Admin == nil && collection.ViewRule != nil && *collection.ViewRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
if !requestInfo.HasSuperuserAuth() && collection.ViewRule != nil && *collection.ViewRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.ViewRule).BuildExpr(resolver)
if err != nil {
return err
@ -126,176 +128,229 @@ func (api *recordApi) view(c echo.Context) error {
return nil
}
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
record, fetchErr := e.App.FindRecordById(collection, recordId, ruleFunc)
if fetchErr != nil || record == nil {
return NewNotFoundError("", fetchErr)
return firstApiError(err, e.NotFoundError("", fetchErr))
}
event := new(core.RecordViewEvent)
event.HttpContext = c
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return api.app.OnRecordViewRequest().Trigger(event, func(e *core.RecordViewEvent) error {
if e.HttpContext.Response().Committed {
return nil
return e.App.OnRecordViewRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if err := EnrichRecord(e.RequestEvent, e.Record); err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
api.app.Logger().Debug(
"Failed to enrich view record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
return e.JSON(http.StatusOK, e.Record)
})
}
func (api *recordApi) create(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
func recordCreate(optFinalizer func() error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
requestInfo := RequestInfo(c)
if requestInfo.Admin == nil && collection.CreateRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
hasFullManageAccess := requestInfo.Admin != nil
err = checkCollectionRateLimit(e, collection, "create")
if err != nil {
return err
}
// temporary save the record and check it against the create rule
if requestInfo.Admin == nil && collection.CreateRule != nil {
testRecord := models.NewRecord(collection)
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
canSkipRuleCheck := hasSuperuserAuth
// special case for the first superuser creation
// ---
if !canSkipRuleCheck && collection.Name == core.CollectionNameSuperusers {
total, totalErr := e.App.CountRecords(core.CollectionNameSuperusers)
canSkipRuleCheck = totalErr == nil && total == 0
}
// ---
if !canSkipRuleCheck && collection.CreateRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
record := core.NewRecord(collection)
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Data using just the field name
if requestInfo.HasModifierDataKeys() {
requestInfo.Data = testRecord.ReplaceModifers(requestInfo.Data)
}
// available when accessing requestInfo.Body
requestInfo.Body = data
testForm := forms.NewRecordUpsert(api.app, testRecord)
testForm.SetFullManageAccess(true)
if err := testForm.LoadRequest(c.Request(), ""); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
form := forms.NewRecordUpsert(e.App, record)
if hasSuperuserAuth {
form.GrantSuperuserAccess()
}
form.Load(data)
// force unset the verified state to prevent ManageRule misuse
if !hasFullManageAccess {
testForm.Verified = false
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
hookErr := e.App.OnRecordCreateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
// temporary save the record and check it against the create and manage rules
if !canSkipRuleCheck && e.Collection.CreateRule != nil {
// temporary grant manager access level
form.GrantManagerAccess()
// manually unset the verified field to prevent manage API rule misuse in case the rule relies on it
initialVerified := e.Record.Verified()
if initialVerified {
e.Record.SetVerified(false)
}
createRuleFunc := func(q *dbx.SelectQuery) error {
if *collection.CreateRule == "" {
if *e.Collection.CreateRule == "" {
return nil // no create rule to resolve
}
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
expr, err := search.FilterData(*collection.CreateRule).BuildExpr(resolver)
resolver := core.NewRecordFieldResolver(e.App, e.Collection, requestInfo, true)
expr, err := search.FilterData(*e.Collection.CreateRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
return nil
}
testErr := testForm.DrySubmit(func(txDao *daos.Dao) error {
foundRecord, err := txDao.FindRecordById(collection.Id, testRecord.Id, createRuleFunc)
testErr := form.DrySubmit(func(txApp core.App, drySavedRecord *core.Record) error {
foundRecord, err := txApp.FindRecordById(drySavedRecord.Collection(), drySavedRecord.Id, createRuleFunc)
if err != nil {
return fmt.Errorf("DrySubmit create rule failure: %w", err)
}
hasFullManageAccess = hasAuthManageAccess(txDao, foundRecord, requestInfo)
// reset the form access level in case it satisfies the Manage API rule
if !hasAuthManageAccess(txApp, requestInfo, foundRecord) {
form.ResetAccess()
}
return nil
})
if testErr != nil {
return NewBadRequestError("Failed to create record.", testErr)
return e.BadRequestError("Failed to create record.", testErr)
}
// restore initial verified state (it will be further validated on submit)
if initialVerified != e.Record.Verified() {
e.Record.SetVerified(initialVerified)
}
}
record := models.NewRecord(collection)
form := forms.NewRecordUpsert(api.app, record)
form.SetFullManageAccess(hasFullManageAccess)
// load request
if err := form.LoadRequest(c.Request(), ""); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
err := form.Submit()
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to create record.", err))
}
event := new(core.RecordCreateEvent)
event.HttpContext = c
event.Collection = collection
event.Record = record
event.UploadedFiles = form.FilesToUpload()
// create the record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(m *models.Record) error {
event.Record = m
return api.app.OnRecordBeforeCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to create record.", err)
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
api.app.Logger().Debug(
"Failed to enrich create record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
err = e.JSON(http.StatusOK, e.Record)
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer()
if err != nil {
return firstApiError(err, e.InternalServerError("", err))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", err))
}
}
return api.app.OnRecordAfterCreateRequest().Trigger(event, func(e *core.RecordCreateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
})
})
}
})
}
func (api *recordApi) update(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
func recordUpdate(optFinalizer func() error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
recordId := c.PathParam("id")
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "update")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return NewNotFoundError("", nil)
return e.NotFoundError("", nil)
}
requestInfo := RequestInfo(c)
if requestInfo.Admin == nil && collection.UpdateRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
// eager fetch the record so that the modifier field values are replaced
// and available when accessing requestInfo.Data using just the field name
if requestInfo.HasModifierDataKeys() {
record, err := api.app.Dao().FindRecordById(collection.Id, recordId)
if err != nil || record == nil {
return NewNotFoundError("", err)
hasSuperuserAuth := requestInfo.HasSuperuserAuth()
if !hasSuperuserAuth && collection.UpdateRule == nil {
return firstApiError(err, e.ForbiddenError("Only superusers can perform this action.", nil))
}
requestInfo.Data = record.ReplaceModifers(requestInfo.Data)
// eager fetch the record so that the modifiers field values can be resolved
record, err := e.App.FindRecordById(collection, recordId)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
data, err := recordDataFromRequest(e, record)
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to read the submitted data.", err))
}
// replace modifiers fields so that the resolved value is always
// available when accessing requestInfo.Body
requestInfo.Body = data
ruleFunc := func(q *dbx.SelectQuery) error {
if requestInfo.Admin == nil && collection.UpdateRule != nil && *collection.UpdateRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
if !hasSuperuserAuth && collection.UpdateRule != nil && *collection.UpdateRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.UpdateRule).BuildExpr(resolver)
if err != nil {
return err
@ -306,78 +361,105 @@ func (api *recordApi) update(c echo.Context) error {
return nil
}
// fetch record
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
if fetchErr != nil || record == nil {
return NewNotFoundError("", fetchErr)
// refetch with access checks
record, err = e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil {
return firstApiError(err, e.NotFoundError("", err))
}
form := forms.NewRecordUpsert(api.app, record)
form.SetFullManageAccess(requestInfo.Admin != nil || hasAuthManageAccess(api.app.Dao(), record, requestInfo))
// load request
if err := form.LoadRequest(c.Request(), ""); err != nil {
return NewBadRequestError("Failed to load the submitted data due to invalid formatting.", err)
form := forms.NewRecordUpsert(e.App, record)
if hasSuperuserAuth {
form.GrantSuperuserAccess()
}
form.Load(data)
event := new(core.RecordUpdateEvent)
event.HttpContext = c
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
event.UploadedFiles = form.FilesToUpload()
// update the record
return form.Submit(func(next forms.InterceptorNextFunc[*models.Record]) forms.InterceptorNextFunc[*models.Record] {
return func(m *models.Record) error {
event.Record = m
return api.app.OnRecordBeforeUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
if err := next(e.Record); err != nil {
return NewBadRequestError("Failed to update record.", err)
hookErr := e.App.OnRecordUpdateRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
form.SetApp(e.App)
form.SetRecord(e.Record)
if !form.HasManageAccess() && hasAuthManageAccess(e.App, requestInfo, e.Record) {
form.GrantManagerAccess()
}
if err := EnrichRecord(e.HttpContext, api.app.Dao(), e.Record); err != nil {
api.app.Logger().Debug(
"Failed to enrich update record",
slog.String("id", e.Record.Id),
slog.String("collectionName", e.Record.Collection().Name),
slog.String("error", err.Error()),
)
err := form.Submit()
if err != nil {
return firstApiError(err, e.BadRequestError("Failed to update record.", err))
}
err = EnrichRecord(e.RequestEvent, e.Record)
if err != nil {
return firstApiError(err, e.InternalServerError("Failed to enrich record", err))
}
err = e.JSON(http.StatusOK, e.Record)
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer()
if err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("update optFinalizer error: %w", err)))
}
}
return api.app.OnRecordAfterUpdateRequest().Trigger(event, func(e *core.RecordUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.Record)
})
})
}
})
}
func (api *recordApi) delete(c echo.Context) error {
collection, _ := c.Get(ContextCollectionKey).(*models.Collection)
if collection == nil {
return NewNotFoundError("", "Missing collection context.")
func recordDelete(optFinalizer func() error) func(e *core.RequestEvent) error {
return func(e *core.RequestEvent) error {
collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue("collection"))
if err != nil || collection == nil {
return e.NotFoundError("Missing collection context.", err)
}
recordId := c.PathParam("id")
if collection.IsView() {
return e.BadRequestError("Unsupported collection type.", nil)
}
err = checkCollectionRateLimit(e, collection, "delete")
if err != nil {
return err
}
recordId := e.Request.PathValue("id")
if recordId == "" {
return NewNotFoundError("", nil)
return e.NotFoundError("", nil)
}
requestInfo := RequestInfo(c)
requestInfo, err := e.RequestInfo()
if err != nil {
return firstApiError(err, e.BadRequestError("", err))
}
if requestInfo.Admin == nil && collection.DeleteRule == nil {
// only admins can access if the rule is nil
return NewForbiddenError("Only admins can perform this action.", nil)
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule == nil {
return e.ForbiddenError("Only superusers can perform this action.", nil)
}
ruleFunc := func(q *dbx.SelectQuery) error {
if requestInfo.Admin == nil && collection.DeleteRule != nil && *collection.DeleteRule != "" {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), collection, requestInfo, true)
if !requestInfo.HasSuperuserAuth() && collection.DeleteRule != nil && *collection.DeleteRule != "" {
resolver := core.NewRecordFieldResolver(e.App, collection, requestInfo, true)
expr, err := search.FilterData(*collection.DeleteRule).BuildExpr(resolver)
if err != nil {
return err
@ -388,28 +470,130 @@ func (api *recordApi) delete(c echo.Context) error {
return nil
}
record, fetchErr := api.app.Dao().FindRecordById(collection.Id, recordId, ruleFunc)
if fetchErr != nil || record == nil {
return NewNotFoundError("", fetchErr)
record, err := e.App.FindRecordById(collection, recordId, ruleFunc)
if err != nil || record == nil {
return e.NotFoundError("", err)
}
event := new(core.RecordDeleteEvent)
event.HttpContext = c
var isOptFinalizerCalled bool
event := new(core.RecordRequestEvent)
event.RequestEvent = e
event.Collection = collection
event.Record = record
return api.app.OnRecordBeforeDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
// delete the record
if err := api.app.Dao().DeleteRecord(e.Record); err != nil {
return NewBadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err)
hookErr := e.App.OnRecordDeleteRequest().Trigger(event, func(e *core.RecordRequestEvent) error {
if err := e.App.Delete(e.Record); err != nil {
return firstApiError(err, e.BadRequestError("Failed to delete record. Make sure that the record is not part of a required relation reference.", err))
}
err = e.NoContent(http.StatusNoContent)
if err != nil {
return err
}
if optFinalizer != nil {
isOptFinalizerCalled = true
err = optFinalizer()
if err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
}
}
return nil
})
if hookErr != nil {
return hookErr
}
// e.g. in case the regular hook chain was stopped and the finalizer cannot be executed as part of the last e.Next() task
if !isOptFinalizerCalled && optFinalizer != nil {
if err := optFinalizer(); err != nil {
return firstApiError(err, e.InternalServerError("", fmt.Errorf("delete optFinalizer error: %w", err)))
}
}
return api.app.OnRecordAfterDeleteRequest().Trigger(event, func(e *core.RecordDeleteEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.NoContent(http.StatusNoContent)
})
})
}
// -------------------------------------------------------------------
func recordDataFromRequest(e *core.RequestEvent, record *core.Record) (map[string]any, error) {
info, err := e.RequestInfo()
if err != nil {
return nil, err
}
// resolve regular fields
result := record.ReplaceModifiers(info.Body)
// resolve uploaded files
uploadedFiles, err := extractUploadedFiles(e.Request, record.Collection(), "")
if err != nil {
return nil, err
}
if len(uploadedFiles) > 0 {
for k, v := range uploadedFiles {
result[k] = v
}
result = record.ReplaceModifiers(result)
}
isAuth := record.Collection().IsAuth()
// unset hidden fields for non-superusers
if !info.HasSuperuserAuth() {
for _, f := range record.Collection().Fields {
if f.GetHidden() {
// exception for the auth collection "password" field
if isAuth && f.GetName() == core.FieldNamePassword {
continue
}
delete(result, f.GetName())
}
}
}
return result, nil
}
func extractUploadedFiles(request *http.Request, collection *core.Collection, prefix string) (map[string][]*filesystem.File, error) {
contentType := request.Header.Get("content-type")
if !strings.HasPrefix(contentType, "multipart/form-data") {
return nil, nil // not multipart/form-data request
}
result := map[string][]*filesystem.File{}
for _, field := range collection.Fields {
if field.Type() != core.FieldTypeFile {
continue
}
baseKey := field.GetName()
keys := []string{
baseKey,
// prepend and append modifiers
"+" + baseKey,
baseKey + "+",
}
for _, k := range keys {
if prefix != "" {
k = prefix + "." + k
}
files, err := FindUploadedFiles(request, k)
if err != nil && !errors.Is(err, http.ErrMissingFile) {
return nil, err
}
if len(files) > 0 {
result[k] = files
}
}
}
return result, nil
}

View File

@ -0,0 +1,314 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudAuthOriginList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"9r2j0m74260ur8i"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without authOrigins",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"9r2j0m74260ur8i"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"fingerprint": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudAuthOriginUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"fingerprint":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameAuthOrigins + "/records/9r2j0m74260ur8i",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"id":"9r2j0m74260ur8i"`,
`"fingerprint":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,316 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudExternalAuthList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"f1z5b3843pzc964"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without externalAuths",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test2@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6Im9hcDY0MGNvdDR5cnUycyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.GfJo6EHIobgas_AXt-M-tj5IoQendPnrkMSe9ExuSEY",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"dlmflokuq1xl342"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"provider": "github",
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudExternalAuthUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"providerId": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameExternalAuths + "/records/dlmflokuq1xl342",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"id":"dlmflokuq1xl342"`,
`"providerId":"abc"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,388 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudMFAList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without mfas",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFADelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFACreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"method": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudMFAUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"method":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameMFAs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubMFARecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,388 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudOTPList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
{
Name: "regular auth with otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":1`,
`"totalPages":1`,
`"id":"user1_0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 1,
},
},
{
Name: "regular auth without otps",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalItems":0`,
`"totalPages":0`,
`"items":[]`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 200,
ExpectedContent: []string{`"id":"user1_0"`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// clients, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6ImdrMzkwcWVnczR5NDd3biIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoidjg1MXE0cjc5MHJoa25sIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.0ONnm_BsvPRZyDNT31GN1CKUB6uQRxvVvQ-Wc9AZfG0",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelDeleteExecute": 1,
"OnModelAfterDeleteSuccess": 1,
"OnRecordDelete": 1,
"OnRecordDeleteExecute": 1,
"OnRecordAfterDeleteSuccess": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"recordRef": "4q1xlclmfloku33",
"collectionRef": "_pb_users_auth_",
"password": "abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"recordRef":"4q1xlclmfloku33"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudOTPUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"password":"abc"
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "owner regular auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameOTPs + "/records/user1_0",
Headers: map[string]string{
// superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
if err := tests.StubOTPRecords(app); err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"id":"user1_0"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

View File

@ -0,0 +1,371 @@
package apis_test
import (
"net/http"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestRecordCrudSuperuserList(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"page":1`,
`"perPage":30`,
`"totalPages":1`,
`"totalItems":4`,
`"items":[{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordsListRequest": 1,
"OnRecordEnrich": 4,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserView(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodGet,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"id":"sywbhecnh46rhm0"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordViewRequest": 1,
"OnRecordEnrich": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserDelete(t *testing.T) {
t.Parallel()
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sbmbsdb40jyxf7h",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 204,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 4, // + 3 AuthOrigins
"OnModelDeleteExecute": 4,
"OnModelAfterDeleteSuccess": 4,
"OnRecordDelete": 4,
"OnRecordDeleteExecute": 4,
"OnRecordAfterDeleteSuccess": 4,
},
},
{
Name: "delete the last superuser",
Method: http.MethodDelete,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all other superusers
superusers, err := app.FindAllRecords(core.CollectionNameSuperusers, dbx.Not(dbx.HashExp{"id": "sywbhecnh46rhm0"}))
if err != nil {
t.Fatal(err)
}
for _, superuser := range superusers {
if err = app.Delete(superuser); err != nil {
t.Fatal(err)
}
}
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordDeleteRequest": 1,
"OnModelDelete": 1,
"OnModelAfterDeleteError": 1,
"OnRecordDelete": 1,
"OnRecordAfterDeleteError": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserCreate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"password": "1234567890",
"passwordConfirm": "1234567890",
"verified": false
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "guest creating first superuser",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Body: body(),
BeforeTestFunc: func(t testing.TB, app *tests.TestApp, e *core.ServeEvent) {
// delete all superusers
_, err := app.DB().NewQuery("DELETE FROM {{" + core.CollectionNameSuperusers + "}}").Execute()
if err != nil {
t.Fatal(err)
}
},
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"verified":true`,
},
NotExpectedContent: []string{
// because the action has no auth the email field shouldn't be returned if emailVisibility is not set
`"email"`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
{
Name: "superusers auth",
Method: http.MethodPost,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordCreateRequest": 1,
"OnRecordEnrich": 1,
"OnModelCreate": 1,
"OnModelCreateExecute": 1,
"OnModelAfterCreateSuccess": 1,
"OnModelValidate": 1,
"OnRecordCreate": 1,
"OnRecordCreateExecute": 1,
"OnRecordAfterCreateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}
func TestRecordCrudSuperuserUpdate(t *testing.T) {
t.Parallel()
body := func() *strings.Reader {
return strings.NewReader(`{
"email": "test_new@example.com",
"verified": true
}`)
}
scenarios := []tests.ApiScenario{
{
Name: "guest",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "non-superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// users, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
Body: body(),
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "superusers auth",
Method: http.MethodPatch,
URL: "/api/collections/" + core.CollectionNameSuperusers + "/records/sywbhecnh46rhm0",
Headers: map[string]string{
// _superusers, test@example.com
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
Body: body(),
ExpectedContent: []string{
`"collectionName":"_superusers"`,
`"id":"sywbhecnh46rhm0"`,
`"email":"test_new@example.com"`,
`"verified":true`,
},
ExpectedStatus: 200,
ExpectedEvents: map[string]int{
"*": 0,
"OnRecordUpdateRequest": 1,
"OnRecordEnrich": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnRecordUpdate": 1,
"OnRecordUpdateExecute": 1,
"OnRecordAfterUpdateSuccess": 1,
"OnRecordValidate": 1,
},
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,121 +1,111 @@
package apis
import (
"database/sql"
"errors"
"fmt"
"log"
"log/slog"
"net/http"
"strings"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/mails"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
)
const ContextRequestInfoKey = "requestInfo"
const (
expandQueryParam = "expand"
fieldsQueryParam = "fields"
)
const expandQueryParam = "expand"
const fieldsQueryParam = "fields"
// Deprecated: Use RequestInfo instead.
func RequestData(c echo.Context) *models.RequestInfo {
log.Println("RequestData(c) is deprecated and will be removed in the future! You can replace it with RequestInfo(c).")
return RequestInfo(c)
}
// RequestInfo exports cached common request data fields
// (query, body, logged auth state, etc.) from the provided context.
func RequestInfo(c echo.Context) *models.RequestInfo {
// return cached to avoid copying the body multiple times
if v := c.Get(ContextRequestInfoKey); v != nil {
if data, ok := v.(*models.RequestInfo); ok {
// refresh auth state
data.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
data.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
return data
}
}
result := &models.RequestInfo{
Context: models.RequestInfoContextDefault,
Method: c.Request().Method,
Query: map[string]any{},
Data: map[string]any{},
Headers: map[string]any{},
}
// extract the first value of all headers and normalizes the keys
// ("X-Token" is converted to "x_token")
for k, v := range c.Request().Header {
if len(v) > 0 {
result.Headers[inflector.Snakecase(k)] = v[0]
}
}
result.AuthRecord, _ = c.Get(ContextAuthRecordKey).(*models.Record)
result.Admin, _ = c.Get(ContextAdminKey).(*models.Admin)
echo.BindQueryParams(c, &result.Query)
rest.BindBody(c, &result.Data)
c.Set(ContextRequestInfoKey, result)
return result
}
// RecordAuthResponse writes standardised json record auth response
// RecordAuthResponse writes standardized json record auth response
// into the specified request context.
func RecordAuthResponse(
app core.App,
c echo.Context,
authRecord *models.Record,
meta any,
finalizers ...func(token string) error,
) error {
if !authRecord.Verified() && authRecord.Collection().AuthOptions().OnlyVerified {
return NewForbiddenError("Please verify your account first.", nil)
}
token, tokenErr := tokens.NewRecordAuthToken(app, authRecord)
//
// The authMethod argument specify the name of the current authentication method (eg. password, oauth2, etc.)
// that it is used primarily as an auth identifier during MFA and for login alerts.
//
// Set authMethod to empty string if you want to ignore the MFA checks and the login alerts
// (can be also adjusted additionally via the OnRecordAuthRequest hook).
func RecordAuthResponse(e *core.RequestEvent, authRecord *core.Record, authMethod string, meta any) error {
token, tokenErr := authRecord.NewAuthToken()
if tokenErr != nil {
return NewBadRequestError("Failed to create auth token.", tokenErr)
return e.InternalServerError("Failed to create auth token.", tokenErr)
}
event := new(core.RecordAuthEvent)
event.HttpContext = c
return recordAuthResponse(e, authRecord, token, authMethod, meta)
}
func recordAuthResponse(e *core.RequestEvent, authRecord *core.Record, token string, authMethod string, meta any) error {
originalRequestInfo, err := e.RequestInfo()
if err != nil {
return err
}
ok, err := e.App.CanAccessRecord(authRecord, originalRequestInfo, authRecord.Collection().AuthRule)
if !ok {
return firstApiError(err, e.ForbiddenError("The request doesn't satisfy the collection requirements to authenticate.", err))
}
event := new(core.RecordAuthRequestEvent)
event.RequestEvent = e
event.Collection = authRecord.Collection()
event.Record = authRecord
event.Token = token
event.Meta = meta
event.AuthMethod = authMethod
return app.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthEvent) error {
if e.HttpContext.Response().Committed {
return e.App.OnRecordAuthRequest().Trigger(event, func(e *core.RecordAuthRequestEvent) error {
if e.Written() {
return nil
}
// allow always returning the email address of the authenticated account
// MFA
// ---
mfaId, err := checkMFA(e.RequestEvent, e.Record, e.AuthMethod)
if err != nil {
return err
}
// require additional authentication
if mfaId != "" {
return e.JSON(http.StatusUnauthorized, map[string]string{
"mfaId": mfaId,
})
}
// ---
// create a shallow copy of the cached request data and adjust it to the current auth record
requestInfo := *originalRequestInfo
requestInfo.Auth = e.Record
err = triggerRecordEnrichHooks(e.App, &requestInfo, []*core.Record{e.Record}, func() error {
if e.Record.IsSuperuser() {
e.Record.Unhide(e.Record.Collection().Fields.FieldNames()...)
}
// allow always returning the email address of the authenticated model
e.Record.IgnoreEmailVisibility(true)
// expand record relations
expands := strings.Split(c.QueryParam(expandQueryParam), ",")
expands := strings.Split(e.Request.URL.Query().Get(expandQueryParam), ",")
if len(expands) > 0 {
// create a copy of the cached request data and adjust it to the current auth record
requestInfo := *RequestInfo(e.HttpContext)
requestInfo.Admin = nil
requestInfo.AuthRecord = e.Record
failed := app.Dao().ExpandRecord(
e.Record,
expands,
expandFetch(app.Dao(), &requestInfo),
)
failed := e.App.ExpandRecord(e.Record, expands, expandFetch(e.App, &requestInfo))
if len(failed) > 0 {
app.Logger().Debug("[RecordAuthResponse] Failed to expand relations", slog.Any("errors", failed))
e.App.Logger().Warn("[recordAuthResponse] Failed to expand relations", "error", failed)
}
}
return nil
})
if err != nil {
return err
}
if e.AuthMethod != "" && authRecord.Collection().AuthAlert.Enabled {
if err = authAlert(e.RequestEvent, e.Record); err != nil {
e.App.Logger().Warn("[recordAuthResponse] Failed to send login alert", "error", err)
}
}
@ -128,68 +118,254 @@ func RecordAuthResponse(
result["meta"] = e.Meta
}
for _, f := range finalizers {
if err := f(e.Token); err != nil {
return err
}
return e.JSON(http.StatusOK, result)
})
}
// wantsMFA checks whether to enable MFA for the specified auth record based on its MFA rule.
func wantsMFA(e *core.RequestEvent, record *core.Record) (bool, error) {
rule := record.Collection().MFA.Rule
if rule == "" {
return true, nil
}
return e.HttpContext.JSON(http.StatusOK, result)
})
requestInfo, err := e.RequestInfo()
if err != nil {
return false, err
}
var exists bool
query := e.App.RecordQuery(record.Collection()).
Select("(1)").
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
// parse and apply the access rule filter
resolver := core.NewRecordFieldResolver(e.App, record.Collection(), requestInfo, true)
expr, err := search.FilterData(rule).BuildExpr(resolver)
if err != nil {
return false, err
}
resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return exists, nil
}
// checkMFA handles any MFA auth checks that needs to be performed for the specified request event.
// Returns the mfaId that needs to be written as response to the user.
//
// (note: all auth methods are treated as equal and there is no requirement for "pairing").
func checkMFA(e *core.RequestEvent, authRecord *core.Record, currentAuthMethod string) (string, error) {
if !authRecord.Collection().MFA.Enabled || currentAuthMethod == "" {
return "", nil
}
ok, err := wantsMFA(e, authRecord)
if !ok {
if err != nil {
return "", e.BadRequestError("Failed to authenticate.", fmt.Errorf("MFA rule failure: %w", err))
}
return "", nil // no mfa needed for this auth record
}
// read the mfaId either from the qyery params or request body
mfaId := e.Request.URL.Query().Get("mfaId")
if mfaId == "" {
// check the body
data := struct {
MfaId string `form:"mfaId" json:"mfaId" xml:"mfaId"`
}{}
if err := e.BindBody(&data); err != nil {
return "", firstApiError(err, e.BadRequestError("Failed to read MFA Id", err))
}
mfaId = data.MfaId
}
// first-time auth
// ---
if mfaId == "" {
mfa := core.NewMFA(e.App)
mfa.SetCollectionRef(authRecord.Collection().Id)
mfa.SetRecordRef(authRecord.Id)
mfa.SetMethod(currentAuthMethod)
if err := e.App.Save(mfa); err != nil {
return "", firstApiError(err, e.InternalServerError("Failed to create MFA record", err))
}
return mfa.Id, nil
}
// second-time auth
// ---
mfa, err := e.App.FindMFAById(mfaId)
deleteMFA := func() {
// try to delete the expired mfa
if mfa != nil {
if deleteErr := e.App.Delete(mfa); deleteErr != nil {
e.App.Logger().Warn("Failed to delete expired MFA record", "error", deleteErr, "mfaId", mfa.Id)
}
}
}
if err != nil || mfa.HasExpired(authRecord.Collection().MFA.DurationTime()) {
deleteMFA()
return "", firstApiError(err, e.BadRequestError("Invalid or expired MFA session.", err))
}
if mfa.RecordRef() != authRecord.Id || mfa.CollectionRef() != authRecord.Collection().Id {
return "", e.BadRequestError("Invalid MFA session.", nil)
}
if mfa.Method() == currentAuthMethod {
return "", e.BadRequestError("A different authentication method is required.", nil)
}
deleteMFA()
return "", nil
}
// EnrichRecord parses the request context and enrich the provided record:
// - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth record and its expanded auth relations
// are visible only for the current logged admin, record owner or record with manage access
func EnrichRecord(c echo.Context, dao *daos.Dao, record *models.Record, defaultExpands ...string) error {
return EnrichRecords(c, dao, []*models.Record{record}, defaultExpands...)
// are visible only for the current logged superuser, record owner or record with manage access
func EnrichRecord(e *core.RequestEvent, record *core.Record, defaultExpands ...string) error {
return EnrichRecords(e, []*core.Record{record}, defaultExpands...)
}
// EnrichRecords parses the request context and enriches the provided records:
// - expands relations (if defaultExpands and/or ?expand query param is set)
// - ensures that the emails of the auth records and their expanded auth relations
// are visible only for the current logged admin, record owner or record with manage access
func EnrichRecords(c echo.Context, dao *daos.Dao, records []*models.Record, defaultExpands ...string) error {
requestInfo := RequestInfo(c)
if err := autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo); err != nil {
return fmt.Errorf("failed to resolve email visibility: %w", err)
// are visible only for the current logged superuser, record owner or record with manage access
//
// Note: Expects all records to be from the same collection!
func EnrichRecords(e *core.RequestEvent, records []*core.Record, defaultExpands ...string) error {
if len(records) == 0 {
return nil
}
info, err := e.RequestInfo()
if err != nil {
return err
}
return triggerRecordEnrichHooks(e.App, info, records, func() error {
expands := defaultExpands
if param := c.QueryParam(expandQueryParam); param != "" {
if param := e.Request.URL.Query().Get(expandQueryParam); param != "" {
expands = append(expands, strings.Split(param, ",")...)
}
if len(expands) == 0 {
return nil // nothing to expand
err := defaultEnrichRecords(e.App, info, records, expands...)
if err != nil {
// only log as it is not critical
e.App.Logger().Warn("failed to apply default enriching", "error", err)
}
errs := dao.ExpandRecords(records, expands, expandFetch(dao, requestInfo))
if len(errs) > 0 {
return fmt.Errorf("failed to expand: %v", errs)
return nil
})
}
var iterate func(record *core.Record) error
type iterator[T any] struct {
items []T
index int
}
func (ri *iterator[T]) next() T {
var item T
if ri.index < len(ri.items) {
item = ri.items[ri.index]
ri.index++
}
return item
}
func triggerRecordEnrichHooks(app core.App, requestInfo *core.RequestInfo, records []*core.Record, finalizer func() error) error {
it := iterator[*core.Record]{items: records}
enrichHook := app.OnRecordEnrich()
event := new(core.RecordEnrichEvent)
event.App = app
event.RequestInfo = requestInfo
iterate = func(record *core.Record) error {
if record == nil {
return nil
}
event.Record = record
return enrichHook.Trigger(event, func(ee *core.RecordEnrichEvent) error {
next := it.next()
if next == nil {
if finalizer != nil {
return finalizer()
}
return nil
}
event.App = ee.App // in case it was replaced with a transaction
event.Record = next
err := iterate(next)
event.App = app
event.Record = record
return err
})
}
return iterate(it.next())
}
func defaultEnrichRecords(app core.App, requestInfo *core.RequestInfo, records []*core.Record, expands ...string) error {
err := autoResolveRecordsFlags(app, records, requestInfo)
if err != nil {
return fmt.Errorf("failed to resolve records flags: %w", err)
}
if len(expands) > 0 {
expandErrs := app.ExpandRecords(records, expands, expandFetch(app, requestInfo))
if len(expandErrs) > 0 {
errsSlice := make([]error, 0, len(expandErrs))
for key, err := range expandErrs {
errsSlice = append(errsSlice, fmt.Errorf("failed to expand %q: %w", key, err))
}
return fmt.Errorf("failed to expand records: %w", errors.Join(errsSlice...))
}
}
return nil
}
// expandFetch is the records fetch function that is used to expand related records.
func expandFetch(
dao *daos.Dao,
requestInfo *models.RequestInfo,
) daos.ExpandFetchFunc {
return func(relCollection *models.Collection, relIds []string) ([]*models.Record, error) {
records, err := dao.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
if requestInfo.Admin != nil {
return nil // admins can access everything
func expandFetch(app core.App, originalRequestInfo *core.RequestInfo) core.ExpandFetchFunc {
requestInfoClone := *originalRequestInfo
requestInfoPtr := &requestInfoClone
requestInfoPtr.Context = core.RequestInfoContextExpand
return func(relCollection *core.Collection, relIds []string) ([]*core.Record, error) {
records, findErr := app.FindRecordsByIds(relCollection.Id, relIds, func(q *dbx.SelectQuery) error {
if requestInfoPtr.Auth != nil && requestInfoPtr.Auth.IsSuperuser() {
return nil // superusers can access everything
}
if relCollection.ViewRule == nil {
return fmt.Errorf("only admins can view collection %q records", relCollection.Name)
return fmt.Errorf("only superusers can view collection %q records", relCollection.Name)
}
if *relCollection.ViewRule != "" {
resolver := resolvers.NewRecordFieldResolver(dao, relCollection, requestInfo, true)
resolver := core.NewRecordFieldResolver(app, relCollection, requestInfoPtr, true)
expr, err := search.FilterData(*(relCollection.ViewRule)).BuildExpr(resolver)
if err != nil {
return err
@ -200,50 +376,66 @@ func expandFetch(
return nil
})
if err == nil && len(records) > 0 {
autoIgnoreAuthRecordsEmailVisibility(dao, records, requestInfo)
if findErr != nil {
return nil, findErr
}
return records, err
enrichErr := triggerRecordEnrichHooks(app, requestInfoPtr, records, func() error {
if err := autoResolveRecordsFlags(app, records, requestInfoPtr); err != nil {
// non-critical error
app.Logger().Warn("Failed to apply autoResolveRecordsFlags for the expanded records", "error", err)
}
return nil
})
if enrichErr != nil {
return nil, enrichErr
}
return records, nil
}
}
// autoIgnoreAuthRecordsEmailVisibility ignores the email visibility check for
// the provided record if the current auth model is admin, owner or a "manager".
// autoResolveRecordsFlags resolves various visibility flags of the provided records.
//
// Note: Expects all records to be from the same auth collection!
func autoIgnoreAuthRecordsEmailVisibility(
dao *daos.Dao,
records []*models.Record,
requestInfo *models.RequestInfo,
) error {
if len(records) == 0 || !records[0].Collection().IsAuth() {
return nil // nothing to check
// Currently it enables:
// - export of hidden fields if the current auth model is a superuser
// - email export ignoring the emailVisibity checks if the current auth model is superuser, owner or a "manager".
//
// Note: Expects all records to be from the same collection!
func autoResolveRecordsFlags(app core.App, records []*core.Record, requestInfo *core.RequestInfo) error {
if len(records) == 0 {
return nil // nothing to resolve
}
if requestInfo.Admin != nil {
if requestInfo.HasSuperuserAuth() {
hiddenFields := records[0].Collection().Fields.FieldNames()
for _, rec := range records {
rec.Unhide(hiddenFields...)
rec.IgnoreEmailVisibility(true)
}
return nil
}
// additional emailVisibility checks
// ---------------------------------------------------------------
if !records[0].Collection().IsAuth() {
return nil // not auth collection records
}
collection := records[0].Collection()
mappedRecords := make(map[string]*models.Record, len(records))
mappedRecords := make(map[string]*core.Record, len(records))
recordIds := make([]any, len(records))
for i, rec := range records {
mappedRecords[rec.Id] = rec
recordIds[i] = rec.Id
}
if requestInfo != nil && requestInfo.AuthRecord != nil && mappedRecords[requestInfo.AuthRecord.Id] != nil {
mappedRecords[requestInfo.AuthRecord.Id].IgnoreEmailVisibility(true)
if requestInfo.Auth != nil && mappedRecords[requestInfo.Auth.Id] != nil {
mappedRecords[requestInfo.Auth.Id].IgnoreEmailVisibility(true)
}
authOptions := collection.AuthOptions()
if authOptions.ManageRule == nil || *authOptions.ManageRule == "" {
if collection.ManageRule == nil || *collection.ManageRule == "" {
return nil // no manage rule to check
}
@ -251,12 +443,12 @@ func autoIgnoreAuthRecordsEmailVisibility(
// ---
managedIds := []string{}
query := dao.RecordQuery(collection).
Select(dao.DB().QuoteSimpleColumnName(collection.Name) + ".id").
AndWhere(dbx.In(dao.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
query := app.RecordQuery(collection).
Select(app.DB().QuoteSimpleColumnName(collection.Name) + ".id").
AndWhere(dbx.In(app.DB().QuoteSimpleColumnName(collection.Name)+".id", recordIds...))
resolver := resolvers.NewRecordFieldResolver(dao, collection, requestInfo, true)
expr, err := search.FilterData(*authOptions.ManageRule).BuildExpr(resolver)
resolver := core.NewRecordFieldResolver(app, collection, requestInfo, true)
expr, err := search.FilterData(*collection.ManageRule).BuildExpr(resolver)
if err != nil {
return err
}
@ -278,30 +470,26 @@ func autoIgnoreAuthRecordsEmailVisibility(
return nil
}
// hasAuthManageAccess checks whether the client is allowed to have full
// hasAuthManageAccess checks whether the client is allowed to have
// [forms.RecordUpsert] auth management permissions
// (aka. allowing to change system auth fields without oldPassword).
func hasAuthManageAccess(
dao *daos.Dao,
record *models.Record,
requestInfo *models.RequestInfo,
) bool {
// (e.g. allowing to change system auth fields without oldPassword).
func hasAuthManageAccess(app core.App, requestInfo *core.RequestInfo, record *core.Record) bool {
if !record.Collection().IsAuth() {
return false
}
manageRule := record.Collection().AuthOptions().ManageRule
manageRule := record.Collection().ManageRule
if manageRule == nil || *manageRule == "" {
return false // only for admins (manageRule can't be empty)
return false // only for superusers (manageRule can't be empty)
}
if requestInfo == nil || requestInfo.AuthRecord == nil {
if requestInfo == nil || requestInfo.Auth == nil {
return false // no auth record
}
ruleFunc := func(q *dbx.SelectQuery) error {
resolver := resolvers.NewRecordFieldResolver(dao, record.Collection(), requestInfo, true)
resolver := core.NewRecordFieldResolver(app, record.Collection(), requestInfo, true)
expr, err := search.FilterData(*manageRule).BuildExpr(resolver)
if err != nil {
return err
@ -311,35 +499,118 @@ func hasAuthManageAccess(
return nil
}
_, findErr := dao.FindRecordById(record.Collection().Id, record.Id, ruleFunc)
_, findErr := app.FindRecordById(record.Collection().Id, record.Id, ruleFunc)
return findErr == nil
}
var ruleQueryParams = []string{search.FilterQueryParam, search.SortQueryParam}
var adminOnlyRuleFields = []string{"@collection.", "@request."}
var superuserOnlyRuleFields = []string{"@collection.", "@request."}
// @todo consider moving the rules check to the RecordFieldResolver.
//
// checkForAdminOnlyRuleFields loosely checks and returns an error if
// the provided RequestInfo contains rule fields that only the admin can use.
func checkForAdminOnlyRuleFields(requestInfo *models.RequestInfo) error {
if requestInfo.Admin != nil || len(requestInfo.Query) == 0 {
return nil // admin or nothing to check
// checkForSuperuserOnlyRuleFields loosely checks and returns an error if
// the provided RequestInfo contains rule fields that only the superuser can use.
func checkForSuperuserOnlyRuleFields(requestInfo *core.RequestInfo) error {
if len(requestInfo.Query) == 0 || requestInfo.HasSuperuserAuth() {
return nil // superuser or nothing to check
}
for _, param := range ruleQueryParams {
v, _ := requestInfo.Query[param].(string)
v := requestInfo.Query[param]
if v == "" {
continue
}
for _, field := range adminOnlyRuleFields {
for _, field := range superuserOnlyRuleFields {
if strings.Contains(v, field) {
return NewForbiddenError("Only admins can filter by "+field, nil)
return router.NewForbiddenError("Only superusers can filter by "+field, nil)
}
}
}
return nil
}
// firstApiError returns the first ApiError from the errors list
// (this is used usually to prevent unnecessary wraping and to allow bubling ApiError from nested hooks)
//
// If no ApiError is found, returns a default "Internal server" error.
func firstApiError(errs ...error) *router.ApiError {
var apiErr *router.ApiError
var ok bool
for _, err := range errs {
if err == nil {
continue
}
// quick assert to avoid the reflection checks
apiErr, ok = err.(*router.ApiError)
if ok {
return apiErr
}
// nested/wrapped errors
if errors.As(err, &apiErr) {
return apiErr
}
}
return router.NewInternalServerError("", errors.Join(errs...))
}
// -------------------------------------------------------------------
const maxAuthOrigins = 5
func authAlert(e *core.RequestEvent, authRecord *core.Record) error {
// generating fingerprint
// ---
userAgent := e.Request.UserAgent()
if len(userAgent) > 300 {
userAgent = userAgent[:300]
}
fingerprint := security.MD5(e.RealIP() + userAgent)
// ---
origins, err := e.App.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
isFirstLogin := len(origins) == 0
var currentOrigin *core.AuthOrigin
for _, origin := range origins {
if origin.Fingerprint() == fingerprint {
currentOrigin = origin
break
}
}
if currentOrigin == nil {
currentOrigin = core.NewAuthOrigin(e.App)
currentOrigin.SetCollectionRef(authRecord.Collection().Id)
currentOrigin.SetRecordRef(authRecord.Id)
currentOrigin.SetFingerprint(fingerprint)
}
// send email alert for the new origin auth (skip first login)
if !isFirstLogin && currentOrigin.IsNew() && authRecord.Email() != "" {
if err := mails.SendRecordAuthAlert(e.App, authRecord); err != nil {
return err
}
}
// try to keep only up to maxAuthOrigins
// (pop the last used ones; it is not executed in a transaction to avoid unnecessary locks)
if currentOrigin.IsNew() && len(origins) >= maxAuthOrigins {
for i := len(origins) - 1; i >= maxAuthOrigins-1; i-- {
if err := e.App.Delete(origins[i]); err != nil {
// treat as non-critical error, just log for now
e.App.Logger().Warn("Failed to delete old AuthOrigin record", "error", err, "authOriginId", origins[i].Id)
}
}
}
// create/update the origin fingerprint
return e.App.Save(currentOrigin)
}

View File

@ -6,231 +6,742 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRequestInfo(t *testing.T) {
t.Parallel()
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/?test=123", strings.NewReader(`{"test":456}`))
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set("X-Token-Test", "123")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
dummyRecord := &models.Record{}
dummyRecord.Id = "id1"
c.Set(apis.ContextAuthRecordKey, dummyRecord)
dummyAdmin := &models.Admin{}
dummyAdmin.Id = "id2"
c.Set(apis.ContextAdminKey, dummyAdmin)
result := apis.RequestInfo(c)
if result == nil {
t.Fatal("Expected *models.RequestInfo instance, got nil")
}
if result.Method != http.MethodPost {
t.Fatalf("Expected Method %v, got %v", http.MethodPost, result.Method)
}
rawHeaders, _ := json.Marshal(result.Headers)
expectedHeaders := `{"content_type":"application/json","x_token_test":"123"}`
if v := string(rawHeaders); v != expectedHeaders {
t.Fatalf("Expected Query %v, got %v", expectedHeaders, v)
}
rawQuery, _ := json.Marshal(result.Query)
expectedQuery := `{"test":"123"}`
if v := string(rawQuery); v != expectedQuery {
t.Fatalf("Expected Query %v, got %v", expectedQuery, v)
}
rawData, _ := json.Marshal(result.Data)
expectedData := `{"test":456}`
if v := string(rawData); v != expectedData {
t.Fatalf("Expected Data %v, got %v", expectedData, v)
}
if result.AuthRecord == nil || result.AuthRecord.Id != dummyRecord.Id {
t.Fatalf("Expected AuthRecord %v, got %v", dummyRecord, result.AuthRecord)
}
if result.Admin == nil || result.Admin.Id != dummyAdmin.Id {
t.Fatalf("Expected Admin %v, got %v", dummyAdmin, result.Admin)
}
}
func TestRecordAuthResponse(t *testing.T) {
func TestEnrichRecords(t *testing.T) {
t.Parallel()
// mock test data
// ---
app, _ := tests.NewTestApp()
defer app.Cleanup()
dummyAdmin := &models.Admin{}
dummyAdmin.Id = "id1"
nonAuthRecord, err := app.Dao().FindRecordById("demo1", "al1h9ijdeojtsjy")
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
authRecord, err := app.Dao().FindRecordById("users", "4q1xlclmfloku33")
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test@example.com")
if err != nil {
t.Fatal(err)
}
unverifiedAuthRecord, err := app.Dao().FindRecordById("clients", "o1y0dd0spd786md")
usersRecords, err := app.FindRecordsByIds("users", []string{"4q1xlclmfloku33", "bgs820n361vj1qd"})
if err != nil {
t.Fatal(err)
}
nologinRecords, err := app.FindRecordsByIds("nologin", []string{"dc49k6jgejn40h3", "oos036e9xvqeexy"})
if err != nil {
t.Fatal(err)
}
demo1Records, err := app.FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
if err != nil {
t.Fatal(err)
}
demo5Records, err := app.FindRecordsByIds("demo5", []string{"la4y2w4o98acwuj", "qjeql998mtp1azp"})
if err != nil {
t.Fatal(err)
}
// temp update the view rule to ensure that request context is set to "expand"
demo4, err := app.FindCollectionByNameOrId("demo4")
if err != nil {
t.Fatal(err)
}
demo4.ViewRule = types.Pointer("@request.context = 'expand'")
if err := app.Save(demo4); err != nil {
t.Fatal(err)
}
// ---
scenarios := []struct {
name string
auth *core.Record
records []*core.Record
queryExpand string
defaultExpands []string
expected []string
notExpected []string
}{
// email visibility checks
{
name: "[emailVisibility] guest",
auth: nil,
records: usersRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`, // emailVisibility=true
},
notExpected: []string{
`"test@example.com"`,
},
},
{
name: "[emailVisibility] owner",
auth: user,
records: usersRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`, // emailVisibility=true
`"test@example.com"`, // owner
},
},
{
name: "[emailVisibility] manager",
auth: user,
records: nologinRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility] superuser",
auth: superuser,
records: nologinRecords,
queryExpand: "",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"test3@example.com"`,
`"test@example.com"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (regular user)",
auth: user,
records: demo1Records,
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel_many"`,
`"expand":{}`,
`"test@example.com"`,
},
notExpected: []string{
`"id":"bgs820n361vj1qd"`,
`"id":"oap640cot4yru2s"`,
},
},
{
name: "[emailVisibility + expand] recursive auth rule checks (superuser)",
auth: superuser,
records: demo1Records,
queryExpand: "",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"test@example.com"`,
`"expand":{"rel_many"`,
`"id":"bgs820n361vj1qd"`,
`"id":"oap640cot4yru2s"`,
},
notExpected: []string{
`"expand":{}`,
},
},
// expand checks
{
name: "[expand] guest (query)",
auth: nil,
records: usersRecords,
queryExpand: "rel",
defaultExpands: nil,
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
notExpected: []string{
`"expand":{}`,
},
},
{
name: "[expand] guest (default expands)",
auth: nil,
records: usersRecords,
queryExpand: "",
defaultExpands: []string{"rel"},
expected: []string{
`"customField":"123"`,
`"expand":{"rel"`,
`"id":"llvuca81nly1qls"`,
`"id":"0yxhwia2amd8gec"`,
},
},
{
name: "[expand] @request.context=expand check",
auth: nil,
records: demo5Records,
queryExpand: "rel_one",
defaultExpands: []string{"rel_many"},
expected: []string{
`"customField":"123"`,
`"expand":{}`,
`"expand":{"`,
`"rel_many":[{`,
`"rel_one":{`,
`"id":"i9naidtvr6qsgb4"`,
`"id":"qzaqccwrmva4o1n"`,
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
app.OnRecordEnrich().BindFunc(func(e *core.RecordEnrichEvent) error {
e.Record.WithCustomData(true)
e.Record.Set("customField", "123")
return e.Next()
})
req := httptest.NewRequest(http.MethodGet, "/?expand="+s.queryExpand, nil)
rec := httptest.NewRecorder()
requestEvent := new(core.RequestEvent)
requestEvent.App = app
requestEvent.Request = req
requestEvent.Response = rec
requestEvent.Auth = s.auth
err := apis.EnrichRecords(requestEvent, s.records, s.defaultExpands...)
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(s.records)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
for _, str := range s.expected {
if !strings.Contains(rawStr, str) {
t.Fatalf("Expected\n%q\nin\n%v", str, rawStr)
}
}
for _, str := range s.notExpected {
if strings.Contains(rawStr, str) {
t.Fatalf("Didn't expected\n%q\nin\n%v", str, rawStr)
}
}
})
}
}
func TestRecordAuthResponseAuthRuleCheck(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
record *models.Record
meta any
rule *string
expectError bool
expectedContent []string
notExpectedContent []string
expectedEvents map[string]int
}{
{
name: "non auth record",
record: nonAuthRecord,
expectError: true,
"admin only rule",
nil,
true,
},
{
name: "valid auth record but with unverified email in onlyVerified collection",
record: unverifiedAuthRecord,
expectError: true,
"empty rule",
types.Pointer(""),
false,
},
{
name: "valid auth record - without meta",
record: authRecord,
expectError: false,
expectedContent: []string{
`"token":"`,
`"record":{`,
`"id":"`,
`"expand":{"rel":{`,
},
notExpectedContent: []string{
`"meta":`,
},
expectedEvents: map[string]int{
"OnRecordAuthRequest": 1,
},
"false rule",
types.Pointer("1=2"),
true,
},
{
name: "valid auth record - with meta",
record: authRecord,
meta: map[string]any{"meta_test": 123},
expectError: false,
expectedContent: []string{
`"token":"`,
`"record":{`,
`"id":"`,
`"expand":{"rel":{`,
`"meta":{"meta_test":123`,
},
expectedEvents: map[string]int{
"OnRecordAuthRequest": 1,
},
"true rule",
types.Pointer("1=1"),
false,
},
}
for _, s := range scenarios {
app.ResetEventCalls()
t.Run(s.name, func(t *testing.T) {
user.Collection().AuthRule = s.rule
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?expand=rel", nil)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(apis.ContextAdminKey, dummyAdmin)
err := apis.RecordAuthResponse(event, user, "", nil)
responseErr := apis.RecordAuthResponse(app, c, s.record, s.meta)
hasErr := responseErr != nil
if hasErr != s.expectError {
t.Fatalf("[%s] Expected hasErr to be %v, got %v (%v)", s.name, s.expectError, hasErr, responseErr)
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if len(app.EventCalls) != len(s.expectedEvents) {
t.Fatalf("[%s] Expected events \n%v, \ngot \n%v", s.name, s.expectedEvents, app.EventCalls)
}
for k, v := range s.expectedEvents {
if app.EventCalls[k] != v {
t.Fatalf("[%s] Expected event %s to be called %d times, got %d", s.name, k, v, app.EventCalls[k])
}
// in all cases login alert shouldn't be send because of the empty auth method
if app.TestMailer.TotalSend() != 0 {
t.Fatalf("Expected no emails send, got %d:\n%v", app.TestMailer.TotalSend(), app.TestMailer.LastMessage().HTML)
}
if hasErr {
continue
if !hasErr {
return
}
response := rec.Body.String()
apiErr, ok := err.(*router.ApiError)
for _, v := range s.expectedContent {
if !strings.Contains(response, v) {
t.Fatalf("[%s] Missing %v in response \n%v", s.name, v, response)
}
if !ok || apiErr == nil {
t.Fatalf("Expected ApiError, got %v", apiErr)
}
for _, v := range s.notExpectedContent {
if strings.Contains(response, v) {
t.Fatalf("[%s] Unexpected %v in response \n%v", s.name, v, response)
}
if apiErr.Status != http.StatusForbidden {
t.Fatalf("Expected ApiError.Status %d, got %d", http.StatusForbidden, apiErr.Status)
}
})
}
}
func TestEnrichRecords(t *testing.T) {
t.Parallel()
func TestRecordAuthResponseAuthAlertCheck(t *testing.T) {
const testFingerprint = "d0f88d6c87767262ba8e93d6acccd784"
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?expand=rel_many", nil)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
dummyAdmin := &models.Admin{}
dummyAdmin.Id = "test_id"
c.Set(apis.ContextAdminKey, dummyAdmin)
scenarios := []struct {
name string
devices []string // mock existing device fingerprints
expectDevices []string
enabled bool
expectEmail bool
}{
{
name: "first login",
devices: nil,
expectDevices: []string{testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "existing device",
devices: []string{"1", testFingerprint},
expectDevices: []string{"1", testFingerprint},
enabled: true,
expectEmail: false,
},
{
name: "new device (< 5)",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "new device (>= 5)",
devices: []string{"1", "2", "3", "4", "5"},
expectDevices: []string{"2", "3", "4", "5", testFingerprint},
enabled: true,
expectEmail: true,
},
{
name: "with disabled auth alert collection flag",
devices: []string{"1", "2"},
expectDevices: []string{"1", "2"},
enabled: false,
expectEmail: false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
records, err := app.Dao().FindRecordsByIds("demo1", []string{"al1h9ijdeojtsjy", "84nmscqy84lsi1t"})
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = httptest.NewRecorder()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
apis.EnrichRecords(c, app.Dao(), records, "rel_one")
user.Collection().MFA.Enabled = false
user.Collection().AuthRule = types.Pointer("")
user.Collection().AuthAlert.Enabled = s.enabled
for _, record := range records {
expand := record.Expand()
if len(expand) == 0 {
t.Fatalf("Expected non-empty expand, got nil for record %v", record)
// ensure that there are no other auth origins
err = app.DeleteAllAuthOriginsByRecord(user)
if err != nil {
t.Fatal(err)
}
if len(record.GetStringSlice("rel_one")) != 0 {
if _, ok := expand["rel_one"]; !ok {
t.Fatalf("Expected rel_one to be expanded for record %v, got \n%v", record, expand)
// insert the mock devices
for _, fingerprint := range s.devices {
d := core.NewAuthOrigin(app)
d.SetCollectionRef(user.Collection().Id)
d.SetRecordRef(user.Id)
d.SetFingerprint(fingerprint)
if err = app.Save(d); err != nil {
t.Fatal(err)
}
}
if len(record.GetStringSlice("rel_many")) != 0 {
if _, ok := expand["rel_many"]; !ok {
t.Fatalf("Expected rel_many to be expanded for record %v, got \n%v", record, expand)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Failed to resolve auth response: %v", err)
}
var expectTotalSend int
if s.expectEmail {
expectTotalSend = 1
}
if total := app.TestMailer.TotalSend(); total != expectTotalSend {
t.Fatalf("Expected %d sent emails, got %d", expectTotalSend, total)
}
devices, err := app.FindAllAuthOriginsByRecord(user)
if err != nil {
t.Fatalf("Failed to retrieve auth origins: %v", err)
}
if len(devices) != len(s.expectDevices) {
t.Fatalf("Expected %d devices, got %d", len(s.expectDevices), len(devices))
}
for _, fingerprint := range s.expectDevices {
var exists bool
fingerprints := make([]string, 0, len(devices))
for _, d := range devices {
if d.Fingerprint() == fingerprint {
exists = true
break
}
fingerprints = append(fingerprints, d.Fingerprint())
}
if !exists {
t.Fatalf("Missing device with fingerprint %q:\n%v", fingerprint, fingerprints)
}
}
})
}
}
func TestRecordAuthResponseMFACheck(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
user2, err := app.FindAuthRecordByEmail("users", "test2@example.com")
if err != nil {
t.Fatal(err)
}
rec := httptest.NewRecorder()
event := new(core.RequestEvent)
event.App = app
event.Request = httptest.NewRequest(http.MethodGet, "/", nil)
event.Response = rec
resetMFAs := func(authRecord *core.Record) {
// ensure that mfa is enabled
user.Collection().MFA.Enabled = true
user.Collection().MFA.Duration = 5
user.Collection().MFA.Rule = ""
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
for _, mfa := range mfas {
if err := app.Delete(mfa); err != nil {
t.Fatalf("Failed to delete mfa %q: %v", mfa.Id, err)
}
}
// reset response
rec = httptest.NewRecorder()
event.Response = rec
}
totalMFAs := func(authRecord *core.Record) int {
mfas, err := app.FindAllMFAsByRecord(authRecord)
if err != nil {
t.Fatalf("Failed to retrieve mfas: %v", err)
}
return len(mfas)
}
t.Run("no collection MFA enabled", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Enabled = false
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no explicit auth method", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("no mfa wanted (mfa rule check failure)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=2"
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if strings.Contains(body, "mfaId") {
t.Fatalf("Expected no mfaId in the response body, got\n%v", body)
}
if !strings.Contains(body, "token") {
t.Fatalf("Expected auth token in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no mfa records to be created, got %d", total)
}
})
t.Run("mfa wanted (mfa rule check success)", func(t *testing.T) {
resetMFAs(user)
user.Collection().MFA.Rule = "1=1"
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa first-time", func(t *testing.T) {
resetMFAs(user)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
body := rec.Body.String()
if !strings.Contains(body, "mfaId") {
t.Fatalf("Expected the created mfaId to be returned in the response body, got\n%v", body)
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected a single mfa record to be created, got %d", total)
}
})
t.Run("mfa second-time with the same auth method", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 1 {
t.Fatalf("Expected only 1 mfa record (the existing one), got %d", total)
}
})
t.Run("mfa second-time with the different auth method (query param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("mfa second-time with the different auth method (body param)", func(t *testing.T) {
resetMFAs(user)
// create a dummy mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"mfaId":"`+mfa.Id+`"}`))
event.Request.Header.Add("content-type", "application/json")
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err != nil {
t.Fatalf("Expected nil, got error: %v", err)
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected the dummy mfa record to be deleted, found %d", total)
}
})
t.Run("missing mfa", func(t *testing.T) {
resetMFAs(user)
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId=missing", nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected 0 mfa records, got %d", total)
}
})
t.Run("expired mfa", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user.Collection().Id)
mfa.SetRecordRef(user.Id)
mfa.SetMethod("example1")
mfa.SetRaw("created", types.NowDateTime().Add(-1*time.Hour))
mfa.SetRaw("updated", types.NowDateTime().Add(-1*time.Hour))
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if totalMFAs(user) != 0 {
t.Fatal("Expected the expired mfa record to be deleted")
}
})
t.Run("mfa for different auth record", func(t *testing.T) {
resetMFAs(user)
// create a dummy expired mfa record
mfa := core.NewMFA(app)
mfa.SetCollectionRef(user2.Collection().Id)
mfa.SetRecordRef(user2.Id)
mfa.SetMethod("example1")
if err = app.Save(mfa); err != nil {
t.Fatal(err)
}
event.Request = httptest.NewRequest(http.MethodGet, "/?mfaId="+mfa.Id, nil)
err = apis.RecordAuthResponse(event, user, "example2", nil)
if err == nil {
t.Fatal("Expected error, got nil")
}
if total := totalMFAs(user); total != 0 {
t.Fatalf("Expected no user mfas, got %d", total)
}
if total := totalMFAs(user2); total != 1 {
t.Fatalf("Expected only 1 user2 mfa, got %d", total)
}
})
}

View File

@ -3,6 +3,7 @@ package apis
import (
"context"
"crypto/tls"
"errors"
"log"
"net"
"net/http"
@ -12,14 +13,10 @@ import (
"time"
"github.com/fatih/color"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/migrations"
"github.com/pocketbase/pocketbase/migrations/logs"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/migrate"
"github.com/pocketbase/pocketbase/ui"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
@ -29,10 +26,16 @@ type ServeConfig struct {
// ShowStartBanner indicates whether to show or hide the server start console message.
ShowStartBanner bool
// HttpAddr is the TCP address to listen for the HTTP server (eg. `127.0.0.1:80`).
// DashboardPath specifies the route path to the superusers dashboard interface
// (default to "/_/{path...}").
//
// Note: Must include the "{path...}" wildcard parameter.
DashboardPath string
// HttpAddr is the TCP address to listen for the HTTP server (eg. "127.0.0.1:80").
HttpAddr string
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. `127.0.0.1:443`).
// HttpsAddr is the TCP address to listen for the HTTPS server (eg. "127.0.0.1:443").
HttpsAddr string
// Optional domains list to use when issuing the TLS certificate.
@ -58,36 +61,43 @@ type ServeConfig struct {
// HttpAddr: "127.0.0.1:8080",
// ShowStartBanner: false,
// })
func Serve(app core.App, config ServeConfig) (*http.Server, error) {
func Serve(app core.App, config ServeConfig) error {
if len(config.AllowedOrigins) == 0 {
config.AllowedOrigins = []string{"*"}
}
if config.DashboardPath == "" {
config.DashboardPath = "/_/{path...}"
} else if !strings.HasSuffix(config.DashboardPath, "{path...}") {
return errors.New("invalid dashboard path - missing {path...} wildcard")
}
// ensure that the latest migrations are applied before starting the server
if err := runMigrations(app); err != nil {
return nil, err
}
// reload app settings in case a new default value was set with a migration
// (or if this is the first time the init migration was executed)
if err := app.RefreshSettings(); err != nil {
color.Yellow("=====================================")
color.Yellow("WARNING: Settings load error! \n%v", err)
color.Yellow("Fallback to the application defaults.")
color.Yellow("=====================================")
}
router, err := InitApi(app)
err := app.RunAllMigrations()
if err != nil {
return nil, err
return err
}
// configure cors
router.Use(middleware.CORSWithConfig(middleware.CORSConfig{
Skipper: middleware.DefaultSkipper,
pbRouter, err := NewRouter(app)
if err != nil {
return err
}
pbRouter.Bind(&hook.Handler[*core.RequestEvent]{
Id: DefaultCorsMiddlewareId,
Func: CORSWithConfig(CORSConfig{
AllowOrigins: config.AllowedOrigins,
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}))
}),
Priority: DefaultCorsMiddlewarePriority,
})
pbRouter.BindFunc(installerRedirect(app, config.DashboardPath))
pbRouter.GET(config.DashboardPath, Static(ui.DistDirFS, false)).
BindFunc(dashboardRemoveInstallerParam()).
BindFunc(dashboardCacheControl()).
BindFunc(Gzip())
// start http server
// ---
@ -118,25 +128,12 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
// implicit www->non-www redirect(s)
if len(wwwRedirects) > 0 {
router.Pre(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
host := c.Request().Host
if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, wwwRedirects) {
return c.Redirect(
http.StatusTemporaryRedirect,
(c.Scheme() + "://" + host[4:] + c.Request().RequestURI),
)
}
return next(c)
}
})
pbRouter.Bind(wwwRedirect(wwwRedirects))
}
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(filepath.Join(app.DataDir(), ".autocert_cache")),
Cache: autocert.DirCache(filepath.Join(app.DataDir(), core.LocalAutocertCacheDirName)),
HostPolicy: autocert.HostWhitelist(hostNames...),
}
@ -151,24 +148,96 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
GetCertificate: certManager.GetCertificate,
NextProtos: []string{acme.ALPNProto},
},
ReadTimeout: 10 * time.Minute,
// higher defaults to accommodate large file uploads/downloads
WriteTimeout: 3 * time.Minute,
ReadTimeout: 3 * time.Minute,
ReadHeaderTimeout: 30 * time.Second,
// WriteTimeout: 60 * time.Second, // breaks sse!
Handler: router,
Addr: mainAddr,
BaseContext: func(l net.Listener) context.Context {
return baseCtx
},
ErrorLog: log.New(&serverErrorLogWriter{app: app}, "", 0),
}
serveEvent := &core.ServeEvent{
App: app,
Router: router,
Server: server,
CertManager: certManager,
serveEvent := new(core.ServeEvent)
serveEvent.App = app
serveEvent.Router = pbRouter
serveEvent.Server = server
serveEvent.CertManager = certManager
var listener net.Listener
// graceful shutdown
// ---------------------------------------------------------------
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
// Note that the WaitGroup would do nothing if the app.OnTerminate() hook isn't triggered.
var wg sync.WaitGroup
// try to gracefully shutdown the server on app termination
app.OnTerminate().Bind(&hook.Handler[*core.TerminateEvent]{
Id: "pbGracefulShutdown",
Func: func(te *core.TerminateEvent) error {
cancelBaseCtx()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
wg.Add(1)
_ = server.Shutdown(ctx)
if te.IsRestart {
// wait for execve and other handlers up to 3 seconds before exit
time.AfterFunc(3*time.Second, func() {
wg.Done()
})
} else {
wg.Done()
}
if err := app.OnBeforeServe().Trigger(serveEvent); err != nil {
return nil, err
return te.Next()
},
Priority: -9999,
})
// wait for the graceful shutdown to complete before exit
defer func() {
wg.Wait()
if listener != nil {
_ = listener.Close()
}
}()
// ---------------------------------------------------------------
// trigger the OnServe hook and start the tcp listener
serveHookErr := app.OnServe().Trigger(serveEvent, func(e *core.ServeEvent) error {
handler, err := e.Router.BuildMux()
if err != nil {
return err
}
e.Server.Handler = handler
addr := e.Server.Addr
// fallback similar to the std Server.ListenAndServe/ListenAndServeTLS
if addr == "" {
if config.HttpsAddr != "" {
addr = ":https"
} else {
addr = ":http"
}
}
var lnErr error
listener, lnErr = net.Listen("tcp", addr)
return lnErr
})
if serveHookErr != nil {
return serveHookErr
}
if config.ShowStartBanner {
@ -198,80 +267,32 @@ func Serve(app core.App, config ServeConfig) (*http.Server, error) {
regular.Printf("└─ Admin UI: %s\n", color.CyanString("%s://%s/_/", schema, addr))
}
// WaitGroup to block until server.ShutDown() returns because Serve and similar methods exit immediately.
// Note that the WaitGroup would not do anything if the app.OnTerminate() hook isn't triggered.
var wg sync.WaitGroup
// try to gracefully shutdown the server on app termination
app.OnTerminate().Add(func(e *core.TerminateEvent) error {
cancelBaseCtx()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
wg.Add(1)
server.Shutdown(ctx)
if e.IsRestart {
// wait for execve and other handlers up to 5 seconds before exit
time.AfterFunc(5*time.Second, func() {
wg.Done()
})
} else {
wg.Done()
}
return nil
})
// wait for the graceful shutdown to complete before exit
defer wg.Wait()
// ---
// @todo consider removing the server return value because it is
// not really useful when combined with the blocking serve calls
// ---
// start HTTPS server
var serveErr error
if config.HttpsAddr != "" {
// if httpAddr is set, start an HTTP server to redirect the traffic to the HTTPS version
if config.HttpAddr != "" {
// start an additional HTTP server for redirecting the traffic to the HTTPS version
go http.ListenAndServe(config.HttpAddr, certManager.HTTPHandler(nil))
}
return server, server.ListenAndServeTLS("", "")
}
// start HTTPS server
serveErr = server.ServeTLS(listener, "", "")
} else {
// OR start HTTP server
return server, server.ListenAndServe()
}
type migrationsConnection struct {
DB *dbx.DB
MigrationsList migrate.MigrationsList
}
func runMigrations(app core.App) error {
connections := []migrationsConnection{
{
DB: app.DB(),
MigrationsList: migrations.AppMigrations,
},
{
DB: app.LogsDB(),
MigrationsList: logs.LogsMigrations,
},
}
for _, c := range connections {
runner, err := migrate.NewRunner(c.DB, c.MigrationsList)
if err != nil {
return err
}
if _, err := runner.Up(); err != nil {
return err
serveErr = server.Serve(listener)
}
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
return serveErr
}
return nil
}
type serverErrorLogWriter struct {
app core.App
}
func (s *serverErrorLogWriter) Write(p []byte) (int, error) {
s.app.Logger().Debug(strings.TrimSpace(string(p)))
return len(p), nil
}

View File

@ -4,136 +4,121 @@ import (
"net/http"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models/settings"
"github.com/pocketbase/pocketbase/tools/router"
)
// bindSettingsApi registers the settings api endpoints.
func bindSettingsApi(app core.App, rg *echo.Group) {
api := settingsApi{app: app}
subGroup := rg.Group("/settings", ActivityLogger(app), RequireAdminAuth())
subGroup.GET("", api.list)
subGroup.PATCH("", api.set)
subGroup.POST("/test/s3", api.testS3)
subGroup.POST("/test/email", api.testEmail)
subGroup.POST("/apple/generate-client-secret", api.generateAppleClientSecret)
func bindSettingsApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
subGroup := rg.Group("/settings").Bind(RequireSuperuserAuth())
subGroup.GET("", settingsList)
subGroup.PATCH("", settingsSet)
subGroup.POST("/test/s3", settingsTestS3)
subGroup.POST("/test/email", settingsTestEmail)
subGroup.POST("/apple/generate-client-secret", settingsGenerateAppleClientSecret)
}
type settingsApi struct {
app core.App
}
func (api *settingsApi) list(c echo.Context) error {
settings, err := api.app.Settings().RedactClone()
func settingsList(e *core.RequestEvent) error {
clone, err := e.App.Settings().Clone()
if err != nil {
return NewBadRequestError("", err)
return e.InternalServerError("", err)
}
event := new(core.SettingsListEvent)
event.HttpContext = c
event.RedactedSettings = settings
event := new(core.SettingsListRequestEvent)
event.RequestEvent = e
event.Settings = clone
return api.app.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
return e.HttpContext.JSON(http.StatusOK, e.RedactedSettings)
return e.App.OnSettingsListRequest().Trigger(event, func(e *core.SettingsListRequestEvent) error {
return e.JSON(http.StatusOK, e.Settings)
})
}
func (api *settingsApi) set(c echo.Context) error {
form := forms.NewSettingsUpsert(api.app)
func settingsSet(e *core.RequestEvent) error {
event := new(core.SettingsUpdateRequestEvent)
event.RequestEvent = e
if clone, err := e.App.Settings().Clone(); err == nil {
event.OldSettings = clone
} else {
return e.BadRequestError("", err)
}
if clone, err := e.App.Settings().Clone(); err == nil {
event.NewSettings = clone
} else {
return e.BadRequestError("", err)
}
if err := e.BindBody(&event.NewSettings); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
return e.App.OnSettingsUpdateRequest().Trigger(event, func(e *core.SettingsUpdateRequestEvent) error {
err := e.App.Save(e.NewSettings)
if err != nil {
return e.BadRequestError("An error occurred while saving the new settings.", err)
}
appSettings, err := e.App.Settings().Clone()
if err != nil {
return e.InternalServerError("Failed to clone app settings.", err)
}
return e.JSON(http.StatusOK, appSettings)
})
}
func settingsTestS3(e *core.RequestEvent) error {
form := forms.NewTestS3Filesystem(e.App)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
}
event := new(core.SettingsUpdateEvent)
event.HttpContext = c
event.OldSettings = api.app.Settings()
// update the settings
return form.Submit(func(next forms.InterceptorNextFunc[*settings.Settings]) forms.InterceptorNextFunc[*settings.Settings] {
return func(s *settings.Settings) error {
event.NewSettings = s
return api.app.OnSettingsBeforeUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error {
if err := next(e.NewSettings); err != nil {
return NewBadRequestError("An error occurred while submitting the form.", err)
}
return api.app.OnSettingsAfterUpdateRequest().Trigger(event, func(e *core.SettingsUpdateEvent) error {
if e.HttpContext.Response().Committed {
return nil
}
redactedSettings, err := api.app.Settings().RedactClone()
if err != nil {
return NewBadRequestError("", err)
}
return e.HttpContext.JSON(http.StatusOK, redactedSettings)
})
})
}
})
}
func (api *settingsApi) testS3(c echo.Context) error {
form := forms.NewTestS3Filesystem(api.app)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// send
if err := form.Submit(); err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Failed to test the S3 filesystem.", fErr)
return e.BadRequestError("Failed to test the S3 filesystem.", fErr)
}
// mailer error
return NewBadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
return e.BadRequestError("Failed to test the S3 filesystem. Raw error: \n"+err.Error(), nil)
}
return c.NoContent(http.StatusNoContent)
return e.NoContent(http.StatusNoContent)
}
func (api *settingsApi) testEmail(c echo.Context) error {
form := forms.NewTestEmailSend(api.app)
func settingsTestEmail(e *core.RequestEvent) error {
form := forms.NewTestEmailSend(e.App)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// send
if err := form.Submit(); err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Failed to send the test email.", fErr)
return e.BadRequestError("Failed to send the test email.", fErr)
}
// mailer error
return NewBadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
return e.BadRequestError("Failed to send the test email. Raw error: \n"+err.Error(), nil)
}
return c.NoContent(http.StatusNoContent)
return e.NoContent(http.StatusNoContent)
}
func (api *settingsApi) generateAppleClientSecret(c echo.Context) error {
form := forms.NewAppleClientSecretCreate(api.app)
func settingsGenerateAppleClientSecret(e *core.RequestEvent) error {
form := forms.NewAppleClientSecretCreate(e.App)
// load request
if err := c.Bind(form); err != nil {
return NewBadRequestError("An error occurred while loading the submitted data.", err)
if err := e.BindBody(form); err != nil {
return e.BadRequestError("An error occurred while loading the submitted data.", err)
}
// generate
@ -141,14 +126,14 @@ func (api *settingsApi) generateAppleClientSecret(c echo.Context) error {
if err != nil {
// form error
if fErr, ok := err.(validation.Errors); ok {
return NewBadRequestError("Invalid client secret data.", fErr)
return e.BadRequestError("Invalid client secret data.", fErr)
}
// secret generation error
return NewBadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
return e.BadRequestError("Failed to generate client secret. Raw error: \n"+err.Error(), nil)
}
return c.JSON(http.StatusOK, map[string]any{
return e.JSON(http.StatusOK, map[string]string{
"secret": secret,
})
}

View File

@ -6,14 +6,11 @@ import (
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net/http"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
@ -24,26 +21,28 @@ func TestSettingsList(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodGet,
Url: "/api/settings",
URL: "/api/settings",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodGet,
Url: "/api/settings",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin",
Name: "authorized as superuser",
Method: http.MethodGet,
Url: "/api/settings",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/settings",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
@ -52,44 +51,10 @@ func TestSettingsList(t *testing.T) {
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"adminAuthToken":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"twitterAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
`"batch":{`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnSettingsListRequest": 1,
},
},
@ -103,35 +68,41 @@ func TestSettingsList(t *testing.T) {
func TestSettingsSet(t *testing.T) {
t.Parallel()
validData := `{"meta":{"appName":"update_test"}}`
validData := `{
"meta":{"appName":"update_test"},
"s3":{"secret": "s3_secret"},
"backups":{"s3":{"secret":"backups_s3_secret"}}
}`
scenarios := []tests.ApiScenario{
{
Name: "unauthorized",
Method: http.MethodPatch,
Url: "/api/settings",
URL: "/api/settings",
Body: strings.NewReader(validData),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPatch,
Url: "/api/settings",
URL: "/api/settings",
Body: strings.NewReader(validData),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin submitting empty data",
Name: "authorized as superuser submitting empty data",
Method: http.MethodPatch,
Url: "/api/settings",
URL: "/api/settings",
Body: strings.NewReader(``),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
@ -140,71 +111,46 @@ func TestSettingsSet(t *testing.T) {
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"adminAuthToken":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
`"appName":"acme_test"`,
`"batch":{`,
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnSettingsBeforeUpdateRequest": 1,
"OnSettingsAfterUpdateRequest": 1,
"*": 0,
"OnSettingsUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnSettingsReload": 1,
},
},
{
Name: "authorized as admin submitting invalid data",
Name: "authorized as superuser submitting invalid data",
Method: http.MethodPatch,
Url: "/api/settings",
URL: "/api/settings",
Body: strings.NewReader(`{"meta":{"appName":""}}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"meta":{"appName":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{
"*": 0,
"OnModelUpdate": 1,
"OnModelAfterUpdateError": 1,
"OnModelValidate": 1,
"OnSettingsUpdateRequest": 1,
},
},
{
Name: "authorized as admin submitting valid data",
Name: "authorized as superuser submitting valid data",
Method: http.MethodPatch,
Url: "/api/settings",
URL: "/api/settings",
Body: strings.NewReader(validData),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
@ -213,71 +159,21 @@ func TestSettingsSet(t *testing.T) {
`"smtp":{`,
`"s3":{`,
`"backups":{`,
`"adminAuthToken":{`,
`"adminPasswordResetToken":{`,
`"adminFileToken":{`,
`"recordAuthToken":{`,
`"recordPasswordResetToken":{`,
`"recordEmailChangeToken":{`,
`"recordVerificationToken":{`,
`"recordFileToken":{`,
`"emailAuth":{`,
`"googleAuth":{`,
`"facebookAuth":{`,
`"githubAuth":{`,
`"gitlabAuth":{`,
`"twitterAuth":{`,
`"discordAuth":{`,
`"microsoftAuth":{`,
`"spotifyAuth":{`,
`"kakaoAuth":{`,
`"twitchAuth":{`,
`"stravaAuth":{`,
`"giteeAuth":{`,
`"livechatAuth":{`,
`"giteaAuth":{`,
`"oidcAuth":{`,
`"oidc2Auth":{`,
`"oidc3Auth":{`,
`"appleAuth":{`,
`"instagramAuth":{`,
`"vkAuth":{`,
`"yandexAuth":{`,
`"patreonAuth":{`,
`"mailcowAuth":{`,
`"bitbucketAuth":{`,
`"planningcenterAuth":{`,
`"secret":"******"`,
`"clientSecret":"******"`,
`"batch":{`,
`"appName":"update_test"`,
},
NotExpectedContent: []string{
"secret",
"password",
},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnSettingsBeforeUpdateRequest": 1,
"OnSettingsAfterUpdateRequest": 1,
},
},
{
Name: "OnSettingsAfterUpdateRequest error response",
Method: http.MethodPatch,
Url: "/api/settings",
Body: strings.NewReader(validData),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
},
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.OnSettingsAfterUpdateRequest().Add(func(e *core.SettingsUpdateEvent) error {
return errors.New("error")
})
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{
"OnModelBeforeUpdate": 1,
"OnModelAfterUpdate": 1,
"OnSettingsBeforeUpdateRequest": 1,
"OnSettingsAfterUpdateRequest": 1,
"*": 0,
"OnSettingsUpdateRequest": 1,
"OnModelUpdate": 1,
"OnModelUpdateExecute": 1,
"OnModelAfterUpdateSuccess": 1,
"OnModelValidate": 1,
"OnSettingsReload": 1,
},
},
}
@ -294,59 +190,64 @@ func TestSettingsTestS3(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/settings/test/s3",
URL: "/api/settings/test/s3",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/settings/test/s3",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/settings/test/s3",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (missing body + no s3)",
Name: "authorized as superuser (missing body + no s3)",
Method: http.MethodPost,
Url: "/api/settings/test/s3",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
URL: "/api/settings/test/s3",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"filesystem":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (invalid filesystem)",
Name: "authorized as superuser (invalid filesystem)",
Method: http.MethodPost,
Url: "/api/settings/test/s3",
URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"invalid"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{`,
`"filesystem":{`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (valid filesystem and no s3)",
Name: "authorized as superuser (valid filesystem and no s3)",
Method: http.MethodPost,
Url: "/api/settings/test/s3",
URL: "/api/settings/test/s3",
Body: strings.NewReader(`{"filesystem":"storage"}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"data":{}`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}
@ -362,156 +263,199 @@ func TestSettingsTestEmail(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (invalid body)",
Name: "authorized as superuser (invalid body)",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (empty json)",
Name: "authorized as superuser (empty json)",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
`"email":{"code":"validation_required"`,
`"template":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (verifiation template)",
Name: "authorized as superuser (verifiation template)",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "verification",
"email": "test@example.com"
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 {
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[verification] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage.To) != 1 {
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[verification] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[verification] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Verify") {
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Verify") {
t.Fatalf("[verification] Expected to sent a verification email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"OnMailerBeforeRecordVerificationSend": 1,
"OnMailerAfterRecordVerificationSend": 1,
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordVerificationSend": 1,
},
},
{
Name: "authorized as admin (password reset template)",
Name: "authorized as superuser (password reset template)",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "password-reset",
"email": "test@example.com"
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 {
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[password-reset] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage.To) != 1 {
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[password-reset] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[password-reset] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Reset password") {
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Reset password") {
t.Fatalf("[password-reset] Expected to sent a password-reset email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"OnMailerBeforeRecordResetPasswordSend": 1,
"OnMailerAfterRecordResetPasswordSend": 1,
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordPasswordResetSend": 1,
},
},
{
Name: "authorized as admin (email change)",
Name: "authorized as superuser (email change)",
Method: http.MethodPost,
Url: "/api/settings/test/email",
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "email-change",
"email": "test@example.com"
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend != 1 {
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend)
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[email-change] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage.To) != 1 {
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage.To)
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[email-change] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage.To[0].Address != "test@example.com" {
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage.To[0].Address)
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[email-change] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage.HTML, "Confirm new email") {
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage.Subject, app.TestMailer.LastMessage.HTML)
if !strings.Contains(app.TestMailer.LastMessage().HTML, "Confirm new email") {
t.Fatalf("[email-change] Expected to sent a confirm new email email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"OnMailerBeforeRecordChangeEmailSend": 1,
"OnMailerAfterRecordChangeEmailSend": 1,
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordEmailChangeSend": 1,
},
},
{
Name: "authorized as superuser (otp)",
Method: http.MethodPost,
URL: "/api/settings/test/email",
Body: strings.NewReader(`{
"template": "otp",
"email": "test@example.com"
}`),
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
AfterTestFunc: func(t testing.TB, app *tests.TestApp, res *http.Response) {
if app.TestMailer.TotalSend() != 1 {
t.Fatalf("[otp] Expected 1 sent email, got %d", app.TestMailer.TotalSend())
}
if len(app.TestMailer.LastMessage().To) != 1 {
t.Fatalf("[otp] Expected 1 recipient, got %v", app.TestMailer.LastMessage().To)
}
if app.TestMailer.LastMessage().To[0].Address != "test@example.com" {
t.Fatalf("[otp] Expected the email to be sent to %s, got %s", "test@example.com", app.TestMailer.LastMessage().To[0].Address)
}
if !strings.Contains(app.TestMailer.LastMessage().HTML, "one-time password") {
t.Fatalf("[otp] Expected to sent OTP email, got \n%v\n%v", app.TestMailer.LastMessage().Subject, app.TestMailer.LastMessage().HTML)
}
},
ExpectedStatus: 204,
ExpectedContent: []string{},
ExpectedEvents: map[string]int{
"*": 0,
"OnMailerSend": 1,
"OnMailerRecordOTPSend": 1,
},
},
}
@ -545,38 +489,41 @@ func TestGenerateAppleClientSecret(t *testing.T) {
{
Name: "unauthorized",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
URL: "/api/settings/apple/generate-client-secret",
ExpectedStatus: 401,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as auth record",
Name: "authorized as regular user",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc",
URL: "/api/settings/apple/generate-client-secret",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.ZT3F0Z3iM-xbGgSG3LEKiEzHrPHr8t8IuHLZGGNuxLo",
},
ExpectedStatus: 401,
ExpectedStatus: 403,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (invalid body)",
Name: "authorized as superuser (invalid body)",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (empty json)",
Name: "authorized as superuser (empty json)",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
@ -586,11 +533,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
`"privateKey":{"code":"validation_required"`,
`"duration":{"code":"validation_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (invalid data)",
Name: "authorized as superuser (invalid data)",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(`{
"clientId": "",
"teamId": "123456789",
@ -598,8 +546,8 @@ func TestGenerateAppleClientSecret(t *testing.T) {
"privateKey": "invalid",
"duration": -1
}`),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 400,
ExpectedContent: []string{
@ -609,11 +557,12 @@ func TestGenerateAppleClientSecret(t *testing.T) {
`"privateKey":{"code":"validation_match_invalid"`,
`"duration":{"code":"validation_min_greater_equal_than_required"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
{
Name: "authorized as admin (valid data)",
Name: "authorized as superuser (valid data)",
Method: http.MethodPost,
Url: "/api/settings/apple/generate-client-secret",
URL: "/api/settings/apple/generate-client-secret",
Body: strings.NewReader(fmt.Sprintf(`{
"clientId": "123",
"teamId": "1234567890",
@ -621,13 +570,14 @@ func TestGenerateAppleClientSecret(t *testing.T) {
"privateKey": %q,
"duration": 1
}`, privatePem)),
RequestHeaders: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8",
Headers: map[string]string{
"Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhdXRoIiwiY29sbGVjdGlvbklkIjoiX3BiY18zMzIzODY2MzM5IiwiZXhwIjoyNTI0NjA0NDYxLCJyZWZyZXNoYWJsZSI6dHJ1ZX0.v_bMAygr6hXPwD2DpPrFpNQ7dd68Q3pGstmYAsvNBJg",
},
ExpectedStatus: 200,
ExpectedContent: []string{
`"secret":"`,
},
ExpectedEvents: map[string]int{"*": 0},
},
}

View File

@ -1,141 +0,0 @@
package cmd
import (
"errors"
"fmt"
"github.com/fatih/color"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/models"
"github.com/spf13/cobra"
)
// NewAdminCommand creates and returns new command for managing
// admin accounts (create, update, delete).
func NewAdminCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "admin",
Short: "Manages admin accounts",
}
command.AddCommand(adminCreateCommand(app))
command.AddCommand(adminUpdateCommand(app))
command.AddCommand(adminDeleteCommand(app))
return command
}
func adminCreateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "create",
Example: "admin create test@example.com 1234567890",
Short: "Creates a new admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
if len(args[1]) < 8 {
return errors.New("The password must be at least 8 chars long.")
}
admin := &models.Admin{}
admin.Email = args[0]
admin.SetPassword(args[1])
if !app.Dao().HasTable(admin.TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
if err := app.Dao().SaveAdmin(admin); err != nil {
return fmt.Errorf("Failed to create new admin account: %v", err)
}
color.Green("Successfully created new admin %s!", admin.Email)
return nil
},
}
return command
}
func adminUpdateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "update",
Example: "admin update test@example.com 1234567890",
Short: "Changes the password of a single admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
if len(args[1]) < 8 {
return errors.New("The new password must be at least 8 chars long.")
}
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
admin, err := app.Dao().FindAdminByEmail(args[0])
if err != nil {
return fmt.Errorf("Admin with email %s doesn't exist.", args[0])
}
admin.SetPassword(args[1])
if err := app.Dao().SaveAdmin(admin); err != nil {
return fmt.Errorf("Failed to change admin %s password: %v", admin.Email, err)
}
color.Green("Successfully changed admin %s password!", admin.Email)
return nil
},
}
return command
}
func adminDeleteCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "delete",
Example: "admin delete test@example.com",
Short: "Deletes an existing admin account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Invalid or missing email address.")
}
if !app.Dao().HasTable((&models.Admin{}).TableName()) {
return errors.New("Migration are not initialized yet. Please run 'migrate up' and try again.")
}
admin, err := app.Dao().FindAdminByEmail(args[0])
if err != nil {
color.Yellow("Admin %s is already deleted.", args[0])
return nil
}
if err := app.Dao().DeleteAdmin(admin); err != nil {
return fmt.Errorf("Failed to delete admin %s: %v", admin.Email, err)
}
color.Green("Successfully deleted admin %s!", admin.Email)
return nil
},
}
return command
}

View File

@ -1,221 +0,0 @@
package cmd_test
import (
"testing"
"github.com/pocketbase/pocketbase/cmd"
"github.com/pocketbase/pocketbase/tests"
)
func TestAdminCreateCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"duplicated email",
"test@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test_new@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"create", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin account was actually created
admin, err := app.Dao().FindAdminByEmail(s.email)
if err != nil {
t.Errorf("[%s] Failed to fetch created admin %s: %v", s.name, s.email, err)
} else if !admin.ValidatePassword(s.password) {
t.Errorf("[%s] Expected the admin password to match", s.name)
}
}
}
func TestAdminUpdateCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"nonexisting admin",
"test_missing@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"update", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin password was actually changed
admin, err := app.Dao().FindAdminByEmail(s.email)
if err != nil {
t.Errorf("[%s] Failed to fetch admin %s: %v", s.name, s.email, err)
} else if !admin.ValidatePassword(s.password) {
t.Errorf("[%s] Expected the admin password to match", s.name)
}
}
}
func TestAdminDeleteCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
expectError bool
}{
{
"empty email",
"",
true,
},
{
"invalid email",
"invalid",
true,
},
{
"nonexisting admin",
"test_missing@example.com",
false,
},
{
"existing admin",
"test@example.com",
false,
},
}
for _, s := range scenarios {
command := cmd.NewAdminCommand(app)
command.SetArgs([]string{"delete", s.email})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
}
if hasErr {
continue
}
// check whether the admin account was actually deleted
if _, err := app.Dao().FindAdminByEmail(s.email); err == nil {
t.Errorf("[%s] Expected the admin account to be deleted", s.name)
}
}
}

View File

@ -15,6 +15,7 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
var allowedOrigins []string
var httpAddr string
var httpsAddr string
var dashboardPath string
command := &cobra.Command{
Use: "serve [domain(s)]",
@ -36,9 +37,10 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
}
}
_, err := apis.Serve(app, apis.ServeConfig{
err := apis.Serve(app, apis.ServeConfig{
HttpAddr: httpAddr,
HttpsAddr: httpsAddr,
DashboardPath: dashboardPath,
ShowStartBanner: showStartBanner,
AllowedOrigins: allowedOrigins,
CertificateDomains: args,
@ -73,5 +75,12 @@ func NewServeCommand(app core.App, showStartBanner bool) *cobra.Command {
"TCP address to listen for the HTTPS server\n(if domain args are specified - default to 0.0.0.0:443, otherwise - default to empty string, aka. no TLS)\nThe incoming HTTP traffic also will be auto redirected to the HTTPS version",
)
command.PersistentFlags().StringVar(
&dashboardPath,
"dashboard",
"/_/{path...}",
"The route path to the superusers dashboard; must include the '{path...}' wildcard parameter",
)
return command
}

166
cmd/superuser.go Normal file
View File

@ -0,0 +1,166 @@
package cmd
import (
"errors"
"fmt"
"github.com/fatih/color"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/core"
"github.com/spf13/cobra"
)
// NewSuperuserCommand creates and returns new command for managing
// superuser accounts (create, update, delete).
func NewSuperuserCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "superuser",
Short: "Manages superuser accounts",
}
command.AddCommand(superuserUpsertCommand(app))
command.AddCommand(superuserCreateCommand(app))
command.AddCommand(superuserUpdateCommand(app))
command.AddCommand(superuserDeleteCommand(app))
return command
}
func superuserUpsertCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "upsert",
Example: "superuser upsert test@example.com 1234567890",
Short: "Creates, or updates if email exists, a single superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
}
superuser, err := app.FindAuthRecordByEmail(superusersCol, args[0])
if err != nil {
superuser = core.NewRecord(superusersCol)
}
superuser.SetEmail(args[0])
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to upsert superuser account: %w.", err)
}
color.Green("Successfully saved superuser %q!", superuser.Email())
return nil
},
}
return command
}
func superuserCreateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "create",
Example: "superuser create test@example.com 1234567890",
Short: "Creates a new superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superusersCol, err := app.FindCachedCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
return fmt.Errorf("Failed to fetch %q collection: %w.", core.CollectionNameSuperusers, err)
}
superuser := core.NewRecord(superusersCol)
superuser.SetEmail(args[0])
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to create new superuser account: %w.", err)
}
color.Green("Successfully created new superuser %q!", superuser.Email())
return nil
},
}
return command
}
func superuserUpdateCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "update",
Example: "superuser update test@example.com 1234567890",
Short: "Changes the password of a single superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) != 2 {
return errors.New("Missing email and password arguments.")
}
if args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Missing or invalid email address.")
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
if err != nil {
return fmt.Errorf("Superuser with email %q doesn't exist.", args[0])
}
superuser.SetPassword(args[1])
if err := app.Save(superuser); err != nil {
return fmt.Errorf("Failed to change superuser %q password: %w.", superuser.Email(), err)
}
color.Green("Successfully changed superuser %q password!", superuser.Email())
return nil
},
}
return command
}
func superuserDeleteCommand(app core.App) *cobra.Command {
command := &cobra.Command{
Use: "delete",
Example: "superuser delete test@example.com",
Short: "Deletes an existing superuser account",
SilenceUsage: true,
RunE: func(command *cobra.Command, args []string) error {
if len(args) == 0 || args[0] == "" || is.EmailFormat.Validate(args[0]) != nil {
return errors.New("Invalid or missing email address.")
}
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, args[0])
if err != nil {
color.Yellow("Superuser %q is missing or already deleted.", args[0])
return nil
}
if err := app.Delete(superuser); err != nil {
return fmt.Errorf("Failed to delete superuser %q: %w.", superuser.Email(), err)
}
color.Green("Successfully deleted superuser %q!", superuser.Email())
return nil
},
}
return command
}

310
cmd/superuser_test.go Normal file
View File

@ -0,0 +1,310 @@
package cmd_test
import (
"testing"
"github.com/pocketbase/pocketbase/cmd"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestSuperuserUpsertCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"existing user",
"test@example.com",
"1234567890!",
false,
},
{
"new user",
"test_new@example.com",
"1234567890!",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"upsert", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser account was actually upserted
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserCreateCommand(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"duplicated email",
"test@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test_new@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"create", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser account was actually created
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch created superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserUpdateCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
password string
expectError bool
}{
{
"empty email and password",
"",
"",
true,
},
{
"empty email",
"",
"1234567890",
true,
},
{
"invalid email",
"invalid",
"1234567890",
true,
},
{
"nonexisting superuser",
"test_missing@example.com",
"1234567890",
true,
},
{
"empty password",
"test@example.com",
"",
true,
},
{
"short password",
"test_new@example.com",
"1234567",
true,
},
{
"valid email and password",
"test@example.com",
"12345678",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"update", s.email, s.password})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
// check whether the superuser password was actually changed
superuser, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email)
if err != nil {
t.Fatalf("Failed to fetch superuser %s: %v", s.email, err)
} else if !superuser.ValidatePassword(s.password) {
t.Fatal("Expected the superuser password to match")
}
})
}
}
func TestSuperuserDeleteCommand(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
email string
expectError bool
}{
{
"empty email",
"",
true,
},
{
"invalid email",
"invalid",
true,
},
{
"nonexisting superuser",
"test_missing@example.com",
false,
},
{
"existing superuser",
"test@example.com",
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
command := cmd.NewSuperuserCommand(app)
command.SetArgs([]string{"delete", s.email})
err := command.Execute()
hasErr := err != nil
if s.expectError != hasErr {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if _, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, s.email); err == nil {
t.Fatal("Expected the superuser account to be deleted")
}
})
}
}

File diff suppressed because it is too large Load Diff

239
core/auth_origin_model.go Normal file
View File

@ -0,0 +1,239 @@
package core
import (
"context"
"errors"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/types"
)
const CollectionNameAuthOrigins = "_authOrigins"
var (
_ Model = (*AuthOrigin)(nil)
_ PreValidator = (*AuthOrigin)(nil)
_ RecordProxy = (*AuthOrigin)(nil)
)
// AuthOrigin defines a Record proxy for working with the authOrigins collection.
type AuthOrigin struct {
*Record
}
// NewAuthOrigin instantiates and returns a new blank *AuthOrigin model.
//
// Example usage:
//
// origin := core.NewOrigin(app)
// origin.SetRecordRef(user.Id)
// origin.SetCollectionRef(user.Collection().Id)
// origin.SetFingerprint("...")
// app.Save(origin)
func NewAuthOrigin(app App) *AuthOrigin {
m := &AuthOrigin{}
c, err := app.FindCachedCollectionByNameOrId(CollectionNameAuthOrigins)
if err != nil {
// this is just to make tests easier since authOrigins is a system collection and it is expected to be always accessible
// (note: the loaded record is further checked on AuthOrigin.PreValidate())
c = NewBaseCollection("@___invalid___")
}
m.Record = NewRecord(c)
return m
}
// PreValidate implements the [PreValidator] interface and checks
// whether the proxy is properly loaded.
func (m *AuthOrigin) PreValidate(ctx context.Context, app App) error {
if m.Record == nil || m.Record.Collection().Name != CollectionNameAuthOrigins {
return errors.New("missing or invalid AuthOrigin ProxyRecord")
}
return nil
}
// ProxyRecord returns the proxied Record model.
func (m *AuthOrigin) ProxyRecord() *Record {
return m.Record
}
// SetProxyRecord loads the specified record model into the current proxy.
func (m *AuthOrigin) SetProxyRecord(record *Record) {
m.Record = record
}
// CollectionRef returns the "collectionRef" field value.
func (m *AuthOrigin) CollectionRef() string {
return m.GetString("collectionRef")
}
// SetCollectionRef updates the "collectionRef" record field value.
func (m *AuthOrigin) SetCollectionRef(collectionId string) {
m.Set("collectionRef", collectionId)
}
// RecordRef returns the "recordRef" record field value.
func (m *AuthOrigin) RecordRef() string {
return m.GetString("recordRef")
}
// SetRecordRef updates the "recordRef" record field value.
func (m *AuthOrigin) SetRecordRef(recordId string) {
m.Set("recordRef", recordId)
}
// Fingerprint returns the "fingerprint" record field value.
func (m *AuthOrigin) Fingerprint() string {
return m.GetString("fingerprint")
}
// SetFingerprint updates the "fingerprint" record field value.
func (m *AuthOrigin) SetFingerprint(fingerprint string) {
m.Set("fingerprint", fingerprint)
}
// Created returns the "created" record field value.
func (m *AuthOrigin) Created() types.DateTime {
return m.GetDateTime("created")
}
// Updated returns the "updated" record field value.
func (m *AuthOrigin) Updated() types.DateTime {
return m.GetDateTime("updated")
}
func (app *BaseApp) registerAuthOriginHooks() {
recordRefHooks[*AuthOrigin](app, CollectionNameAuthOrigins, CollectionTypeAuth)
// delete existing auth origins on password change
app.OnRecordUpdate().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
err := e.Next()
if err != nil || !e.Record.Collection().IsAuth() {
return err
}
old := e.Record.Original().GetString(FieldNamePassword + ":hash")
new := e.Record.GetString(FieldNamePassword + ":hash")
if old != new {
err = e.App.DeleteAllAuthOriginsByRecord(e.Record)
if err != nil {
e.App.Logger().Warn(
"Failed to delete all previous auth origin fingerprints",
"error", err,
"recordId", e.Record.Id,
"collectionId", e.Record.Collection().Id,
)
}
}
return nil
},
Priority: 99,
})
}
// -------------------------------------------------------------------
// recordRefHooks registers common hooks that are usually used with record proxies
// that have polymorphic record relations (aka. "collectionRef" and "recordRef" fields).
func recordRefHooks[T RecordProxy](app App, collectionName string, optCollectionTypes ...string) {
app.OnRecordValidate(collectionName).Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
collectionId := e.Record.GetString("collectionRef")
err := validation.Validate(collectionId, validation.Required, validation.By(validateCollectionId(e.App, optCollectionTypes...)))
if err != nil {
return validation.Errors{"collectionRef": err}
}
recordId := e.Record.GetString("recordRef")
err = validation.Validate(recordId, validation.Required, validation.By(validateRecordId(e.App, collectionId)))
if err != nil {
return validation.Errors{"recordRef": err}
}
return e.Next()
},
Priority: 99,
})
// delete on collection ref delete
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
Func: func(e *CollectionEvent) error {
if e.Collection.Name == collectionName || (len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Collection.Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{"collectionRef": e.Collection.Id})
if err != nil {
return err
}
for _, mfa := range rels {
if err := txApp.Delete(mfa); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
// delete on record ref delete
app.OnRecordDeleteExecute().Bind(&hook.Handler[*RecordEvent]{
Func: func(e *RecordEvent) error {
if e.Record.Collection().Name == collectionName ||
(len(optCollectionTypes) > 0 && !slices.Contains(optCollectionTypes, e.Record.Collection().Type)) {
return e.Next()
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
if err := e.Next(); err != nil {
return err
}
rels, err := txApp.FindAllRecords(collectionName, dbx.HashExp{
"collectionRef": e.Record.Collection().Id,
"recordRef": e.Record.Id,
})
if err != nil {
return err
}
for _, rel := range rels {
if err := txApp.Delete(rel); err != nil {
return err
}
}
return nil
})
e.App = originalApp
return txErr
},
Priority: 99,
})
}

View File

@ -0,0 +1,332 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewAuthOrigin(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if origin.Collection().Name != core.CollectionNameAuthOrigins {
t.Fatalf("Expected record with %q collection, got %q", core.CollectionNameAuthOrigins, origin.Collection().Name)
}
}
func TestAuthOriginProxyRecord(t *testing.T) {
t.Parallel()
record := core.NewRecord(core.NewBaseCollection("test"))
record.Id = "test_id"
origin := core.AuthOrigin{}
origin.SetProxyRecord(record)
if origin.ProxyRecord() == nil || origin.ProxyRecord().Id != record.Id {
t.Fatalf("Expected proxy record with id %q, got %v", record.Id, origin.ProxyRecord())
}
}
func TestAuthOriginRecordRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetRecordRef(testValue)
if v := origin.RecordRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("recordRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCollectionRef(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetCollectionRef(testValue)
if v := origin.CollectionRef(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("collectionRef"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
testValues := []string{"test_1", "test2", ""}
for i, testValue := range testValues {
t.Run(fmt.Sprintf("%d_%q", i, testValue), func(t *testing.T) {
origin.SetFingerprint(testValue)
if v := origin.Fingerprint(); v != testValue {
t.Fatalf("Expected getter %q, got %q", testValue, v)
}
if v := origin.GetString("fingerprint"); v != testValue {
t.Fatalf("Expected field value %q, got %q", testValue, v)
}
})
}
}
func TestAuthOriginCreated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Created().String(); v != "" {
t.Fatalf("Expected empty created, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("created", now)
if v := origin.Created().String(); v != now.String() {
t.Fatalf("Expected %q created, got %q", now.String(), v)
}
}
func TestAuthOriginUpdated(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
origin := core.NewAuthOrigin(app)
if v := origin.Updated().String(); v != "" {
t.Fatalf("Expected empty updated, got %q", v)
}
now := types.NowDateTime()
origin.SetRaw("updated", now)
if v := origin.Updated().String(); v != now.String() {
t.Fatalf("Expected %q updated, got %q", now.String(), v)
}
}
func TestAuthOriginPreValidate(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
originsCol, err := app.FindCollectionByNameOrId(core.CollectionNameAuthOrigins)
if err != nil {
t.Fatal(err)
}
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
t.Run("no proxy record", func(t *testing.T) {
origin := &core.AuthOrigin{}
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("non-AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(core.NewBaseCollection("invalid")))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err == nil {
t.Fatal("Expected collection validation error")
}
})
t.Run("AuthOrigin collection", func(t *testing.T) {
origin := &core.AuthOrigin{}
origin.SetProxyRecord(core.NewRecord(originsCol))
origin.SetRecordRef(user.Id)
origin.SetCollectionRef(user.Collection().Id)
origin.SetFingerprint("abc")
if err := app.Validate(origin); err != nil {
t.Fatalf("Expected nil validation error, got %v", err)
}
})
}
func TestAuthOriginValidateHook(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
user, err := app.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
origin func() *core.AuthOrigin
expectErrors []string
}{
{
"empty",
func() *core.AuthOrigin {
return core.NewAuthOrigin(app)
},
[]string{"collectionRef", "recordRef", "fingerprint"},
},
{
"non-auth collection",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(demo1.Collection().Id)
origin.SetRecordRef(demo1.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{"collectionRef"},
},
{
"missing record id",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef("missing")
origin.SetFingerprint("abc")
return origin
},
[]string{"recordRef"},
},
{
"valid ref",
func() *core.AuthOrigin {
origin := core.NewAuthOrigin(app)
origin.SetCollectionRef(user.Collection().Id)
origin.SetRecordRef(user.Id)
origin.SetFingerprint("abc")
return origin
},
[]string{},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
errs := app.Validate(s.origin())
tests.TestValidationErrors(t, errs, s.expectErrors)
})
}
}
func TestAuthOriginPasswordChangeDeletion(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
// no auth origin associated with it
user1, err := testApp.FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{user1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
s.record.SetPassword("new_password")
err := app.Save(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

101
core/auth_origin_query.go Normal file
View File

@ -0,0 +1,101 @@
package core
import (
"errors"
"github.com/pocketbase/dbx"
)
// FindAllAuthOriginsByRecord returns all AuthOrigin models linked to the provided auth record (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByRecord(authRecord *Record) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAllAuthOriginsByCollection returns all AuthOrigin models linked to the provided collection (in DESC order).
func (app *BaseApp) FindAllAuthOriginsByCollection(collection *Collection) ([]*AuthOrigin, error) {
result := []*AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"collectionRef": collection.Id}).
OrderBy("created DESC").
All(&result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginById returns a single AuthOrigin model by its id.
func (app *BaseApp) FindAuthOriginById(id string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{"id": id}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// FindAuthOriginByRecordAndFingerprint returns a single AuthOrigin model
// by its authRecord relation and fingerprint.
func (app *BaseApp) FindAuthOriginByRecordAndFingerprint(authRecord *Record, fingerprint string) (*AuthOrigin, error) {
result := &AuthOrigin{}
err := app.RecordQuery(CollectionNameAuthOrigins).
AndWhere(dbx.HashExp{
"collectionRef": authRecord.Collection().Id,
"recordRef": authRecord.Id,
"fingerprint": fingerprint,
}).
Limit(1).
One(result)
if err != nil {
return nil, err
}
return result, nil
}
// DeleteAllAuthOriginsByRecord deletes all AuthOrigin models associated with the provided record.
//
// Returns a combined error with the failed deletes.
func (app *BaseApp) DeleteAllAuthOriginsByRecord(authRecord *Record) error {
models, err := app.FindAllAuthOriginsByRecord(authRecord)
if err != nil {
return err
}
var errs []error
for _, m := range models {
if err := app.Delete(m); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}

View File

@ -0,0 +1,268 @@
package core_test
import (
"fmt"
"slices"
"testing"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestFindAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := app.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
expected []string
}{
{demo1, nil},
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.record.Collection().Name+"_"+s.record.Id, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAllAuthOriginsByCollection(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
superusers, err := app.FindCollectionByNameOrId(core.CollectionNameSuperusers)
if err != nil {
t.Fatal(err)
}
clients, err := app.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
collection *core.Collection
expected []string
}{
{demo1, nil},
{superusers, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib", "5f29jy38bf5zm3f"}},
{clients, []string{"9r2j0m74260ur8i"}},
}
for _, s := range scenarios {
t.Run(s.collection.Name, func(t *testing.T) {
result, err := app.FindAllAuthOriginsByCollection(s.collection)
if err != nil {
t.Fatal(err)
}
if len(result) != len(s.expected) {
t.Fatalf("Expected total origins %d, got %d", len(s.expected), len(result))
}
for i, id := range s.expected {
if result[i].Id != id {
t.Errorf("[%d] Expected id %q, got %q", i, id, result[i].Id)
}
}
})
}
}
func TestFindAuthOriginById(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
id string
expectError bool
}{
{"", true},
{"84nmscqy84lsi1t", true}, // non-origin id
{"9r2j0m74260ur8i", false},
}
for _, s := range scenarios {
t.Run(s.id, func(t *testing.T) {
result, err := app.FindAuthOriginById(s.id)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Id != s.id {
t.Fatalf("Expected record with id %q, got %q", s.id, result.Id)
}
})
}
}
func TestFindAuthOriginByRecordAndFingerprint(t *testing.T) {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
demo1, err := app.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := app.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
fingerprint string
expectError bool
}{
{demo1, "6afbfe481c31c08c55a746cccb88ece0", true},
{superuser2, "", true},
{superuser2, "abc", true},
{superuser2, "22bbbcbed36e25321f384ccf99f60057", false}, // fingerprint from different origin
{superuser2, "6afbfe481c31c08c55a746cccb88ece0", false},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Id, s.fingerprint), func(t *testing.T) {
result, err := app.FindAuthOriginByRecordAndFingerprint(s.record, s.fingerprint)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
if hasErr {
return
}
if result.Fingerprint() != s.fingerprint {
t.Fatalf("Expected origin with fingerprint %q, got %q", s.fingerprint, result.Fingerprint())
}
if result.RecordRef() != s.record.Id || result.CollectionRef() != s.record.Collection().Id {
t.Fatalf("Expected record %q (%q), got %q (%q)", s.record.Id, s.record.Collection().Id, result.RecordRef(), result.CollectionRef())
}
})
}
}
func TestDeleteAllAuthOriginsByRecord(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
demo1, err := testApp.FindRecordById("demo1", "84nmscqy84lsi1t")
if err != nil {
t.Fatal(err)
}
superuser2, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test2@example.com")
if err != nil {
t.Fatal(err)
}
superuser4, err := testApp.FindAuthRecordByEmail(core.CollectionNameSuperusers, "test4@example.com")
if err != nil {
t.Fatal(err)
}
client1, err := testApp.FindAuthRecordByEmail("clients", "test@example.com")
if err != nil {
t.Fatal(err)
}
scenarios := []struct {
record *core.Record
deletedIds []string
}{
{demo1, nil}, // non-auth record
{superuser2, []string{"5798yh833k6w6w0", "ic55o70g4f8pcl4", "dmy260k6ksjr4ib"}},
{superuser4, nil},
{client1, []string{"9r2j0m74260ur8i"}},
}
for i, s := range scenarios {
t.Run(fmt.Sprintf("%d_%s_%s", i, s.record.Collection().Name, s.record.Id), func(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
deletedIds := []string{}
app.OnRecordDelete().BindFunc(func(e *core.RecordEvent) error {
deletedIds = append(deletedIds, e.Record.Id)
return e.Next()
})
err := app.DeleteAllAuthOriginsByRecord(s.record)
if err != nil {
t.Fatal(err)
}
if len(deletedIds) != len(s.deletedIds) {
t.Fatalf("Expected deleted ids\n%v\ngot\n%v", s.deletedIds, deletedIds)
}
for _, id := range s.deletedIds {
if !slices.Contains(deletedIds, id) {
t.Errorf("Expected to find deleted id %q in %v", id, deletedIds)
}
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -12,20 +12,16 @@ import (
"sort"
"time"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/archive"
"github.com/pocketbase/pocketbase/tools/cron"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/osutils"
"github.com/pocketbase/pocketbase/tools/security"
)
// Deprecated: Replaced with StoreKeyActiveBackup.
const CacheKeyActiveBackup string = "@activeBackup"
const StoreKeyActiveBackup string = "@activeBackup"
const (
StoreKeyActiveBackup = "@activeBackup"
)
// CreateBackup creates a new backup of the current app pb_data directory.
//
@ -50,32 +46,37 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
return errors.New("try again later - another backup/restore operation has already been started")
}
if name == "" {
name = app.generateBackupName("pb_backup_")
}
app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup)
// root dir entries to exclude from the backup generation
exclude := []string{LocalBackupsDirName, LocalTempDirName}
event := new(BackupEvent)
event.App = app
event.Context = ctx
event.Name = name
// default root dir entries to exclude from the backup generation
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
return app.OnBackupCreate().Trigger(event, func(e *BackupEvent) error {
// generate a default name if missing
if e.Name == "" {
e.Name = generateBackupName(e.App, "pb_backup_")
}
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName)
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
// Archive pb_data in a temp directory, exluding the "backups" and the temp dirs.
// archive pb_data in a temp directory, exluding the "backups" and the temp dirs
//
// Run in transaction to temporary block other writes (transactions uses the NonconcurrentDB connection).
// ---
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(4))
createErr := app.Dao().RunInTransaction(func(dataTXDao *daos.Dao) error {
return app.LogsDao().RunInTransaction(func(logsTXDao *daos.Dao) error {
// @todo consider experimenting with temp switching the readonly pragma after the db interface change
return archive.Create(app.DataDir(), tempPath, exclude...)
tempPath := filepath.Join(localTempDir, "pb_backup_"+security.PseudorandomString(6))
createErr := e.App.RunInTransaction(func(txApp App) error {
return txApp.AuxRunInTransaction(func(txApp App) error {
return archive.Create(txApp.DataDir(), tempPath, e.Exclude...)
})
})
if createErr != nil {
@ -83,21 +84,21 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
}
defer os.Remove(tempPath)
// Persist the backup in the backups filesystem.
// persist the backup in the backups filesystem
// ---
fsys, err := app.NewBackupsFilesystem()
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(ctx)
fsys.SetContext(e.Context)
file, err := filesystem.NewFileFromPath(tempPath)
if err != nil {
return err
}
file.OriginalName = name
file.OriginalName = e.Name
file.Name = file.OriginalName
if err := fsys.UploadFile(file, file.Name); err != nil {
@ -105,6 +106,7 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
}
return nil
})
}
// RestoreBackup restores the backup with the specified name and restarts
@ -136,10 +138,6 @@ func (app *BaseApp) CreateBackup(ctx context.Context, name string) error {
// If a failure occure during the restore process the dir changes are reverted.
// If for whatever reason the revert is not possible, it panics.
func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
if runtime.GOOS == "windows" {
return errors.New("restore is not supported on windows")
}
if app.Store().Has(StoreKeyActiveBackup) {
return errors.New("try again later - another backup/restore operation has already been started")
}
@ -147,13 +145,25 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
app.Store().Set(StoreKeyActiveBackup, name)
defer app.Store().Remove(StoreKeyActiveBackup)
fsys, err := app.NewBackupsFilesystem()
event := new(BackupEvent)
event.App = app
event.Context = ctx
event.Name = name
// default root dir entries to exclude from the backup restore
event.Exclude = []string{LocalBackupsDirName, LocalTempDirName, LocalAutocertCacheDirName}
return app.OnBackupRestore().Trigger(event, func(e *BackupEvent) error {
if runtime.GOOS == "windows" {
return errors.New("restore is not supported on Windows")
}
fsys, err := e.App.NewBackupsFilesystem()
if err != nil {
return err
}
defer fsys.Close()
fsys.SetContext(ctx)
fsys.SetContext(e.Context)
// fetch the backup file in a temp location
br, err := fsys.GetFile(name)
@ -164,7 +174,7 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
// make sure that the special temp directory exists
// note: it needs to be inside the current pb_data to avoid "cross-device link" errors
localTempDir := filepath.Join(app.DataDir(), LocalTempDirName)
localTempDir := filepath.Join(e.App.DataDir(), LocalTempDirName)
if err := os.MkdirAll(localTempDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create a temp dir: %w", err)
}
@ -195,35 +205,32 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
// remove the extracted zip file since we no longer need it
// (this is in case the app restarts and the defer calls are not called)
if err := os.Remove(tempZip.Name()); err != nil {
app.Logger().Debug(
e.App.Logger().Debug(
"[RestoreBackup] Failed to remove the temp zip backup file",
slog.String("file", tempZip.Name()),
slog.String("error", err.Error()),
)
}
// root dir entries to exclude from the backup restore
exclude := []string{LocalBackupsDirName, LocalTempDirName}
// move the current pb_data content to a special temp location
// that will hold the old data between dirs replace
// (the temp dir will be automatically removed on the next app start)
oldTempDataDir := filepath.Join(localTempDir, "old_pb_data_"+security.PseudorandomString(4))
if err := osutils.MoveDirContent(app.DataDir(), oldTempDataDir, exclude...); err != nil {
if err := osutils.MoveDirContent(e.App.DataDir(), oldTempDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to move the current pb_data content to a temp location: %w", err)
}
// move the extracted archive content to the app's pb_data
if err := osutils.MoveDirContent(extractedDataDir, app.DataDir(), exclude...); err != nil {
if err := osutils.MoveDirContent(extractedDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to move the extracted archive content to pb_data: %w", err)
}
revertDataDirChanges := func() error {
if err := osutils.MoveDirContent(app.DataDir(), extractedDataDir, exclude...); err != nil {
if err := osutils.MoveDirContent(e.App.DataDir(), extractedDataDir, e.Exclude...); err != nil {
return fmt.Errorf("failed to revert the extracted dir change: %w", err)
}
if err := osutils.MoveDirContent(oldTempDataDir, app.DataDir(), exclude...); err != nil {
if err := osutils.MoveDirContent(oldTempDataDir, e.App.DataDir(), e.Exclude...); err != nil {
return fmt.Errorf("failed to revert old pb_data dir change: %w", err)
}
@ -231,7 +238,7 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
}
// restart the app
if err := app.Restart(); err != nil {
if err := e.App.Restart(); err != nil {
if revertErr := revertDataDirChanges(); revertErr != nil {
panic(revertErr)
}
@ -240,38 +247,27 @@ func (app *BaseApp) RestoreBackup(ctx context.Context, name string) error {
}
return nil
})
}
// initAutobackupHooks registers the autobackup app serve hooks.
func (app *BaseApp) initAutobackupHooks() error {
c := cron.New()
isServe := false
// registerAutobackupHooks registers the autobackup app serve hooks.
func (app *BaseApp) registerAutobackupHooks() {
const jobId = "__auto_pb_backup__"
loadJob := func() {
c.Stop()
// make sure that app.Settings() is always up to date
//
// @todo remove with the refactoring as core.App and daos.Dao will be one.
if err := app.RefreshSettings(); err != nil {
app.Logger().Debug(
"[Backup cron] Failed to get the latest app settings",
slog.String("error", err.Error()),
)
}
rawSchedule := app.Settings().Backups.Cron
if rawSchedule == "" || !isServe || !app.IsBootstrapped() {
if rawSchedule == "" {
app.Cron().Remove(jobId)
return
}
c.Add("@autobackup", rawSchedule, func() {
app.Cron().Add(jobId, rawSchedule, func() {
const autoPrefix = "@auto_pb_backup_"
name := app.generateBackupName(autoPrefix)
name := generateBackupName(app, autoPrefix)
if err := app.CreateBackup(context.Background(), name); err != nil {
app.Logger().Debug(
app.Logger().Error(
"[Backup cron] Failed to create backup",
slog.String("name", name),
slog.String("error", err.Error()),
@ -286,7 +282,7 @@ func (app *BaseApp) initAutobackupHooks() error {
fsys, err := app.NewBackupsFilesystem()
if err != nil {
app.Logger().Debug(
app.Logger().Error(
"[Backup cron] Failed to initialize the backup filesystem",
slog.String("error", err.Error()),
)
@ -296,7 +292,7 @@ func (app *BaseApp) initAutobackupHooks() error {
files, err := fsys.List(autoPrefix)
if err != nil {
app.Logger().Debug(
app.Logger().Error(
"[Backup cron] Failed to list autogenerated backups",
slog.String("error", err.Error()),
)
@ -317,7 +313,7 @@ func (app *BaseApp) initAutobackupHooks() error {
for _, f := range toRemove {
if err := fsys.Delete(f.Key); err != nil {
app.Logger().Debug(
app.Logger().Error(
"[Backup cron] Failed to remove old autogenerated backup",
slog.String("key", f.Key),
slog.String("error", err.Error()),
@ -325,29 +321,11 @@ func (app *BaseApp) initAutobackupHooks() error {
}
}
})
// restart the ticker
c.Start()
}
// load on app serve
app.OnBeforeServe().Add(func(e *ServeEvent) error {
isServe = true
loadJob()
return nil
})
// stop the ticker on app termination
app.OnTerminate().Add(func(e *TerminateEvent) error {
c.Stop()
return nil
})
// reload on app settings change
app.OnModelAfterUpdate((&models.Param{}).TableName()).Add(func(e *ModelEvent) error {
p := e.Model.(*models.Param)
if p == nil || p.Key != models.ParamAppSettings {
return nil
app.OnBootstrap().BindFunc(func(e *BootstrapEvent) error {
if err := e.Next(); err != nil {
return err
}
loadJob()
@ -355,10 +333,18 @@ func (app *BaseApp) initAutobackupHooks() error {
return nil
})
app.OnSettingsReload().BindFunc(func(e *SettingsReloadEvent) error {
if err := e.Next(); err != nil {
return err
}
loadJob()
return nil
})
}
func (app *BaseApp) generateBackupName(prefix string) string {
func generateBackupName(app App, prefix string) string {
appName := inflector.Snakecase(app.Settings().Meta.AppName)
if len(appName) > 50 {
appName = appName[:50]

View File

@ -128,9 +128,9 @@ func verifyBackupContent(app core.App, path string) error {
"data.db",
"data.db-shm",
"data.db-wal",
"logs.db",
"logs.db-shm",
"logs.db-wal",
"aux.db",
"aux.db-shm",
"aux.db-wal",
".gitignore",
}

View File

@ -1,63 +0,0 @@
package core_test
import (
"testing"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestBaseAppRefreshSettings(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
// cleanup all stored settings
if _, err := app.DB().NewQuery("DELETE from _params;").Execute(); err != nil {
t.Fatalf("Failed to delete all test settings: %v", err)
}
// check if the new settings are saved in the db
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the settings after delete: %v", err)
}
testEventCalls(t, app, map[string]int{
"OnModelBeforeCreate": 1,
"OnModelAfterCreate": 1,
})
param, err := app.Dao().FindParamByKey(models.ParamAppSettings)
if err != nil {
t.Fatalf("Expected new settings to be persisted, got %v", err)
}
// change the db entry and refresh the app settings (ensure that there was no db update)
param.Value = types.JsonRaw([]byte(`{"example": 123}`))
if err := app.Dao().SaveParam(param.Key, param.Value); err != nil {
t.Fatalf("Failed to update the test settings: %v", err)
}
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the app settings: %v", err)
}
testEventCalls(t, app, nil)
// try to refresh again without doing any changes
app.ResetEventCalls()
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh the app settings without change: %v", err)
}
testEventCalls(t, app, nil)
}
func testEventCalls(t *testing.T, app *tests.TestApp, events map[string]int) {
if len(events) != len(app.EventCalls) {
t.Fatalf("Expected events doesn't match: \n%v, \ngot \n%v", events, app.EventCalls)
}
for name, total := range events {
if v, ok := app.EventCalls[name]; !ok || v != total {
t.Fatalf("Expected events doesn't exist or match: \n%v, \ngot \n%v", events, app.EventCalls)
}
}
}

View File

@ -1,59 +1,56 @@
package core
package core_test
import (
"context"
"database/sql"
"fmt"
"log/slog"
"os"
"strings"
"testing"
"time"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/migrations"
"github.com/pocketbase/pocketbase/migrations/logs"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tools/list"
_ "unsafe"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/logger"
"github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/migrate"
"github.com/pocketbase/pocketbase/tools/types"
)
func TestNewBaseApp(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "test_env",
IsDev: true,
})
if app.dataDir != testDataDir {
t.Fatalf("expected dataDir %q, got %q", testDataDir, app.dataDir)
if app.DataDir() != testDataDir {
t.Fatalf("expected DataDir %q, got %q", testDataDir, app.DataDir())
}
if app.encryptionEnv != "test_env" {
t.Fatalf("expected encryptionEnv test_env, got %q", app.dataDir)
if app.EncryptionEnv() != "test_env" {
t.Fatalf("expected EncryptionEnv test_env, got %q", app.EncryptionEnv())
}
if !app.isDev {
t.Fatalf("expected isDev true, got %v", app.isDev)
if !app.IsDev() {
t.Fatalf("expected IsDev true, got %v", app.IsDev())
}
if app.store == nil {
t.Fatal("expected store to be set, got nil")
if app.Store() == nil {
t.Fatal("expected Store to be set, got nil")
}
if app.settings == nil {
t.Fatal("expected settings to be set, got nil")
if app.Settings() == nil {
t.Fatal("expected Settings to be set, got nil")
}
if app.subscriptionsBroker == nil {
t.Fatal("expected subscriptionsBroker to be set, got nil")
if app.SubscriptionsBroker() == nil {
t.Fatal("expected SubscriptionsBroker to be set, got nil")
}
if app.Cron() == nil {
t.Fatal("expected Cron to be set, got nil")
}
}
@ -61,9 +58,8 @@ func TestBaseAppBootstrap(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
})
defer app.ResetBootstrapState()
@ -83,72 +79,59 @@ func TestBaseAppBootstrap(t *testing.T) {
t.Fatal("Expected test data directory to be created.")
}
if app.dao == nil {
t.Fatal("Expected app.dao to be initialized, got nil.")
type nilCheck struct {
name string
value any
expectNil bool
}
if app.dao.BeforeCreateFunc == nil {
t.Fatal("Expected app.dao.BeforeCreateFunc to be set, got nil.")
runNilChecks := func(checks []nilCheck) {
for _, check := range checks {
t.Run(check.name, func(t *testing.T) {
isNil := check.value == nil
if isNil != check.expectNil {
t.Fatalf("Expected isNil %v, got %v", check.expectNil, isNil)
}
})
}
}
if app.dao.AfterCreateFunc == nil {
t.Fatal("Expected app.dao.AfterCreateFunc to be set, got nil.")
nilChecksBeforeReset := []nilCheck{
{"[before] concurrentDB", app.DB(), false},
{"[before] nonconcurrentDB", app.NonconcurrentDB(), false},
{"[before] auxConcurrentDB", app.AuxDB(), false},
{"[before] auxNonconcurrentDB", app.AuxNonconcurrentDB(), false},
{"[before] settings", app.Settings(), false},
{"[before] logger", app.Logger(), false},
{"[before] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
}
if app.dao.BeforeUpdateFunc == nil {
t.Fatal("Expected app.dao.BeforeUpdateFunc to be set, got nil.")
}
if app.dao.AfterUpdateFunc == nil {
t.Fatal("Expected app.dao.AfterUpdateFunc to be set, got nil.")
}
if app.dao.BeforeDeleteFunc == nil {
t.Fatal("Expected app.dao.BeforeDeleteFunc to be set, got nil.")
}
if app.dao.AfterDeleteFunc == nil {
t.Fatal("Expected app.dao.AfterDeleteFunc to be set, got nil.")
}
if app.logsDao == nil {
t.Fatal("Expected app.logsDao to be initialized, got nil.")
}
if app.settings == nil {
t.Fatal("Expected app.settings to be initialized, got nil.")
}
if app.logger == nil {
t.Fatal("Expected app.logger to be initialized, got nil.")
}
if _, ok := app.logger.Handler().(*logger.BatchHandler); !ok {
t.Fatal("Expected app.logger handler to be initialized.")
}
runNilChecks(nilChecksBeforeReset)
// reset
if err := app.ResetBootstrapState(); err != nil {
t.Fatal(err)
}
if app.dao != nil {
t.Fatalf("Expected app.dao to be nil, got %v.", app.dao)
nilChecksAfterReset := []nilCheck{
{"[after] concurrentDB", app.DB(), true},
{"[after] nonconcurrentDB", app.NonconcurrentDB(), true},
{"[after] auxConcurrentDB", app.AuxDB(), true},
{"[after] auxNonconcurrentDB", app.AuxNonconcurrentDB(), true},
{"[after] settings", app.Settings(), false},
{"[after] logger", app.Logger(), false},
{"[after] cached collections", app.Store().Get(core.StoreKeyCachedCollections), false},
}
if app.logsDao != nil {
t.Fatalf("Expected app.logsDao to be nil, got %v.", app.logsDao)
}
runNilChecks(nilChecksAfterReset)
}
func TestBaseAppGetters(t *testing.T) {
func TestNewBaseAppIsTransactional(t *testing.T) {
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := NewBaseApp(BaseAppConfig{
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
IsDev: true,
})
defer app.ResetBootstrapState()
@ -156,81 +139,58 @@ func TestBaseAppGetters(t *testing.T) {
t.Fatal(err)
}
if app.dao != app.Dao() {
t.Fatalf("Expected app.Dao %v, got %v", app.Dao(), app.dao)
if app.IsTransactional() {
t.Fatalf("Didn't expect the app to be transactional")
}
if app.dao.ConcurrentDB() != app.DB() {
t.Fatalf("Expected app.DB %v, got %v", app.DB(), app.dao.ConcurrentDB())
app.RunInTransaction(func(txApp core.App) error {
if !txApp.IsTransactional() {
t.Fatalf("Expected the app to be transactional")
}
if app.logsDao != app.LogsDao() {
t.Fatalf("Expected app.LogsDao %v, got %v", app.LogsDao(), app.logsDao)
}
if app.logsDao.ConcurrentDB() != app.LogsDB() {
t.Fatalf("Expected app.LogsDB %v, got %v", app.LogsDB(), app.logsDao.ConcurrentDB())
}
if app.dataDir != app.DataDir() {
t.Fatalf("Expected app.DataDir %v, got %v", app.DataDir(), app.dataDir)
}
if app.encryptionEnv != app.EncryptionEnv() {
t.Fatalf("Expected app.EncryptionEnv %v, got %v", app.EncryptionEnv(), app.encryptionEnv)
}
if app.isDev != app.IsDev() {
t.Fatalf("Expected app.IsDev %v, got %v", app.IsDev(), app.isDev)
}
if app.settings != app.Settings() {
t.Fatalf("Expected app.Settings %v, got %v", app.Settings(), app.settings)
}
if app.store != app.Store() {
t.Fatalf("Expected app.Store %v, got %v", app.Store(), app.store)
}
if app.logger != app.Logger() {
t.Fatalf("Expected app.Logger %v, got %v", app.Logger(), app.logger)
}
if app.subscriptionsBroker != app.SubscriptionsBroker() {
t.Fatalf("Expected app.SubscriptionsBroker %v, got %v", app.SubscriptionsBroker(), app.subscriptionsBroker)
}
if app.onBeforeServe != app.OnBeforeServe() || app.OnBeforeServe() == nil {
t.Fatalf("Getter app.OnBeforeServe does not match or nil (%v vs %v)", app.OnBeforeServe(), app.onBeforeServe)
}
return nil
})
}
func TestBaseAppNewMailClient(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
EncryptionEnv: "pb_test_env",
})
defer app.ResetBootstrapState()
client1 := app.NewMailClient()
if val, ok := client1.(*mailer.Sendmail); !ok {
t.Fatalf("Expected mailer.Sendmail instance, got %v", val)
m1, ok := client1.(*mailer.Sendmail)
if !ok {
t.Fatalf("Expected mailer.Sendmail instance, got %v", m1)
}
if m1.OnSend() == nil || m1.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
}
app.Settings().Smtp.Enabled = true
app.Settings().SMTP.Enabled = true
client2 := app.NewMailClient()
if val, ok := client2.(*mailer.SmtpClient); !ok {
t.Fatalf("Expected mailer.SmtpClient instance, got %v", val)
m2, ok := client2.(*mailer.SMTPClient)
if !ok {
t.Fatalf("Expected mailer.SMTPClient instance, got %v", m2)
}
if m2.OnSend() == nil || m2.OnSend().Length() == 0 {
t.Fatal("Expected OnSend hook to be registered")
}
}
func TestBaseAppNewFilesystem(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local
local, localErr := app.NewFilesystem()
@ -253,11 +213,13 @@ func TestBaseAppNewFilesystem(t *testing.T) {
}
func TestBaseAppNewBackupsFilesystem(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
})
defer app.ResetBootstrapState()
// local
local, localErr := app.NewBackupsFilesystem()
@ -280,18 +242,22 @@ func TestBaseAppNewBackupsFilesystem(t *testing.T) {
}
func TestBaseAppLoggerWrites(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Parallel()
app, _ := tests.NewTestApp()
defer app.Cleanup()
// reset
if err := app.DeleteOldLogs(time.Now()); err != nil {
t.Fatal(err)
}
defer cleanup()
const logsThreshold = 200
totalLogs := func(app App, t *testing.T) int {
totalLogs := func(app core.App, t *testing.T) int {
var total int
err := app.LogsDao().LogQuery().Select("count(*)").Row(&total)
err := app.LogQuery().Select("count(*)").Row(&total)
if err != nil {
t.Fatalf("Failed to fetch total logs: %v", err)
}
@ -338,106 +304,9 @@ func TestBaseAppLoggerWrites(t *testing.T) {
t.Fatalf("Expected %d logs, got %d", logsThreshold+1, total)
}
})
t.Run("test batch logs delete", func(t *testing.T) {
app.Settings().Logs.MaxDays = 2
deleteQueries := 0
// reset
app.Store().Set("lastLogsDeletedAt", time.Now())
if err := app.LogsDao().DeleteOldLogs(time.Now()); err != nil {
t.Fatal(err)
}
db := app.LogsDao().NonconcurrentDB().(*dbx.DB)
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
if strings.Contains(sql, "DELETE") {
deleteQueries++
}
}
// trigger batch write (A)
expectedLogs := logsThreshold
for i := 0; i < expectedLogs; i++ {
app.Logger().Error("testA")
}
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write A] Expected %d logs, got %d", expectedLogs, total)
}
// mark the A inserted logs as 2-day expired
aExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -2))
if err != nil {
t.Fatal(err)
}
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date}").Bind(dbx.Params{
"date": aExpiredDate.String(),
}).Execute()
if err != nil {
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
}
// simulate recently deleted logs
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-5*time.Hour))
// trigger batch write (B)
for i := 0; i < logsThreshold; i++ {
app.Logger().Error("testB")
}
expectedLogs = 2 * logsThreshold
// note: even though there are expired logs it shouldn't perform the delete operation because of the lastLogsDeledAt time
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write B] Expected %d logs, got %d", expectedLogs, total)
}
// mark the B inserted logs as 1-day expired to ensure that they will not be deleted
bExpiredDate, err := types.ParseDateTime(time.Now().AddDate(0, 0, -1))
if err != nil {
t.Fatal(err)
}
_, err = app.LogsDao().NonconcurrentDB().NewQuery("UPDATE _logs SET created={:date}, updated={:date} where message='testB'").Bind(dbx.Params{
"date": bExpiredDate.String(),
}).Execute()
if err != nil {
t.Fatalf("Failed to mock logs timestamp fields: %v", err)
}
// should trigger delete on the next batch write
app.Store().Set("lastLogsDeletedAt", time.Now().Add(-6*time.Hour))
// trigger batch write (C)
for i := 0; i < logsThreshold; i++ {
app.Logger().Error("testC")
}
expectedLogs = 2 * logsThreshold // only B and C logs should remain
if total := totalLogs(app, t); total != expectedLogs {
t.Fatalf("[batch write C] Expected %d logs, got %d", expectedLogs, total)
}
if deleteQueries != 1 {
t.Fatalf("Expected DeleteOldLogs to be called %d, got %d", 1, deleteQueries)
}
})
}
func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
scenarios := []struct {
name string
isDev bool
@ -469,173 +338,35 @@ func TestBaseAppRefreshSettingsLoggerMinLevelEnabled(t *testing.T) {
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
app.isDev = s.isDev
const testDataDir = "./pb_base_app_test_data_dir/"
defer os.RemoveAll(testDataDir)
app := core.NewBaseApp(core.BaseAppConfig{
DataDir: testDataDir,
IsDev: s.isDev,
})
defer app.ResetBootstrapState()
if err := app.Bootstrap(); err != nil {
t.Fatal(err)
}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
app.Settings().Logs.MinLevel = s.level
if err := app.Dao().SaveSettings(app.Settings()); err != nil {
if err := app.Save(app.Settings()); err != nil {
t.Fatalf("Failed to save settings: %v", err)
}
if err := app.RefreshSettings(); err != nil {
t.Fatalf("Failed to refresh app settings: %v", err)
}
for level, enabled := range s.expectations {
if v := handler.Enabled(nil, slog.Level(level)); v != enabled {
if v := handler.Enabled(context.Background(), slog.Level(level)); v != enabled {
t.Fatalf("Expected level %d Enabled() to be %v, got %v", level, enabled, v)
}
}
})
}
}
func TestBaseAppLoggerLevelDevPrint(t *testing.T) {
app, cleanup, err := initTestBaseApp()
if err != nil {
t.Fatal(err)
}
defer cleanup()
testLogLevel := 4
app.Settings().Logs.MinLevel = testLogLevel
if err := app.Dao().SaveSettings(app.Settings()); err != nil {
t.Fatal(err)
}
scenarios := []struct {
name string
isDev bool
levels []int
printedLevels []int
persistedLevels []int
}{
{
"dev mode",
true,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{testLogLevel, testLogLevel + 1},
},
{
"nondev mode",
false,
[]int{testLogLevel - 1, testLogLevel, testLogLevel + 1},
[]int{},
[]int{testLogLevel, testLogLevel + 1},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
var printedLevels []int
var persistedLevels []int
app.isDev = s.isDev
// trigger slog handler min level refresh
if err := app.RefreshSettings(); err != nil {
t.Fatal(err)
}
// track printed logs
originalPrintLog := printLog
defer func() {
printLog = originalPrintLog
}()
printLog = func(log *logger.Log) {
printedLevels = append(printedLevels, int(log.Level))
}
// track persisted logs
app.LogsDao().AfterCreateFunc = func(eventDao *daos.Dao, m models.Model) error {
l, ok := m.(*models.Log)
if ok {
persistedLevels = append(persistedLevels, l.Level)
}
return nil
}
// write and persist logs
for _, l := range s.levels {
app.Logger().Log(nil, slog.Level(l), "test")
}
handler, ok := app.Logger().Handler().(*logger.BatchHandler)
if !ok {
t.Fatalf("Expected BatchHandler, got %v", app.Logger().Handler())
}
if err := handler.WriteAll(nil); err != nil {
t.Fatalf("Failed to write all logs: %v", err)
}
// check persisted log levels
if len(s.persistedLevels) != len(persistedLevels) {
t.Fatalf("Expected persisted levels \n%v\ngot\n%v", s.persistedLevels, persistedLevels)
}
for _, l := range persistedLevels {
if !list.ExistInSlice(l, s.persistedLevels) {
t.Fatalf("Missing expected persisted level %v in %v", l, persistedLevels)
}
}
// check printed log levels
if len(s.printedLevels) != len(printedLevels) {
t.Fatalf("Expected printed levels \n%v\ngot\n%v", s.printedLevels, printedLevels)
}
for _, l := range printedLevels {
if !list.ExistInSlice(l, s.printedLevels) {
t.Fatalf("Missing expected printed level %v in %v", l, printedLevels)
}
}
})
}
}
// -------------------------------------------------------------------
// note: make sure to call `defer cleanup()` when the app is no longer needed.
func initTestBaseApp() (app *BaseApp, cleanup func(), err error) {
testDataDir, err := os.MkdirTemp("", "test_base_app")
if err != nil {
return nil, nil, err
}
cleanup = func() {
os.RemoveAll(testDataDir)
}
app = NewBaseApp(BaseAppConfig{
DataDir: testDataDir,
})
initErr := func() error {
if err := app.Bootstrap(); err != nil {
return fmt.Errorf("bootstrap error: %w", err)
}
logsRunner, err := migrate.NewRunner(app.LogsDB(), logs.LogsMigrations)
if err != nil {
return fmt.Errorf("logsRunner error: %w", err)
}
if _, err := logsRunner.Up(); err != nil {
return fmt.Errorf("logsRunner migrations execution error: %w", err)
}
dataRunner, err := migrate.NewRunner(app.DB(), migrations.AppMigrations)
if err != nil {
return fmt.Errorf("logsRunner error: %w", err)
}
if _, err := dataRunner.Up(); err != nil {
return fmt.Errorf("dataRunner migrations execution error: %w", err)
}
return nil
}()
if initErr != nil {
cleanup()
return nil, nil, initErr
}
return app, cleanup, nil
}

194
core/collection_import.go Normal file
View File

@ -0,0 +1,194 @@
package core
import (
"cmp"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/spf13/cast"
)
// ImportCollectionsByMarshaledJSON is the same as [ImportCollections]
// but accept marshaled json array as import data (usually used for the autogenerated snapshots).
func (app *BaseApp) ImportCollectionsByMarshaledJSON(rawSliceOfMaps []byte, deleteMissing bool) error {
data := []map[string]any{}
err := json.Unmarshal(rawSliceOfMaps, &data)
if err != nil {
return err
}
return app.ImportCollections(data, deleteMissing)
}
// ImportCollections imports the provided collections data in a single transaction.
//
// For existing matching collections, the imported data is unmarshaled on top of the existing model.
//
// NB! If deleteMissing is true, ALL NON-SYSTEM COLLECTIONS AND SCHEMA FIELDS,
// that are not present in the imported configuration, WILL BE DELETED
// (this includes their related records data).
func (app *BaseApp) ImportCollections(toImport []map[string]any, deleteMissing bool) error {
if len(toImport) == 0 {
// prevent accidentally deleting all collections
return errors.New("no collections to import")
}
importedCollections := make([]*Collection, len(toImport))
mappedImported := make(map[string]*Collection, len(toImport))
// normalize imported collections data to ensure that all
// collection fields are present and properly initialized
for i, data := range toImport {
var imported *Collection
identifier := cast.ToString(data["id"])
if identifier == "" {
identifier = cast.ToString(data["name"])
}
existing, err := app.FindCollectionByNameOrId(identifier)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if existing != nil {
// refetch for deep copy
imported, err = app.FindCollectionByNameOrId(existing.Id)
if err != nil {
return err
}
// ensure that the fields will be cleared
if data["fields"] == nil && deleteMissing {
data["fields"] = []map[string]any{}
}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
// extend with the existing fields if necessary
for _, f := range existing.Fields {
if !f.GetSystem() && deleteMissing {
continue
}
if imported.Fields.GetById(f.GetId()) == nil {
imported.Fields.Add(f)
}
}
} else {
imported = &Collection{}
rawData, err := json.Marshal(data)
if err != nil {
return err
}
// load the imported data
err = json.Unmarshal(rawData, imported)
if err != nil {
return err
}
}
imported.IntegrityChecks(false)
importedCollections[i] = imported
mappedImported[imported.Id] = imported
}
// reorder views last since the view query could depend on some of the other collections
slices.SortStableFunc(importedCollections, func(a, b *Collection) int {
cmpA := -1
if a.IsView() {
cmpA = 1
}
cmpB := -1
if b.IsView() {
cmpB = 1
}
res := cmp.Compare(cmpA, cmpB)
if res == 0 {
res = a.Created.Compare(b.Created)
if res == 0 {
res = a.Updated.Compare(b.Updated)
}
}
return res
})
return app.RunInTransaction(func(txApp App) error {
existingCollections := []*Collection{}
if err := txApp.CollectionQuery().OrderBy("updated ASC").All(&existingCollections); err != nil {
return err
}
mappedExisting := make(map[string]*Collection, len(existingCollections))
for _, existing := range existingCollections {
existing.IntegrityChecks(false)
mappedExisting[existing.Id] = existing
}
// delete old collections not available in the new configuration
// (before saving the imports in case a deleted collection name is being reused)
if deleteMissing {
for _, existing := range existingCollections {
if mappedImported[existing.Id] != nil || existing.System {
continue // exist or system
}
// delete collection
if err := txApp.Delete(existing); err != nil {
return err
}
}
}
// upsert imported collections
for _, imported := range importedCollections {
if err := txApp.SaveNoValidate(imported); err != nil {
return fmt.Errorf("failed to save collection %q: %w", imported.Name, err)
}
}
// run validations
for _, imported := range importedCollections {
original := mappedExisting[imported.Id]
if original == nil {
original = imported
}
validator := newCollectionValidator(
context.Background(),
txApp,
imported,
original,
)
if err := validator.run(); err != nil {
// serialize the validation error(s)
serializedErr, _ := json.MarshalIndent(err, "", " ")
return validation.Errors{"collections": validation.NewError(
"validation_collections_import_failure",
fmt.Sprintf("Data validations failed for collection %q (%s):\n%s", imported.Name, imported.Id, serializedErr),
)}
}
}
return nil
})
}

View File

@ -0,0 +1,476 @@
package core_test
import (
"encoding/json"
"strings"
"testing"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
func TestImportCollections(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data []map[string]any
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "empty collections",
data: []map[string]any{},
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (with missing system fields)",
data: []map[string]any{
{"name": "import_test1", "type": "auth"},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 2,
},
{
name: "minimal collection import (trigger collection model validations)",
data: []map[string]any{
{"name": ""},
{
"name": "import_test2", "fields": []map[string]any{
{"name": "test", "type": "text"},
},
},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import (trigger field settings validation)",
data: []map[string]any{
{"name": "import_test", "fields": []map[string]any{{"name": "test", "type": "text", "min": -1}}},
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: []map[string]any{
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
},
"indexes": []string{},
},
{
"name": "import1",
"fields": []map[string]any{
{
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
{
name: "test with deleteMissing: false",
data: []map[string]any{
{
// "id": "wsmn24bux7wo113", // test update with only name as identifier
"name": "demo1",
"fields": []map[string]any{
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": nil,
"pattern": "",
},
{
"id": "_2hlxbmp",
"name": "field_with_duplicate_id",
"type": "text",
"system": false,
"required": true,
"unique": false,
"min": 4,
"max": nil,
"pattern": "",
},
{
"id": "abcd_import",
"name": "new_field",
"type": "text",
},
},
},
{
"name": "new_import",
"fields": []map[string]any{
{
"id": "abcd_import",
"name": "active",
"type": "bool",
},
},
},
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: totalCollections + 1,
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*core.Collection) {
expectedCollectionFields := map[string]int{
core.CollectionNameAuthOrigins: 6,
"nologin": 10,
"demo1": 18,
"demo2": 5,
"demo3": 5,
"demo4": 16,
"demo5": 9,
"new_import": 2,
}
for name, expectedCount := range expectedCollectionFields {
collection, err := testApp.FindCollectionByNameOrId(name)
if err != nil {
t.Fatal(err)
}
if totalFields := len(collection.Fields); totalFields != expectedCount {
t.Errorf("Expected %d %q fields, got %d", expectedCount, collection.Name, totalFields)
}
}
},
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections(s.data, s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsByMarshaledJSON(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
var regularCollections []*core.Collection
err := testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": false}).All(&regularCollections)
if err != nil {
t.Fatal(err)
}
var systemCollections []*core.Collection
err = testApp.CollectionQuery().AndWhere(dbx.HashExp{"system": true}).All(&systemCollections)
if err != nil {
t.Fatal(err)
}
totalRegularCollections := len(regularCollections)
totalSystemCollections := len(systemCollections)
totalCollections := totalRegularCollections + totalSystemCollections
scenarios := []struct {
name string
data string
deleteMissing bool
expectError bool
expectCollectionsCount int
afterTestFunc func(testApp *tests.TestApp, resultCollections []*core.Collection)
}{
{
name: "invalid json array",
data: `{"test":123}`,
expectError: true,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete (system collections delete should be ignored)",
data: `[
{
"id": "wsmn24bux7wo113",
"name": "demo",
"fields": [
{
"id": "_2hlxbmp",
"name": "title",
"type": "text",
"system": false,
"required": true,
"min": 3,
"max": null,
"pattern": ""
}
],
"indexes": []
},
{
"name": "import1",
"fields": [
{
"name": "active",
"type": "bool"
}
]
}
]`,
deleteMissing: true,
expectError: false,
expectCollectionsCount: totalSystemCollections + 2,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollectionsByMarshaledJSON([]byte(s.data), s.deleteMissing)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr to be %v, got %v (%v)", s.expectError, hasErr, err)
}
// check collections count
collections := []*core.Collection{}
if err := testApp.CollectionQuery().All(&collections); err != nil {
t.Fatal(err)
}
if len(collections) != s.expectCollectionsCount {
t.Fatalf("Expected %d collections, got %d", s.expectCollectionsCount, len(collections))
}
if s.afterTestFunc != nil {
s.afterTestFunc(testApp, collections)
}
})
}
}
func TestImportCollectionsUpdateRules(t *testing.T) {
t.Parallel()
scenarios := []struct {
name string
data map[string]any
deleteMissing bool
}{
{
"extend existing by name (without deleteMissing)",
map[string]any{"name": "clients", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend existing by id (without deleteMissing)",
map[string]any{"id": "v851q4r790rhknl", "authToken": map[string]any{"duration": 100}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
false,
},
{
"extend with delete missing",
map[string]any{
"id": "v851q4r790rhknl",
"authToken": map[string]any{"duration": 100},
"fields": []map[string]any{{"name": "test", "type": "text"}},
"passwordAuth": map[string]any{"identityFields": []string{"email"}},
"indexes": []string{
// min required system fields indexes
"CREATE UNIQUE INDEX `_v851q4r790rhknl_email_idx` ON `clients` (email) WHERE email != ''",
"CREATE UNIQUE INDEX `_v851q4r790rhknl_tokenKey_idx` ON `clients` (tokenKey)",
},
},
true,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
beforeCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
err = testApp.ImportCollections([]map[string]any{s.data}, s.deleteMissing)
if err != nil {
t.Fatal(err)
}
afterCollection, err := testApp.FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
if afterCollection.AuthToken.Duration != 100 {
t.Fatalf("Expected AuthToken duration to be %d, got %d", 100, afterCollection.AuthToken.Duration)
}
if beforeCollection.AuthToken.Secret != afterCollection.AuthToken.Secret {
t.Fatalf("Expected AuthToken secrets to remain the same, got\n%q\nVS\n%q", beforeCollection.AuthToken.Secret, afterCollection.AuthToken.Secret)
}
if beforeCollection.Name != afterCollection.Name {
t.Fatalf("Expected Name to remain the same, got\n%q\nVS\n%q", beforeCollection.Name, afterCollection.Name)
}
if beforeCollection.Id != afterCollection.Id {
t.Fatalf("Expected Id to remain the same, got\n%q\nVS\n%q", beforeCollection.Id, afterCollection.Id)
}
if !s.deleteMissing {
totalExpectedFields := len(beforeCollection.Fields) + 1
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old fields still exist
oldFields := beforeCollection.Fields.FieldNames()
for _, name := range oldFields {
if afterCollection.Fields.GetByName(name) == nil {
t.Fatalf("Missing expected old field %q", name)
}
}
} else {
totalExpectedFields := 1
for _, f := range beforeCollection.Fields {
if f.GetSystem() {
totalExpectedFields++
}
}
if v := len(afterCollection.Fields); v != totalExpectedFields {
t.Fatalf("Expected %d total fields, got %d", totalExpectedFields, v)
}
if afterCollection.Fields.GetByName("test") == nil {
t.Fatalf("Missing new field %q", "test")
}
// ensure that the old system fields still exist
for _, f := range beforeCollection.Fields {
if f.GetSystem() && afterCollection.Fields.GetByName(f.GetName()) == nil {
t.Fatalf("Missing expected old field %q", f.GetName())
}
}
}
})
}
}
func TestImportCollectionsCreateRules(t *testing.T) {
t.Parallel()
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
err := testApp.ImportCollections([]map[string]any{
{"name": "new_test", "type": "auth", "authToken": map[string]any{"duration": 123}, "fields": []map[string]any{{"name": "test", "type": "text"}}},
}, false)
if err != nil {
t.Fatal(err)
}
collection, err := testApp.FindCollectionByNameOrId("new_test")
if err != nil {
t.Fatal(err)
}
raw, err := json.Marshal(collection)
if err != nil {
t.Fatal(err)
}
rawStr := string(raw)
expectedParts := []string{
`"name":"new_test"`,
`"fields":[`,
`"name":"id"`,
`"name":"email"`,
`"name":"tokenKey"`,
`"name":"password"`,
`"name":"test"`,
`"indexes":[`,
`CREATE UNIQUE INDEX`,
`"duration":123`,
}
for _, part := range expectedParts {
if !strings.Contains(rawStr, part) {
t.Errorf("Missing %q in\n%s", part, rawStr)
}
}
}

949
core/collection_model.go Normal file
View File

@ -0,0 +1,949 @@
package core
import (
"encoding/json"
"fmt"
"strings"
"github.com/pocketbase/pocketbase/tools/dbutils"
"github.com/pocketbase/pocketbase/tools/hook"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
var (
_ Model = (*Collection)(nil)
_ DBExporter = (*Collection)(nil)
_ FilesManager = (*Collection)(nil)
)
const (
CollectionTypeBase = "base"
CollectionTypeAuth = "auth"
CollectionTypeView = "view"
)
const systemHookIdCollection = "__pbCollectionSystemHook__"
func (app *BaseApp) registerCollectionHooks() {
app.OnModelValidate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionValidate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelCreate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionCreate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelCreateExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionCreateExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterCreateSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterCreateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterCreateError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterCreateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelUpdate().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionUpdate().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelUpdateExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionUpdateExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterUpdateSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterUpdateSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterUpdateError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterUpdateError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelDelete().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionDelete().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelDeleteExecute().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionDeleteExecute().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterDeleteSuccess().Bind(&hook.Handler[*ModelEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelEvent) error {
if ce, ok := newCollectionEventFromModelEvent(me); ok {
return me.App.OnCollectionAfterDeleteSuccess().Trigger(ce, func(ce *CollectionEvent) error {
syncModelEventWithCollectionEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
app.OnModelAfterDeleteError().Bind(&hook.Handler[*ModelErrorEvent]{
Id: systemHookIdCollection,
Func: func(me *ModelErrorEvent) error {
if ce, ok := newCollectionErrorEventFromModelErrorEvent(me); ok {
return me.App.OnCollectionAfterDeleteError().Trigger(ce, func(ce *CollectionErrorEvent) error {
syncModelErrorEventWithCollectionErrorEvent(me, ce)
return me.Next()
})
}
return me.Next()
},
Priority: -99,
})
// --------------------------------------------------------------
app.OnCollectionValidate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionValidate,
Priority: 99,
})
app.OnCollectionCreate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSave,
Priority: -99,
})
app.OnCollectionUpdate().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSave,
Priority: -99,
})
app.OnCollectionCreateExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSaveExecute,
// execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
Priority: 99,
})
app.OnCollectionUpdateExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionSaveExecute,
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
})
app.OnCollectionDeleteExecute().Bind(&hook.Handler[*CollectionEvent]{
Id: systemHookIdCollection,
Func: onCollectionDeleteExecute,
Priority: 99, // execute as latest as possible, aka. closer to the db action to minimize the transactions lock time
})
// reload cache on failure
// ---
onErrorReloadCachedCollections := func(ce *CollectionErrorEvent) error {
if err := ce.App.ReloadCachedCollections(); err != nil {
ce.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
return ce.Next()
}
app.OnCollectionAfterCreateError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
app.OnCollectionAfterUpdateError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
app.OnCollectionAfterDeleteError().Bind(&hook.Handler[*CollectionErrorEvent]{
Id: systemHookIdCollection,
Func: onErrorReloadCachedCollections,
Priority: -99,
})
// ---
app.OnBootstrap().Bind(&hook.Handler[*BootstrapEvent]{
Id: systemHookIdCollection,
Func: func(e *BootstrapEvent) error {
if err := e.Next(); err != nil {
return err
}
if err := e.App.ReloadCachedCollections(); err != nil {
return fmt.Errorf("failed to load collections cache: %w", err)
}
return nil
},
Priority: 99, // execute as latest as possible
})
}
// @todo experiment eventually replacing the rules *string with a struct?
type baseCollection struct {
BaseModel
disableIntegrityChecks bool
ListRule *string `db:"listRule" json:"listRule" form:"listRule"`
ViewRule *string `db:"viewRule" json:"viewRule" form:"viewRule"`
CreateRule *string `db:"createRule" json:"createRule" form:"createRule"`
UpdateRule *string `db:"updateRule" json:"updateRule" form:"updateRule"`
DeleteRule *string `db:"deleteRule" json:"deleteRule" form:"deleteRule"`
// RawOptions represents the raw serialized collection option loaded from the DB.
// NB! This field shouldn't be modified manually. It is automatically updated
// with the collection type specific option before save.
RawOptions types.JSONRaw `db:"options" json:"-" xml:"-" form:"-"`
Name string `db:"name" json:"name" form:"name"`
Type string `db:"type" json:"type" form:"type"`
Fields FieldsList `db:"fields" json:"fields" form:"fields"`
Indexes types.JSONArray[string] `db:"indexes" json:"indexes" form:"indexes"`
System bool `db:"system" json:"system" form:"system"`
Created types.DateTime `db:"created" json:"created"`
Updated types.DateTime `db:"updated" json:"updated"`
}
// Collection defines the table, fields and various options related to a set of records.
type Collection struct {
baseCollection
collectionAuthOptions
collectionViewOptions
}
// NewCollection initializes and returns a new Collection model with the specified type and name.
func NewCollection(typ, name string) *Collection {
switch typ {
case CollectionTypeAuth:
return NewAuthCollection(name)
case CollectionTypeView:
return NewViewCollection(name)
default:
return NewBaseCollection(name)
}
}
// NewBaseCollection initializes and returns a new "base" Collection model.
func NewBaseCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeBase
m.initDefaultId()
m.initDefaultFields()
return m
}
// NewViewCollection initializes and returns a new "view" Collection model.
func NewViewCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeView
m.initDefaultId()
m.initDefaultFields()
return m
}
// NewAuthCollection initializes and returns a new "auth" Collection model.
func NewAuthCollection(name string) *Collection {
m := &Collection{}
m.Name = name
m.Type = CollectionTypeAuth
m.initDefaultId()
m.initDefaultFields()
m.setDefaultAuthOptions()
return m
}
// TableName returns the Collection model SQL table name.
func (m *Collection) TableName() string {
return "_collections"
}
// BaseFilesPath returns the storage dir path used by the collection.
func (m *Collection) BaseFilesPath() string {
return m.Id
}
// IsBase checks if the current collection has "base" type.
func (m *Collection) IsBase() bool {
return m.Type == CollectionTypeBase
}
// IsAuth checks if the current collection has "auth" type.
func (m *Collection) IsAuth() bool {
return m.Type == CollectionTypeAuth
}
// IsView checks if the current collection has "view" type.
func (m *Collection) IsView() bool {
return m.Type == CollectionTypeView
}
// IntegrityChecks toggles the current collection integrity checks (ex. checking references on delete).
func (m *Collection) IntegrityChecks(enable bool) {
m.disableIntegrityChecks = !enable
}
// PostScan implements the [dbx.PostScanner] interface to auto unmarshal
// the raw serialized options into the concrete type specific fields.
func (m *Collection) PostScan() error {
if err := m.BaseModel.PostScan(); err != nil {
return err
}
return m.unmarshalRawOptions()
}
func (m *Collection) unmarshalRawOptions() error {
raw, err := m.RawOptions.MarshalJSON()
if err != nil {
return nil
}
switch m.Type {
case CollectionTypeView:
return json.Unmarshal(raw, &m.collectionViewOptions)
case CollectionTypeAuth:
return json.Unmarshal(raw, &m.collectionAuthOptions)
}
return nil
}
// UnmarshalJSON implements the [json.Unmarshaler] interface.
//
// For new/"blank" Collection models it replaces the model with a factory
// instance and then unmarshal the provided data one on top of it.
func (m *Collection) UnmarshalJSON(b []byte) error {
type alias *Collection
// initialize the default fields
// (e.g. in case the collection was NOT created using the designated factories)
if m.IsNew() && m.Type == "" {
minimal := &struct {
Type string `json:"type"`
Name string `json:"name"`
}{}
if err := json.Unmarshal(b, minimal); err != nil {
return err
}
blank := NewCollection(minimal.Type, minimal.Name)
*m = *blank
}
return json.Unmarshal(b, alias(m))
}
// MarshalJSON implements the [json.Marshaler] interface.
//
// Note that non-type related fields are ignored from the serialization
// (ex. for "view" colections the "auth" fields are skipped).
func (m Collection) MarshalJSON() ([]byte, error) {
switch m.Type {
case CollectionTypeView:
return json.Marshal(struct {
baseCollection
collectionViewOptions
}{m.baseCollection, m.collectionViewOptions})
case CollectionTypeAuth:
alias := struct {
baseCollection
collectionAuthOptions
}{m.baseCollection, m.collectionAuthOptions}
// ensure that it is always returned as array
if alias.OAuth2.Providers == nil {
alias.OAuth2.Providers = []OAuth2ProviderConfig{}
}
// hide secret keys from the serialization
alias.AuthToken.Secret = ""
alias.FileToken.Secret = ""
alias.PasswordResetToken.Secret = ""
alias.EmailChangeToken.Secret = ""
alias.VerificationToken.Secret = ""
for i := range alias.OAuth2.Providers {
alias.OAuth2.Providers[i].ClientSecret = ""
}
return json.Marshal(alias)
default:
return json.Marshal(m.baseCollection)
}
}
// String returns a string representation of the current collection.
func (m Collection) String() string {
raw, _ := json.Marshal(m)
return string(raw)
}
// DBExport prepares and exports the current collection data for db persistence.
func (m *Collection) DBExport(app App) (map[string]any, error) {
result := map[string]any{
"id": m.Id,
"type": m.Type,
"listRule": m.ListRule,
"viewRule": m.ViewRule,
"createRule": m.CreateRule,
"updateRule": m.UpdateRule,
"deleteRule": m.DeleteRule,
"name": m.Name,
"fields": m.Fields,
"indexes": m.Indexes,
"system": m.System,
"created": m.Created,
"updated": m.Updated,
"options": `{}`,
}
switch m.Type {
case CollectionTypeView:
if raw, err := types.ParseJSONRaw(m.collectionViewOptions); err == nil {
result["options"] = raw
} else {
return nil, err
}
case CollectionTypeAuth:
if raw, err := types.ParseJSONRaw(m.collectionAuthOptions); err == nil {
result["options"] = raw
} else {
return nil, err
}
}
return result, nil
}
// GetIndex returns s single Collection index expression by its name.
func (m *Collection) GetIndex(name string) string {
for _, idx := range m.Indexes {
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
return idx
}
}
return ""
}
// AddIndex adds a new index into the current collection.
//
// If the collection has an existing index matching the new name it will be replaced with the new one.
func (m *Collection) AddIndex(name string, unique bool, columnsExpr string, optWhereExpr string) {
m.RemoveIndex(name)
var idx strings.Builder
idx.WriteString("CREATE ")
if unique {
idx.WriteString("UNIQUE ")
}
idx.WriteString("INDEX `")
idx.WriteString(name)
idx.WriteString("` ")
idx.WriteString("ON `")
idx.WriteString(m.Name)
idx.WriteString("` (")
idx.WriteString(columnsExpr)
idx.WriteString(")")
if optWhereExpr != "" {
idx.WriteString(" WHERE ")
idx.WriteString(optWhereExpr)
}
m.Indexes = append(m.Indexes, idx.String())
}
// RemoveIndex removes a single index with the specified name from the current collection.
func (m *Collection) RemoveIndex(name string) {
for i, idx := range m.Indexes {
if strings.EqualFold(dbutils.ParseIndex(idx).IndexName, name) {
m.Indexes = append(m.Indexes[:i], m.Indexes[i+1:]...)
return
}
}
}
// delete hook
// -------------------------------------------------------------------
func onCollectionDeleteExecute(e *CollectionEvent) error {
if e.Collection.System {
return fmt.Errorf("[%s] system collections cannot be deleted", e.Collection.Name)
}
defer func() {
if err := e.App.ReloadCachedCollections(); err != nil {
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
}()
if !e.Collection.disableIntegrityChecks {
// ensure that there aren't any existing references.
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
references, err := e.App.FindCollectionReferences(e.Collection, e.Collection.Id)
if err != nil {
return fmt.Errorf("[%s] failed to check collection references: %w", e.Collection.Name, err)
}
if total := len(references); total > 0 {
names := make([]string, 0, len(references))
for ref := range references {
names = append(names, ref.Name)
}
return fmt.Errorf("[%s] failed to delete due to existing relation references: %s", e.Collection.Name, strings.Join(names, ", "))
}
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
// delete the related view or records table
if e.Collection.IsView() {
if err := txApp.DeleteView(e.Collection.Name); err != nil {
return err
}
} else {
if err := txApp.DeleteTable(e.Collection.Name); err != nil {
return err
}
}
if !e.Collection.disableIntegrityChecks {
// trigger views resave to check for dependencies
if err := resaveViewsWithChangedFields(txApp, e.Collection.Id); err != nil {
return fmt.Errorf("[%s] failed to delete due to existing view dependency: %w", e.Collection.Name, err)
}
}
// delete
return e.Next()
})
e.App = originalApp
return txErr
}
// save hook
// -------------------------------------------------------------------
func (c *Collection) initDefaultId() {
if c.Id == "" && c.Name != "" {
c.Id = "_pbc_" + crc32Checksum(c.Name)
}
}
func (c *Collection) savePrepare() error {
if c.Type == "" {
c.Type = CollectionTypeBase
}
if c.IsNew() {
c.initDefaultId()
c.Created = types.NowDateTime()
}
c.Updated = types.NowDateTime()
// recreate the fields list to ensure that all normalizations
// like default field id are applied
c.Fields = NewFieldsList(c.Fields...)
c.initDefaultFields()
if c.IsAuth() {
c.unsetMissingOAuth2MappedFields()
}
return nil
}
func onCollectionSave(e *CollectionEvent) error {
if err := e.Collection.savePrepare(); err != nil {
return err
}
return e.Next()
}
func onCollectionSaveExecute(e *CollectionEvent) error {
defer func() {
if err := e.App.ReloadCachedCollections(); err != nil {
e.App.Logger().Warn("Failed to reload collections cache", "error", err)
}
}()
var oldCollection *Collection
if !e.Collection.IsNew() {
var err error
oldCollection, err = e.App.FindCachedCollectionByNameOrId(e.Collection.Id)
if err != nil {
return err
}
// invalidate previously issued auth tokens on auth rule change
if oldCollection.AuthRule != e.Collection.AuthRule &&
cast.ToString(oldCollection.AuthRule) != cast.ToString(e.Collection.AuthRule) {
e.Collection.AuthToken.Secret = security.RandomString(50)
}
}
originalApp := e.App
txErr := e.App.RunInTransaction(func(txApp App) error {
e.App = txApp
isView := e.Collection.IsView()
// ensures that the view collection shema is properly loaded
if isView {
query := e.Collection.ViewQuery
// generate collection fields list from the query
viewFields, err := e.App.CreateViewFields(query)
if err != nil {
return err
}
// delete old renamed view
if oldCollection != nil {
if err := e.App.DeleteView(oldCollection.Name); err != nil {
return err
}
}
// wrap view query if necessary
query, err = normalizeViewQueryId(e.App, query)
if err != nil {
return fmt.Errorf("failed to normalize view query id: %w", err)
}
// (re)create the view
if err := e.App.SaveView(e.Collection.Name, query); err != nil {
return err
}
// updates newCollection.Fields based on the generated view table info and query
e.Collection.Fields = viewFields
}
// save the Collection model
if err := e.Next(); err != nil {
return err
}
// sync the changes with the related records table
if !isView {
if err := e.App.SyncRecordTableSchema(e.Collection, oldCollection); err != nil {
// note: don't wrap to allow propagating indexes validation.Errors
return err
}
}
return nil
})
e.App = originalApp
if txErr != nil {
return txErr
}
// trigger an update for all views with changed fields as a result of the current collection save
// (ignoring view errors to allow users to update the query from the UI)
resaveViewsWithChangedFields(e.App, e.Collection.Id)
return nil
}
func (m *Collection) initDefaultFields() {
switch m.Type {
case CollectionTypeBase:
m.initIdField()
case CollectionTypeAuth:
m.initIdField()
m.initPasswordField()
m.initTokenKeyField()
m.initEmailField()
m.initEmailVisibilityField()
m.initVerifiedField()
case CollectionTypeView:
// view fields are autogenerated
}
}
func (m *Collection) initIdField() {
field, _ := m.Fields.GetByName(FieldNameId).(*TextField)
if field == nil {
// create default field
field = &TextField{
Name: FieldNameId,
System: true,
PrimaryKey: true,
Required: true,
Min: 15,
Max: 15,
Pattern: `^[a-z0-9]+$`,
AutogeneratePattern: `[a-z0-9]{15}`,
}
// prepend it
m.Fields = NewFieldsList(append([]Field{field}, m.Fields...)...)
} else {
// enforce system defaults
field.System = true
field.Required = true
field.PrimaryKey = true
field.Hidden = false
}
}
func (m *Collection) initPasswordField() {
field, _ := m.Fields.GetByName(FieldNamePassword).(*PasswordField)
if field == nil {
// load default field
m.Fields.Add(&PasswordField{
Name: FieldNamePassword,
System: true,
Hidden: true,
Required: true,
Min: 8,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = true
field.Required = true
}
}
func (m *Collection) initTokenKeyField() {
field, _ := m.Fields.GetByName(FieldNameTokenKey).(*TextField)
if field == nil {
// load default field
m.Fields.Add(&TextField{
Name: FieldNameTokenKey,
System: true,
Hidden: true,
Min: 30,
Max: 60,
Required: true,
AutogeneratePattern: `[a-zA-Z0-9]{50}`,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = true
field.Required = true
}
// ensure that there is a unique index for the field
if !dbutils.HasSingleColumnUniqueIndex(FieldNameTokenKey, m.Indexes) {
m.Indexes = append(m.Indexes, fmt.Sprintf(
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`)",
m.fieldIndexName(FieldNameTokenKey),
m.Name,
FieldNameTokenKey,
))
}
}
func (m *Collection) initEmailField() {
field, _ := m.Fields.GetByName(FieldNameEmail).(*EmailField)
if field == nil {
// load default field
m.Fields.Add(&EmailField{
Name: FieldNameEmail,
System: true,
Required: true,
})
} else {
// enforce system defaults
field.System = true
field.Hidden = false // managed by the emailVisibility flag
}
// ensure that there is a unique index for the email field
if !dbutils.HasSingleColumnUniqueIndex(FieldNameEmail, m.Indexes) {
m.Indexes = append(m.Indexes, fmt.Sprintf(
"CREATE UNIQUE INDEX `%s` ON `%s` (`%s`) WHERE `%s` != ''",
m.fieldIndexName(FieldNameEmail),
m.Name,
FieldNameEmail,
FieldNameEmail,
))
}
}
func (m *Collection) initEmailVisibilityField() {
field, _ := m.Fields.GetByName(FieldNameEmailVisibility).(*BoolField)
if field == nil {
// load default field
m.Fields.Add(&BoolField{
Name: FieldNameEmailVisibility,
System: true,
})
} else {
// enforce system defaults
field.System = true
}
}
func (m *Collection) initVerifiedField() {
field, _ := m.Fields.GetByName(FieldNameVerified).(*BoolField)
if field == nil {
// load default field
m.Fields.Add(&BoolField{
Name: FieldNameVerified,
System: true,
})
} else {
// enforce system defaults
field.System = true
}
}
func (m *Collection) fieldIndexName(field string) string {
name := "idx_" + field + "_"
if m.Id != "" {
name += m.Id
} else if m.Name != "" {
name += m.Name
} else {
name += security.PseudorandomString(10)
}
if len(name) > 64 {
return name[:64]
}
return name
}

View File

@ -0,0 +1,535 @@
package core
import (
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/pocketbase/pocketbase/tools/auth"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
func (m *Collection) unsetMissingOAuth2MappedFields() {
if !m.IsAuth() {
return
}
if m.OAuth2.MappedFields.Id != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Id) == nil {
m.OAuth2.MappedFields.Id = ""
}
}
if m.OAuth2.MappedFields.Name != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Name) == nil {
m.OAuth2.MappedFields.Name = ""
}
}
if m.OAuth2.MappedFields.Username != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.Username) == nil {
m.OAuth2.MappedFields.Username = ""
}
}
if m.OAuth2.MappedFields.AvatarURL != "" {
if m.Fields.GetByName(m.OAuth2.MappedFields.AvatarURL) == nil {
m.OAuth2.MappedFields.AvatarURL = ""
}
}
}
func (m *Collection) setDefaultAuthOptions() {
m.collectionAuthOptions = collectionAuthOptions{
VerificationTemplate: defaultVerificationTemplate,
ResetPasswordTemplate: defaultResetPasswordTemplate,
ConfirmEmailChangeTemplate: defaultConfirmEmailChangeTemplate,
AuthRule: types.Pointer(""),
AuthAlert: AuthAlertConfig{
Enabled: true,
EmailTemplate: defaultAuthAlertTemplate,
},
PasswordAuth: PasswordAuthConfig{
Enabled: true,
IdentityFields: []string{FieldNameEmail},
},
MFA: MFAConfig{
Enabled: false,
Duration: 1800, // 30min
},
OTP: OTPConfig{
Enabled: false,
Duration: 180, // 3min
Length: 8,
EmailTemplate: defaultOTPTemplate,
},
AuthToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 604800, // 7 days
},
PasswordResetToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
EmailChangeToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 1800, // 30min
},
VerificationToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 259200, // 3days
},
FileToken: TokenConfig{
Secret: security.RandomString(50),
Duration: 180, // 3min
},
}
}
var _ optionsValidator = (*collectionAuthOptions)(nil)
// collectionAuthOptions defines the options for the "auth" type collection.
type collectionAuthOptions struct {
// AuthRule could be used to specify additional record constraints
// applied after record authentication and right before returning the
// auth token response to the client.
//
// For example, to allow only verified users you could set it to
// "verified = true".
//
// Set it to empty string to allow any Auth collection record to authenticate.
//
// Set it to nil to disallow authentication altogether for the collection
// (that includes password, OAuth2, etc.).
AuthRule *string `form:"authRule" json:"authRule"`
// ManageRule gives admin-like permissions to allow fully managing
// the auth record(s), eg. changing the password without requiring
// to enter the old one, directly updating the verified state and email, etc.
//
// This rule is executed in addition to the Create and Update API rules.
ManageRule *string `form:"manageRule" json:"manageRule"`
// AuthAlert defines options related to the auth alerts on new device login.
AuthAlert AuthAlertConfig `form:"authAlert" json:"authAlert"`
// OAuth2 specifies whether OAuth2 auth is enabled for the collection
// and which OAuth2 providers are allowed.
OAuth2 OAuth2Config `form:"oauth2" json:"oauth2"`
PasswordAuth PasswordAuthConfig `form:"passwordAuth" json:"passwordAuth"`
MFA MFAConfig `form:"mfa" json:"mfa"`
OTP OTPConfig `form:"otp" json:"otp"`
// Various token configurations
// ---
AuthToken TokenConfig `form:"authToken" json:"authToken"`
PasswordResetToken TokenConfig `form:"passwordResetToken" json:"passwordResetToken"`
EmailChangeToken TokenConfig `form:"emailChangeToken" json:"emailChangeToken"`
VerificationToken TokenConfig `form:"verificationToken" json:"verificationToken"`
FileToken TokenConfig `form:"fileToken" json:"fileToken"`
// default email templates
// ---
VerificationTemplate EmailTemplate `form:"verificationTemplate" json:"verificationTemplate"`
ResetPasswordTemplate EmailTemplate `form:"resetPasswordTemplate" json:"resetPasswordTemplate"`
ConfirmEmailChangeTemplate EmailTemplate `form:"confirmEmailChangeTemplate" json:"confirmEmailChangeTemplate"`
}
func (o *collectionAuthOptions) validate(cv *collectionValidator) error {
err := validation.ValidateStruct(o,
validation.Field(
&o.AuthRule,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.AuthRule)),
),
validation.Field(
&o.ManageRule,
validation.NilOrNotEmpty,
validation.By(cv.checkRule),
validation.By(cv.ensureNoSystemRuleChange(cv.original.ManageRule)),
),
validation.Field(&o.AuthAlert),
validation.Field(&o.PasswordAuth),
validation.Field(&o.OAuth2),
validation.Field(&o.OTP),
validation.Field(&o.MFA),
validation.Field(&o.AuthToken),
validation.Field(&o.PasswordResetToken),
validation.Field(&o.EmailChangeToken),
validation.Field(&o.VerificationToken),
validation.Field(&o.FileToken),
validation.Field(&o.VerificationTemplate, validation.Required),
validation.Field(&o.ResetPasswordTemplate, validation.Required),
validation.Field(&o.ConfirmEmailChangeTemplate, validation.Required),
)
if err != nil {
return err
}
if o.MFA.Enabled {
// if MFA is enabled require at least 2 auth methods
//
// @todo maybe consider disabling the check because if custom auth methods
// are registered it may fail since we don't have mechanism to detect them at the moment
authsEnabled := 0
if o.PasswordAuth.Enabled {
authsEnabled++
}
if o.OAuth2.Enabled {
authsEnabled++
}
if o.OTP.Enabled {
authsEnabled++
}
if authsEnabled < 2 {
return validation.Errors{
"mfa": validation.Errors{
"enabled": validation.NewError("validation_mfa_not_enough_auths", "MFA requires at least 2 auth methods to be enabled."),
},
}
}
if o.MFA.Rule != "" {
mfaRuleValidators := []validation.RuleFunc{
cv.checkRule,
cv.ensureNoSystemRuleChange(&cv.original.MFA.Rule),
}
for _, validator := range mfaRuleValidators {
err := validator(&o.MFA.Rule)
if err != nil {
return validation.Errors{
"mfa": validation.Errors{
"rule": err,
},
}
}
}
}
}
// extra check to ensure that only unique identity fields are used
if o.PasswordAuth.Enabled {
err = validation.Validate(o.PasswordAuth.IdentityFields, validation.By(cv.checkFieldsForUniqueIndex))
if err != nil {
return validation.Errors{
"passwordAuth": validation.Errors{
"identityFields": err,
},
}
}
}
return nil
}
// -------------------------------------------------------------------
type EmailTemplate struct {
Subject string `form:"subject" json:"subject"`
Body string `form:"body" json:"body"`
}
// Validate makes EmailTemplate validatable by implementing [validation.Validatable] interface.
func (t EmailTemplate) Validate() error {
return validation.ValidateStruct(&t,
validation.Field(&t.Subject, validation.Required),
validation.Field(&t.Body, validation.Required),
)
}
// Resolve replaces the placeholder parameters in the current email
// template and returns its components as ready-to-use strings.
func (t EmailTemplate) Resolve(placeholders map[string]any) (subject, body string) {
body = t.Body
subject = t.Subject
for k, v := range placeholders {
vStr := cast.ToString(v)
// replace subject placeholder params (if any)
subject = strings.ReplaceAll(subject, k, vStr)
// replace body placeholder params (if any)
body = strings.ReplaceAll(body, k, vStr)
}
return subject, body
}
// -------------------------------------------------------------------
type AuthAlertConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes AuthAlertConfig validatable by implementing [validation.Validatable] interface.
func (c AuthAlertConfig) Validate() error {
return validation.ValidateStruct(&c,
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// -------------------------------------------------------------------
type TokenConfig struct {
Secret string `form:"secret" json:"secret,omitempty"`
// Duration specifies how long an issued token to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
}
// Validate makes TokenConfig validatable by implementing [validation.Validatable] interface.
func (c TokenConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Secret, validation.Required, validation.Length(30, 255)),
validation.Field(&c.Duration, validation.Required, validation.Min(10), validation.Max(94670856)), // ~3y max
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c TokenConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type OTPConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long the OTP to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Length specifies the auto generated password length.
Length int `form:"length" json:"length"`
// EmailTemplate is the default OTP email template that will be send to the auth record.
//
// In addition to the system placeholders you can also make use of
// [core.EmailPlaceholderOTPId] and [core.EmailPlaceholderOTP].
EmailTemplate EmailTemplate `form:"emailTemplate" json:"emailTemplate"`
}
// Validate makes OTPConfig validatable by implementing [validation.Validatable] interface.
func (c OTPConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
validation.Field(&c.Length, validation.When(c.Enabled, validation.Required, validation.Min(4))),
// note: for now always run the email template validations even
// if not enabled since it could be used separately
validation.Field(&c.EmailTemplate),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c OTPConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type MFAConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// Duration specifies how long an issued MFA to be valid (in seconds)
Duration int64 `form:"duration" json:"duration"`
// Rule is an optional field to restrict MFA only for the records that satisfy the rule.
//
// Leave it empty to enable MFA for everyone.
Rule string `form:"rule" json:"rule"`
}
// Validate makes MFAConfig validatable by implementing [validation.Validatable] interface.
func (c MFAConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Duration, validation.When(c.Enabled, validation.Required, validation.Min(10), validation.Max(86400))),
)
}
// DurationTime returns the current Duration as [time.Duration].
func (c MFAConfig) DurationTime() time.Duration {
return time.Duration(c.Duration) * time.Second
}
// -------------------------------------------------------------------
type PasswordAuthConfig struct {
Enabled bool `form:"enabled" json:"enabled"`
// IdentityFields is a list of field names that could be used as
// identity during password authentication.
//
// Usually only fields that has single column UNIQUE index are accepted as values.
IdentityFields []string `form:"identityFields" json:"identityFields"`
}
// Validate makes PasswordAuthConfig validatable by implementing [validation.Validatable] interface.
func (c PasswordAuthConfig) Validate() error {
// strip duplicated values
c.IdentityFields = list.ToUniqueStringSlice(c.IdentityFields)
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
validation.Field(&c.IdentityFields, validation.Required),
)
}
// -------------------------------------------------------------------
type OAuth2KnownFields struct {
Id string `form:"id" json:"id"`
Name string `form:"name" json:"name"`
Username string `form:"username" json:"username"`
AvatarURL string `form:"avatarURL" json:"avatarURL"`
}
type OAuth2Config struct {
Providers []OAuth2ProviderConfig `form:"providers" json:"providers"`
MappedFields OAuth2KnownFields `form:"mappedFields" json:"mappedFields"`
Enabled bool `form:"enabled" json:"enabled"`
}
// GetProviderConfig returns the first OAuth2ProviderConfig that matches the specified name.
//
// Returns false and zero config if no such provider is available in c.Providers.
func (c OAuth2Config) GetProviderConfig(name string) (config OAuth2ProviderConfig, exists bool) {
for _, p := range c.Providers {
if p.Name == name {
return p, true
}
}
return
}
// Validate makes OAuth2Config validatable by implementing [validation.Validatable] interface.
func (c OAuth2Config) Validate() error {
if !c.Enabled {
return nil // no need to validate
}
return validation.ValidateStruct(&c,
// note: don't require providers for now as they could be externally registered/removed
validation.Field(&c.Providers, validation.By(checkForDuplicatedProviders)),
)
}
func checkForDuplicatedProviders(value any) error {
configs, _ := value.([]OAuth2ProviderConfig)
existing := map[string]struct{}{}
for i, c := range configs {
if c.Name == "" {
continue // the name nonempty state is validated separately
}
if _, ok := existing[c.Name]; ok {
return validation.Errors{
strconv.Itoa(i): validation.Errors{
"name": validation.NewError("validation_duplicated_provider", "The provider "+c.Name+" is already registered.").
SetParams(map[string]any{"name": c.Name}),
},
}
}
existing[c.Name] = struct{}{}
}
return nil
}
type OAuth2ProviderConfig struct {
// PKCE overwrites the default provider PKCE config option.
//
// This usually shouldn't be needed but some OAuth2 vendors, like the LinkedIn OIDC,
// may require manual adjustment due to returning error if extra parameters are added to the request
// (https://github.com/pocketbase/pocketbase/discussions/3799#discussioncomment-7640312)
PKCE *bool `form:"pkce" json:"pkce"`
Name string `form:"name" json:"name"`
ClientId string `form:"clientId" json:"clientId"`
ClientSecret string `form:"clientSecret" json:"clientSecret,omitempty"`
AuthURL string `form:"authURL" json:"authURL"`
TokenURL string `form:"tokenURL" json:"tokenURL"`
UserInfoURL string `form:"userInfoURL" json:"userInfoURL"`
DisplayName string `form:"displayName" json:"displayName"`
}
// Validate makes OAuth2ProviderConfig validatable by implementing [validation.Validatable] interface.
func (c OAuth2ProviderConfig) Validate() error {
return validation.ValidateStruct(&c,
validation.Field(&c.Name, validation.Required, validation.By(checkProviderName)),
validation.Field(&c.ClientId, validation.Required),
validation.Field(&c.ClientSecret, validation.Required),
validation.Field(&c.AuthURL, is.URL),
validation.Field(&c.TokenURL, is.URL),
validation.Field(&c.UserInfoURL, is.URL),
)
}
func checkProviderName(value any) error {
name, _ := value.(string)
if name == "" {
return nil // nothing to check
}
if _, err := auth.NewProviderByName(name); err != nil {
return validation.NewError("validation_missing_provider", "Invalid or missing provider with name "+name+".").
SetParams(map[string]any{"name": name})
}
return nil
}
// InitProvider returns a new auth.Provider instance loaded with the current OAuth2ProviderConfig options.
func (c OAuth2ProviderConfig) InitProvider() (auth.Provider, error) {
provider, err := auth.NewProviderByName(c.Name)
if err != nil {
return nil, err
}
if c.ClientId != "" {
provider.SetClientId(c.ClientId)
}
if c.ClientSecret != "" {
provider.SetClientSecret(c.ClientSecret)
}
if c.AuthURL != "" {
provider.SetAuthURL(c.AuthURL)
}
if c.UserInfoURL != "" {
provider.SetUserInfoURL(c.UserInfoURL)
}
if c.TokenURL != "" {
provider.SetTokenURL(c.TokenURL)
}
if c.DisplayName != "" {
provider.SetDisplayName(c.DisplayName)
}
if c.PKCE != nil {
provider.SetPKCE(*c.PKCE)
}
return provider, nil
}

Some files were not shown because too many files have changed in this diff Show More