1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-09-16 09:06:20 +02:00

Rewrite oauth module

- Tried to be clear about OAuth2 vs OAuth in all places.
- Allow users to be locked from OAuth logins (if done manually for some
  reason other than failed logins)
- Cleaned up some docs and wording around the previously very confusing
  (now hopefully only somewhat confusing) oauth2 module.
This commit is contained in:
Aaron L
2018-03-08 18:39:51 -08:00
parent 634892e29c
commit 1112987bce
15 changed files with 746 additions and 280 deletions

View File

@@ -27,6 +27,11 @@ type Config struct {
// LogoutOK is the redirect path after a log out.
LogoutOK string
// OAuth2LoginOK is the redirect path after a successful oauth2 login
OAuth2LoginOK string
// OAuth2LoginNotOK is the redirect path after an unsuccessful oauth2 login
OAuth2LoginNotOK string
// RecoverOK is the redirect path after a successful recovery of a password.
RecoverOK string

View File

@@ -13,9 +13,9 @@ type Event int
const (
EventRegister Event = iota
EventAuth
EventOAuth
EventOAuth2
EventAuthFail
EventOAuthFail
EventOAuth2Fail
EventRecoverStart
EventRecoverEnd
EventGetUser

View File

@@ -132,9 +132,9 @@ func TestEventString(t *testing.T) {
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventOAuth, "EventOAuth"},
{EventOAuth2, "EventOAuth2"},
{EventAuthFail, "EventAuthFail"},
{EventOAuthFail, "EventOAuthFail"},
{EventOAuth2Fail, "EventOAuth2Fail"},
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGetUser, "EventGetUser"},

View File

@@ -25,9 +25,12 @@ type User struct {
AttemptCount int
LastAttempt time.Time
Locked time.Time
OAuthToken string
OAuthRefresh string
OAuthExpiry time.Time
OAuth2UID string
OAuth2Provider string
OAuth2Token string
OAuth2Refresh string
OAuth2Expiry time.Time
Arbitrary map[string]string
}
@@ -43,9 +46,12 @@ func (m User) GetConfirmed() bool { return m.Confirmed }
func (m User) GetAttemptCount() int { return m.AttemptCount }
func (m User) GetLastAttempt() time.Time { return m.LastAttempt }
func (m User) GetLocked() time.Time { return m.Locked }
func (m User) GetOAuthToken() string { return m.OAuthToken }
func (m User) GetOAuthRefresh() string { return m.OAuthRefresh }
func (m User) GetOAuthExpiry() time.Time { return m.OAuthExpiry }
func (m User) IsOAuth2User() bool { return len(m.OAuth2Provider) != 0 }
func (m User) GetOAuth2UID() string { return m.OAuth2UID }
func (m User) GetOAuth2Provider() string { return m.OAuth2Provider }
func (m User) GetOAuth2AccessToken() string { return m.OAuth2Token }
func (m User) GetOAuth2RefreshToken() string { return m.OAuth2Refresh }
func (m User) GetOAuth2Expiry() time.Time { return m.OAuth2Expiry }
func (m User) GetArbitrary() map[string]string { return m.Arbitrary }
func (m *User) PutPID(email string) { m.Email = email }
@@ -61,9 +67,11 @@ func (m *User) PutConfirmed(confirmed bool) { m.Confirmed = confirmed }
func (m *User) PutAttemptCount(attemptCount int) { m.AttemptCount = attemptCount }
func (m *User) PutLastAttempt(attemptTime time.Time) { m.LastAttempt = attemptTime }
func (m *User) PutLocked(locked time.Time) { m.Locked = locked }
func (m *User) PutOAuthToken(oAuthToken string) { m.OAuthToken = oAuthToken }
func (m *User) PutOAuthRefresh(oAuthRefresh string) { m.OAuthRefresh = oAuthRefresh }
func (m *User) PutOAuthExpiry(oAuthExpiry time.Time) { m.OAuthExpiry = oAuthExpiry }
func (m *User) PutOAuth2UID(uid string) { m.OAuth2UID = uid }
func (m *User) PutOAuth2Provider(provider string) { m.OAuth2Provider = provider }
func (m *User) PutOAuth2AccessToken(token string) { m.OAuth2Token = token }
func (m *User) PutOAuth2RefreshToken(refresh string) { m.OAuth2Refresh = refresh }
func (m *User) PutOAuth2Expiry(expiry time.Time) { m.OAuth2Expiry = expiry }
func (m *User) PutArbitrary(arb map[string]string) { m.Arbitrary = arb }
// ServerStorer should be valid for any module storer defined in authboss.
@@ -115,6 +123,38 @@ func (s *ServerStorer) Save(ctx context.Context, user authboss.User) error {
return nil
}
// NewFromOAuth2 finds a user with the given details, or returns a new one
func (s *ServerStorer) NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (authboss.OAuth2User, error) {
uid := details["uid"]
email := details["email"]
name := details["name"]
pid := authboss.MakeOAuth2PID(provider, uid)
u, ok := s.Users[pid]
if ok {
u.Username = name
u.Email = email
return u, nil
}
return &User{
OAuth2UID: uid,
OAuth2Provider: provider,
Email: email,
Username: name,
}, nil
}
// SaveOAuth2 creates a user if not found, or updates one that exists.
func (s *ServerStorer) SaveOAuth2(ctx context.Context, user authboss.OAuth2User) error {
u := user.(*User)
pid := authboss.MakeOAuth2PID(u.OAuth2Provider, u.OAuth2UID)
// Since we don't have to differentiate between insert/update in a map, we just overwrite
s.Users[pid] = u
return nil
}
// LoadByConfirmToken finds a user by his confirm token
func (s *ServerStorer) LoadByConfirmToken(ctx context.Context, token string) (authboss.ConfirmableUser, error) {
for _, v := range s.Users {

View File

@@ -35,8 +35,8 @@ type Lock struct {
func (l *Lock) Init(ab *authboss.Authboss) error {
l.Authboss = ab
// Events
l.Events.Before(authboss.EventAuth, l.BeforeAuth)
l.Events.Before(authboss.EventOAuth, l.BeforeAuth)
l.Events.After(authboss.EventAuth, l.AfterAuthSuccess)
l.Events.After(authboss.EventAuthFail, l.AfterAuthFail)

View File

@@ -7,11 +7,6 @@ import (
"golang.org/x/oauth2"
)
// FormValue constants
const (
FormValueOAuth2State = "state"
)
/*
OAuth2Provider is the entire configuration
required to authenticate with this provider.
@@ -23,22 +18,14 @@ AdditionalParams can be used to specify extra parameters to tack on to the
end of the initial request, this allows for provider specific oauth options
like access_type=offline to be passed to the provider.
Callback gives the config and the token allowing an http client using the
authenticated token to be created. Because each OAuth2 implementation has a different
API this must be handled for each provider separately. It is used to return two things
specifically: UID (the ID according to the provider) and the Email address.
The UID must be passed back or there will be an error as it is the means of identifying the
user in the system, e-mail is optional but should be returned in systems using
emailing. The keys authboss.StoreOAuth2UID and authboss.StoreEmail can be used to set
these values in the map returned by the callback.
In addition to the required values mentioned above any additional
values that you wish to have in your user struct can be included here, such as the
Name of the user at the endpoint. This will be passed back in the Arbitrary()
function if it exists.
FindUserDetails gives the config and the token allowing an http client using the
authenticated token to be created, a call is then made to a known endpoint that will
return details about the user we've retrieved the token for. Those details are returned
as a map[string]string and subsequently passed into OAuth2ServerStorer.NewFromOAuth2.
API this must be handled for each provider separately.
*/
type OAuth2Provider struct {
OAuth2Config *oauth2.Config
AdditionalParams url.Values
Callback func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error)
FindUserDetails func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error)
}

View File

@@ -1,6 +1,36 @@
// Package oauth2 allows users to be created and authenticated
// via oauth2 services like facebook, google etc. Currently
// only the web server flow is supported.
//
// The general flow looks like this:
// 1. User goes to Start handler and has his session packed with goodies
// then redirects to the OAuth service.
// 2. OAuth service returns to OAuthCallback which extracts state and parameters
// and generally checks that everything is ok. It uses the token received to
// get an access token from the oauth2 library
// 3. Calls the OAuth2Provider.FindUserDetails which should return the user's details
// in a generic form.
// 4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in order
// to create a user object we can work with.
// 5. Saves the user in the database, logs them in, redirects.
//
// In order to do this there are a number of parts:
// 1. The configuration of a provider (handled by authboss.Config.Modules.OAuth2Providers)
// 2. The flow of redirection of client, parameter passing etc (handled by this package)
// 3. The HTTP call to the service once a token has been retrieved to get user details
// (handled by OAuth2Provider.FindUserDetails)
// 4. The creation of a user from the user details returned from the FindUserDetails
// (authboss.OAuth2ServerStorer)
//
// Of these parts, the responsibility of the authboss library consumer is on 1, 3, and 4.
// Configuration of providers that should be used is totally up to the consumer. The FindUserDetails
// function is typically up to the user, but we have some basic ones included in this package too.
// The creation of users from the FindUserDetail's map[string]string return is handled as part
// of the implementation of the OAuth2ServerStorer.
package oauth2
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
@@ -9,13 +39,19 @@ import (
"net/url"
"path"
"path/filepath"
"sort"
"strings"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"github.com/volatiletech/authboss"
"github.com/volatiletech/authboss/internal/response"
"golang.org/x/oauth2"
)
// FormValue constants
const (
FormValueOAuth2State = "state"
FormValueOAuth2Redir = "redir"
)
var (
@@ -31,68 +67,60 @@ func init() {
authboss.RegisterModule("oauth2", &OAuth2{})
}
// Initialize module
func (o *OAuth2) Initialize(ab *authboss.Authboss) error {
// Init module
func (o *OAuth2) Init(ab *authboss.Authboss) error {
o.Authboss = ab
if o.OAuth2Storer == nil && o.OAuth2StoreMaker == nil {
return errors.New("need an oauth2Storer")
// Do annoying sorting on keys so we can have predictible
// route registration (both for consistency inside the router but also for tests -_-)
var keys []string
for k := range o.Authboss.Config.Modules.OAuth2Providers {
keys = append(keys, k)
}
sort.Strings(keys)
for _, provider := range keys {
cfg := o.Authboss.Config.Modules.OAuth2Providers[provider]
provider = strings.ToLower(provider)
init := fmt.Sprintf("/oauth2/%s", provider)
callback := fmt.Sprintf("/oauth2/callback/%s", provider)
o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start))
o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End))
if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 {
callback = path.Join(mount, callback)
}
cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback
}
return nil
}
// Routes for module
func (o *OAuth2) Routes() authboss.RouteTable {
routes := make(authboss.RouteTable)
// Start the oauth2 process
func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error {
logger := o.Authboss.RequestLogger(r)
for prov, cfg := range o.OAuth2Providers {
prov = strings.ToLower(prov)
init := fmt.Sprintf("/oauth2/%s", prov)
callback := fmt.Sprintf("/oauth2/callback/%s", prov)
routes[init] = o.oauthInit
routes[callback] = o.oauthCallback
if len(o.MountPath) > 0 {
callback = path.Join(o.MountPath, callback)
}
cfg.OAuth2Config.RedirectURL = o.RootURL + callback
}
routes["/oauth2/logout"] = o.logout
return routes
}
// Storage requirements
func (o *OAuth2) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.StoreEmail: authboss.String,
authboss.StoreOAuth2UID: authboss.String,
authboss.StoreOAuth2Provider: authboss.String,
authboss.StoreOAuth2Token: authboss.String,
authboss.StoreOAuth2Refresh: authboss.String,
authboss.StoreOAuth2Expiry: authboss.DateTime,
}
}
func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
provider := strings.ToLower(filepath.Base(r.URL.Path))
cfg, ok := o.OAuth2Providers[provider]
logger.Infof("started oauth2 flow for provider: %s", provider)
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
if !ok {
return errors.Errorf("OAuth2 provider %q not found", provider)
return errors.Errorf("oauth2 provider %q not found", provider)
}
random := make([]byte, 32)
_, err := rand.Read(random)
if err != nil {
return err
// Create nonce
nonce := make([]byte, 32)
if _, err := rand.Read(nonce); err != nil {
return errors.Wrap(err, "failed to create nonce")
}
state := base64.URLEncoding.EncodeToString(random)
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
state := base64.URLEncoding.EncodeToString(nonce)
authboss.PutSession(w, authboss.SessionOAuth2State, state)
// This clearly ignores the fact that query parameters can have multiple
// values but I guess we're ignoring that
passAlongs := make(map[string]string)
for k, vals := range r.URL.Query() {
for _, val := range vals {
@@ -101,13 +129,13 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http
}
if len(passAlongs) > 0 {
str, err := json.Marshal(passAlongs)
byt, err := json.Marshal(passAlongs)
if err != nil {
return err
}
ctx.SessionStorer.Put(authboss.SessionOAuth2Params, string(str))
authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt))
} else {
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
authboss.DelSession(w, authboss.SessionOAuth2Params)
}
url := cfg.OAuth2Config.AuthCodeURL(state)
@@ -117,128 +145,153 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http
url = fmt.Sprintf("%s&%s", url, extraParams)
}
http.Redirect(w, r, url, http.StatusFound)
return nil
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: url,
}
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
}
// for testing
// for testing, mocked out at the beginning
var exchanger = (*oauth2.Config).Exchange
func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
// End the oauth2 process, this is the handler for the oauth2 callback
// that the third party will redirect to.
func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error {
logger := o.Authboss.RequestLogger(r)
provider := strings.ToLower(filepath.Base(r.URL.Path))
logger.Infof("finishing oauth2 flow for provider: %s", provider)
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
if err != nil {
return err
}
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
// Don't delete this value from session immediately, Events use this too
var values map[string]string
if ok {
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
return err
}
}
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
if err := o.Events.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
return err
}
return authboss.ErrAndRedirect{
Err: errors.New(r.FormValue("error_reason")),
Location: o.AuthLoginFailPath,
FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
}
}
cfg, ok := o.OAuth2Providers[provider]
// This shouldn't happen because the router should 404 first, but just in case
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
if !ok {
return errors.Errorf("oauth2 provider %q not found", provider)
}
// Ensure request is genuine
state := r.FormValue(authboss.FormValueOAuth2State)
splState := strings.Split(state, ";")
if len(splState) == 0 || splState[0] != sessState {
wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State)
if !ok {
return errors.New("oauth2 endpoint hit without session state")
}
// Verify we got the same state in the session as was passed to us in the
// query parameter.
state := r.FormValue(FormValueOAuth2State)
if state != wantState {
return errOAuthStateValidation
}
// Get the code
rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params)
var params map[string]string
if ok {
if err := json.Unmarshal([]byte(rawParams), &params); err != nil {
return errors.Wrap(err, "failed to decode oauth2 params")
}
}
authboss.DelSession(w, authboss.SessionOAuth2State)
authboss.DelSession(w, authboss.SessionOAuth2Params)
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
reason := r.FormValue("error_reason")
logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason)
handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r)
if err != nil {
return err
} else if handled {
return nil
}
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK,
Failure: fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)),
}
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
}
// Get the code which we can use to make an access token
code := r.FormValue("code")
token, err := exchanger(cfg.OAuth2Config, o.Config.ContextProvider(r), code)
token, err := exchanger(cfg.OAuth2Config, r.Context(), code)
if err != nil {
return errors.Wrap(err, "could not validate oauth2 code")
}
user, err := cfg.Callback(o.Config.ContextProvider(r), *cfg.OAuth2Config, token)
details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token)
if err != nil {
return err
}
// OAuth2UID is required.
uid, err := user.StringErr(authboss.StoreOAuth2UID)
storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server)
user, err := storer.NewFromOAuth2(r.Context(), provider, details)
if err != nil {
return err
return errors.Wrap(err, "failed to create oauth2 user from values")
}
user[authboss.StoreOAuth2UID] = uid
user[authboss.StoreOAuth2Provider] = provider
user[authboss.StoreOAuth2Expiry] = token.Expiry
user[authboss.StoreOAuth2Token] = token.AccessToken
user.PutOAuth2Provider(provider)
user.PutOAuth2AccessToken(token.AccessToken)
user.PutOAuth2Expiry(token.Expiry)
if len(token.RefreshToken) != 0 {
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
user.PutOAuth2RefreshToken(token.RefreshToken)
}
if err = ctx.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
if err := storer.SaveOAuth2(r.Context(), user); err != nil {
return err
}
// Fully log user in
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
if err = o.Events.FireAfter(authboss.EventOAuth, ctx); err != nil {
handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r)
if err != nil {
return err
} else if handled {
return nil
}
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
// Fully log user in
authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID()))
authboss.DelSession(w, authboss.SessionHalfAuthKey)
redirect := o.AuthLoginOKPath
// Create a query string from all the pieces we've received
// as passthru from the original request.
redirect := o.Authboss.Config.Paths.OAuth2LoginOK
query := make(url.Values)
for k, v := range values {
for k, v := range params {
switch k {
case authboss.CookieRemember:
case authboss.FormValueRedirect:
if v == "true" {
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{}))
}
case FormValueOAuth2Redir:
redirect = v
default:
query.Set(k, v)
}
}
handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r)
if err != nil {
return err
} else if handled {
return nil
}
if len(query) > 0 {
redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
}
sf := fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider))
response.Redirect(ctx, w, r, redirect, sf, "", false)
return nil
}
func (o *OAuth2) logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
switch r.Method {
case "GET":
ctx.SessionStorer.Del(authboss.SessionKey)
ctx.CookieStorer.Del(authboss.CookieRemember)
ctx.SessionStorer.Del(authboss.SessionLastAction)
response.Redirect(ctx, w, r, o.AuthLogoutOKPath, "You have logged out", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: redirect,
Success: fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)),
}
return nil
return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro)
}
// RMTrue is a dummy struct implementing authboss.RememberValuer
// in order to tell the remember me module to remember them.
type RMTrue struct{}
// GetShouldRemember always returns true
func (RMTrue) GetShouldRemember() bool { return true }

View File

@@ -2,11 +2,9 @@ package oauth2
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"time"
@@ -18,6 +16,12 @@ import (
"golang.org/x/oauth2/google"
)
func init() {
exchanger = func(_ *oauth2.Config, _ context.Context, _ string) (*oauth2.Token, error) {
return testToken, nil
}
}
var testProviders = map[string]authboss.OAuth2Provider{
"google": authboss.OAuth2Provider{
OAuth2Config: &oauth2.Config{
@@ -25,8 +29,10 @@ var testProviders = map[string]authboss.OAuth2Provider{
ClientSecret: `hands`,
Scopes: []string{`profile`, `email`},
Endpoint: google.Endpoint,
// This is typically set by Init() but some tests rely on it's existence
RedirectURL: "https://www.example.com/auth/oauth2/callback/google",
},
Callback: Google,
FindUserDetails: GoogleUserDetails,
AdditionalParams: url.Values{"include_requested_scopes": []string{"true"}},
},
"facebook": authboss.OAuth2Provider{
@@ -35,11 +41,341 @@ var testProviders = map[string]authboss.OAuth2Provider{
ClientSecret: `hands`,
Scopes: []string{`email`},
Endpoint: facebook.Endpoint,
// This is typically set by Init() but some tests rely on it's existence
RedirectURL: "https://www.example.com/auth/oauth2/callback/facebook",
},
Callback: Facebook,
FindUserDetails: FacebookUserDetails,
},
}
var testToken = &oauth2.Token{
AccessToken: "token",
TokenType: "Bearer",
RefreshToken: "refresh",
Expiry: time.Now().AddDate(0, 0, 1),
}
func TestInit(t *testing.T) {
// No t.Parallel() since the cfg.RedirectURL is set in Init()
ab := authboss.New()
oauth := &OAuth2{}
router := &mocks.Router{}
ab.Config.Modules.OAuth2Providers = testProviders
ab.Config.Core.Router = router
ab.Config.Core.ErrorHandler = &mocks.ErrorHandler{}
ab.Config.Paths.Mount = "/auth"
ab.Config.Paths.RootURL = "https://www.example.com"
if err := oauth.Init(ab); err != nil {
t.Fatal(err)
}
gets := []string{
"/oauth2/facebook", "/oauth2/callback/facebook",
"/oauth2/google", "/oauth2/callback/google",
}
if err := router.HasGets(gets...); err != nil {
t.Error(err)
}
}
type testHarness struct {
oauth *OAuth2
ab *authboss.Authboss
bodyReader *mocks.BodyReader
responder *mocks.Responder
redirector *mocks.Redirector
session *mocks.ClientStateRW
storer *mocks.ServerStorer
}
func testSetup() *testHarness {
harness := &testHarness{}
harness.ab = authboss.New()
harness.redirector = &mocks.Redirector{}
harness.session = mocks.NewClientRW()
harness.storer = mocks.NewServerStorer()
harness.ab.Modules.OAuth2Providers = testProviders
harness.ab.Paths.OAuth2LoginOK = "/auth/oauth2/ok"
harness.ab.Paths.OAuth2LoginNotOK = "/auth/oauth2/not/ok"
harness.ab.Config.Core.Logger = mocks.Logger{}
harness.ab.Config.Core.Redirector = harness.redirector
harness.ab.Config.Storage.SessionState = harness.session
harness.ab.Config.Storage.Server = harness.storer
harness.oauth = &OAuth2{harness.ab}
return harness
}
func TestStart(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
r := httptest.NewRequest("GET", "/oauth2/google?cake=yes&death=no", nil)
if err := h.oauth.Start(w, r); err != nil {
t.Error(err)
}
if h.redirector.Options.Code != http.StatusTemporaryRedirect {
t.Error("code was wrong:", h.redirector.Options.Code)
}
url, err := url.Parse(h.redirector.Options.RedirectPath)
if err != nil {
t.Fatal(err)
}
query := url.Query()
if state := query.Get("state"); len(state) == 0 {
t.Error("our nonce should have been here")
}
if callback := query.Get("redirect_uri"); callback != "https://www.example.com/auth/oauth2/callback/google" {
t.Error("callback was wrong:", callback)
}
if clientID := query.Get("client_id"); clientID != "jazz" {
t.Error("clientID was wrong:", clientID)
}
if url.Host != "accounts.google.com" {
t.Error("host was wrong:", url.Host)
}
if h.session.ClientValues[authboss.SessionOAuth2State] != query.Get("state") {
t.Error("the state should have been saved in the session")
}
if v := h.session.ClientValues[authboss.SessionOAuth2Params]; v != `{"cake":"yes","death":"no"}` {
t.Error("oauth2 session params are wrong:", v)
}
}
func TestStartBadProvider(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
r := httptest.NewRequest("GET", "/oauth2/test", nil)
err := h.oauth.Start(w, r)
if e := err.Error(); !strings.Contains(e, `provider "test" not found`) {
t.Error("it should have errored:", e)
}
}
func TestEnd(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
if err != nil {
t.Fatal(err)
}
if err := h.oauth.End(w, r); err != nil {
t.Error(err)
}
w.WriteHeader(http.StatusOK) // Flush headers
opts := h.redirector.Options
if opts.Code != http.StatusTemporaryRedirect {
t.Error("it should have redirected")
}
if opts.RedirectPath != "/auth/oauth2/ok" {
t.Error("redir path was wrong:", opts.RedirectPath)
}
if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" {
t.Error("session id should have been set:", s)
}
}
func TestEndBadProvider(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
r := httptest.NewRequest("GET", "/oauth2/callback/test", nil)
err := h.oauth.End(w, r)
if e := err.Error(); !strings.Contains(e, `provider "test" not found`) {
t.Error("it should have errored:", e)
}
}
func TestEndBadState(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
r := httptest.NewRequest("GET", "/oauth2/callback/google", nil)
err := h.oauth.End(w, r)
if e := err.Error(); !strings.Contains(e, `oauth2 endpoint hit without session state`) {
t.Error("it should have errored:", e)
}
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err = h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=x", nil))
if err != nil {
t.Fatal(err)
}
if err := h.oauth.End(w, r); err != errOAuthStateValidation {
t.Error("error was wrong:", err)
}
}
func TestEndErrors(t *testing.T) {
t.Parallel()
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil))
if err != nil {
t.Fatal(err)
}
if err := h.oauth.End(w, r); err != nil {
t.Error(err)
}
opts := h.redirector.Options
if opts.Code != http.StatusTemporaryRedirect {
t.Error("code was wrong:", opts.Code)
}
if opts.RedirectPath != "/auth/oauth2/not/ok" {
t.Error("path was wrong:", opts.RedirectPath)
}
}
func TestEndHandling(t *testing.T) {
t.Parallel()
t.Run("AfterOAuth2Fail", func(t *testing.T) {
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil))
if err != nil {
t.Fatal(err)
}
called := false
h.ab.Events.After(authboss.EventOAuth2Fail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
called = true
return true, nil
})
if err := h.oauth.End(w, r); err != nil {
t.Error(err)
}
if !called {
t.Error("it should have been called")
}
if h.redirector.Options.Code != 0 {
t.Error("it should not have tried to redirect")
}
})
t.Run("BeforeOAuth2", func(t *testing.T) {
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
if err != nil {
t.Fatal(err)
}
called := false
h.ab.Events.Before(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
called = true
return true, nil
})
if err := h.oauth.End(w, r); err != nil {
t.Error(err)
}
w.WriteHeader(http.StatusOK) // Flush headers
if !called {
t.Error("it should have been called")
}
if h.redirector.Options.Code != 0 {
t.Error("it should not have tried to redirect")
}
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
t.Error("should have not logged the user in")
}
})
t.Run("AfterOAuth2", func(t *testing.T) {
h := testSetup()
rec := httptest.NewRecorder()
w := h.ab.NewResponse(rec)
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
if err != nil {
t.Fatal(err)
}
called := false
h.ab.Events.After(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
called = true
return true, nil
})
if err := h.oauth.End(w, r); err != nil {
t.Error(err)
}
w.WriteHeader(http.StatusOK) // Flush headers
if !called {
t.Error("it should have been called")
}
if h.redirector.Options.Code != 0 {
t.Error("it should not have tried to redirect")
}
if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" {
t.Error("session id should have been set:", s)
}
})
}
/*
func TestInitialize(t *testing.T) {
t.Parallel()
@@ -281,45 +617,4 @@ func TestOAuthFailure(t *testing.T) {
t.Error("It should record the failure.")
}
}
func TestLogout(t *testing.T) {
t.Parallel()
ab := authboss.New()
oauth := OAuth2{ab}
ab.AuthLogoutOKPath = "/dashboard"
r, _ := http.NewRequest("GET", "/oauth2/google?", nil)
w := httptest.NewRecorder()
ctx := ab.NewContext()
session := mocks.NewMockClientStorer(authboss.SessionKey, "asdf", authboss.SessionLastAction, "1234")
cookies := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert")
ctx.SessionStorer = session
ctx.CookieStorer = cookies
if err := oauth.logout(ctx, w, r); err != nil {
t.Error(err)
}
if val, ok := session.Get(authboss.SessionKey); ok {
t.Error("Unexpected session key:", val)
}
if val, ok := session.Get(authboss.SessionLastAction); ok {
t.Error("Unexpected last action:", val)
}
if val, ok := cookies.Get(authboss.CookieRemember); ok {
t.Error("Unexpected rm cookie:", val)
}
if http.StatusFound != w.Code {
t.Errorf("Expected status code %d, got %d", http.StatusFound, w.Code)
}
location := w.Header().Get("Location")
if location != ab.AuthLogoutOKPath {
t.Error("Redirect wrong:", location)
}
}
*/

View File

@@ -3,12 +3,20 @@ package oauth2
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"github.com/volatiletech/authboss"
"github.com/pkg/errors"
"golang.org/x/oauth2"
)
// Constants for returning in the FindUserDetails call
const (
OAuth2UID = "uid"
OAuth2Email = "email"
OAuth2Name = "name"
)
const (
googleInfoEndpoint = `https://www.googleapis.com/userinfo/v2/me`
facebookInfoEndpoint = `https://graph.facebook.com/me?fields=name,email`
@@ -22,8 +30,8 @@ type googleMeResponse struct {
// testing
var clientGet = (*http.Client).Get
// Google is a callback appropriate for use with Google's OAuth2 configuration.
func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
// GoogleUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider
func GoogleUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
client := cfg.Client(ctx, token)
resp, err := clientGet(client, googleInfoEndpoint)
if err != nil {
@@ -31,15 +39,19 @@ func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[st
}
defer resp.Body.Close()
dec := json.NewDecoder(resp.Body)
var jsonResp googleMeResponse
if err = dec.Decode(&jsonResp); err != nil {
byt, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read body from google oauth2 endpoint")
}
var response googleMeResponse
if err = json.Unmarshal(byt, &response); err != nil {
return nil, err
}
return map[string]string{
authboss.StoreOAuth2UID: jsonResp.ID,
authboss.StoreEmail: jsonResp.Email,
OAuth2UID: response.ID,
OAuth2Email: response.Email,
}, nil
}
@@ -49,8 +61,8 @@ type facebookMeResponse struct {
Name string `json:"name"`
}
// Facebook is a callback appropriate for use with Facebook's OAuth2 configuration.
func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
// FacebookUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider
func FacebookUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
client := cfg.Client(ctx, token)
resp, err := clientGet(client, facebookInfoEndpoint)
if err != nil {
@@ -58,15 +70,19 @@ func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[
}
defer resp.Body.Close()
dec := json.NewDecoder(resp.Body)
var jsonResp facebookMeResponse
if err = dec.Decode(&jsonResp); err != nil {
return nil, err
byt, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read body from facebook oauth2 endpoint")
}
var response facebookMeResponse
if err = json.Unmarshal(byt, &response); err != nil {
return nil, errors.Wrap(err, "failed to parse json from facebook oauth2 endpoint")
}
return map[string]string{
"name": jsonResp.Name,
authboss.StoreOAuth2UID: jsonResp.ID,
authboss.StoreEmail: jsonResp.Email,
OAuth2UID: response.ID,
OAuth2Email: response.Email,
OAuth2Name: response.Name,
}, nil
}

View File

@@ -8,21 +8,21 @@ import (
"testing"
"time"
"github.com/volatiletech/authboss"
"golang.org/x/oauth2"
)
func TestGoogle(t *testing.T) {
saveClientGet := clientGet
defer func() {
clientGet = saveClientGet
}()
func init() {
// This has an extra parameter that the Google client wouldn't normally get, but it'll safely be
// ignored.
clientGet = func(_ *http.Client, url string) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email"}`)),
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name": "name"}`)),
}, nil
}
}
func TestGoogle(t *testing.T) {
t.Parallel()
cfg := *testProviders["google"].OAuth2Config
tok := &oauth2.Token{
@@ -32,30 +32,21 @@ func TestGoogle(t *testing.T) {
Expiry: time.Now().Add(60 * time.Minute),
}
user, err := Google(context.TODO(), cfg, tok)
details, err := GoogleUserDetails(context.Background(), cfg, tok)
if err != nil {
t.Error(err)
}
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
if uid, ok := details[OAuth2UID]; !ok || uid != "id" {
t.Error("UID wrong:", uid)
}
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
if email, ok := details[OAuth2Email]; !ok || email != "email" {
t.Error("Email wrong:", email)
}
}
func TestFacebook(t *testing.T) {
saveClientGet := clientGet
defer func() {
clientGet = saveClientGet
}()
clientGet = func(_ *http.Client, url string) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name":"name"}`)),
}, nil
}
t.Parallel()
cfg := *testProviders["facebook"].OAuth2Config
tok := &oauth2.Token{
@@ -65,18 +56,18 @@ func TestFacebook(t *testing.T) {
Expiry: time.Now().Add(60 * time.Minute),
}
user, err := Facebook(context.TODO(), cfg, tok)
details, err := FacebookUserDetails(context.Background(), cfg, tok)
if err != nil {
t.Error(err)
}
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
if uid, ok := details[OAuth2UID]; !ok || uid != "id" {
t.Error("UID wrong:", uid)
}
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
if email, ok := details[OAuth2Email]; !ok || email != "email" {
t.Error("Email wrong:", email)
}
if name, ok := user["name"]; !ok || name != "name" {
if name, ok := details[OAuth2Name]; !ok || name != "name" {
t.Error("Name wrong:", name)
}
}

View File

@@ -35,8 +35,7 @@ func (r *Remember) Init(ab *authboss.Authboss) error {
r.Authboss = ab
r.Events.After(authboss.EventAuth, r.RememberAfterAuth)
//TODO(aarondl): Rectify this once oauth2 is done
// r.Events.After(authboss.EventOAuth, r.RememberAfterAuth)
r.Events.After(authboss.EventOAuth, r.RememberAfterAuth)
r.Events.After(authboss.EventPasswordReset, r.AfterPasswordReset)
return nil

View File

@@ -13,22 +13,6 @@ import (
"github.com/pkg/errors"
)
// Data store constants for attribute names.
const (
StoreEmail = "email"
StoreUsername = "username"
StorePassword = "password"
)
// Data store constants for OAuth2 attribute names.
const (
StoreOAuth2UID = "oauth2_uid"
StoreOAuth2Provider = "oauth2_provider"
StoreOAuth2Token = "oauth2_token"
StoreOAuth2Refresh = "oauth2_refresh"
StoreOAuth2Expiry = "oauth2_expiry"
)
var (
// ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID
// of the record is found.
@@ -52,7 +36,7 @@ type ServerStorer interface {
}
// CreatingServerStorer is used for creating new users
// like when Registration is being done.
// like when Registration or OAuth2 is being done.
type CreatingServerStorer interface {
ServerStorer
@@ -64,6 +48,34 @@ type CreatingServerStorer interface {
Create(ctx context.Context, user User) error
}
// OAuth2ServerStorer has the ability to create users from data from the provider.
type OAuth2ServerStorer interface {
ServerStorer
// NewFromOAuth2 should return an OAuth2User from a set
// of details returned from OAuth2Provider.FindUserDetails
// A more in-depth explanation is that once we've got an access token
// for the service in question (say a service that rhymes with book)
// the FindUserDetails function does an http request to a known endpoint that
// provides details about the user, those details are captured in a generic
// way as map[string]string and passed into this function to be turned
// into a real user.
//
// It's possible that the user exists in the database already, and so
// an attempt should be made to look that user up using the details.
// Any details that have changed should be updated. Do not save the user
// since that will be done by a later call to OAuth2ServerStorer.SaveOAuth2()
NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (OAuth2User, error)
// SaveOAuth2 has different semantics from the typical ServerStorer.Save, in this case
// we want to insert a user if they do not exist. The difference must be made clear because
// in the non-oauth2 case, we know exactly when we want to Create vs Update. However
// since we're simply trying to persist a user that may have been in our database, but if not
// should already be (since you can think of the operation as a caching of what's on the oauth2 provider's
// servers).
SaveOAuth2(ctx context.Context, user OAuth2User) error
}
// ConfirmingServerStorer can find a user by a confirm token
type ConfirmingServerStorer interface {
ServerStorer
@@ -84,6 +96,8 @@ type RecoveringServerStorer interface {
// RememberingServerStorer allows users to be remembered across sessions
type RememberingServerStorer interface {
ServerStorer
// AddRememberToken to a user
AddRememberToken(pid, token string) error
// DelRememberTokens removes all tokens for the given pid
@@ -132,3 +146,13 @@ func EnsureCanRemember(storer ServerStorer) RememberingServerStorer {
return s
}
// EnsureCanOAuth2 makes sure the server storer supports oauth2 creation and lookup
func EnsureCanOAuth2(storer ServerStorer) OAuth2ServerStorer {
s, ok := storer.(OAuth2ServerStorer)
if !ok {
panic("could not upgrade ServerStorer to OAuth2ServerStorer, check your struct")
}
return s
}

View File

@@ -4,9 +4,9 @@ package authboss
import "strconv"
const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset"
const _Event_name = "EventRegisterEventAuthEventOAuth2EventAuthFailEventOAuth2FailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset"
var _Event_index = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 103, 122, 140}
var _Event_index = [...]uint8{0, 13, 22, 33, 46, 61, 78, 93, 105, 124, 142}
func (i Event) String() string {
if i < 0 || i >= Event(len(_Event_index)-1) {

53
user.go
View File

@@ -2,6 +2,7 @@ package authboss
import (
"fmt"
"strings"
"time"
)
@@ -81,6 +82,8 @@ type ArbitraryUser interface {
}
// OAuth2User allows reading and writing values relating to OAuth2
// Also see MakeOAuthPID/ParseOAuthPID for helpers to fullfill the User
// part of the interface.
type OAuth2User interface {
User
@@ -88,17 +91,17 @@ type OAuth2User interface {
// oauth2 user.
IsOAuth2User() bool
GetUID() (uid string)
GetProvider() (provider string)
GetToken() (token string)
GetRefreshToken() (refreshToken string)
GetExpiry() (expiry time.Duration)
GetOAuth2UID() (uid string)
GetOAuth2Provider() (provider string)
GetOAuth2AccessToken() (token string)
GetOAuth2RefreshToken() (refreshToken string)
GetOAuth2Expiry() (expiry time.Time)
PutUID(uid string)
PutProvider(provider string)
PutToken(token string)
PutRefreshToken(refreshToken string)
PutExpiry(expiry time.Duration)
PutOAuth2UID(uid string)
PutOAuth2Provider(provider string)
PutOAuth2AccessToken(token string)
PutOAuth2RefreshToken(refreshToken string)
PutOAuth2Expiry(expiry time.Time)
}
// MustBeAuthable forces an upgrade to an AuthableUser or panic.
@@ -132,3 +135,33 @@ func MustBeRecoverable(u User) RecoverableUser {
}
panic(fmt.Sprintf("could not upgrade user to a recoverable user, given type: %T", u))
}
// MustBeOAuthable forces an upgrade to an OAuth2User or panic.
func MustBeOAuthable(u User) OAuth2User {
if ou, ok := u.(OAuth2User); ok {
return ou
}
panic(fmt.Sprintf("could not upgrade user to an oauthable user, given type: %T", u))
}
// MakeOAuth2PID is used to create a pid for users that don't have
// an e-mail address or username in the normal system. This allows
// all the modules to continue to working as intended without having
// a true primary id. As well as not having to divide the regular and oauth
// stuff all down the middle.
func MakeOAuth2PID(provider, uid string) string {
return fmt.Sprintf("oauth2;;%s;;%s", provider, uid)
}
// ParseOAuth2PID returns the uid and provider for a given OAuth2 pid
func ParseOAuth2PID(pid string) (provider, uid string) {
splits := strings.Split(pid, ";;")
if len(splits) != 3 {
panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid))
}
if splits[0] != "oauth2" {
panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid))
}
return splits[1], splits[2]
}

23
user_test.go Normal file
View File

@@ -0,0 +1,23 @@
package authboss
import "testing"
func TestOAuth2PIDs(t *testing.T) {
t.Parallel()
provider := "provider"
uid := "uid"
pid := MakeOAuth2PID(provider, uid)
if pid != "oauth2;;provider;;uid" {
t.Error("pid was wrong:", pid)
}
gotProvider, gotUID := ParseOAuth2PID(pid)
if gotUID != uid {
t.Error("uid was wrong:", gotUID)
}
if gotProvider != provider {
t.Error("provider was wrong:", gotProvider)
}
}