1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-11-06 09:29:19 +02:00

[#4177] added graceful OAuth2 redirect error handling

This commit is contained in:
Gani Georgiev
2024-01-19 19:14:52 +02:00
parent fc18e69183
commit b2b792b763
38 changed files with 293 additions and 161 deletions

View File

@@ -656,29 +656,42 @@ func (api *recordAuthApi) unlinkExternalAuth(c echo.Context) error {
// -------------------------------------------------------------------
const oauth2SubscriptionTopic = "@oauth2"
const (
oauth2SubscriptionTopic string = "@oauth2"
oauth2FailureRedirectPath string = "../_/#/auth/oauth2-redirect-failure"
oauth2SuccessRedirectPath string = "../_/#/auth/oauth2-redirect-success"
)
type oauth2EventMessage struct {
State string `json:"state"`
Code string `json:"code"`
Error string `json:"error,omitempty"`
}
func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
state := c.QueryParam("state")
code := c.QueryParam("code")
if code == "" || state == "" {
return NewBadRequestError("Invalid OAuth2 redirect parameters.", nil)
if state == "" {
api.app.Logger().Debug("Missing OAuth2 state parameter")
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
}
client, err := api.app.SubscriptionsBroker().ClientById(state)
if err != nil || client.IsDiscarded() || !client.HasSubscription(oauth2SubscriptionTopic) {
return NewNotFoundError("Missing or invalid OAuth2 subscription client.", err)
api.app.Logger().Debug("Missing or invalid OAuth2 subscription client", "error", err, "clientId", state)
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
}
defer client.Unsubscribe(oauth2SubscriptionTopic)
data := map[string]string{
"state": state,
"code": code,
data := oauth2EventMessage{
State: state,
Code: c.QueryParam("code"),
Error: c.QueryParam("error"),
}
encodedData, err := json.Marshal(data)
if err != nil {
return NewBadRequestError("Failed to marshalize OAuth2 redirect data.", err)
api.app.Logger().Debug("Failed to marshalize OAuth2 redirect data", "error", err)
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
}
msg := subscriptions.Message{
@@ -688,5 +701,10 @@ func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error {
client.Send(msg)
return c.Redirect(http.StatusTemporaryRedirect, "../_/#/auth/oauth2-redirect")
if data.Error != "" || data.Code == "" {
api.app.Logger().Debug("Failed OAuth2 redirect due to an error or missing code parameter", "error", data.Error, "clientId", data.State)
return c.Redirect(http.StatusTemporaryRedirect, oauth2FailureRedirectPath)
}
return c.Redirect(http.StatusTemporaryRedirect, oauth2SuccessRedirectPath)
}

View File

@@ -1503,114 +1503,205 @@ func TestRecordAuthUnlinkExternalsAuth(t *testing.T) {
func TestRecordAuthOAuth2Redirect(t *testing.T) {
t.Parallel()
c1 := subscriptions.NewDefaultClient()
clientStubs := make([]map[string]subscriptions.Client, 0, 10)
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
for i := 0; i < 10; i++ {
c1 := subscriptions.NewDefaultClient()
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c2 := subscriptions.NewDefaultClient()
c2.Subscribe("@oauth2")
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
c3 := subscriptions.NewDefaultClient()
c3.Subscribe("test1", "@oauth2")
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
c4 := subscriptions.NewDefaultClient()
c4.Subscribe("test1", "test2")
beforeTestFunc := func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
app.SubscriptionsBroker().Register(c1)
app.SubscriptionsBroker().Register(c2)
app.SubscriptionsBroker().Register(c3)
app.SubscriptionsBroker().Register(c4)
app.SubscriptionsBroker().Register(c5)
c5 := subscriptions.NewDefaultClient()
c5.Subscribe("@oauth2")
c5.Discard()
clientStubs = append(clientStubs, map[string]subscriptions.Client{
"c1": c1,
"c2": c2,
"c3": c3,
"c4": c4,
"c5": c5,
})
}
checkFailureRedirect := func(t *testing.T, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-failure") {
t.Fatalf("Expected failure redirect, got %q", loc)
}
}
checkSuccessRedirect := func(t *testing.T, app *tests.TestApp, res *http.Response) {
loc := res.Header.Get("Location")
if !strings.Contains(loc, "/oauth2-redirect-success") {
t.Fatalf("Expected success redirect, got %q", loc)
}
}
checkClientMessages := func(t *testing.T, clientId string, msg subscriptions.Message, expectedMessages map[string][]string) {
if len(expectedMessages[clientId]) == 0 {
t.Fatalf("Unexpected client %q message, got %s:\n%s", clientId, msg.Name, msg.Data)
}
if msg.Name != "@oauth2" {
t.Fatalf("Expected @oauth2 msg.Name, got %q", msg.Name)
}
for _, txt := range expectedMessages[clientId] {
if !strings.Contains(string(msg.Data), txt) {
t.Fatalf("Failed to find %q in \n%s", txt, msg.Data)
}
}
}
beforeTestFunc := func(
clients map[string]subscriptions.Client,
expectedMessages map[string][]string,
) func(*testing.T, *tests.TestApp, *echo.Echo) {
return func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
for _, client := range clients {
app.SubscriptionsBroker().Register(client)
}
ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond)
// add to the app store so that it can be cancelled manually after test completion
app.Store().Set("cancelFunc", cancelFunc)
go func() {
defer cancelFunc()
for {
select {
case msg := <-clients["c1"].Channel():
checkClientMessages(t, "c1", msg, expectedMessages)
case msg := <-clients["c2"].Channel():
checkClientMessages(t, "c2", msg, expectedMessages)
case msg := <-clients["c3"].Channel():
checkClientMessages(t, "c3", msg, expectedMessages)
case msg := <-clients["c4"].Channel():
checkClientMessages(t, "c4", msg, expectedMessages)
case msg := <-clients["c5"].Channel():
checkClientMessages(t, "c5", msg, expectedMessages)
case <-ctx.Done():
for _, c := range clients {
close(c.Channel())
}
return
}
}
}()
}
}
scenarios := []tests.ApiScenario{
{
Name: "no state query param",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123",
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
Name: "no state query param",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123",
BeforeTestFunc: beforeTestFunc(clientStubs[0], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "no code query param",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?state=" + c3.Id(),
ExpectedStatus: 400,
ExpectedContent: []string{`"data":{}`},
Name: "invalid or missing client",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=missing",
BeforeTestFunc: beforeTestFunc(clientStubs[1], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "missing client",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=missing",
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
Name: "no code query param",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?state=" + clientStubs[2]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[2], map[string][]string{
"c3": {`"state":"` + clientStubs[2]["c3"].Id(), `"code":""`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[2]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=" + c5.Id(),
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
Name: "error query param",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?error=example&code=123&state=" + clientStubs[3]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[3], map[string][]string{
"c3": {`"state":"` + clientStubs[3]["c3"].Id(), `"code":"123"`, `"error":"example"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
if clientStubs[3]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=" + c4.Id(),
BeforeTestFunc: beforeTestFunc,
ExpectedStatus: 404,
ExpectedContent: []string{`"data":{}`},
Name: "discarded client with @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c5"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[4], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client without @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[4]["c4"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[5], nil),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkFailureRedirect(t, app, res)
},
},
{
Name: "client with @oauth2 subscription",
Method: http.MethodGet,
Url: "/api/oauth2-redirect?code=123&state=" + c3.Id(),
BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) {
beforeTestFunc(t, app, e)
ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
go func() {
defer cancelFunc()
L:
for {
select {
case <-c1.Channel():
t.Error("Unexpected c1 message")
break L
case <-c2.Channel():
t.Error("Unexpected c2 message")
break L
case msg := <-c3.Channel():
if msg.Name != "@oauth2" {
t.Errorf("Expected @oauth2 msg.Name, got %q", msg.Name)
}
expectedParams := []string{`"state"`, `"code"`}
for _, p := range expectedParams {
if !strings.Contains(string(msg.Data), p) {
t.Errorf("Couldn't find %s in \n%v", p, msg.Data)
}
}
break L
case <-c4.Channel():
t.Error("Unexpected c4 message")
break L
case <-c5.Channel():
t.Error("Unexpected c5 message")
break L
case <-ctx.Done():
t.Error("Context timeout reached")
break L
}
}
}()
},
Url: "/api/oauth2-redirect?code=123&state=" + clientStubs[6]["c3"].Id(),
BeforeTestFunc: beforeTestFunc(clientStubs[6], map[string][]string{
"c3": {`"state":"` + clientStubs[6]["c3"].Id(), `"code":"123"`},
}),
ExpectedStatus: http.StatusTemporaryRedirect,
AfterTestFunc: func(t *testing.T, app *tests.TestApp, res *http.Response) {
app.Store().Get("cancelFunc").(context.CancelFunc)()
checkSuccessRedirect(t, app, res)
if clientStubs[6]["c3"].HasSubscription("@oauth2") {
t.Fatalf("Expected oauth2 subscription to be removed")
}
},
},
}