You've already forked pocketbase
mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-11-06 17:39:57 +02:00
[#55] added OAuth2 subscription redirect handler
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package apis
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/pocketbase/pocketbase/tools/routine"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@@ -25,12 +27,15 @@ import (
|
||||
func bindRecordAuthApi(app core.App, rg *echo.Group) {
|
||||
api := recordAuthApi{app: app}
|
||||
|
||||
// global oauth2 subscription redirect handler
|
||||
rg.GET("/oauth2-redirect", api.oauth2SubscriptionRedirect)
|
||||
|
||||
// common collection record related routes
|
||||
subGroup := rg.Group(
|
||||
"/collections/:collection",
|
||||
ActivityLogger(app),
|
||||
LoadCollectionContext(app, models.CollectionTypeAuth),
|
||||
)
|
||||
|
||||
subGroup.GET("/auth-methods", api.authMethods)
|
||||
subGroup.POST("/auth-refresh", api.authRefresh, RequireSameContextRecordAuth())
|
||||
subGroup.POST("/auth-with-oauth2", api.authWithOAuth2)
|
||||
@@ -628,3 +633,36 @@ func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const oauth2SubscribeTopic = "@oauth2"
|
||||
|
||||
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
|
||||
state := c.QueryParam("state")
|
||||
code := c.QueryParam("code")
|
||||
|
||||
client, err := api.app.SubscriptionsBroker().ClientById(state)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscribeTopic) {
|
||||
return NewNotFoundError("Missing or invalid oauth2 subscription client", err)
|
||||
}
|
||||
|
||||
data := map[string]string{
|
||||
"state": state,
|
||||
"code": code,
|
||||
}
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to marshalize oauth2 redirect data", err)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: oauth2SubscribeTopic,
|
||||
Data: string(encodedData),
|
||||
}
|
||||
|
||||
client.Channel() <- msg
|
||||
|
||||
return c.Redirect(http.StatusTemporaryRedirect, "/_/#/auth/oauth2-redirect")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package apis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
@@ -1144,3 +1146,139 @@ func TestRecordAuthUnlinkExternalsAuth(t *testing.T) {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordAuthOAuth2Redirect(t *testing.T) {
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
|
||||
baseBeforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.SubscriptionsBroker().Register(c1)
|
||||
app.SubscriptionsBroker().Register(c2)
|
||||
app.SubscriptionsBroker().Register(c3)
|
||||
app.SubscriptionsBroker().Register(c4)
|
||||
app.SubscriptionsBroker().Register(c5)
|
||||
}
|
||||
|
||||
noMessagesBeforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
baseBeforeTestFunc(t, app, e)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
L:
|
||||
for {
|
||||
select {
|
||||
case <-c1.Channel():
|
||||
t.Error("Unexpected c1 message")
|
||||
break L
|
||||
case <-c2.Channel():
|
||||
t.Error("Unexpected c2 message")
|
||||
break L
|
||||
case <-c3.Channel():
|
||||
t.Error("Unexpected c3 message")
|
||||
break L
|
||||
case <-c4.Channel():
|
||||
t.Error("Unexpected c4 message")
|
||||
break L
|
||||
case <-c5.Channel():
|
||||
t.Error("Unexpected c5 message")
|
||||
break L
|
||||
case <-ctx.Done():
|
||||
t.Error("Context timeout reached")
|
||||
break L
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "no clients",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?state=" + c5.Id(),
|
||||
BeforeTestFunc: noMessagesBeforeTestFunc,
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?state=" + c4.Id(),
|
||||
BeforeTestFunc: noMessagesBeforeTestFunc,
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?state=" + c3.Id(),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
baseBeforeTestFunc(t, app, e)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
L:
|
||||
for {
|
||||
select {
|
||||
case <-c1.Channel():
|
||||
t.Error("Unexpected c1 message")
|
||||
break L
|
||||
case <-c2.Channel():
|
||||
t.Error("Unexpected c2 message")
|
||||
break L
|
||||
case msg := <-c3.Channel():
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
}
|
||||
|
||||
expectedParams := []string{`"state"`, `"code"`}
|
||||
for _, p := range expectedParams {
|
||||
if !strings.Contains(msg.Data, p) {
|
||||
t.Errorf("Couldn't find %s in \n%v", p, msg.Data)
|
||||
}
|
||||
}
|
||||
|
||||
break L
|
||||
case <-c4.Channel():
|
||||
t.Error("Unexpected c4 message")
|
||||
break L
|
||||
case <-c5.Channel():
|
||||
t.Error("Unexpected c5 message")
|
||||
break L
|
||||
case <-ctx.Done():
|
||||
t.Error("Context timeout reached")
|
||||
break L
|
||||
}
|
||||
}
|
||||
}()
|
||||
},
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
scenario.Test(t)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user