1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-03 13:21:22 +02:00

Make OAuth2 implementation less shoddy.

- Add a new storer specifically for OAuth2 to enable clients to choose
  regular database storing OR Oauth2 but not have to have both.
- Stop storing OAuth2 credentials in a combined form inside username.
- Add new events to capture OAuth events just like auth.
- Have pass-through parameters for OAuth init urls, this allows us to
  pass additional behavior options (redirects and remember me) as well
  as other things that should be present on the page that is redirected
  to.
- Context.LoadUser is now OAuth aware.
- Remember's callbacks now include an OAuth check to see if a horribly
  packed state variable contains a flag to say that we want to be
  remembered.
- Change the OAuth2 Callback to use Attributes instead of that custom
  struct to allow people to append whatever attributes they want into
  the user that will be saved.
This commit is contained in:
Aaron L 2015-03-13 16:23:43 -07:00
parent 082caf88b3
commit 06edd2e615
16 changed files with 349 additions and 95 deletions

View File

@ -13,7 +13,9 @@ type Event int
const (
EventRegister Event = iota
EventAuth
EventOAuth
EventAuthFail
EventOAuthFail
EventRecoverStart
EventRecoverEnd
EventGet
@ -21,14 +23,15 @@ const (
EventPasswordReset
)
const eventNames = "EventRegisterEventAuthEventAuthFailEventRecoverStartEventRecoverEndEventGetEventGetUserSessionEventPasswordReset"
const eventNames = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetEventGetUserSessionEventPasswordReset"
var eventIndexes = [...]uint8{0, 13, 22, 35, 52, 67, 75, 94, 112}
var eventIndexes = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 99, 118, 136}
func (i Event) String() string {
if i < 0 || i+1 >= Event(len(eventIndexes)) {
return fmt.Sprintf("Event(%d)", i)
}
return eventNames[eventIndexes[i]:eventIndexes[i+1]]
}

View File

@ -160,7 +160,9 @@ func TestEventString(t *testing.T) {
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventOAuth, "EventOAuth"},
{EventAuthFail, "EventAuthFail"},
{EventOAuthFail, "EventOAuthFail"},
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGet, "EventGet"},

View File

@ -68,6 +68,7 @@ type Config struct {
XSRFMaker XSRF
Storer Storer
OAuth2Storer OAuth2Storer
CookieStoreMaker CookieStoreMaker
SessionStoreMaker SessionStoreMaker
LogWriter io.Writer

View File

@ -9,6 +9,11 @@ import (
"time"
)
var (
FormValueRedirect = "redir"
FormValueOAuth2State = "state"
)
// Context provides context for module operations and callbacks. One obvious
// need for context is a request's session store. It is not safe for use by
// multiple goroutines.
@ -100,12 +105,19 @@ func (c *Context) LoadUser(key string) error {
return nil
}
intf, err := Cfg.Storer.Get(key, ModuleAttrMeta)
var user interface{}
var err error
if index := strings.IndexByte(key, ';'); index > 0 {
user, err = Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:], ModuleAttrMeta)
} else {
user, err = Cfg.Storer.Get(key, ModuleAttrMeta)
}
if err != nil {
return err
}
c.User = Unbind(intf)
c.User = Unbind(user)
return nil
}

View File

@ -95,22 +95,51 @@ func TestContext_SaveUser(t *testing.T) {
func TestContext_LoadUser(t *testing.T) {
Cfg = NewConfig()
ctx := NewContext()
attr := Attributes{
"email": "hello@joe.com",
"password": "mysticalhash",
"uid": "what",
"provider": "google",
}
storer := mockStorer{
"joe": Attributes{"email": "hello@joe.com", "password": "mysticalhash"},
"joe": attr,
"whatgoogle": attr,
}
Cfg.Storer = storer
Cfg.OAuth2Storer = storer
err := ctx.LoadUser("joe")
if err != nil {
ctx.User = nil
if err := ctx.LoadUser("joe"); err != nil {
t.Error("Unexpected error:", err)
}
attr := storer["joe"]
if email, err := ctx.User.StringErr("email"); err != nil {
t.Error(err)
} else if email != attr["email"] {
t.Error("Email wrong:", email)
}
if password, err := ctx.User.StringErr("password"); err != nil {
t.Error(err)
} else if password != attr["password"] {
t.Error("Password wrong:", password)
}
for k, v := range attr {
if v != ctx.User[k] {
t.Error(v, "not equal to", ctx.User[k])
}
ctx.User = nil
if err := ctx.LoadUser("what;google"); err != nil {
t.Error("Unexpected error:", err)
}
if email, err := ctx.User.StringErr("email"); err != nil {
t.Error(err)
} else if email != attr["email"] {
t.Error("Email wrong:", email)
}
if password, err := ctx.User.StringErr("password"); err != nil {
t.Error(err)
} else if password != attr["password"] {
t.Error("Password wrong:", password)
}
}

View File

@ -91,6 +91,39 @@ func (m *MockStorer) Get(key string, attrMeta authboss.AttributeMeta) (result in
return u, nil
}
func (m *MockStorer) PutOAuth(uid, provider string, attr authboss.Attributes) error {
if len(m.PutErr) > 0 {
return errors.New(m.PutErr)
}
if _, ok := m.Users[uid+provider]; !ok {
m.Users[uid+provider] = attr
return nil
}
for k, v := range attr {
m.Users[uid+provider][k] = v
}
return nil
}
func (m *MockStorer) GetOAuth(uid, provider string, attrMeta authboss.AttributeMeta) (result interface{}, err error) {
if len(m.GetErr) > 0 {
return nil, errors.New(m.GetErr)
}
userAttrs, ok := m.Users[uid+provider]
if !ok {
return nil, authboss.ErrUserNotFound
}
u := &MockUser{}
if err := userAttrs.Bind(u, true); err != nil {
panic(err)
}
return u, nil
}
func (m *MockStorer) AddToken(key, token string) error {
if len(m.AddTokenErr) > 0 {
return errors.New(m.AddTokenErr)

View File

@ -29,6 +29,17 @@ func (m mockStorer) Get(key string, attrMeta AttributeMeta) (result interface{},
}, nil
}
func (m mockStorer) PutOAuth(uid, provider string, attr Attributes) error {
m[uid+provider] = attr
return nil
}
func (m mockStorer) GetOAuth(uid, provider string, attrMeta AttributeMeta) (result interface{}, err error) {
return &mockUser{
m[uid+provider]["email"].(string), m[uid+provider]["password"].(string),
}, nil
}
type mockClientStore map[string]string
func (m mockClientStore) Get(key string) (string, bool) {

View File

@ -6,20 +6,28 @@ import (
"golang.org/x/oauth2"
)
// OAuth2Provider is the entire configuration
// required to authenticate with this provider.
//
// The OAuth2Config does not need a redirect URL because it will
// be automatically created by the
/*
OAuth2Provider is the entire configuration
required to authenticate with this provider.
The OAuth2Config does not need a redirect URL because it will
be automatically created by the route registration in oauth2 module.
AdditionalParams can be used to specify extra parameters to tack on to the
end of the initial request, this allows for provider specific oauth options
like access_type=offline to be passed to the provider.
Callback gives the config and the token allowing an http client using the
authenticated token to be created. Because each OAuth2 implementation has a different
API this must be handled for each provider separately. It is used to return two things
specifically: UID (the ID according to the provider) and the Email address.
The UID must be passed back or there will be an error as it is the means of identifying the
user in the system, e-mail is optional but should be returned in systems using
emailing. The keys authboss.StoreOAuth2UID and authboss.StoreEmail can be used to set
these values in the authboss.Attributes map returned by the callback.
*/
type OAuthProvider struct {
OAuth2Config *oauth2.Config
AdditionalParams url.Values
Callback func(oauth2.Config, *oauth2.Token) (OAuth2Credentials, error)
}
// OAuth2Credentials are used to store in the database.
// Email is optional
type OAuth2Credentials struct {
UID string
Email string
Callback func(oauth2.Config, *oauth2.Token) (Attributes, error)
}

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
@ -18,15 +19,6 @@ var (
errOAuthStateValidation = errors.New("Could not validate oauth2 state param")
)
// OAuth2Storer is required to do OAuth2 storing.
type OAuth2Storer interface {
authboss.Storer
// NewOrUpdate should retrieve the user if he already exists, or create
// a new one. The key is composed of the provider:UID together and is stored
// in the authboss.StoreUsername field.
OAuth2NewOrUpdate(key string, attr authboss.Attributes) error
}
type OAuth2 struct{}
func init() {
@ -34,7 +26,7 @@ func init() {
}
func (o *OAuth2) Initialize() error {
if _, ok := authboss.Cfg.Storer.(OAuth2Storer); !ok {
if authboss.Cfg.OAuth2Storer == nil {
return errors.New("oauth2: need an OAuth2Storer")
}
return nil
@ -65,11 +57,12 @@ func (o *OAuth2) Routes() authboss.RouteTable {
func (o *OAuth2) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.StoreUsername: authboss.String,
authboss.StoreEmail: authboss.String,
authboss.StoreOAuth2Token: authboss.String,
authboss.StoreOAuth2Refresh: authboss.String,
authboss.StoreOAuth2Expiry: authboss.DateTime,
authboss.StoreEmail: authboss.String,
authboss.StoreOAuth2UID: authboss.String,
authboss.StoreOAuth2Provider: authboss.String,
authboss.StoreOAuth2Token: authboss.String,
authboss.StoreOAuth2Refresh: authboss.String,
authboss.StoreOAuth2Expiry: authboss.DateTime,
}
}
@ -89,6 +82,16 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
state := base64.URLEncoding.EncodeToString(random)
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
var passAlongs []string
for k, vals := range r.URL.Query() {
for _, val := range vals {
passAlongs = append(passAlongs, fmt.Sprintf("%s=%s", k, val))
}
}
if len(passAlongs) > 0 {
state += ";" + strings.Join(passAlongs, ";")
}
url := cfg.OAuth2Config.AuthCodeURL(state)
extraParams := cfg.AdditionalParams.Encode()
@ -108,6 +111,10 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
return err
}
return authboss.ErrAndRedirect{
Err: errors.New(r.FormValue("error_reason")),
Location: authboss.Cfg.AuthLoginFailPath,
@ -126,8 +133,9 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
}
// Ensure request is genuine
state := r.FormValue("state")
if state != sessState {
state := r.FormValue(authboss.FormValueOAuth2State)
splState := strings.Split(state, ";")
if len(splState) == 0 || splState[0] != sessState {
return errOAuthStateValidation
}
@ -138,32 +146,55 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
return fmt.Errorf("Could not validate oauth2 code: %v", err)
}
credentials, err := cfg.Callback(*cfg.OAuth2Config, token)
user, err := cfg.Callback(*cfg.OAuth2Config, token)
if err != nil {
return err
}
// User is authenticated
key := fmt.Sprintf("%s:%s", provider, credentials.UID)
user := make(authboss.Attributes)
user[authboss.StoreUsername] = key
// OAuth2UID is required.
uid, err := user.StringErr(authboss.StoreOAuth2UID)
if err != nil {
return err
}
user[authboss.StoreOAuth2UID] = uid
user[authboss.StoreOAuth2Provider] = provider
user[authboss.StoreOAuth2Expiry] = token.Expiry
user[authboss.StoreOAuth2Token] = token.AccessToken
if len(token.RefreshToken) != 0 {
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
}
if len(credentials.Email) > 0 {
user[authboss.StoreEmail] = credentials.Email
}
// Log user in
ctx.SessionStorer.Put(authboss.SessionKey, key)
storer := authboss.Cfg.Storer.(OAuth2Storer)
if err = storer.OAuth2NewOrUpdate(key, user); err != nil {
if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
return err
}
http.Redirect(w, r, authboss.Cfg.AuthLoginOKPath, http.StatusFound)
// Log user in
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
return nil
}
redirect := authboss.Cfg.AuthLoginOKPath
values := make(url.Values)
if len(splState) > 0 {
for _, arg := range splState[1:] {
spl := strings.Split(arg, "=")
switch spl[0] {
case authboss.CookieRemember:
case authboss.FormValueRedirect:
redirect = spl[1]
default:
values.Set(spl[0], spl[1])
}
}
}
if len(values) > 0 {
redirect = fmt.Sprintf("%s?%s", redirect, values.Encode())
}
http.Redirect(w, r, redirect, http.StatusFound)
return nil
}

View File

@ -1,10 +1,12 @@
package oauth2
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"path"
"sort"
"strings"
"testing"
"time"
@ -30,7 +32,7 @@ var testProviders = map[string]authboss.OAuthProvider{
func TestInitialize(t *testing.T) {
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.Storer = mocks.NewMockStorer()
authboss.Cfg.OAuth2Storer = mocks.NewMockStorer()
o := OAuth2{}
if err := o.Initialize(); err != nil {
t.Error(err)
@ -76,7 +78,7 @@ func TestOAuth2Init(t *testing.T) {
cfg.OAuth2Providers = testProviders
authboss.Cfg = cfg
r, _ := http.NewRequest("GET", "/oauth2/google", nil)
r, _ := http.NewRequest("GET", "/oauth2/google?r=/my/redirect&rm=true", nil)
w := httptest.NewRecorder()
ctx := authboss.NewContext()
ctx.SessionStorer = session
@ -100,6 +102,30 @@ func TestOAuth2Init(t *testing.T) {
if query["include_requested_scopes"][0] != "true" {
t.Error("Missing extra parameters:", loc)
}
state := query[authboss.FormValueOAuth2State][0]
if len(state) == 0 {
t.Error("It should have had some state:", loc)
}
splits := strings.Split(state, ";")
if len(splits[0]) != 44 {
t.Error("The xsrf token was wrong size:", len(splits[0]), splits[0])
}
// Maps are fun
sort.Strings(splits[1:])
if v, err := url.QueryUnescape(splits[1]); err != nil {
t.Error(err)
} else if v != "r=/my/redirect" {
t.Error("Redirect parameter not saved:", splits[1])
}
if v, err := url.QueryUnescape(splits[2]); err != nil {
t.Error(err)
} else if v != "rm=true" {
t.Error("Remember parameter not saved:", splits[2])
}
}
func TestOAuthSuccess(t *testing.T) {
@ -113,10 +139,10 @@ func TestOAuthSuccess(t *testing.T) {
Expiry: expiry,
}
fakeCallback := func(_ oauth2.Config, _ *oauth2.Token) (authboss.OAuth2Credentials, error) {
return authboss.OAuth2Credentials{
UID: "uid",
Email: "email",
fakeCallback := func(_ oauth2.Config, _ *oauth2.Token) (authboss.Attributes, error) {
return authboss.Attributes{
authboss.StoreOAuth2UID: "uid",
authboss.StoreEmail: "email",
}, nil
}
@ -142,29 +168,27 @@ func TestOAuthSuccess(t *testing.T) {
}
authboss.Cfg = cfg
r, _ := http.NewRequest("GET", "/oauth2/fake?code=code&state=state", nil)
url := fmt.Sprintf("/oauth2/fake?code=code&state=%s", url.QueryEscape("state;redir=/myurl;rm=true;myparam=5"))
r, _ := http.NewRequest("GET", url, nil)
w := httptest.NewRecorder()
ctx := authboss.NewContext()
session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, "state")
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
storer := mocks.NewMockStorer()
ctx.SessionStorer = session
cfg.Storer = storer
cfg.OAuth2Storer = storer
cfg.AuthLoginOKPath = "/fakeloginok"
if err := oauthCallback(ctx, w, r); err != nil {
t.Error(err)
}
key := "fake:uid"
key := "uidfake"
user, ok := storer.Users[key]
if !ok {
t.Error("Couldn't find user.")
}
if val, _ := user.String(authboss.StoreUsername); val != key {
t.Error("Username was wrong:", val)
}
if val, _ := user.String(authboss.StoreEmail); val != "email" {
t.Error("Email was wrong:", val)
}
@ -178,13 +202,13 @@ func TestOAuthSuccess(t *testing.T) {
t.Error("Expiry was wrong:", val)
}
if val, _ := session.Get(authboss.SessionKey); val != key {
if val, _ := session.Get(authboss.SessionKey); val != "uid;fake" {
t.Error("User was not logged in:", val)
}
if w.Code != http.StatusFound {
t.Error("It should redirect")
} else if loc := w.Header().Get("Location"); loc != authboss.Cfg.AuthLoginOKPath {
} else if loc := w.Header().Get("Location"); loc != "/myurl?myparam=5" {
t.Error("Redirect is wrong:", loc)
}
}
@ -193,13 +217,13 @@ func TestOAuthXSRFFailure(t *testing.T) {
cfg := authboss.NewConfig()
session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, "state")
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
cfg.OAuth2Providers = testProviders
authboss.Cfg = cfg
values := url.Values{}
values.Set("state", "notstate")
values.Set(authboss.FormValueOAuth2State, "notstate")
values.Set("code", "code")
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)

View File

@ -26,21 +26,22 @@ type googleMeResponse struct {
var clientGet = (*http.Client).Get
// Google is a callback appropriate for use with Google's OAuth2 configuration.
func Google(cfg oauth2.Config, token *oauth2.Token) (cred authboss.OAuth2Credentials, err error) {
func Google(cfg oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) {
client := cfg.Client(oauth2.NoContext, token)
resp, err := clientGet(client, googleInfoEndpoint)
if err != nil {
return cred, err
return nil, err
}
defer resp.Body.Close()
dec := json.NewDecoder(resp.Body)
var jsonResp googleMeResponse
if err = dec.Decode(&jsonResp); err != nil {
return cred, err
return nil, err
}
cred.UID = jsonResp.ID
cred.Email = jsonResp.Email
return cred, nil
return authboss.Attributes{
authboss.StoreOAuth2UID: jsonResp.ID,
authboss.StoreEmail: jsonResp.Email,
}, nil
}

View File

@ -8,6 +8,7 @@ import (
"time"
"golang.org/x/oauth2"
"gopkg.in/authboss.v0"
)
func TestGoogle(t *testing.T) {
@ -30,15 +31,15 @@ func TestGoogle(t *testing.T) {
Expiry: time.Now().Add(60 * time.Minute),
}
cred, err := Google(cfg, tok)
user, err := Google(cfg, tok)
if err != nil {
t.Error(err)
}
if cred.UID != "id" {
t.Error("UID wrong:", cred.UID)
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
t.Error("UID wrong:", uid)
}
if cred.Email != "email" {
t.Error("Email wrong:", cred.Email)
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
t.Error("Email wrong:", email)
}
}

View File

@ -251,7 +251,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
mailer := mocks.NewMockMailer()
authboss.Cfg.EmailSubjectPrefix = "foo "
authboss.Cfg.HostName = "bar"
authboss.Cfg.RootURL = "bar"
authboss.Cfg.Mailer = mailer
a.sendRecoverEmail("a@b.c", "abc=")
@ -265,7 +265,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
t.Error("Unexpected subject:", mailer.Last.Subject)
}
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.HostName)
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.RootURL)
if !strings.Contains(mailer.Last.HTMLBody, url) {
t.Error("Expected HTMLBody to contain url:", url)
}
@ -409,11 +409,11 @@ func TestRecover_completeHanlderFunc_POST(t *testing.T) {
}
if !cbCalled {
t.Error("Expected EventPasswordReste callback to have been fired")
t.Error("Expected EventPasswordReset callback to have been fired")
}
if val, ok := sessionStorer.Get(authboss.SessionKey); !ok || val != "john" {
t.Errorf("Ecxpected SessionKey to be:", "john")
t.Error("Expected SessionKey to be:", "john")
}
if w.Code != http.StatusFound {

View File

@ -10,6 +10,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"strings"
"gopkg.in/authboss.v0"
)
@ -50,11 +51,12 @@ func (r *Remember) Initialize() error {
}
if _, ok := authboss.Cfg.Storer.(TokenStorer); !ok {
return errors.New("remember: TokenStorer required for remember me functionality")
return errors.New("remember: TokenStorer required for remember functionality")
}
authboss.Cfg.Callbacks.Before(authboss.EventGetUserSession, r.auth)
authboss.Cfg.Callbacks.After(authboss.EventAuth, r.afterAuth)
authboss.Cfg.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
return nil
@ -90,6 +92,53 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
return nil
}
// afterOAuth is called after oauth authentication is successful.
// Has to pander to horrible state variable packing to figure out if we want
// to be remembered.
func (r *Remember) afterOAuth(ctx *authboss.Context) error {
state, ok := ctx.FirstFormValue(authboss.FormValueOAuth2State)
if !ok {
return nil
}
splState := strings.Split(state, ";")
if len(splState) < 0 {
return nil
}
should := false
for _, arg := range splState[1:] {
spl := strings.Split(arg, "=")
if spl[0] == authboss.CookieRemember {
should = spl[1] == "true"
break
}
}
if !should {
return nil
}
if ctx.User == nil {
return errUserMissing
}
uid, err := ctx.User.StringErr(authboss.StoreOAuth2Provider)
if err != nil {
return err
}
provider, err := ctx.User.StringErr(authboss.StoreOAuth2Provider)
if err != nil {
return err
}
if _, err := r.new(ctx.CookieStorer, uid+";"+provider); err != nil {
return fmt.Errorf("remember: Failed to create remember token: %v", err)
}
return nil
}
// afterPassword is called after the password has been reset.
func (r *Remember) afterPassword(ctx *authboss.Context) error {
if ctx.User == nil {
@ -157,7 +206,7 @@ func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) {
index := bytes.IndexByte(token, ';')
if index < 0 {
return authboss.InterruptNone, errors.New("remember: Invalid remember me token.")
return authboss.InterruptNone, errors.New("remember: Invalid remember token.")
}
// Get the key.

View File

@ -2,7 +2,9 @@ package remember
import (
"bytes"
"fmt"
"net/http"
"net/url"
"testing"
"gopkg.in/authboss.v0"
@ -64,6 +66,42 @@ func TestAfterAuth(t *testing.T) {
}
}
func TestAfterOAuth(t *testing.T) {
r := Remember{}
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer()
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", url.QueryEscape("xsrf;rm=true"))
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
t.Error("Unexpected Error:", err)
}
ctx, err := authboss.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected error:", err)
}
ctx.SessionStorer = session
ctx.CookieStorer = cookies
ctx.User = authboss.Attributes{
authboss.StoreOAuth2UID: "uid",
authboss.StoreOAuth2Provider: "google",
}
if err := r.afterOAuth(ctx); err != nil {
t.Error(err)
}
if _, ok := cookies.Values[authboss.CookieRemember]; !ok {
t.Error("Expected a cookie to have been set.")
}
}
func TestAfterPasswordReset(t *testing.T) {
r := Remember{}
authboss.NewConfig()

View File

@ -20,9 +20,11 @@ const (
// Data store constants for OAuth2 attribute names.
const (
StoreOAuth2Token = "oauth2_token"
StoreOAuth2Refresh = "oauth2_refresh"
StoreOAuth2Expiry = "oauth2_expiry"
StoreOAuth2UID = "oauth2_uid"
StoreOAuth2Provider = "oauth2_provider"
StoreOAuth2Token = "oauth2_token"
StoreOAuth2Refresh = "oauth2_refresh"
StoreOAuth2Expiry = "oauth2_expiry"
)
var (
@ -39,8 +41,8 @@ type StorageOptions map[string]DataType
// The type of store is up to the developer implementing it, and all it has to
// do is be able to store several simple types.
type Storer interface {
// Put is for storing the attributes passed in. The type information can
// help serialization without using type assertions.
// Put is for storing the attributes passed in using the key. This is an
// update only method and should not store if it does not find the key.
Put(key string, attr Attributes) error
// Get is for retrieving attributes for a given key. The return value
// must be a struct that contains all fields with the correct types as shown
@ -49,6 +51,15 @@ type Storer interface {
Get(key string, attrMeta AttributeMeta) (interface{}, error)
}
// OAuth2Storer is a replacement (or addition) to the Storer interface.
// It allows users to be stored and fetched via a uid/provider combination.
type OAuth2Storer interface {
// PutOAuth creates or updates an existing record (unlike Storer.Put)
// because in the OAuth flow there is no separate create/update.
PutOAuth(uid, provider string, attr Attributes) error
GetOAuth(uid, provider string, attrMeta AttributeMeta) (interface{}, error)
}
// DataType represents the various types that clients must be able to store.
type DataType int