mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-06 03:54:17 +02:00
Make remember and oauth2 work better together.
- Change OAuth2 extra params to not use state, but session instead.
This commit is contained in:
parent
e83110ee31
commit
07cbd6016f
@ -13,6 +13,8 @@ const (
|
||||
SessionLastAction = "last_action"
|
||||
// SessionOAuth2State is the xsrf protection key for oauth.
|
||||
SessionOAuth2State = "oauth2_state"
|
||||
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
||||
SessionOAuth2Params = "oauth2_params"
|
||||
|
||||
// CookieRemember is used for cookies and form input names.
|
||||
CookieRemember = "rm"
|
||||
|
@ -3,6 +3,7 @@ package oauth2
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -87,14 +88,21 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
|
||||
state := base64.URLEncoding.EncodeToString(random)
|
||||
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
|
||||
|
||||
var passAlongs []string
|
||||
passAlongs := make(map[string]string)
|
||||
for k, vals := range r.URL.Query() {
|
||||
for _, val := range vals {
|
||||
passAlongs = append(passAlongs, fmt.Sprintf("%s=%s", k, val))
|
||||
passAlongs[k] = val
|
||||
}
|
||||
}
|
||||
|
||||
if len(passAlongs) > 0 {
|
||||
state += ";" + strings.Join(passAlongs, ";")
|
||||
str, err := json.Marshal(passAlongs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SessionStorer.Put(authboss.SessionOAuth2Params, string(str))
|
||||
} else {
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
|
||||
}
|
||||
|
||||
url := cfg.OAuth2Config.AuthCodeURL(state)
|
||||
@ -114,6 +122,21 @@ var exchanger = (*oauth2.Config).Exchange
|
||||
func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
provider := strings.ToLower(filepath.Base(r.URL.Path))
|
||||
|
||||
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
|
||||
// Don't delete this value from session immediately, callbacks use this too
|
||||
var values map[string]string
|
||||
if ok {
|
||||
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hasErr := r.FormValue("error")
|
||||
if len(hasErr) > 0 {
|
||||
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
|
||||
@ -127,12 +150,6 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
|
||||
|
||||
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
|
||||
if !ok {
|
||||
return fmt.Errorf("OAuth2 provider %q not found", provider)
|
||||
@ -183,23 +200,22 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
|
||||
|
||||
redirect := authboss.Cfg.AuthLoginOKPath
|
||||
values := make(url.Values)
|
||||
if len(splState) > 0 {
|
||||
for _, arg := range splState[1:] {
|
||||
spl := strings.Split(arg, "=")
|
||||
switch spl[0] {
|
||||
case authboss.CookieRemember:
|
||||
case authboss.FormValueRedirect:
|
||||
redirect = spl[1]
|
||||
default:
|
||||
values.Set(spl[0], spl[1])
|
||||
}
|
||||
query := make(url.Values)
|
||||
for k, v := range values {
|
||||
switch k {
|
||||
case authboss.CookieRemember:
|
||||
case authboss.FormValueRedirect:
|
||||
redirect = v
|
||||
default:
|
||||
query.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
redirect = fmt.Sprintf("%s?%s", redirect, values.Encode())
|
||||
if len(query) > 0 {
|
||||
redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
|
||||
}
|
||||
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -78,7 +77,7 @@ func TestOAuth2Init(t *testing.T) {
|
||||
cfg.OAuth2Providers = testProviders
|
||||
authboss.Cfg = cfg
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?r=/my/redirect&rm=true", nil)
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := authboss.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
@ -107,24 +106,8 @@ func TestOAuth2Init(t *testing.T) {
|
||||
t.Error("It should have had some state:", loc)
|
||||
}
|
||||
|
||||
splits := strings.Split(state, ";")
|
||||
if len(splits[0]) != 44 {
|
||||
t.Error("The xsrf token was wrong size:", len(splits[0]), splits[0])
|
||||
}
|
||||
|
||||
// Maps are fun
|
||||
sort.Strings(splits[1:])
|
||||
|
||||
if v, err := url.QueryUnescape(splits[1]); err != nil {
|
||||
t.Error(err)
|
||||
} else if v != "r=/my/redirect" {
|
||||
t.Error("Redirect parameter not saved:", splits[1])
|
||||
}
|
||||
|
||||
if v, err := url.QueryUnescape(splits[2]); err != nil {
|
||||
t.Error(err)
|
||||
} else if v != "rm=true" {
|
||||
t.Error("Remember parameter not saved:", splits[2])
|
||||
if params := session.Values[authboss.SessionOAuth2Params]; params != `{"redir":"/my/redirect#lol","rm":"true"}` {
|
||||
t.Error("The params were wrong:", params)
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,12 +154,18 @@ func TestOAuthSuccess(t *testing.T) {
|
||||
}
|
||||
authboss.Cfg = cfg
|
||||
|
||||
url := fmt.Sprintf("/oauth2/fake?code=code&state=%s", url.QueryEscape("state;redir=/myurl;rm=true;myparam=5"))
|
||||
values := make(url.Values)
|
||||
values.Set("code", "code")
|
||||
values.Set("state", "state")
|
||||
|
||||
url := fmt.Sprintf("/oauth2/fake?%s", values.Encode())
|
||||
r, _ := http.NewRequest("GET", url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := authboss.NewContext()
|
||||
session := mocks.NewMockClientStorer()
|
||||
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
|
||||
session.Put(authboss.SessionOAuth2Params, `{"redir":"/myurl?myparam=5","rm":"true"}`)
|
||||
|
||||
storer := mocks.NewMockStorer()
|
||||
ctx.SessionStorer = session
|
||||
cfg.OAuth2Storer = storer
|
||||
@ -232,9 +221,9 @@ func TestOAuthXSRFFailure(t *testing.T) {
|
||||
values.Set(authboss.FormValueOAuth2State, "notstate")
|
||||
values.Set("code", "code")
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
|
||||
ctx := authboss.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
|
||||
|
||||
err := oauthCallback(ctx, nil, r)
|
||||
if err != errOAuthStateValidation {
|
||||
@ -253,9 +242,13 @@ func TestOAuthFailure(t *testing.T) {
|
||||
values.Set("error_reason", "auth_failure")
|
||||
values.Set("error_description", "Failed to auth.")
|
||||
|
||||
ctx := authboss.NewContext()
|
||||
session := mocks.NewMockClientStorer()
|
||||
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
|
||||
ctx.SessionStorer = session
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
|
||||
|
||||
err := oauthCallback(nil, nil, r)
|
||||
err := oauthCallback(ctx, nil, r)
|
||||
if red, ok := err.(authboss.ErrAndRedirect); !ok {
|
||||
t.Error("Should be a redirect error")
|
||||
} else if len(red.FlashError) == 0 {
|
||||
|
@ -6,9 +6,9 @@ import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
)
|
||||
@ -97,24 +97,18 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
|
||||
// Has to pander to horrible state variable packing to figure out if we want
|
||||
// to be remembered.
|
||||
func (r *Remember) afterOAuth(ctx *authboss.Context) error {
|
||||
state, ok := ctx.FirstFormValue(authboss.FormValueOAuth2State)
|
||||
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
splState := strings.Split(state, ";")
|
||||
if len(splState) < 0 {
|
||||
return nil
|
||||
var values map[string]string
|
||||
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
should := false
|
||||
for _, arg := range splState[1:] {
|
||||
spl := strings.Split(arg, "=")
|
||||
if spl[0] == authboss.CookieRemember {
|
||||
should = spl[1] == "true"
|
||||
break
|
||||
}
|
||||
}
|
||||
val, ok := values[authboss.CookieRemember]
|
||||
should := ok && val == "true"
|
||||
|
||||
if !should {
|
||||
return nil
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
@ -73,9 +72,9 @@ func TestAfterOAuth(t *testing.T) {
|
||||
authboss.Cfg.Storer = storer
|
||||
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
|
||||
|
||||
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", url.QueryEscape("xsrf;rm=true"))
|
||||
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", "xsrf")
|
||||
req, err := http.NewRequest("GET", uri, nil)
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
|
Loading…
Reference in New Issue
Block a user