mirror of
https://github.com/volatiletech/authboss.git
synced 2025-09-16 09:06:20 +02:00
Rewrite oauth module
- Tried to be clear about OAuth2 vs OAuth in all places. - Allow users to be locked from OAuth logins (if done manually for some reason other than failed logins) - Cleaned up some docs and wording around the previously very confusing (now hopefully only somewhat confusing) oauth2 module.
This commit is contained in:
@@ -27,6 +27,11 @@ type Config struct {
|
||||
// LogoutOK is the redirect path after a log out.
|
||||
LogoutOK string
|
||||
|
||||
// OAuth2LoginOK is the redirect path after a successful oauth2 login
|
||||
OAuth2LoginOK string
|
||||
// OAuth2LoginNotOK is the redirect path after an unsuccessful oauth2 login
|
||||
OAuth2LoginNotOK string
|
||||
|
||||
// RecoverOK is the redirect path after a successful recovery of a password.
|
||||
RecoverOK string
|
||||
|
||||
|
@@ -13,9 +13,9 @@ type Event int
|
||||
const (
|
||||
EventRegister Event = iota
|
||||
EventAuth
|
||||
EventOAuth
|
||||
EventOAuth2
|
||||
EventAuthFail
|
||||
EventOAuthFail
|
||||
EventOAuth2Fail
|
||||
EventRecoverStart
|
||||
EventRecoverEnd
|
||||
EventGetUser
|
||||
|
@@ -132,9 +132,9 @@ func TestEventString(t *testing.T) {
|
||||
}{
|
||||
{EventRegister, "EventRegister"},
|
||||
{EventAuth, "EventAuth"},
|
||||
{EventOAuth, "EventOAuth"},
|
||||
{EventOAuth2, "EventOAuth2"},
|
||||
{EventAuthFail, "EventAuthFail"},
|
||||
{EventOAuthFail, "EventOAuthFail"},
|
||||
{EventOAuth2Fail, "EventOAuth2Fail"},
|
||||
{EventRecoverStart, "EventRecoverStart"},
|
||||
{EventRecoverEnd, "EventRecoverEnd"},
|
||||
{EventGetUser, "EventGetUser"},
|
||||
|
@@ -25,9 +25,12 @@ type User struct {
|
||||
AttemptCount int
|
||||
LastAttempt time.Time
|
||||
Locked time.Time
|
||||
OAuthToken string
|
||||
OAuthRefresh string
|
||||
OAuthExpiry time.Time
|
||||
|
||||
OAuth2UID string
|
||||
OAuth2Provider string
|
||||
OAuth2Token string
|
||||
OAuth2Refresh string
|
||||
OAuth2Expiry time.Time
|
||||
|
||||
Arbitrary map[string]string
|
||||
}
|
||||
@@ -43,9 +46,12 @@ func (m User) GetConfirmed() bool { return m.Confirmed }
|
||||
func (m User) GetAttemptCount() int { return m.AttemptCount }
|
||||
func (m User) GetLastAttempt() time.Time { return m.LastAttempt }
|
||||
func (m User) GetLocked() time.Time { return m.Locked }
|
||||
func (m User) GetOAuthToken() string { return m.OAuthToken }
|
||||
func (m User) GetOAuthRefresh() string { return m.OAuthRefresh }
|
||||
func (m User) GetOAuthExpiry() time.Time { return m.OAuthExpiry }
|
||||
func (m User) IsOAuth2User() bool { return len(m.OAuth2Provider) != 0 }
|
||||
func (m User) GetOAuth2UID() string { return m.OAuth2UID }
|
||||
func (m User) GetOAuth2Provider() string { return m.OAuth2Provider }
|
||||
func (m User) GetOAuth2AccessToken() string { return m.OAuth2Token }
|
||||
func (m User) GetOAuth2RefreshToken() string { return m.OAuth2Refresh }
|
||||
func (m User) GetOAuth2Expiry() time.Time { return m.OAuth2Expiry }
|
||||
func (m User) GetArbitrary() map[string]string { return m.Arbitrary }
|
||||
|
||||
func (m *User) PutPID(email string) { m.Email = email }
|
||||
@@ -61,9 +67,11 @@ func (m *User) PutConfirmed(confirmed bool) { m.Confirmed = confirmed }
|
||||
func (m *User) PutAttemptCount(attemptCount int) { m.AttemptCount = attemptCount }
|
||||
func (m *User) PutLastAttempt(attemptTime time.Time) { m.LastAttempt = attemptTime }
|
||||
func (m *User) PutLocked(locked time.Time) { m.Locked = locked }
|
||||
func (m *User) PutOAuthToken(oAuthToken string) { m.OAuthToken = oAuthToken }
|
||||
func (m *User) PutOAuthRefresh(oAuthRefresh string) { m.OAuthRefresh = oAuthRefresh }
|
||||
func (m *User) PutOAuthExpiry(oAuthExpiry time.Time) { m.OAuthExpiry = oAuthExpiry }
|
||||
func (m *User) PutOAuth2UID(uid string) { m.OAuth2UID = uid }
|
||||
func (m *User) PutOAuth2Provider(provider string) { m.OAuth2Provider = provider }
|
||||
func (m *User) PutOAuth2AccessToken(token string) { m.OAuth2Token = token }
|
||||
func (m *User) PutOAuth2RefreshToken(refresh string) { m.OAuth2Refresh = refresh }
|
||||
func (m *User) PutOAuth2Expiry(expiry time.Time) { m.OAuth2Expiry = expiry }
|
||||
func (m *User) PutArbitrary(arb map[string]string) { m.Arbitrary = arb }
|
||||
|
||||
// ServerStorer should be valid for any module storer defined in authboss.
|
||||
@@ -115,6 +123,38 @@ func (s *ServerStorer) Save(ctx context.Context, user authboss.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewFromOAuth2 finds a user with the given details, or returns a new one
|
||||
func (s *ServerStorer) NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (authboss.OAuth2User, error) {
|
||||
uid := details["uid"]
|
||||
email := details["email"]
|
||||
name := details["name"]
|
||||
pid := authboss.MakeOAuth2PID(provider, uid)
|
||||
|
||||
u, ok := s.Users[pid]
|
||||
if ok {
|
||||
u.Username = name
|
||||
u.Email = email
|
||||
return u, nil
|
||||
}
|
||||
|
||||
return &User{
|
||||
OAuth2UID: uid,
|
||||
OAuth2Provider: provider,
|
||||
Email: email,
|
||||
Username: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SaveOAuth2 creates a user if not found, or updates one that exists.
|
||||
func (s *ServerStorer) SaveOAuth2(ctx context.Context, user authboss.OAuth2User) error {
|
||||
u := user.(*User)
|
||||
|
||||
pid := authboss.MakeOAuth2PID(u.OAuth2Provider, u.OAuth2UID)
|
||||
// Since we don't have to differentiate between insert/update in a map, we just overwrite
|
||||
s.Users[pid] = u
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadByConfirmToken finds a user by his confirm token
|
||||
func (s *ServerStorer) LoadByConfirmToken(ctx context.Context, token string) (authboss.ConfirmableUser, error) {
|
||||
for _, v := range s.Users {
|
||||
|
@@ -35,8 +35,8 @@ type Lock struct {
|
||||
func (l *Lock) Init(ab *authboss.Authboss) error {
|
||||
l.Authboss = ab
|
||||
|
||||
// Events
|
||||
l.Events.Before(authboss.EventAuth, l.BeforeAuth)
|
||||
l.Events.Before(authboss.EventOAuth, l.BeforeAuth)
|
||||
l.Events.After(authboss.EventAuth, l.AfterAuthSuccess)
|
||||
l.Events.After(authboss.EventAuthFail, l.AfterAuthFail)
|
||||
|
||||
|
25
oauth2.go
25
oauth2.go
@@ -7,11 +7,6 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// FormValue constants
|
||||
const (
|
||||
FormValueOAuth2State = "state"
|
||||
)
|
||||
|
||||
/*
|
||||
OAuth2Provider is the entire configuration
|
||||
required to authenticate with this provider.
|
||||
@@ -23,22 +18,14 @@ AdditionalParams can be used to specify extra parameters to tack on to the
|
||||
end of the initial request, this allows for provider specific oauth options
|
||||
like access_type=offline to be passed to the provider.
|
||||
|
||||
Callback gives the config and the token allowing an http client using the
|
||||
authenticated token to be created. Because each OAuth2 implementation has a different
|
||||
API this must be handled for each provider separately. It is used to return two things
|
||||
specifically: UID (the ID according to the provider) and the Email address.
|
||||
The UID must be passed back or there will be an error as it is the means of identifying the
|
||||
user in the system, e-mail is optional but should be returned in systems using
|
||||
emailing. The keys authboss.StoreOAuth2UID and authboss.StoreEmail can be used to set
|
||||
these values in the map returned by the callback.
|
||||
|
||||
In addition to the required values mentioned above any additional
|
||||
values that you wish to have in your user struct can be included here, such as the
|
||||
Name of the user at the endpoint. This will be passed back in the Arbitrary()
|
||||
function if it exists.
|
||||
FindUserDetails gives the config and the token allowing an http client using the
|
||||
authenticated token to be created, a call is then made to a known endpoint that will
|
||||
return details about the user we've retrieved the token for. Those details are returned
|
||||
as a map[string]string and subsequently passed into OAuth2ServerStorer.NewFromOAuth2.
|
||||
API this must be handled for each provider separately.
|
||||
*/
|
||||
type OAuth2Provider struct {
|
||||
OAuth2Config *oauth2.Config
|
||||
AdditionalParams url.Values
|
||||
Callback func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error)
|
||||
FindUserDetails func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error)
|
||||
}
|
||||
|
309
oauth2/oauth2.go
309
oauth2/oauth2.go
@@ -1,6 +1,36 @@
|
||||
// Package oauth2 allows users to be created and authenticated
|
||||
// via oauth2 services like facebook, google etc. Currently
|
||||
// only the web server flow is supported.
|
||||
//
|
||||
// The general flow looks like this:
|
||||
// 1. User goes to Start handler and has his session packed with goodies
|
||||
// then redirects to the OAuth service.
|
||||
// 2. OAuth service returns to OAuthCallback which extracts state and parameters
|
||||
// and generally checks that everything is ok. It uses the token received to
|
||||
// get an access token from the oauth2 library
|
||||
// 3. Calls the OAuth2Provider.FindUserDetails which should return the user's details
|
||||
// in a generic form.
|
||||
// 4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in order
|
||||
// to create a user object we can work with.
|
||||
// 5. Saves the user in the database, logs them in, redirects.
|
||||
//
|
||||
// In order to do this there are a number of parts:
|
||||
// 1. The configuration of a provider (handled by authboss.Config.Modules.OAuth2Providers)
|
||||
// 2. The flow of redirection of client, parameter passing etc (handled by this package)
|
||||
// 3. The HTTP call to the service once a token has been retrieved to get user details
|
||||
// (handled by OAuth2Provider.FindUserDetails)
|
||||
// 4. The creation of a user from the user details returned from the FindUserDetails
|
||||
// (authboss.OAuth2ServerStorer)
|
||||
//
|
||||
// Of these parts, the responsibility of the authboss library consumer is on 1, 3, and 4.
|
||||
// Configuration of providers that should be used is totally up to the consumer. The FindUserDetails
|
||||
// function is typically up to the user, but we have some basic ones included in this package too.
|
||||
// The creation of users from the FindUserDetail's map[string]string return is handled as part
|
||||
// of the implementation of the OAuth2ServerStorer.
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -9,13 +39,19 @@ import (
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/response"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// FormValue constants
|
||||
const (
|
||||
FormValueOAuth2State = "state"
|
||||
FormValueOAuth2Redir = "redir"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -31,68 +67,60 @@ func init() {
|
||||
authboss.RegisterModule("oauth2", &OAuth2{})
|
||||
}
|
||||
|
||||
// Initialize module
|
||||
func (o *OAuth2) Initialize(ab *authboss.Authboss) error {
|
||||
// Init module
|
||||
func (o *OAuth2) Init(ab *authboss.Authboss) error {
|
||||
o.Authboss = ab
|
||||
if o.OAuth2Storer == nil && o.OAuth2StoreMaker == nil {
|
||||
return errors.New("need an oauth2Storer")
|
||||
|
||||
// Do annoying sorting on keys so we can have predictible
|
||||
// route registration (both for consistency inside the router but also for tests -_-)
|
||||
var keys []string
|
||||
for k := range o.Authboss.Config.Modules.OAuth2Providers {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, provider := range keys {
|
||||
cfg := o.Authboss.Config.Modules.OAuth2Providers[provider]
|
||||
provider = strings.ToLower(provider)
|
||||
|
||||
init := fmt.Sprintf("/oauth2/%s", provider)
|
||||
callback := fmt.Sprintf("/oauth2/callback/%s", provider)
|
||||
|
||||
o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start))
|
||||
o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End))
|
||||
|
||||
if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 {
|
||||
callback = path.Join(mount, callback)
|
||||
}
|
||||
|
||||
cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Routes for module
|
||||
func (o *OAuth2) Routes() authboss.RouteTable {
|
||||
routes := make(authboss.RouteTable)
|
||||
// Start the oauth2 process
|
||||
func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error {
|
||||
logger := o.Authboss.RequestLogger(r)
|
||||
|
||||
for prov, cfg := range o.OAuth2Providers {
|
||||
prov = strings.ToLower(prov)
|
||||
|
||||
init := fmt.Sprintf("/oauth2/%s", prov)
|
||||
callback := fmt.Sprintf("/oauth2/callback/%s", prov)
|
||||
|
||||
routes[init] = o.oauthInit
|
||||
routes[callback] = o.oauthCallback
|
||||
|
||||
if len(o.MountPath) > 0 {
|
||||
callback = path.Join(o.MountPath, callback)
|
||||
}
|
||||
|
||||
cfg.OAuth2Config.RedirectURL = o.RootURL + callback
|
||||
}
|
||||
|
||||
routes["/oauth2/logout"] = o.logout
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// Storage requirements
|
||||
func (o *OAuth2) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
authboss.StoreEmail: authboss.String,
|
||||
authboss.StoreOAuth2UID: authboss.String,
|
||||
authboss.StoreOAuth2Provider: authboss.String,
|
||||
authboss.StoreOAuth2Token: authboss.String,
|
||||
authboss.StoreOAuth2Refresh: authboss.String,
|
||||
authboss.StoreOAuth2Expiry: authboss.DateTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
provider := strings.ToLower(filepath.Base(r.URL.Path))
|
||||
cfg, ok := o.OAuth2Providers[provider]
|
||||
logger.Infof("started oauth2 flow for provider: %s", provider)
|
||||
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
|
||||
if !ok {
|
||||
return errors.Errorf("OAuth2 provider %q not found", provider)
|
||||
return errors.Errorf("oauth2 provider %q not found", provider)
|
||||
}
|
||||
|
||||
random := make([]byte, 32)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
return err
|
||||
// Create nonce
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return errors.Wrap(err, "failed to create nonce")
|
||||
}
|
||||
|
||||
state := base64.URLEncoding.EncodeToString(random)
|
||||
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
|
||||
state := base64.URLEncoding.EncodeToString(nonce)
|
||||
authboss.PutSession(w, authboss.SessionOAuth2State, state)
|
||||
|
||||
// This clearly ignores the fact that query parameters can have multiple
|
||||
// values but I guess we're ignoring that
|
||||
passAlongs := make(map[string]string)
|
||||
for k, vals := range r.URL.Query() {
|
||||
for _, val := range vals {
|
||||
@@ -101,13 +129,13 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
if len(passAlongs) > 0 {
|
||||
str, err := json.Marshal(passAlongs)
|
||||
byt, err := json.Marshal(passAlongs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.SessionStorer.Put(authboss.SessionOAuth2Params, string(str))
|
||||
authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt))
|
||||
} else {
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
|
||||
authboss.DelSession(w, authboss.SessionOAuth2Params)
|
||||
}
|
||||
|
||||
url := cfg.OAuth2Config.AuthCodeURL(state)
|
||||
@@ -117,128 +145,153 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http
|
||||
url = fmt.Sprintf("%s&%s", url, extraParams)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
return nil
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: url,
|
||||
}
|
||||
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
|
||||
}
|
||||
|
||||
// for testing
|
||||
// for testing, mocked out at the beginning
|
||||
var exchanger = (*oauth2.Config).Exchange
|
||||
|
||||
func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
// End the oauth2 process, this is the handler for the oauth2 callback
|
||||
// that the third party will redirect to.
|
||||
func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error {
|
||||
logger := o.Authboss.RequestLogger(r)
|
||||
provider := strings.ToLower(filepath.Base(r.URL.Path))
|
||||
logger.Infof("finishing oauth2 flow for provider: %s", provider)
|
||||
|
||||
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2State)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
|
||||
// Don't delete this value from session immediately, Events use this too
|
||||
var values map[string]string
|
||||
if ok {
|
||||
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hasErr := r.FormValue("error")
|
||||
if len(hasErr) > 0 {
|
||||
if err := o.Events.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return authboss.ErrAndRedirect{
|
||||
Err: errors.New(r.FormValue("error_reason")),
|
||||
Location: o.AuthLoginFailPath,
|
||||
FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
|
||||
}
|
||||
}
|
||||
|
||||
cfg, ok := o.OAuth2Providers[provider]
|
||||
// This shouldn't happen because the router should 404 first, but just in case
|
||||
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
|
||||
if !ok {
|
||||
return errors.Errorf("oauth2 provider %q not found", provider)
|
||||
}
|
||||
|
||||
// Ensure request is genuine
|
||||
state := r.FormValue(authboss.FormValueOAuth2State)
|
||||
splState := strings.Split(state, ";")
|
||||
if len(splState) == 0 || splState[0] != sessState {
|
||||
wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State)
|
||||
if !ok {
|
||||
return errors.New("oauth2 endpoint hit without session state")
|
||||
}
|
||||
|
||||
// Verify we got the same state in the session as was passed to us in the
|
||||
// query parameter.
|
||||
state := r.FormValue(FormValueOAuth2State)
|
||||
if state != wantState {
|
||||
return errOAuthStateValidation
|
||||
}
|
||||
|
||||
// Get the code
|
||||
rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params)
|
||||
var params map[string]string
|
||||
if ok {
|
||||
if err := json.Unmarshal([]byte(rawParams), ¶ms); err != nil {
|
||||
return errors.Wrap(err, "failed to decode oauth2 params")
|
||||
}
|
||||
}
|
||||
|
||||
authboss.DelSession(w, authboss.SessionOAuth2State)
|
||||
authboss.DelSession(w, authboss.SessionOAuth2Params)
|
||||
|
||||
hasErr := r.FormValue("error")
|
||||
if len(hasErr) > 0 {
|
||||
reason := r.FormValue("error_reason")
|
||||
logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason)
|
||||
|
||||
handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if handled {
|
||||
return nil
|
||||
}
|
||||
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK,
|
||||
Failure: fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)),
|
||||
}
|
||||
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
|
||||
}
|
||||
|
||||
// Get the code which we can use to make an access token
|
||||
code := r.FormValue("code")
|
||||
token, err := exchanger(cfg.OAuth2Config, o.Config.ContextProvider(r), code)
|
||||
token, err := exchanger(cfg.OAuth2Config, r.Context(), code)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not validate oauth2 code")
|
||||
}
|
||||
|
||||
user, err := cfg.Callback(o.Config.ContextProvider(r), *cfg.OAuth2Config, token)
|
||||
details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// OAuth2UID is required.
|
||||
uid, err := user.StringErr(authboss.StoreOAuth2UID)
|
||||
storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server)
|
||||
user, err := storer.NewFromOAuth2(r.Context(), provider, details)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "failed to create oauth2 user from values")
|
||||
}
|
||||
|
||||
user[authboss.StoreOAuth2UID] = uid
|
||||
user[authboss.StoreOAuth2Provider] = provider
|
||||
user[authboss.StoreOAuth2Expiry] = token.Expiry
|
||||
user[authboss.StoreOAuth2Token] = token.AccessToken
|
||||
user.PutOAuth2Provider(provider)
|
||||
user.PutOAuth2AccessToken(token.AccessToken)
|
||||
user.PutOAuth2Expiry(token.Expiry)
|
||||
if len(token.RefreshToken) != 0 {
|
||||
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
|
||||
user.PutOAuth2RefreshToken(token.RefreshToken)
|
||||
}
|
||||
|
||||
if err = ctx.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
|
||||
if err := storer.SaveOAuth2(r.Context(), user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fully log user in
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
|
||||
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
|
||||
|
||||
if err = o.Events.FireAfter(authboss.EventOAuth, ctx); err != nil {
|
||||
handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if handled {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
|
||||
// Fully log user in
|
||||
authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID()))
|
||||
authboss.DelSession(w, authboss.SessionHalfAuthKey)
|
||||
|
||||
redirect := o.AuthLoginOKPath
|
||||
// Create a query string from all the pieces we've received
|
||||
// as passthru from the original request.
|
||||
redirect := o.Authboss.Config.Paths.OAuth2LoginOK
|
||||
query := make(url.Values)
|
||||
for k, v := range values {
|
||||
for k, v := range params {
|
||||
switch k {
|
||||
case authboss.CookieRemember:
|
||||
case authboss.FormValueRedirect:
|
||||
if v == "true" {
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{}))
|
||||
}
|
||||
case FormValueOAuth2Redir:
|
||||
redirect = v
|
||||
default:
|
||||
query.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if handled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(query) > 0 {
|
||||
redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
|
||||
}
|
||||
|
||||
sf := fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider))
|
||||
response.Redirect(ctx, w, r, redirect, sf, "", false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OAuth2) logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
ctx.SessionStorer.Del(authboss.SessionKey)
|
||||
ctx.CookieStorer.Del(authboss.CookieRemember)
|
||||
ctx.SessionStorer.Del(authboss.SessionLastAction)
|
||||
|
||||
response.Redirect(ctx, w, r, o.AuthLogoutOKPath, "You have logged out", "", true)
|
||||
default:
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
RedirectPath: redirect,
|
||||
Success: fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)),
|
||||
}
|
||||
|
||||
return nil
|
||||
return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro)
|
||||
}
|
||||
|
||||
// RMTrue is a dummy struct implementing authboss.RememberValuer
|
||||
// in order to tell the remember me module to remember them.
|
||||
type RMTrue struct{}
|
||||
|
||||
// GetShouldRemember always returns true
|
||||
func (RMTrue) GetShouldRemember() bool { return true }
|
||||
|
@@ -2,11 +2,9 @@ package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -18,6 +16,12 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
func init() {
|
||||
exchanger = func(_ *oauth2.Config, _ context.Context, _ string) (*oauth2.Token, error) {
|
||||
return testToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
var testProviders = map[string]authboss.OAuth2Provider{
|
||||
"google": authboss.OAuth2Provider{
|
||||
OAuth2Config: &oauth2.Config{
|
||||
@@ -25,8 +29,10 @@ var testProviders = map[string]authboss.OAuth2Provider{
|
||||
ClientSecret: `hands`,
|
||||
Scopes: []string{`profile`, `email`},
|
||||
Endpoint: google.Endpoint,
|
||||
// This is typically set by Init() but some tests rely on it's existence
|
||||
RedirectURL: "https://www.example.com/auth/oauth2/callback/google",
|
||||
},
|
||||
Callback: Google,
|
||||
FindUserDetails: GoogleUserDetails,
|
||||
AdditionalParams: url.Values{"include_requested_scopes": []string{"true"}},
|
||||
},
|
||||
"facebook": authboss.OAuth2Provider{
|
||||
@@ -35,11 +41,341 @@ var testProviders = map[string]authboss.OAuth2Provider{
|
||||
ClientSecret: `hands`,
|
||||
Scopes: []string{`email`},
|
||||
Endpoint: facebook.Endpoint,
|
||||
// This is typically set by Init() but some tests rely on it's existence
|
||||
RedirectURL: "https://www.example.com/auth/oauth2/callback/facebook",
|
||||
},
|
||||
Callback: Facebook,
|
||||
FindUserDetails: FacebookUserDetails,
|
||||
},
|
||||
}
|
||||
|
||||
var testToken = &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: "refresh",
|
||||
Expiry: time.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
// No t.Parallel() since the cfg.RedirectURL is set in Init()
|
||||
|
||||
ab := authboss.New()
|
||||
oauth := &OAuth2{}
|
||||
|
||||
router := &mocks.Router{}
|
||||
ab.Config.Modules.OAuth2Providers = testProviders
|
||||
ab.Config.Core.Router = router
|
||||
ab.Config.Core.ErrorHandler = &mocks.ErrorHandler{}
|
||||
|
||||
ab.Config.Paths.Mount = "/auth"
|
||||
ab.Config.Paths.RootURL = "https://www.example.com"
|
||||
|
||||
if err := oauth.Init(ab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
gets := []string{
|
||||
"/oauth2/facebook", "/oauth2/callback/facebook",
|
||||
"/oauth2/google", "/oauth2/callback/google",
|
||||
}
|
||||
if err := router.HasGets(gets...); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
type testHarness struct {
|
||||
oauth *OAuth2
|
||||
ab *authboss.Authboss
|
||||
|
||||
bodyReader *mocks.BodyReader
|
||||
responder *mocks.Responder
|
||||
redirector *mocks.Redirector
|
||||
session *mocks.ClientStateRW
|
||||
storer *mocks.ServerStorer
|
||||
}
|
||||
|
||||
func testSetup() *testHarness {
|
||||
harness := &testHarness{}
|
||||
|
||||
harness.ab = authboss.New()
|
||||
harness.redirector = &mocks.Redirector{}
|
||||
harness.session = mocks.NewClientRW()
|
||||
harness.storer = mocks.NewServerStorer()
|
||||
|
||||
harness.ab.Modules.OAuth2Providers = testProviders
|
||||
|
||||
harness.ab.Paths.OAuth2LoginOK = "/auth/oauth2/ok"
|
||||
harness.ab.Paths.OAuth2LoginNotOK = "/auth/oauth2/not/ok"
|
||||
|
||||
harness.ab.Config.Core.Logger = mocks.Logger{}
|
||||
harness.ab.Config.Core.Redirector = harness.redirector
|
||||
harness.ab.Config.Storage.SessionState = harness.session
|
||||
harness.ab.Config.Storage.Server = harness.storer
|
||||
|
||||
harness.oauth = &OAuth2{harness.ab}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
r := httptest.NewRequest("GET", "/oauth2/google?cake=yes&death=no", nil)
|
||||
|
||||
if err := h.oauth.Start(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.redirector.Options.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", h.redirector.Options.Code)
|
||||
}
|
||||
|
||||
url, err := url.Parse(h.redirector.Options.RedirectPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
query := url.Query()
|
||||
if state := query.Get("state"); len(state) == 0 {
|
||||
t.Error("our nonce should have been here")
|
||||
}
|
||||
if callback := query.Get("redirect_uri"); callback != "https://www.example.com/auth/oauth2/callback/google" {
|
||||
t.Error("callback was wrong:", callback)
|
||||
}
|
||||
if clientID := query.Get("client_id"); clientID != "jazz" {
|
||||
t.Error("clientID was wrong:", clientID)
|
||||
}
|
||||
if url.Host != "accounts.google.com" {
|
||||
t.Error("host was wrong:", url.Host)
|
||||
}
|
||||
|
||||
if h.session.ClientValues[authboss.SessionOAuth2State] != query.Get("state") {
|
||||
t.Error("the state should have been saved in the session")
|
||||
}
|
||||
if v := h.session.ClientValues[authboss.SessionOAuth2Params]; v != `{"cake":"yes","death":"no"}` {
|
||||
t.Error("oauth2 session params are wrong:", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartBadProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
r := httptest.NewRequest("GET", "/oauth2/test", nil)
|
||||
|
||||
err := h.oauth.Start(w, r)
|
||||
if e := err.Error(); !strings.Contains(e, `provider "test" not found`) {
|
||||
t.Error("it should have errored:", e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.oauth.End(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK) // Flush headers
|
||||
|
||||
opts := h.redirector.Options
|
||||
if opts.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("it should have redirected")
|
||||
}
|
||||
if opts.RedirectPath != "/auth/oauth2/ok" {
|
||||
t.Error("redir path was wrong:", opts.RedirectPath)
|
||||
}
|
||||
if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" {
|
||||
t.Error("session id should have been set:", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndBadProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
r := httptest.NewRequest("GET", "/oauth2/callback/test", nil)
|
||||
|
||||
err := h.oauth.End(w, r)
|
||||
if e := err.Error(); !strings.Contains(e, `provider "test" not found`) {
|
||||
t.Error("it should have errored:", e)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndBadState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
r := httptest.NewRequest("GET", "/oauth2/callback/google", nil)
|
||||
|
||||
err := h.oauth.End(w, r)
|
||||
if e := err.Error(); !strings.Contains(e, `oauth2 endpoint hit without session state`) {
|
||||
t.Error("it should have errored:", e)
|
||||
}
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err = h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=x", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := h.oauth.End(w, r); err != errOAuthStateValidation {
|
||||
t.Error("error was wrong:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.oauth.End(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
opts := h.redirector.Options
|
||||
if opts.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", opts.Code)
|
||||
}
|
||||
if opts.RedirectPath != "/auth/oauth2/not/ok" {
|
||||
t.Error("path was wrong:", opts.RedirectPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndHandling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AfterOAuth2Fail", func(t *testing.T) {
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
called := false
|
||||
h.ab.Events.After(authboss.EventOAuth2Fail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
called = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err := h.oauth.End(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("it should have been called")
|
||||
}
|
||||
if h.redirector.Options.Code != 0 {
|
||||
t.Error("it should not have tried to redirect")
|
||||
}
|
||||
})
|
||||
t.Run("BeforeOAuth2", func(t *testing.T) {
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
called := false
|
||||
h.ab.Events.Before(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
called = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err := h.oauth.End(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK) // Flush headers
|
||||
|
||||
if !called {
|
||||
t.Error("it should have been called")
|
||||
}
|
||||
if h.redirector.Options.Code != 0 {
|
||||
t.Error("it should not have tried to redirect")
|
||||
}
|
||||
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
||||
t.Error("should have not logged the user in")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AfterOAuth2", func(t *testing.T) {
|
||||
h := testSetup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec)
|
||||
|
||||
h.session.ClientValues[authboss.SessionOAuth2State] = "state"
|
||||
r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
called := false
|
||||
h.ab.Events.After(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
called = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err := h.oauth.End(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK) // Flush headers
|
||||
|
||||
if !called {
|
||||
t.Error("it should have been called")
|
||||
}
|
||||
if h.redirector.Options.Code != 0 {
|
||||
t.Error("it should not have tried to redirect")
|
||||
}
|
||||
if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" {
|
||||
t.Error("session id should have been set:", s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestInitialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -281,45 +617,4 @@ func TestOAuthFailure(t *testing.T) {
|
||||
t.Error("It should record the failure.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
oauth := OAuth2{ab}
|
||||
ab.AuthLogoutOKPath = "/dashboard"
|
||||
|
||||
r, _ := http.NewRequest("GET", "/oauth2/google?", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ctx := ab.NewContext()
|
||||
session := mocks.NewMockClientStorer(authboss.SessionKey, "asdf", authboss.SessionLastAction, "1234")
|
||||
cookies := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert")
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
|
||||
if err := oauth.logout(ctx, w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if val, ok := session.Get(authboss.SessionKey); ok {
|
||||
t.Error("Unexpected session key:", val)
|
||||
}
|
||||
|
||||
if val, ok := session.Get(authboss.SessionLastAction); ok {
|
||||
t.Error("Unexpected last action:", val)
|
||||
}
|
||||
|
||||
if val, ok := cookies.Get(authboss.CookieRemember); ok {
|
||||
t.Error("Unexpected rm cookie:", val)
|
||||
}
|
||||
|
||||
if http.StatusFound != w.Code {
|
||||
t.Errorf("Expected status code %d, got %d", http.StatusFound, w.Code)
|
||||
}
|
||||
|
||||
location := w.Header().Get("Location")
|
||||
if location != ab.AuthLogoutOKPath {
|
||||
t.Error("Redirect wrong:", location)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@@ -3,12 +3,20 @@ package oauth2
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Constants for returning in the FindUserDetails call
|
||||
const (
|
||||
OAuth2UID = "uid"
|
||||
OAuth2Email = "email"
|
||||
OAuth2Name = "name"
|
||||
)
|
||||
|
||||
const (
|
||||
googleInfoEndpoint = `https://www.googleapis.com/userinfo/v2/me`
|
||||
facebookInfoEndpoint = `https://graph.facebook.com/me?fields=name,email`
|
||||
@@ -22,8 +30,8 @@ type googleMeResponse struct {
|
||||
// testing
|
||||
var clientGet = (*http.Client).Get
|
||||
|
||||
// Google is a callback appropriate for use with Google's OAuth2 configuration.
|
||||
func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
|
||||
// GoogleUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider
|
||||
func GoogleUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
|
||||
client := cfg.Client(ctx, token)
|
||||
resp, err := clientGet(client, googleInfoEndpoint)
|
||||
if err != nil {
|
||||
@@ -31,15 +39,19 @@ func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[st
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var jsonResp googleMeResponse
|
||||
if err = dec.Decode(&jsonResp); err != nil {
|
||||
byt, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read body from google oauth2 endpoint")
|
||||
}
|
||||
|
||||
var response googleMeResponse
|
||||
if err = json.Unmarshal(byt, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
authboss.StoreOAuth2UID: jsonResp.ID,
|
||||
authboss.StoreEmail: jsonResp.Email,
|
||||
OAuth2UID: response.ID,
|
||||
OAuth2Email: response.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -49,8 +61,8 @@ type facebookMeResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Facebook is a callback appropriate for use with Facebook's OAuth2 configuration.
|
||||
func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
|
||||
// FacebookUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider
|
||||
func FacebookUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) {
|
||||
client := cfg.Client(ctx, token)
|
||||
resp, err := clientGet(client, facebookInfoEndpoint)
|
||||
if err != nil {
|
||||
@@ -58,15 +70,19 @@ func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var jsonResp facebookMeResponse
|
||||
if err = dec.Decode(&jsonResp); err != nil {
|
||||
return nil, err
|
||||
byt, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read body from facebook oauth2 endpoint")
|
||||
}
|
||||
|
||||
var response facebookMeResponse
|
||||
if err = json.Unmarshal(byt, &response); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse json from facebook oauth2 endpoint")
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"name": jsonResp.Name,
|
||||
authboss.StoreOAuth2UID: jsonResp.ID,
|
||||
authboss.StoreEmail: jsonResp.Email,
|
||||
OAuth2UID: response.ID,
|
||||
OAuth2Email: response.Email,
|
||||
OAuth2Name: response.Name,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -8,21 +8,21 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestGoogle(t *testing.T) {
|
||||
saveClientGet := clientGet
|
||||
defer func() {
|
||||
clientGet = saveClientGet
|
||||
}()
|
||||
|
||||
func init() {
|
||||
// This has an extra parameter that the Google client wouldn't normally get, but it'll safely be
|
||||
// ignored.
|
||||
clientGet = func(_ *http.Client, url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email"}`)),
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name": "name"}`)),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := *testProviders["google"].OAuth2Config
|
||||
tok := &oauth2.Token{
|
||||
@@ -32,30 +32,21 @@ func TestGoogle(t *testing.T) {
|
||||
Expiry: time.Now().Add(60 * time.Minute),
|
||||
}
|
||||
|
||||
user, err := Google(context.TODO(), cfg, tok)
|
||||
details, err := GoogleUserDetails(context.Background(), cfg, tok)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
|
||||
if uid, ok := details[OAuth2UID]; !ok || uid != "id" {
|
||||
t.Error("UID wrong:", uid)
|
||||
}
|
||||
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
|
||||
if email, ok := details[OAuth2Email]; !ok || email != "email" {
|
||||
t.Error("Email wrong:", email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFacebook(t *testing.T) {
|
||||
saveClientGet := clientGet
|
||||
defer func() {
|
||||
clientGet = saveClientGet
|
||||
}()
|
||||
|
||||
clientGet = func(_ *http.Client, url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name":"name"}`)),
|
||||
}, nil
|
||||
}
|
||||
t.Parallel()
|
||||
|
||||
cfg := *testProviders["facebook"].OAuth2Config
|
||||
tok := &oauth2.Token{
|
||||
@@ -65,18 +56,18 @@ func TestFacebook(t *testing.T) {
|
||||
Expiry: time.Now().Add(60 * time.Minute),
|
||||
}
|
||||
|
||||
user, err := Facebook(context.TODO(), cfg, tok)
|
||||
details, err := FacebookUserDetails(context.Background(), cfg, tok)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" {
|
||||
if uid, ok := details[OAuth2UID]; !ok || uid != "id" {
|
||||
t.Error("UID wrong:", uid)
|
||||
}
|
||||
if email, ok := user[authboss.StoreEmail]; !ok || email != "email" {
|
||||
if email, ok := details[OAuth2Email]; !ok || email != "email" {
|
||||
t.Error("Email wrong:", email)
|
||||
}
|
||||
if name, ok := user["name"]; !ok || name != "name" {
|
||||
if name, ok := details[OAuth2Name]; !ok || name != "name" {
|
||||
t.Error("Name wrong:", name)
|
||||
}
|
||||
}
|
||||
|
@@ -35,8 +35,7 @@ func (r *Remember) Init(ab *authboss.Authboss) error {
|
||||
r.Authboss = ab
|
||||
|
||||
r.Events.After(authboss.EventAuth, r.RememberAfterAuth)
|
||||
//TODO(aarondl): Rectify this once oauth2 is done
|
||||
// r.Events.After(authboss.EventOAuth, r.RememberAfterAuth)
|
||||
r.Events.After(authboss.EventOAuth, r.RememberAfterAuth)
|
||||
r.Events.After(authboss.EventPasswordReset, r.AfterPasswordReset)
|
||||
|
||||
return nil
|
||||
|
58
storage.go
58
storage.go
@@ -13,22 +13,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Data store constants for attribute names.
|
||||
const (
|
||||
StoreEmail = "email"
|
||||
StoreUsername = "username"
|
||||
StorePassword = "password"
|
||||
)
|
||||
|
||||
// Data store constants for OAuth2 attribute names.
|
||||
const (
|
||||
StoreOAuth2UID = "oauth2_uid"
|
||||
StoreOAuth2Provider = "oauth2_provider"
|
||||
StoreOAuth2Token = "oauth2_token"
|
||||
StoreOAuth2Refresh = "oauth2_refresh"
|
||||
StoreOAuth2Expiry = "oauth2_expiry"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID
|
||||
// of the record is found.
|
||||
@@ -52,7 +36,7 @@ type ServerStorer interface {
|
||||
}
|
||||
|
||||
// CreatingServerStorer is used for creating new users
|
||||
// like when Registration is being done.
|
||||
// like when Registration or OAuth2 is being done.
|
||||
type CreatingServerStorer interface {
|
||||
ServerStorer
|
||||
|
||||
@@ -64,6 +48,34 @@ type CreatingServerStorer interface {
|
||||
Create(ctx context.Context, user User) error
|
||||
}
|
||||
|
||||
// OAuth2ServerStorer has the ability to create users from data from the provider.
|
||||
type OAuth2ServerStorer interface {
|
||||
ServerStorer
|
||||
|
||||
// NewFromOAuth2 should return an OAuth2User from a set
|
||||
// of details returned from OAuth2Provider.FindUserDetails
|
||||
// A more in-depth explanation is that once we've got an access token
|
||||
// for the service in question (say a service that rhymes with book)
|
||||
// the FindUserDetails function does an http request to a known endpoint that
|
||||
// provides details about the user, those details are captured in a generic
|
||||
// way as map[string]string and passed into this function to be turned
|
||||
// into a real user.
|
||||
//
|
||||
// It's possible that the user exists in the database already, and so
|
||||
// an attempt should be made to look that user up using the details.
|
||||
// Any details that have changed should be updated. Do not save the user
|
||||
// since that will be done by a later call to OAuth2ServerStorer.SaveOAuth2()
|
||||
NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (OAuth2User, error)
|
||||
|
||||
// SaveOAuth2 has different semantics from the typical ServerStorer.Save, in this case
|
||||
// we want to insert a user if they do not exist. The difference must be made clear because
|
||||
// in the non-oauth2 case, we know exactly when we want to Create vs Update. However
|
||||
// since we're simply trying to persist a user that may have been in our database, but if not
|
||||
// should already be (since you can think of the operation as a caching of what's on the oauth2 provider's
|
||||
// servers).
|
||||
SaveOAuth2(ctx context.Context, user OAuth2User) error
|
||||
}
|
||||
|
||||
// ConfirmingServerStorer can find a user by a confirm token
|
||||
type ConfirmingServerStorer interface {
|
||||
ServerStorer
|
||||
@@ -84,6 +96,8 @@ type RecoveringServerStorer interface {
|
||||
|
||||
// RememberingServerStorer allows users to be remembered across sessions
|
||||
type RememberingServerStorer interface {
|
||||
ServerStorer
|
||||
|
||||
// AddRememberToken to a user
|
||||
AddRememberToken(pid, token string) error
|
||||
// DelRememberTokens removes all tokens for the given pid
|
||||
@@ -132,3 +146,13 @@ func EnsureCanRemember(storer ServerStorer) RememberingServerStorer {
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// EnsureCanOAuth2 makes sure the server storer supports oauth2 creation and lookup
|
||||
func EnsureCanOAuth2(storer ServerStorer) OAuth2ServerStorer {
|
||||
s, ok := storer.(OAuth2ServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade ServerStorer to OAuth2ServerStorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
@@ -4,9 +4,9 @@ package authboss
|
||||
|
||||
import "strconv"
|
||||
|
||||
const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset"
|
||||
const _Event_name = "EventRegisterEventAuthEventOAuth2EventAuthFailEventOAuth2FailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset"
|
||||
|
||||
var _Event_index = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 103, 122, 140}
|
||||
var _Event_index = [...]uint8{0, 13, 22, 33, 46, 61, 78, 93, 105, 124, 142}
|
||||
|
||||
func (i Event) String() string {
|
||||
if i < 0 || i >= Event(len(_Event_index)-1) {
|
||||
|
53
user.go
53
user.go
@@ -2,6 +2,7 @@ package authboss
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -81,6 +82,8 @@ type ArbitraryUser interface {
|
||||
}
|
||||
|
||||
// OAuth2User allows reading and writing values relating to OAuth2
|
||||
// Also see MakeOAuthPID/ParseOAuthPID for helpers to fullfill the User
|
||||
// part of the interface.
|
||||
type OAuth2User interface {
|
||||
User
|
||||
|
||||
@@ -88,17 +91,17 @@ type OAuth2User interface {
|
||||
// oauth2 user.
|
||||
IsOAuth2User() bool
|
||||
|
||||
GetUID() (uid string)
|
||||
GetProvider() (provider string)
|
||||
GetToken() (token string)
|
||||
GetRefreshToken() (refreshToken string)
|
||||
GetExpiry() (expiry time.Duration)
|
||||
GetOAuth2UID() (uid string)
|
||||
GetOAuth2Provider() (provider string)
|
||||
GetOAuth2AccessToken() (token string)
|
||||
GetOAuth2RefreshToken() (refreshToken string)
|
||||
GetOAuth2Expiry() (expiry time.Time)
|
||||
|
||||
PutUID(uid string)
|
||||
PutProvider(provider string)
|
||||
PutToken(token string)
|
||||
PutRefreshToken(refreshToken string)
|
||||
PutExpiry(expiry time.Duration)
|
||||
PutOAuth2UID(uid string)
|
||||
PutOAuth2Provider(provider string)
|
||||
PutOAuth2AccessToken(token string)
|
||||
PutOAuth2RefreshToken(refreshToken string)
|
||||
PutOAuth2Expiry(expiry time.Time)
|
||||
}
|
||||
|
||||
// MustBeAuthable forces an upgrade to an AuthableUser or panic.
|
||||
@@ -132,3 +135,33 @@ func MustBeRecoverable(u User) RecoverableUser {
|
||||
}
|
||||
panic(fmt.Sprintf("could not upgrade user to a recoverable user, given type: %T", u))
|
||||
}
|
||||
|
||||
// MustBeOAuthable forces an upgrade to an OAuth2User or panic.
|
||||
func MustBeOAuthable(u User) OAuth2User {
|
||||
if ou, ok := u.(OAuth2User); ok {
|
||||
return ou
|
||||
}
|
||||
panic(fmt.Sprintf("could not upgrade user to an oauthable user, given type: %T", u))
|
||||
}
|
||||
|
||||
// MakeOAuth2PID is used to create a pid for users that don't have
|
||||
// an e-mail address or username in the normal system. This allows
|
||||
// all the modules to continue to working as intended without having
|
||||
// a true primary id. As well as not having to divide the regular and oauth
|
||||
// stuff all down the middle.
|
||||
func MakeOAuth2PID(provider, uid string) string {
|
||||
return fmt.Sprintf("oauth2;;%s;;%s", provider, uid)
|
||||
}
|
||||
|
||||
// ParseOAuth2PID returns the uid and provider for a given OAuth2 pid
|
||||
func ParseOAuth2PID(pid string) (provider, uid string) {
|
||||
splits := strings.Split(pid, ";;")
|
||||
if len(splits) != 3 {
|
||||
panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid))
|
||||
}
|
||||
if splits[0] != "oauth2" {
|
||||
panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid))
|
||||
}
|
||||
|
||||
return splits[1], splits[2]
|
||||
}
|
||||
|
23
user_test.go
Normal file
23
user_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package authboss
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOAuth2PIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := "provider"
|
||||
uid := "uid"
|
||||
pid := MakeOAuth2PID(provider, uid)
|
||||
|
||||
if pid != "oauth2;;provider;;uid" {
|
||||
t.Error("pid was wrong:", pid)
|
||||
}
|
||||
|
||||
gotProvider, gotUID := ParseOAuth2PID(pid)
|
||||
if gotUID != uid {
|
||||
t.Error("uid was wrong:", gotUID)
|
||||
}
|
||||
if gotProvider != provider {
|
||||
t.Error("provider was wrong:", gotProvider)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user