1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-06 03:54:17 +02:00

Make remember and oauth2 work better together.

- Change OAuth2 extra params to not use state, but session instead.
This commit is contained in:
Aaron 2015-03-24 19:39:20 -07:00
parent e83110ee31
commit 07cbd6016f
5 changed files with 65 additions and 61 deletions

View File

@ -13,6 +13,8 @@ const (
SessionLastAction = "last_action"
// SessionOAuth2State is the xsrf protection key for oauth.
SessionOAuth2State = "oauth2_state"
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
SessionOAuth2Params = "oauth2_params"
// CookieRemember is used for cookies and form input names.
CookieRemember = "rm"

View File

@ -3,6 +3,7 @@ package oauth2
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -87,14 +88,21 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
state := base64.URLEncoding.EncodeToString(random)
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
var passAlongs []string
passAlongs := make(map[string]string)
for k, vals := range r.URL.Query() {
for _, val := range vals {
passAlongs = append(passAlongs, fmt.Sprintf("%s=%s", k, val))
passAlongs[k] = val
}
}
if len(passAlongs) > 0 {
state += ";" + strings.Join(passAlongs, ";")
str, err := json.Marshal(passAlongs)
if err != nil {
return err
}
ctx.SessionStorer.Put(authboss.SessionOAuth2Params, string(str))
} else {
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
}
url := cfg.OAuth2Config.AuthCodeURL(state)
@ -114,6 +122,21 @@ var exchanger = (*oauth2.Config).Exchange
func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
provider := strings.ToLower(filepath.Base(r.URL.Path))
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
if err != nil {
return err
}
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
// Don't delete this value from session immediately, callbacks use this too
var values map[string]string
if ok {
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
return err
}
}
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
@ -127,12 +150,6 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
}
}
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
if err != nil {
return err
}
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
if !ok {
return fmt.Errorf("OAuth2 provider %q not found", provider)
@ -183,23 +200,22 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
return nil
}
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
redirect := authboss.Cfg.AuthLoginOKPath
values := make(url.Values)
if len(splState) > 0 {
for _, arg := range splState[1:] {
spl := strings.Split(arg, "=")
switch spl[0] {
case authboss.CookieRemember:
case authboss.FormValueRedirect:
redirect = spl[1]
default:
values.Set(spl[0], spl[1])
}
query := make(url.Values)
for k, v := range values {
switch k {
case authboss.CookieRemember:
case authboss.FormValueRedirect:
redirect = v
default:
query.Set(k, v)
}
}
if len(values) > 0 {
redirect = fmt.Sprintf("%s?%s", redirect, values.Encode())
if len(query) > 0 {
redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
}
http.Redirect(w, r, redirect, http.StatusFound)

View File

@ -6,7 +6,6 @@ import (
"net/http/httptest"
"net/url"
"path"
"sort"
"strings"
"testing"
"time"
@ -78,7 +77,7 @@ func TestOAuth2Init(t *testing.T) {
cfg.OAuth2Providers = testProviders
authboss.Cfg = cfg
r, _ := http.NewRequest("GET", "/oauth2/google?r=/my/redirect&rm=true", nil)
r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil)
w := httptest.NewRecorder()
ctx := authboss.NewContext()
ctx.SessionStorer = session
@ -107,24 +106,8 @@ func TestOAuth2Init(t *testing.T) {
t.Error("It should have had some state:", loc)
}
splits := strings.Split(state, ";")
if len(splits[0]) != 44 {
t.Error("The xsrf token was wrong size:", len(splits[0]), splits[0])
}
// Maps are fun
sort.Strings(splits[1:])
if v, err := url.QueryUnescape(splits[1]); err != nil {
t.Error(err)
} else if v != "r=/my/redirect" {
t.Error("Redirect parameter not saved:", splits[1])
}
if v, err := url.QueryUnescape(splits[2]); err != nil {
t.Error(err)
} else if v != "rm=true" {
t.Error("Remember parameter not saved:", splits[2])
if params := session.Values[authboss.SessionOAuth2Params]; params != `{"redir":"/my/redirect#lol","rm":"true"}` {
t.Error("The params were wrong:", params)
}
}
@ -171,12 +154,18 @@ func TestOAuthSuccess(t *testing.T) {
}
authboss.Cfg = cfg
url := fmt.Sprintf("/oauth2/fake?code=code&state=%s", url.QueryEscape("state;redir=/myurl;rm=true;myparam=5"))
values := make(url.Values)
values.Set("code", "code")
values.Set("state", "state")
url := fmt.Sprintf("/oauth2/fake?%s", values.Encode())
r, _ := http.NewRequest("GET", url, nil)
w := httptest.NewRecorder()
ctx := authboss.NewContext()
session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
session.Put(authboss.SessionOAuth2Params, `{"redir":"/myurl?myparam=5","rm":"true"}`)
storer := mocks.NewMockStorer()
ctx.SessionStorer = session
cfg.OAuth2Storer = storer
@ -232,9 +221,9 @@ func TestOAuthXSRFFailure(t *testing.T) {
values.Set(authboss.FormValueOAuth2State, "notstate")
values.Set("code", "code")
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
ctx := authboss.NewContext()
ctx.SessionStorer = session
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
err := oauthCallback(ctx, nil, r)
if err != errOAuthStateValidation {
@ -253,9 +242,13 @@ func TestOAuthFailure(t *testing.T) {
values.Set("error_reason", "auth_failure")
values.Set("error_description", "Failed to auth.")
ctx := authboss.NewContext()
session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
ctx.SessionStorer = session
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
err := oauthCallback(nil, nil, r)
err := oauthCallback(ctx, nil, r)
if red, ok := err.(authboss.ErrAndRedirect); !ok {
t.Error("Should be a redirect error")
} else if len(red.FlashError) == 0 {

View File

@ -6,9 +6,9 @@ import (
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"gopkg.in/authboss.v0"
)
@ -97,24 +97,18 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
// Has to pander to horrible state variable packing to figure out if we want
// to be remembered.
func (r *Remember) afterOAuth(ctx *authboss.Context) error {
state, ok := ctx.FirstFormValue(authboss.FormValueOAuth2State)
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
if !ok {
return nil
}
splState := strings.Split(state, ";")
if len(splState) < 0 {
return nil
var values map[string]string
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
return err
}
should := false
for _, arg := range splState[1:] {
spl := strings.Split(arg, "=")
if spl[0] == authboss.CookieRemember {
should = spl[1] == "true"
break
}
}
val, ok := values[authboss.CookieRemember]
should := ok && val == "true"
if !should {
return nil

View File

@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"net/http"
"net/url"
"testing"
"gopkg.in/authboss.v0"
@ -73,9 +72,9 @@ func TestAfterOAuth(t *testing.T) {
authboss.Cfg.Storer = storer
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", url.QueryEscape("xsrf;rm=true"))
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", "xsrf")
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
t.Error("Unexpected Error:", err)