mirror of
https://github.com/volatiletech/authboss.git
synced 2025-02-03 13:21:22 +02:00
Make OAuth2 implementation less shoddy.
- Add a new storer specifically for OAuth2 to enable clients to choose regular database storing OR Oauth2 but not have to have both. - Stop storing OAuth2 credentials in a combined form inside username. - Add new events to capture OAuth events just like auth. - Have pass-through parameters for OAuth init urls, this allows us to pass additional behavior options (redirects and remember me) as well as other things that should be present on the page that is redirected to. - Context.LoadUser is now OAuth aware. - Remember's callbacks now include an OAuth check to see if a horribly packed state variable contains a flag to say that we want to be remembered. - Change the OAuth2 Callback to use Attributes instead of that custom struct to allow people to append whatever attributes they want into the user that will be saved.
This commit is contained in:
parent
082caf88b3
commit
06edd2e615
@ -13,7 +13,9 @@ type Event int
|
||||
const (
|
||||
EventRegister Event = iota
|
||||
EventAuth
|
||||
EventOAuth
|
||||
EventAuthFail
|
||||
EventOAuthFail
|
||||
EventRecoverStart
|
||||
EventRecoverEnd
|
||||
EventGet
|
||||
@ -21,14 +23,15 @@ const (
|
||||
EventPasswordReset
|
||||
)
|
||||
|
||||
const eventNames = "EventRegisterEventAuthEventAuthFailEventRecoverStartEventRecoverEndEventGetEventGetUserSessionEventPasswordReset"
|
||||
const eventNames = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetEventGetUserSessionEventPasswordReset"
|
||||
|
||||
var eventIndexes = [...]uint8{0, 13, 22, 35, 52, 67, 75, 94, 112}
|
||||
var eventIndexes = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 99, 118, 136}
|
||||
|
||||
func (i Event) String() string {
|
||||
if i < 0 || i+1 >= Event(len(eventIndexes)) {
|
||||
return fmt.Sprintf("Event(%d)", i)
|
||||
}
|
||||
|
||||
return eventNames[eventIndexes[i]:eventIndexes[i+1]]
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,9 @@ func TestEventString(t *testing.T) {
|
||||
}{
|
||||
{EventRegister, "EventRegister"},
|
||||
{EventAuth, "EventAuth"},
|
||||
{EventOAuth, "EventOAuth"},
|
||||
{EventAuthFail, "EventAuthFail"},
|
||||
{EventOAuthFail, "EventOAuthFail"},
|
||||
{EventRecoverStart, "EventRecoverStart"},
|
||||
{EventRecoverEnd, "EventRecoverEnd"},
|
||||
{EventGet, "EventGet"},
|
||||
|
@ -68,6 +68,7 @@ type Config struct {
|
||||
XSRFMaker XSRF
|
||||
|
||||
Storer Storer
|
||||
OAuth2Storer OAuth2Storer
|
||||
CookieStoreMaker CookieStoreMaker
|
||||
SessionStoreMaker SessionStoreMaker
|
||||
LogWriter io.Writer
|
||||
|
16
context.go
16
context.go
@ -9,6 +9,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
FormValueRedirect = "redir"
|
||||
FormValueOAuth2State = "state"
|
||||
)
|
||||
|
||||
// Context provides context for module operations and callbacks. One obvious
|
||||
// need for context is a request's session store. It is not safe for use by
|
||||
// multiple goroutines.
|
||||
@ -100,12 +105,19 @@ func (c *Context) LoadUser(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf, err := Cfg.Storer.Get(key, ModuleAttrMeta)
|
||||
var user interface{}
|
||||
var err error
|
||||
|
||||
if index := strings.IndexByte(key, ';'); index > 0 {
|
||||
user, err = Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:], ModuleAttrMeta)
|
||||
} else {
|
||||
user, err = Cfg.Storer.Get(key, ModuleAttrMeta)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.User = Unbind(intf)
|
||||
c.User = Unbind(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -95,22 +95,51 @@ func TestContext_SaveUser(t *testing.T) {
|
||||
func TestContext_LoadUser(t *testing.T) {
|
||||
Cfg = NewConfig()
|
||||
ctx := NewContext()
|
||||
|
||||
attr := Attributes{
|
||||
"email": "hello@joe.com",
|
||||
"password": "mysticalhash",
|
||||
"uid": "what",
|
||||
"provider": "google",
|
||||
}
|
||||
|
||||
storer := mockStorer{
|
||||
"joe": Attributes{"email": "hello@joe.com", "password": "mysticalhash"},
|
||||
"joe": attr,
|
||||
"whatgoogle": attr,
|
||||
}
|
||||
Cfg.Storer = storer
|
||||
Cfg.OAuth2Storer = storer
|
||||
|
||||
err := ctx.LoadUser("joe")
|
||||
if err != nil {
|
||||
ctx.User = nil
|
||||
if err := ctx.LoadUser("joe"); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
attr := storer["joe"]
|
||||
if email, err := ctx.User.StringErr("email"); err != nil {
|
||||
t.Error(err)
|
||||
} else if email != attr["email"] {
|
||||
t.Error("Email wrong:", email)
|
||||
}
|
||||
if password, err := ctx.User.StringErr("password"); err != nil {
|
||||
t.Error(err)
|
||||
} else if password != attr["password"] {
|
||||
t.Error("Password wrong:", password)
|
||||
}
|
||||
|
||||
for k, v := range attr {
|
||||
if v != ctx.User[k] {
|
||||
t.Error(v, "not equal to", ctx.User[k])
|
||||
}
|
||||
ctx.User = nil
|
||||
if err := ctx.LoadUser("what;google"); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if email, err := ctx.User.StringErr("email"); err != nil {
|
||||
t.Error(err)
|
||||
} else if email != attr["email"] {
|
||||
t.Error("Email wrong:", email)
|
||||
}
|
||||
if password, err := ctx.User.StringErr("password"); err != nil {
|
||||
t.Error(err)
|
||||
} else if password != attr["password"] {
|
||||
t.Error("Password wrong:", password)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,6 +91,39 @@ func (m *MockStorer) Get(key string, attrMeta authboss.AttributeMeta) (result in
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (m *MockStorer) PutOAuth(uid, provider string, attr authboss.Attributes) error {
|
||||
if len(m.PutErr) > 0 {
|
||||
return errors.New(m.PutErr)
|
||||
}
|
||||
|
||||
if _, ok := m.Users[uid+provider]; !ok {
|
||||
m.Users[uid+provider] = attr
|
||||
return nil
|
||||
}
|
||||
for k, v := range attr {
|
||||
m.Users[uid+provider][k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockStorer) GetOAuth(uid, provider string, attrMeta authboss.AttributeMeta) (result interface{}, err error) {
|
||||
if len(m.GetErr) > 0 {
|
||||
return nil, errors.New(m.GetErr)
|
||||
}
|
||||
|
||||
userAttrs, ok := m.Users[uid+provider]
|
||||
if !ok {
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
u := &MockUser{}
|
||||
if err := userAttrs.Bind(u, true); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (m *MockStorer) AddToken(key, token string) error {
|
||||
if len(m.AddTokenErr) > 0 {
|
||||
return errors.New(m.AddTokenErr)
|
||||
|
@ -29,6 +29,17 @@ func (m mockStorer) Get(key string, attrMeta AttributeMeta) (result interface{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m mockStorer) PutOAuth(uid, provider string, attr Attributes) error {
|
||||
m[uid+provider] = attr
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockStorer) GetOAuth(uid, provider string, attrMeta AttributeMeta) (result interface{}, err error) {
|
||||
return &mockUser{
|
||||
m[uid+provider]["email"].(string), m[uid+provider]["password"].(string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockClientStore map[string]string
|
||||
|
||||
func (m mockClientStore) Get(key string) (string, bool) {
|
||||
|
34
oauth2.go
34
oauth2.go
@ -6,20 +6,28 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth2Provider is the entire configuration
|
||||
// required to authenticate with this provider.
|
||||
//
|
||||
// The OAuth2Config does not need a redirect URL because it will
|
||||
// be automatically created by the
|
||||
/*
|
||||
OAuth2Provider is the entire configuration
|
||||
required to authenticate with this provider.
|
||||
|
||||
The OAuth2Config does not need a redirect URL because it will
|
||||
be automatically created by the route registration in oauth2 module.
|
||||
|
||||
AdditionalParams can be used to specify extra parameters to tack on to the
|
||||
end of the initial request, this allows for provider specific oauth options
|
||||
like access_type=offline to be passed to the provider.
|
||||
|
||||
Callback gives the config and the token allowing an http client using the
|
||||
authenticated token to be created. Because each OAuth2 implementation has a different
|
||||
API this must be handled for each provider separately. It is used to return two things
|
||||
specifically: UID (the ID according to the provider) and the Email address.
|
||||
The UID must be passed back or there will be an error as it is the means of identifying the
|
||||
user in the system, e-mail is optional but should be returned in systems using
|
||||
emailing. The keys authboss.StoreOAuth2UID and authboss.StoreEmail can be used to set
|
||||
these values in the authboss.Attributes map returned by the callback.
|
||||
*/
|
||||
type OAuthProvider struct {
|
||||
OAuth2Config *oauth2.Config
|
||||
AdditionalParams url.Values
|
||||
Callback func(oauth2.Config, *oauth2.Token) (OAuth2Credentials, error)
|
||||
}
|
||||
|
||||
// OAuth2Credentials are used to store in the database.
|
||||
// Email is optional
|
||||
type OAuth2Credentials struct {
|
||||
UID string
|
||||
Email string
|
||||
Callback func(oauth2.Config, *oauth2.Token) (Attributes, error)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -18,15 +19,6 @@ var (
|
||||
errOAuthStateValidation = errors.New("Could not validate oauth2 state param")
|
||||
)
|
||||
|
||||
// OAuth2Storer is required to do OAuth2 storing.
|
||||
type OAuth2Storer interface {
|
||||
authboss.Storer
|
||||
// NewOrUpdate should retrieve the user if he already exists, or create
|
||||
// a new one. The key is composed of the provider:UID together and is stored
|
||||
// in the authboss.StoreUsername field.
|
||||
OAuth2NewOrUpdate(key string, attr authboss.Attributes) error
|
||||
}
|
||||
|
||||
type OAuth2 struct{}
|
||||
|
||||
func init() {
|
||||
@ -34,7 +26,7 @@ func init() {
|
||||
}
|
||||
|
||||
func (o *OAuth2) Initialize() error {
|
||||
if _, ok := authboss.Cfg.Storer.(OAuth2Storer); !ok {
|
||||
if authboss.Cfg.OAuth2Storer == nil {
|
||||
return errors.New("oauth2: need an OAuth2Storer")
|
||||
}
|
||||
return nil
|
||||
@ -65,11 +57,12 @@ func (o *OAuth2) Routes() authboss.RouteTable {
|
||||
|
||||
func (o *OAuth2) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
authboss.StoreUsername: authboss.String,
|
||||
authboss.StoreEmail: authboss.String,
|
||||
authboss.StoreOAuth2Token: authboss.String,
|
||||
authboss.StoreOAuth2Refresh: authboss.String,
|
||||
authboss.StoreOAuth2Expiry: authboss.DateTime,
|
||||
authboss.StoreEmail: authboss.String,
|
||||
authboss.StoreOAuth2UID: authboss.String,
|
||||
authboss.StoreOAuth2Provider: authboss.String,
|
||||
authboss.StoreOAuth2Token: authboss.String,
|
||||
authboss.StoreOAuth2Refresh: authboss.String,
|
||||
authboss.StoreOAuth2Expiry: authboss.DateTime,
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,6 +82,16 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
|
||||
state := base64.URLEncoding.EncodeToString(random)
|
||||
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
|
||||
|
||||
var passAlongs []string
|
||||
for k, vals := range r.URL.Query() {
|
||||
for _, val := range vals {
|
||||
passAlongs = append(passAlongs, fmt.Sprintf("%s=%s", k, val))
|
||||
}
|
||||
}
|
||||
if len(passAlongs) > 0 {
|
||||
state += ";" + strings.Join(passAlongs, ";")
|
||||
}
|
||||
|
||||
url := cfg.OAuth2Config.AuthCodeURL(state)
|
||||
|
||||
extraParams := cfg.AdditionalParams.Encode()
|
||||
@ -108,6 +111,10 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
|
||||
|
||||
hasErr := r.FormValue("error")
|
||||
if len(hasErr) > 0 {
|
||||
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return authboss.ErrAndRedirect{
|
||||
Err: errors.New(r.FormValue("error_reason")),
|
||||
Location: authboss.Cfg.AuthLoginFailPath,
|
||||
@ -126,8 +133,9 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// Ensure request is genuine
|
||||
state := r.FormValue("state")
|
||||
if state != sessState {
|
||||
state := r.FormValue(authboss.FormValueOAuth2State)
|
||||
splState := strings.Split(state, ";")
|
||||
if len(splState) == 0 || splState[0] != sessState {
|
||||
return errOAuthStateValidation
|
||||
}
|
||||
|
||||
@ -138,32 +146,55 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
|
||||
return fmt.Errorf("Could not validate oauth2 code: %v", err)
|
||||
}
|
||||
|
||||
credentials, err := cfg.Callback(*cfg.OAuth2Config, token)
|
||||
user, err := cfg.Callback(*cfg.OAuth2Config, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// User is authenticated
|
||||
key := fmt.Sprintf("%s:%s", provider, credentials.UID)
|
||||
user := make(authboss.Attributes)
|
||||
user[authboss.StoreUsername] = key
|
||||
// OAuth2UID is required.
|
||||
uid, err := user.StringErr(authboss.StoreOAuth2UID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user[authboss.StoreOAuth2UID] = uid
|
||||
user[authboss.StoreOAuth2Provider] = provider
|
||||
user[authboss.StoreOAuth2Expiry] = token.Expiry
|
||||
user[authboss.StoreOAuth2Token] = token.AccessToken
|
||||
if len(token.RefreshToken) != 0 {
|
||||
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
|
||||
}
|
||||
if len(credentials.Email) > 0 {
|
||||
user[authboss.StoreEmail] = credentials.Email
|
||||
}
|
||||
|
||||
// Log user in
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, key)
|
||||
|
||||
storer := authboss.Cfg.Storer.(OAuth2Storer)
|
||||
if err = storer.OAuth2NewOrUpdate(key, user); err != nil {
|
||||
if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.Redirect(w, r, authboss.Cfg.AuthLoginOKPath, http.StatusFound)
|
||||
// Log user in
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
|
||||
|
||||
if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
redirect := authboss.Cfg.AuthLoginOKPath
|
||||
values := make(url.Values)
|
||||
if len(splState) > 0 {
|
||||
for _, arg := range splState[1:] {
|
||||
spl := strings.Split(arg, "=")
|
||||
switch spl[0] {
|
||||
case authboss.CookieRemember:
|
||||
case authboss.FormValueRedirect:
|
||||
redirect = spl[1]
|
||||
default:
|
||||
values.Set(spl[0], spl[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
redirect = fmt.Sprintf("%s?%s", redirect, values.Encode())
|
||||
}
|
||||
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
@ -1,10 +1,12 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -30,7 +32,7 @@ var testProviders = map[string]authboss.OAuthProvider{
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
authboss.Cfg = authboss.NewConfig()
|
||||
authboss.Cfg.Storer = mocks.NewMockStorer()
|
||||
authboss.Cfg.OAuth2Storer = mocks.NewMockStorer()
|
||||
o := OAuth2{}
|
||||
if err := o.Initialize(); err != nil {
|
||||
t.Error(err)
|
||||
@ -76,7 +78,7 @@ func TestOAuth2Init(t *testing.T) {
|
||||
cfg.OAuth2Providers = testProviders
|
||||
authboss.Cfg = cfg
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google", nil)
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?r=/my/redirect&rm=true", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := authboss.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
@ -100,6 +102,30 @@ func TestOAuth2Init(t *testing.T) {
|
||||
if query["include_requested_scopes"][0] != "true" {
|
||||
t.Error("Missing extra parameters:", loc)
|
||||
}
|
||||
state := query[authboss.FormValueOAuth2State][0]
|
||||
if len(state) == 0 {
|
||||
t.Error("It should have had some state:", loc)
|
||||
}
|
||||
|
||||
splits := strings.Split(state, ";")
|
||||
if len(splits[0]) != 44 {
|
||||
t.Error("The xsrf token was wrong size:", len(splits[0]), splits[0])
|
||||
}
|
||||
|
||||
// Maps are fun
|
||||
sort.Strings(splits[1:])
|
||||
|
||||
if v, err := url.QueryUnescape(splits[1]); err != nil {
|
||||
t.Error(err)
|
||||
} else if v != "r=/my/redirect" {
|
||||
t.Error("Redirect parameter not saved:", splits[1])
|
||||
}
|
||||
|
||||
if v, err := url.QueryUnescape(splits[2]); err != nil {
|
||||
t.Error(err)
|
||||
} else if v != "rm=true" {
|
||||
t.Error("Remember parameter not saved:", splits[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthSuccess(t *testing.T) {
|
||||
@ -113,10 +139,10 @@ func TestOAuthSuccess(t *testing.T) {
|
||||
Expiry: expiry,
|
||||
}
|
||||
|
||||
fakeCallback := func(_ oauth2.Config, _ *oauth2.Token) (authboss.OAuth2Credentials, error) {
|
||||
return authboss.OAuth2Credentials{
|
||||
UID: "uid",
|
||||
Email: "email",
|
||||
fakeCallback := func(_ oauth2.Config, _ *oauth2.Token) (authboss.Attributes, error) {
|
||||
return authboss.Attributes{
|
||||
authboss.StoreOAuth2UID: "uid",
|
||||
authboss.StoreEmail: "email",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -142,29 +168,27 @@ func TestOAuthSuccess(t *testing.T) {
|
||||
}
|
||||
authboss.Cfg = cfg
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/fake?code=code&state=state", nil)
|
||||
url := fmt.Sprintf("/oauth2/fake?code=code&state=%s", url.QueryEscape("state;redir=/myurl;rm=true;myparam=5"))
|
||||
r, _ := http.NewRequest("GET", url, nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx := authboss.NewContext()
|
||||
session := mocks.NewMockClientStorer()
|
||||
session.Put(authboss.SessionOAuth2State, "state")
|
||||
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
|
||||
storer := mocks.NewMockStorer()
|
||||
ctx.SessionStorer = session
|
||||
cfg.Storer = storer
|
||||
cfg.OAuth2Storer = storer
|
||||
cfg.AuthLoginOKPath = "/fakeloginok"
|
||||
|
||||
if err := oauthCallback(ctx, w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
key := "fake:uid"
|
||||
key := "uidfake"
|
||||
user, ok := storer.Users[key]
|
||||
if !ok {
|
||||
t.Error("Couldn't find user.")
|
||||
}
|
||||
|
||||
if val, _ := user.String(authboss.StoreUsername); val != key {
|
||||
t.Error("Username was wrong:", val)
|
||||
}
|
||||
if val, _ := user.String(authboss.StoreEmail); val != "email" {
|
||||
t.Error("Email was wrong:", val)
|
||||
}
|
||||
@ -178,13 +202,13 @@ func TestOAuthSuccess(t *testing.T) {
|
||||
t.Error("Expiry was wrong:", val)
|
||||
}
|
||||
|
||||
if val, _ := session.Get(authboss.SessionKey); val != key {
|
||||
if val, _ := session.Get(authboss.SessionKey); val != "uid;fake" {
|
||||
t.Error("User was not logged in:", val)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Error("It should redirect")
|
||||
} else if loc := w.Header().Get("Location"); loc != authboss.Cfg.AuthLoginOKPath {
|
||||
} else if loc := w.Header().Get("Location"); loc != "/myurl?myparam=5" {
|
||||
t.Error("Redirect is wrong:", loc)
|
||||
}
|
||||
}
|
||||
@ -193,13 +217,13 @@ func TestOAuthXSRFFailure(t *testing.T) {
|
||||
cfg := authboss.NewConfig()
|
||||
|
||||
session := mocks.NewMockClientStorer()
|
||||
session.Put(authboss.SessionOAuth2State, "state")
|
||||
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
|
||||
|
||||
cfg.OAuth2Providers = testProviders
|
||||
authboss.Cfg = cfg
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("state", "notstate")
|
||||
values.Set(authboss.FormValueOAuth2State, "notstate")
|
||||
values.Set("code", "code")
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
|
||||
|
@ -26,21 +26,22 @@ type googleMeResponse struct {
|
||||
var clientGet = (*http.Client).Get
|
||||
|
||||
// Google is a callback appropriate for use with Google's OAuth2 configuration.
|
||||
func Google(cfg oauth2.Config, token *oauth2.Token) (cred authboss.OAuth2Credentials, err error) {
|
||||
func Google(cfg oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) {
|
||||
client := cfg.Client(oauth2.NoContext, token)
|
||||
resp, err := clientGet(client, googleInfoEndpoint)
|
||||
if err != nil {
|
||||
return cred, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var jsonResp googleMeResponse
|
||||
if err = dec.Decode(&jsonResp); err != nil {
|
||||
return cred, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cred.UID = jsonResp.ID
|
||||
cred.Email = jsonResp.Email
|
||||
return cred, nil
|
||||
return authboss.Attributes{
|
||||
authboss.StoreOAuth2UID: jsonResp.ID,
|
||||
authboss.StoreEmail: jsonResp.Email,
|
||||
}, nil
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/authboss.v0"
|
||||
)
|
||||
|
||||
func TestGoogle(t *testing.T) {
|
||||
@ -30,15 +31,15 @@ func TestGoogle(t *testing.T) {
|
||||
Expiry: time.Now().Add(60 * time.Minute),
|
||||
}
|
||||
|
||||
cred, err := Google(cfg, tok)
|
||||
user, err := Google(cfg, tok)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if cred.UID != "id" {
|
||||
t.Error("UID wrong:", cred.UID)
|
||||
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
|
||||
t.Error("UID wrong:", uid)
|
||||
}
|
||||
if cred.Email != "email" {
|
||||
t.Error("Email wrong:", cred.Email)
|
||||
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
|
||||
t.Error("Email wrong:", email)
|
||||
}
|
||||
}
|
||||
|
@ -251,7 +251,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
|
||||
|
||||
mailer := mocks.NewMockMailer()
|
||||
authboss.Cfg.EmailSubjectPrefix = "foo "
|
||||
authboss.Cfg.HostName = "bar"
|
||||
authboss.Cfg.RootURL = "bar"
|
||||
authboss.Cfg.Mailer = mailer
|
||||
|
||||
a.sendRecoverEmail("a@b.c", "abc=")
|
||||
@ -265,7 +265,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
|
||||
t.Error("Unexpected subject:", mailer.Last.Subject)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.HostName)
|
||||
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.RootURL)
|
||||
if !strings.Contains(mailer.Last.HTMLBody, url) {
|
||||
t.Error("Expected HTMLBody to contain url:", url)
|
||||
}
|
||||
@ -409,11 +409,11 @@ func TestRecover_completeHanlderFunc_POST(t *testing.T) {
|
||||
}
|
||||
|
||||
if !cbCalled {
|
||||
t.Error("Expected EventPasswordReste callback to have been fired")
|
||||
t.Error("Expected EventPasswordReset callback to have been fired")
|
||||
}
|
||||
|
||||
if val, ok := sessionStorer.Get(authboss.SessionKey); !ok || val != "john" {
|
||||
t.Errorf("Ecxpected SessionKey to be:", "john")
|
||||
t.Error("Expected SessionKey to be:", "john")
|
||||
}
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
)
|
||||
@ -50,11 +51,12 @@ func (r *Remember) Initialize() error {
|
||||
}
|
||||
|
||||
if _, ok := authboss.Cfg.Storer.(TokenStorer); !ok {
|
||||
return errors.New("remember: TokenStorer required for remember me functionality")
|
||||
return errors.New("remember: TokenStorer required for remember functionality")
|
||||
}
|
||||
|
||||
authboss.Cfg.Callbacks.Before(authboss.EventGetUserSession, r.auth)
|
||||
authboss.Cfg.Callbacks.After(authboss.EventAuth, r.afterAuth)
|
||||
authboss.Cfg.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
|
||||
authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
|
||||
|
||||
return nil
|
||||
@ -90,6 +92,53 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// afterOAuth is called after oauth authentication is successful.
|
||||
// Has to pander to horrible state variable packing to figure out if we want
|
||||
// to be remembered.
|
||||
func (r *Remember) afterOAuth(ctx *authboss.Context) error {
|
||||
state, ok := ctx.FirstFormValue(authboss.FormValueOAuth2State)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
splState := strings.Split(state, ";")
|
||||
if len(splState) < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
should := false
|
||||
for _, arg := range splState[1:] {
|
||||
spl := strings.Split(arg, "=")
|
||||
if spl[0] == authboss.CookieRemember {
|
||||
should = spl[1] == "true"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !should {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ctx.User == nil {
|
||||
return errUserMissing
|
||||
}
|
||||
|
||||
uid, err := ctx.User.StringErr(authboss.StoreOAuth2Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
provider, err := ctx.User.StringErr(authboss.StoreOAuth2Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := r.new(ctx.CookieStorer, uid+";"+provider); err != nil {
|
||||
return fmt.Errorf("remember: Failed to create remember token: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// afterPassword is called after the password has been reset.
|
||||
func (r *Remember) afterPassword(ctx *authboss.Context) error {
|
||||
if ctx.User == nil {
|
||||
@ -157,7 +206,7 @@ func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) {
|
||||
|
||||
index := bytes.IndexByte(token, ';')
|
||||
if index < 0 {
|
||||
return authboss.InterruptNone, errors.New("remember: Invalid remember me token.")
|
||||
return authboss.InterruptNone, errors.New("remember: Invalid remember token.")
|
||||
}
|
||||
|
||||
// Get the key.
|
||||
|
@ -2,7 +2,9 @@ package remember
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
@ -64,6 +66,42 @@ func TestAfterAuth(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterOAuth(t *testing.T) {
|
||||
r := Remember{}
|
||||
authboss.NewConfig()
|
||||
storer := mocks.NewMockStorer()
|
||||
authboss.Cfg.Storer = storer
|
||||
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer()
|
||||
|
||||
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", url.QueryEscape("xsrf;rm=true"))
|
||||
req, err := http.NewRequest("GET", uri, nil)
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
|
||||
ctx, err := authboss.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.User = authboss.Attributes{
|
||||
authboss.StoreOAuth2UID: "uid",
|
||||
authboss.StoreOAuth2Provider: "google",
|
||||
}
|
||||
|
||||
if err := r.afterOAuth(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, ok := cookies.Values[authboss.CookieRemember]; !ok {
|
||||
t.Error("Expected a cookie to have been set.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterPasswordReset(t *testing.T) {
|
||||
r := Remember{}
|
||||
authboss.NewConfig()
|
||||
|
21
storer.go
21
storer.go
@ -20,9 +20,11 @@ const (
|
||||
|
||||
// Data store constants for OAuth2 attribute names.
|
||||
const (
|
||||
StoreOAuth2Token = "oauth2_token"
|
||||
StoreOAuth2Refresh = "oauth2_refresh"
|
||||
StoreOAuth2Expiry = "oauth2_expiry"
|
||||
StoreOAuth2UID = "oauth2_uid"
|
||||
StoreOAuth2Provider = "oauth2_provider"
|
||||
StoreOAuth2Token = "oauth2_token"
|
||||
StoreOAuth2Refresh = "oauth2_refresh"
|
||||
StoreOAuth2Expiry = "oauth2_expiry"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -39,8 +41,8 @@ type StorageOptions map[string]DataType
|
||||
// The type of store is up to the developer implementing it, and all it has to
|
||||
// do is be able to store several simple types.
|
||||
type Storer interface {
|
||||
// Put is for storing the attributes passed in. The type information can
|
||||
// help serialization without using type assertions.
|
||||
// Put is for storing the attributes passed in using the key. This is an
|
||||
// update only method and should not store if it does not find the key.
|
||||
Put(key string, attr Attributes) error
|
||||
// Get is for retrieving attributes for a given key. The return value
|
||||
// must be a struct that contains all fields with the correct types as shown
|
||||
@ -49,6 +51,15 @@ type Storer interface {
|
||||
Get(key string, attrMeta AttributeMeta) (interface{}, error)
|
||||
}
|
||||
|
||||
// OAuth2Storer is a replacement (or addition) to the Storer interface.
|
||||
// It allows users to be stored and fetched via a uid/provider combination.
|
||||
type OAuth2Storer interface {
|
||||
// PutOAuth creates or updates an existing record (unlike Storer.Put)
|
||||
// because in the OAuth flow there is no separate create/update.
|
||||
PutOAuth(uid, provider string, attr Attributes) error
|
||||
GetOAuth(uid, provider string, attrMeta AttributeMeta) (interface{}, error)
|
||||
}
|
||||
|
||||
// DataType represents the various types that clients must be able to store.
|
||||
type DataType int
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user