You've already forked pocketbase
mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-11-06 09:29:19 +02:00
[#4177] added graceful OAuth2 redirect error handling
This commit is contained in:
@@ -656,29 +656,42 @@ func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
const oauth2SubscriptionTopic = "@oauth2"
|
||||
const (
|
||||
oauth2SubscriptionTopic string = "@oauth2"
|
||||
oauth2FailureRedirectPath string = "../_/#/auth/oauth2-redirect-failure"
|
||||
oauth2SuccessRedirectPath string = "../_/#/auth/oauth2-redirect-success"
|
||||
)
|
||||
|
||||
type oauth2EventMessage struct {
|
||||
State string `json:"state"`
|
||||
Code string `json:"code"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
|
||||
state := c.QueryParam("state")
|
||||
code := c.QueryParam("code")
|
||||
|
||||
if code == "" || state == "" {
|
||||
return NewBadRequestError("Invalid OAuth2 redirect parameters.", nil)
|
||||
if state == "" {
|
||||
api.app.Logger().Debug("Missing OAuth2 state parameter")
|
||||
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
|
||||
}
|
||||
|
||||
client, err := api.app.SubscriptionsBroker().ClientById(state)
|
||||
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
|
||||
return NewNotFoundError("Missing or invalid OAuth2 subscription client.", err)
|
||||
api.app.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", state)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
|
||||
}
|
||||
defer client.Unsubscribe(oauth2SubscriptionTopic)
|
||||
|
||||
data := map[string]string{
|
||||
"state": state,
|
||||
"code": code,
|
||||
data := oauth2EventMessage{
|
||||
State: state,
|
||||
Code: c.QueryParam("code"),
|
||||
Error: c.QueryParam("error"),
|
||||
}
|
||||
|
||||
encodedData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return NewBadRequestError("Failed to marshalize OAuth2 redirect data.", err)
|
||||
api.app.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
@@ -688,5 +701,10 @@ func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
|
||||
|
||||
client.Send(msg)
|
||||
|
||||
return c.Redirect(http.StatusTemporaryRedirect, "../_/#/auth/oauth2-redirect")
|
||||
if data.Error != "" || data.Code == "" {
|
||||
api.app.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
|
||||
}
|
||||
|
||||
return c.Redirect(http.StatusTemporaryRedirect, oauth2SuccessRedirectPath)
|
||||
}
|
||||
|
||||
@@ -1503,114 +1503,205 @@ func TestRecordAuthUnlinkExternalsAuth(t *testing.T) {
|
||||
func TestRecordAuthOAuth2Redirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
|
||||
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
for i := 0; i < 10; i++ {
|
||||
c1 := subscriptions.NewDefaultClient()
|
||||
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
c2 := subscriptions.NewDefaultClient()
|
||||
c2.Subscribe("@oauth2")
|
||||
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
c3 := subscriptions.NewDefaultClient()
|
||||
c3.Subscribe("test1", "@oauth2")
|
||||
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
c4 := subscriptions.NewDefaultClient()
|
||||
c4.Subscribe("test1", "test2")
|
||||
|
||||
beforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
app.SubscriptionsBroker().Register(c1)
|
||||
app.SubscriptionsBroker().Register(c2)
|
||||
app.SubscriptionsBroker().Register(c3)
|
||||
app.SubscriptionsBroker().Register(c4)
|
||||
app.SubscriptionsBroker().Register(c5)
|
||||
c5 := subscriptions.NewDefaultClient()
|
||||
c5.Subscribe("@oauth2")
|
||||
c5.Discard()
|
||||
|
||||
clientStubs = append(clientStubs, map[string]subscriptions.Client{
|
||||
"c1": c1,
|
||||
"c2": c2,
|
||||
"c3": c3,
|
||||
"c4": c4,
|
||||
"c5": c5,
|
||||
})
|
||||
}
|
||||
|
||||
checkFailureRedirect := func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-failure") {
|
||||
t.Fatalf("Expected failure redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkSuccessRedirect := func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
loc := res.Header.Get("Location")
|
||||
if !strings.Contains(loc, "/oauth2-redirect-success") {
|
||||
t.Fatalf("Expected success redirect, got %q", loc)
|
||||
}
|
||||
}
|
||||
|
||||
checkClientMessages := func(t *testing.T, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
|
||||
if len(expectedMessages[clientId]) == 0 {
|
||||
t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data)
|
||||
}
|
||||
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
}
|
||||
|
||||
for _, txt := range expectedMessages[clientId] {
|
||||
if !strings.Contains(string(msg.Data), txt) {
|
||||
t.Fatalf("Failed to find %q in \n%s", txt, msg.Data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
beforeTestFunc := func(
|
||||
clients map[string]subscriptions.Client,
|
||||
expectedMessages map[string][]string,
|
||||
) func(*testing.T, *tests.TestApp, *echo.Echo) {
|
||||
return func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
for _, client := range clients {
|
||||
app.SubscriptionsBroker().Register(client)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
|
||||
// add to the app store so that it can be cancelled manually after test completion
|
||||
app.Store().Set("cancelFunc", cancelFunc)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-clients["c1"].Channel():
|
||||
checkClientMessages(t, "c1", msg, expectedMessages)
|
||||
case msg := <-clients["c2"].Channel():
|
||||
checkClientMessages(t, "c2", msg, expectedMessages)
|
||||
case msg := <-clients["c3"].Channel():
|
||||
checkClientMessages(t, "c3", msg, expectedMessages)
|
||||
case msg := <-clients["c4"].Channel():
|
||||
checkClientMessages(t, "c4", msg, expectedMessages)
|
||||
case msg := <-clients["c5"].Channel():
|
||||
checkClientMessages(t, "c5", msg, expectedMessages)
|
||||
case <-ctx.Done():
|
||||
for _, c := range clients {
|
||||
close(c.Channel())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
scenarios := []tests.ApiScenario{
|
||||
{
|
||||
Name: "no state query param",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123",
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
Name: "no state query param",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "no code query param",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?state=" + c3.Id(),
|
||||
ExpectedStatus: 400,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
Name: "invalid or missing client",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=missing",
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "missing client",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=missing",
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
Name: "no code query param",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + c5.Id(),
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
Name: "error query param",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
|
||||
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + c4.Id(),
|
||||
BeforeTestFunc: beforeTestFunc,
|
||||
ExpectedStatus: 404,
|
||||
ExpectedContent: []string{`"data":{}`},
|
||||
Name: "discarded client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client without @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkFailureRedirect(t, app, res)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "client with @oauth2 subscription",
|
||||
Method: http.MethodGet,
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + c3.Id(),
|
||||
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
|
||||
beforeTestFunc(t, app, e)
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
go func() {
|
||||
defer cancelFunc()
|
||||
L:
|
||||
for {
|
||||
select {
|
||||
case <-c1.Channel():
|
||||
t.Error("Unexpected c1 message")
|
||||
break L
|
||||
case <-c2.Channel():
|
||||
t.Error("Unexpected c2 message")
|
||||
break L
|
||||
case msg := <-c3.Channel():
|
||||
if msg.Name != "@oauth2" {
|
||||
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
|
||||
}
|
||||
|
||||
expectedParams := []string{`"state"`, `"code"`}
|
||||
for _, p := range expectedParams {
|
||||
if !strings.Contains(string(msg.Data), p) {
|
||||
t.Errorf("Couldn't find %s in \n%v", p, msg.Data)
|
||||
}
|
||||
}
|
||||
|
||||
break L
|
||||
case <-c4.Channel():
|
||||
t.Error("Unexpected c4 message")
|
||||
break L
|
||||
case <-c5.Channel():
|
||||
t.Error("Unexpected c5 message")
|
||||
break L
|
||||
case <-ctx.Done():
|
||||
t.Error("Context timeout reached")
|
||||
break L
|
||||
}
|
||||
}
|
||||
}()
|
||||
},
|
||||
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
|
||||
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
|
||||
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
|
||||
}),
|
||||
ExpectedStatus: http.StatusTemporaryRedirect,
|
||||
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
|
||||
app.Store().Get("cancelFunc").(context.CancelFunc)()
|
||||
|
||||
checkSuccessRedirect(t, app, res)
|
||||
|
||||
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
|
||||
t.Fatalf("Expected oauth2 subscription to be removed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user