package apis

import (
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"time"

	"github.com/pocketbase/pocketbase/core"
	"github.com/pocketbase/pocketbase/tools/hook"
	"github.com/pocketbase/pocketbase/tools/list"
	"github.com/pocketbase/pocketbase/tools/router"
	"github.com/pocketbase/pocketbase/tools/routine"
	"github.com/spf13/cast"
)

// Common request event store keys used by the middlewares and api handlers.
const (
	// RequestEventKeyLogMeta is the store key for extra data to attach
	// to the request activity log (logged under the "meta" attribute).
	RequestEventKeyLogMeta = "pbLogMeta" // extra data to store with the request activity log

	// requestEventKeyExecStart holds the moment the activity logger
	// middleware started processing the request.
	requestEventKeyExecStart = "__execStart" // the value must be time.Time

	// requestEventKeySkipSuccessActivityLog marks requests whose successful
	// execution should not be written to the activity log.
	requestEventKeySkipSuccessActivityLog = "__skipSuccessActivityLogger" // the value must be bool
)

// Ids and priorities of the middlewares declared in this file.
//
// The priorities of the activity logger, auth token loader and security
// headers middlewares are expressed relative to
// DefaultRateLimitMiddlewarePriority (declared elsewhere in the package).
const (
	DefaultWWWRedirectMiddlewarePriority = -99999
	DefaultWWWRedirectMiddlewareId       = "pbWWWRedirect"

	DefaultActivityLoggerMiddlewarePriority   = DefaultRateLimitMiddlewarePriority - 30
	DefaultActivityLoggerMiddlewareId         = "pbActivityLogger"
	DefaultSkipSuccessActivityLogMiddlewareId = "pbSkipSuccessActivityLog"
	// NOTE(review): not referenced by any middleware visible in this file;
	// presumably kept for callers elsewhere - confirm before removing.
	DefaultEnableAuthIdActivityLog = "pbEnableAuthIdActivityLog"

	DefaultLoadAuthTokenMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 20
	DefaultLoadAuthTokenMiddlewareId       = "pbLoadAuthToken"

	DefaultSecurityHeadersMiddlewarePriority = DefaultRateLimitMiddlewarePriority - 10
	DefaultSecurityHeadersMiddlewareId       = "pbSecurityHeaders"

	DefaultRequireGuestOnlyMiddlewareId                 = "pbRequireGuestOnly"
	DefaultRequireAuthMiddlewareId                      = "pbRequireAuth"
	DefaultRequireSuperuserAuthMiddlewareId             = "pbRequireSuperuserAuth"
	DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId    = "pbRequireSuperuserAuthOnlyIfAny"
	DefaultRequireSuperuserOrOwnerAuthMiddlewareId      = "pbRequireSuperuserOrOwnerAuth"
	DefaultRequireSameCollectionContextAuthMiddlewareId = "pbRequireSameCollectionContextAuth"
)

// RequireGuestOnly middleware requires a request to NOT have a valid
// Authorization header.
//
// This middleware is the opposite of [apis.RequireAuth()].
func RequireGuestOnly() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireGuestOnlyMiddlewareId, Func: func(e *core.RequestEvent) error { if e.Auth != nil { return router.NewBadRequestError("The request can be accessed only by guests.", nil) } return e.Next() }, } } // RequireAuth middleware requires a request to have a valid record Authorization header. // // The auth record could be from any collection. // You can further filter the allowed record auth collections by specifying their names. // // Example: // // apis.RequireAuth() // any auth collection // apis.RequireAuth("_superusers", "users") // only the listed auth collections func RequireAuth(optCollectionNames ...string) *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireAuthMiddlewareId, Func: requireAuth(optCollectionNames...), } } func requireAuth(optCollectionNames ...string) hook.HandlerFunc[*core.RequestEvent] { return func(e *core.RequestEvent) error { if e.Auth == nil { return e.UnauthorizedError("The request requires valid record authorization token.", nil) } // check record collection name if len(optCollectionNames) > 0 && !slices.Contains(optCollectionNames, e.Auth.Collection().Name) { return e.ForbiddenError("The authorized record is not allowed to perform this action.", nil) } return e.Next() } } // RequireSuperuserAuth middleware requires a request to have // a valid superuser Authorization header. func RequireSuperuserAuth() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireSuperuserAuthMiddlewareId, Func: requireAuth(core.CollectionNameSuperusers), } } // RequireSuperuserAuthOnlyIfAny middleware requires a request to have // a valid superuser Authorization header ONLY if the application has // at least 1 existing superuser. 
func RequireSuperuserAuthOnlyIfAny() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireSuperuserAuthOnlyIfAnyMiddlewareId, Func: func(e *core.RequestEvent) error { if e.HasSuperuserAuth() { return e.Next() } totalSuperusers, err := e.App.CountRecords(core.CollectionNameSuperusers) if err != nil { return e.InternalServerError("Failed to fetch superusers info.", err) } if totalSuperusers == 0 { return e.Next() } return requireAuth(core.CollectionNameSuperusers)(e) }, } } // RequireSuperuserOrOwnerAuth middleware requires a request to have // a valid superuser or regular record owner Authorization header set. // // This middleware is similar to [apis.RequireAuth()] but // for the auth record token expects to have the same id as the path // parameter ownerIdPathParam (default to "id" if empty). func RequireSuperuserOrOwnerAuth(ownerIdPathParam string) *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireSuperuserOrOwnerAuthMiddlewareId, Func: func(e *core.RequestEvent) error { if e.Auth == nil { return e.UnauthorizedError("The request requires superuser or record authorization token.", nil) } if e.Auth.IsSuperuser() { return e.Next() } if ownerIdPathParam == "" { ownerIdPathParam = "id" } ownerId := e.Request.PathValue(ownerIdPathParam) // note: it is considered "safe" to compare only the record id // since the auth record ids are treated as unique across all auth collections if e.Auth.Id != ownerId { return e.ForbiddenError("You are not allowed to perform this request.", nil) } return e.Next() }, } } // RequireSameCollectionContextAuth middleware requires a request to have // a valid record Authorization header and the auth record's collection to // match the one from the route path parameter (default to "collection" if collectionParam is empty). 
func RequireSameCollectionContextAuth(collectionPathParam string) *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRequireSameCollectionContextAuthMiddlewareId, Func: func(e *core.RequestEvent) error { if e.Auth == nil { return e.UnauthorizedError("The request requires valid record authorization token.", nil) } if collectionPathParam == "" { collectionPathParam = "collection" } collection, _ := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam)) if collection == nil || e.Auth.Collection().Id != collection.Id { return e.ForbiddenError(fmt.Sprintf("The request requires auth record from %s collection.", e.Auth.Collection().Name), nil) } return e.Next() }, } } // loadAuthToken attempts to load the auth context based on the "Authorization: TOKEN" header value. // // This middleware does nothing in case of missing, invalid or expired token. // // This middleware is registered by default for all routes. // // Note: We don't throw an error on invalid or expired token to allow // users to extend with their own custom handling in external middleware(s). 
func loadAuthToken() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultLoadAuthTokenMiddlewareId, Priority: DefaultLoadAuthTokenMiddlewarePriority, Func: func(e *core.RequestEvent) error { token := getAuthTokenFromRequest(e) if token == "" { return e.Next() } record, err := e.App.FindAuthRecordByToken(token, core.TokenTypeAuth) if err != nil { e.App.Logger().Debug("loadAuthToken failure", "error", err) } else if record != nil { e.Auth = record } return e.Next() }, } } func getAuthTokenFromRequest(e *core.RequestEvent) string { token := e.Request.Header.Get("Authorization") if token != "" { // the schema prefix is not required and it is only for // compatibility with the defaults of some HTTP clients token = strings.TrimPrefix(token, "Bearer ") } return token } // wwwRedirect performs www->non-www redirect(s) if the request host // matches with one of the values in redirectHosts. // // This middleware is registered by default on Serve for all routes. func wwwRedirect(redirectHosts []string) *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultWWWRedirectMiddlewareId, Priority: DefaultWWWRedirectMiddlewarePriority, Func: func(e *core.RequestEvent) error { host := e.Request.Host if strings.HasPrefix(host, "www.") && list.ExistInSlice(host, redirectHosts) { return e.Redirect( http.StatusTemporaryRedirect, (e.Request.URL.Scheme + "://" + host[4:] + e.Request.RequestURI), ) } return e.Next() }, } } // securityHeaders middleware adds common security headers to the response. // // This middleware is registered by default for all routes. 
func securityHeaders() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultSecurityHeadersMiddlewareId, Priority: DefaultSecurityHeadersMiddlewarePriority, Func: func(e *core.RequestEvent) error { e.Response.Header().Set("X-XSS-Protection", "1; mode=block") e.Response.Header().Set("X-Content-Type-Options", "nosniff") e.Response.Header().Set("X-Frame-Options", "SAMEORIGIN") // @todo consider a default HSTS? // (see also https://webkit.org/blog/8146/protecting-against-hsts-abuse/) return e.Next() }, } } // SkipSuccessActivityLog is a helper middleware that instructs the global // activity logger to log only requests that have failed/returned an error. func SkipSuccessActivityLog() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultSkipSuccessActivityLogMiddlewareId, Func: func(e *core.RequestEvent) error { e.Set(requestEventKeySkipSuccessActivityLog, true) return e.Next() }, } } // activityLogger middleware takes care to save the request information // into the logs database. // // This middleware is registered by default for all routes. // // The middleware does nothing if the app logs retention period is zero // (aka. app.Settings().Logs.MaxDays = 0). // // Users can attach the [apis.SkipSuccessActivityLog()] middleware if // you want to log only the failed requests. 
func activityLogger() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultActivityLoggerMiddlewareId, Priority: DefaultActivityLoggerMiddlewarePriority, Func: func(e *core.RequestEvent) error { e.Set(requestEventKeyExecStart, time.Now()) err := e.Next() logRequest(e, err) return err }, } } func logRequest(event *core.RequestEvent, err error) { // no logs retention if event.App.Settings().Logs.MaxDays == 0 { return } // the non-error route has explicitly disabled the activity logger if err == nil && event.Get(requestEventKeySkipSuccessActivityLog) != nil { return } attrs := make([]any, 0, 15) attrs = append(attrs, slog.String("type", "request")) started := cast.ToTime(event.Get(requestEventKeyExecStart)) if !started.IsZero() { attrs = append(attrs, slog.Float64("execTime", float64(time.Since(started))/float64(time.Millisecond))) } if meta := event.Get(RequestEventKeyLogMeta); meta != nil { attrs = append(attrs, slog.Any("meta", meta)) } status := event.Status() method := cutStr(strings.ToUpper(event.Request.Method), 50) requestUri := cutStr(event.Request.URL.RequestURI(), 3000) // parse the request error if err != nil { if apiErr, ok := err.(*router.ApiError); ok { status = apiErr.Status attrs = append( attrs, slog.String("error", apiErr.Message), slog.Any("details", apiErr.RawData()), ) } else { attrs = append(attrs, slog.String("error", err.Error())) } } attrs = append( attrs, slog.String("url", requestUri), slog.String("method", method), slog.Int("status", status), slog.String("referer", cutStr(event.Request.Referer(), 2000)), slog.String("userAgent", cutStr(event.Request.UserAgent(), 2000)), ) if event.Auth != nil { attrs = append(attrs, slog.String("auth", event.Auth.Collection().Name)) if event.App.Settings().Logs.LogAuthId { attrs = append(attrs, slog.String("authId", event.Auth.Id)) } } else { attrs = append(attrs, slog.String("auth", "")) } if event.App.Settings().Logs.LogIP { var userIP string if 
len(event.App.Settings().TrustedProxy.Headers) > 0 { userIP = event.RealIP() } else { // fallback to the legacy behavior (it is "safe" since it is only for log purposes) userIP = cutStr(event.UnsafeRealIP(), 50) } attrs = append( attrs, slog.String("userIP", userIP), slog.String("remoteIP", event.RemoteIP()), ) } // don't block on logs write routine.FireAndForget(func() { message := method + " " if escaped, unescapeErr := url.PathUnescape(requestUri); unescapeErr == nil { message += escaped } else { message += requestUri } if err != nil { event.App.Logger().Error(message, attrs...) } else { event.App.Logger().Info(message, attrs...) } }) } func cutStr(str string, max int) string { if len(str) > max { return str[:max] + "..." } return str }