1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-11-06 17:39:57 +02:00

[#55] added OAuth2 subscription redirect handler

This commit is contained in:
Gani Georgiev
2023-04-10 22:27:00 +03:00
parent c826514eca
commit dc72d5adee
34 changed files with 336 additions and 111 deletions

View File

@@ -1,6 +1,7 @@
package apis
import (
"encoding/json"
"errors"
"fmt"
"log"
@@ -17,6 +18,7 @@ import (
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"golang.org/x/oauth2"
)
@@ -25,12 +27,15 @@ import (
func bindRecordAuthApi(app core.App, rg *echo.Group) {
api := recordAuthApi{app: app}
// global oauth2 subscription redirect handler
rg.GET("/oauth2-redirect", api.oauth2SubscriptionRedirect)
// common collection record related routes
subGroup := rg.Group(
"/collections/:collection",
ActivityLogger(app),
LoadCollectionContext(app, models.CollectionTypeAuth),
)
subGroup.GET("/auth-methods", api.authMethods)
subGroup.POST("/auth-refresh", api.authRefresh, RequireSameContextRecordAuth())
subGroup.POST("/auth-with-oauth2", api.authWithOAuth2)
@@ -628,3 +633,36 @@ func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
return handlerErr
}
// -------------------------------------------------------------------
const oauth2SubscribeTopic = "@oauth2"
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
state := c.QueryParam("state")
code := c.QueryParam("code")
client, err := api.app.SubscriptionsBroker().ClientById(state)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscribeTopic) {
return NewNotFoundError("Missing or invalid oauth2 subscription client", err)
}
data := map[string]string{
"state": state,
"code": code,
}
encodedData, err := json.Marshal(data)
if err != nil {
return NewBadRequestError("Failed to marshalize oauth2 redirect data", err)
}
msg := subscriptions.Message{
Name: oauth2SubscribeTopic,
Data: string(encodedData),
}
client.Channel() <- msg
return c.Redirect(http.StatusTemporaryRedirect, "/_/#/auth/oauth2-redirect")
}

View File

@@ -1,6 +1,7 @@
package apis_test
import (
"context"
"net/http"
"strings"
"testing"
@@ -9,6 +10,7 @@ import (
"github.com/labstack/echo/v5"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/pocketbase/pocketbase/tools/types"
)
@@ -1144,3 +1146,139 @@ func TestRecordAuthUnlinkExternalsAuth(t *testing.T) {
scenario.Test(t)
}
}
func TestRecordAuthOAuth2Redirect(t *testing.T) {
c1 := subscriptions.NewDefaultClient()
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
baseBeforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.SubscriptionsBroker().Register(c1)
app.SubscriptionsBroker().Register(c2)
app.SubscriptionsBroker().Register(c3)
app.SubscriptionsBroker().Register(c4)
app.SubscriptionsBroker().Register(c5)
}
noMessagesBeforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
baseBeforeTestFunc(t, app, e)
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
go func() {
defer cancelFunc()
L:
for {
select {
case <-c1.Channel():
t.Error("Unexpected c1 message")
break L
case <-c2.Channel():
t.Error("Unexpected c2 message")
break L
case <-c3.Channel():
t.Error("Unexpected c3 message")
break L
case <-c4.Channel():
t.Error("Unexpected c4 message")
break L
case <-c5.Channel():
t.Error("Unexpected c5 message")
break L
case <-ctx.Done():
t.Error("Context timeout reached")
break L
}
}
}()
}
scenarios := []tests.ApiScenario{
{
Name: "no clients",
Method: http.MethodGet,
Url: "/api/oauth2-redirect",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?state=" + c5.Id(),
BeforeTestFunc: noMessagesBeforeTestFunc,
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?state=" + c4.Id(),
BeforeTestFunc: noMessagesBeforeTestFunc,
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?state=" + c3.Id(),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
baseBeforeTestFunc(t, app, e)
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
go func() {
defer cancelFunc()
L:
for {
select {
case <-c1.Channel():
t.Error("Unexpected c1 message")
break L
case <-c2.Channel():
t.Error("Unexpected c2 message")
break L
case msg := <-c3.Channel():
if msg.Name != "@oauth2" {
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
}
expectedParams := []string{`"state"`, `"code"`}
for _, p := range expectedParams {
if !strings.Contains(msg.Data, p) {
t.Errorf("Couldn't find %s in \n%v", p, msg.Data)
}
}
break L
case <-c4.Channel():
t.Error("Unexpected c4 message")
break L
case <-c5.Channel():
t.Error("Unexpected c5 message")
break L
case <-ctx.Done():
t.Error("Context timeout reached")
break L
}
}
}()
},
ExpectedStatus: http.StatusTemporaryRedirect,
},
}
for _, scenario := range scenarios {
scenario.Test(t)
}
}