From 1112987bce41221c240ed3d3bda01f02be564ab8 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Thu, 8 Mar 2018 18:39:51 -0800 Subject: [PATCH] Rewrite oauth module - Tried to be clear about OAuth2 vs OAuth in all places. - Allow users to be locked from OAuth logins (if done manually for some reason other than failed logins) - Cleaned up some docs and wording around the previously very confusing (now hopefully only somewhat confusing) oauth2 module. --- config.go | 5 + events.go | 4 +- events_test.go | 4 +- internal/mocks/mocks.go | 58 +++++- lock/lock.go | 2 +- oauth2.go | 25 +-- oauth2/oauth2.go | 309 ++++++++++++++++++------------- oauth2/oauth2_test.go | 387 ++++++++++++++++++++++++++++++++++----- oauth2/providers.go | 50 +++-- oauth2/providers_test.go | 41 ++--- remember/remember.go | 3 +- storage.go | 58 ++++-- stringers.go | 4 +- user.go | 53 +++++- user_test.go | 23 +++ 15 files changed, 746 insertions(+), 280 deletions(-) create mode 100644 user_test.go diff --git a/config.go b/config.go index 2361dcf..db3fef0 100644 --- a/config.go +++ b/config.go @@ -27,6 +27,11 @@ type Config struct { // LogoutOK is the redirect path after a log out. LogoutOK string + // OAuth2LoginOK is the redirect path after a successful oauth2 login + OAuth2LoginOK string + // OAuth2LoginNotOK is the redirect path after an unsuccessful oauth2 login + OAuth2LoginNotOK string + // RecoverOK is the redirect path after a successful recovery of a password. RecoverOK string diff --git a/events.go b/events.go index d692132..18bdb8b 100644 --- a/events.go +++ b/events.go @@ -13,9 +13,9 @@ type Event int const ( EventRegister Event = iota EventAuth - EventOAuth + EventOAuth2 EventAuthFail - EventOAuthFail + EventOAuth2Fail EventRecoverStart EventRecoverEnd EventGetUser diff --git a/events_test.go b/events_test.go index 4d90c22..96e3167 100644 --- a/events_test.go +++ b/events_test.go @@ -132,9 +132,9 @@ func TestEventString(t *testing.T) { }{ {EventRegister, "EventRegister"}, {EventAuth, "EventAuth"}, - {EventOAuth, "EventOAuth"}, + {EventOAuth2, "EventOAuth2"}, {EventAuthFail, "EventAuthFail"}, - {EventOAuthFail, "EventOAuthFail"}, + {EventOAuth2Fail, "EventOAuth2Fail"}, {EventRecoverStart, "EventRecoverStart"}, {EventRecoverEnd, "EventRecoverEnd"}, {EventGetUser, "EventGetUser"}, diff --git a/internal/mocks/mocks.go b/internal/mocks/mocks.go index cd9877d..aff508d 100644 --- a/internal/mocks/mocks.go +++ b/internal/mocks/mocks.go @@ -25,9 +25,12 @@ type User struct { AttemptCount int LastAttempt time.Time Locked time.Time - OAuthToken string - OAuthRefresh string - OAuthExpiry time.Time + + OAuth2UID string + OAuth2Provider string + OAuth2Token string + OAuth2Refresh string + OAuth2Expiry time.Time Arbitrary map[string]string } @@ -43,9 +46,12 @@ func (m User) GetConfirmed() bool { return m.Confirmed } func (m User) GetAttemptCount() int { return m.AttemptCount } func (m User) GetLastAttempt() time.Time { return m.LastAttempt } func (m User) GetLocked() time.Time { return m.Locked } -func (m User) GetOAuthToken() string { return m.OAuthToken } -func (m User) GetOAuthRefresh() string { return m.OAuthRefresh } -func (m User) GetOAuthExpiry() time.Time { return m.OAuthExpiry } +func (m User) IsOAuth2User() bool { return len(m.OAuth2Provider) != 0 } +func (m User) GetOAuth2UID() string { return m.OAuth2UID } +func (m User) GetOAuth2Provider() string { return m.OAuth2Provider } +func (m User) GetOAuth2AccessToken() string { return m.OAuth2Token } +func (m User) GetOAuth2RefreshToken() string { return m.OAuth2Refresh } +func (m User) GetOAuth2Expiry() time.Time { return m.OAuth2Expiry } func (m User) GetArbitrary() map[string]string { return m.Arbitrary } func (m *User) PutPID(email string) { m.Email = email } @@ -61,9 +67,11 @@ func (m *User) PutConfirmed(confirmed bool) { m.Confirmed = confirmed } func (m *User) PutAttemptCount(attemptCount int) { m.AttemptCount = attemptCount } func (m *User) PutLastAttempt(attemptTime time.Time) { m.LastAttempt = attemptTime } func (m *User) PutLocked(locked time.Time) { m.Locked = locked } -func (m *User) PutOAuthToken(oAuthToken string) { m.OAuthToken = oAuthToken } -func (m *User) PutOAuthRefresh(oAuthRefresh string) { m.OAuthRefresh = oAuthRefresh } -func (m *User) PutOAuthExpiry(oAuthExpiry time.Time) { m.OAuthExpiry = oAuthExpiry } +func (m *User) PutOAuth2UID(uid string) { m.OAuth2UID = uid } +func (m *User) PutOAuth2Provider(provider string) { m.OAuth2Provider = provider } +func (m *User) PutOAuth2AccessToken(token string) { m.OAuth2Token = token } +func (m *User) PutOAuth2RefreshToken(refresh string) { m.OAuth2Refresh = refresh } +func (m *User) PutOAuth2Expiry(expiry time.Time) { m.OAuth2Expiry = expiry } func (m *User) PutArbitrary(arb map[string]string) { m.Arbitrary = arb } // ServerStorer should be valid for any module storer defined in authboss. @@ -115,6 +123,38 @@ func (s *ServerStorer) Save(ctx context.Context, user authboss.User) error { return nil } +// NewFromOAuth2 finds a user with the given details, or returns a new one +func (s *ServerStorer) NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (authboss.OAuth2User, error) { + uid := details["uid"] + email := details["email"] + name := details["name"] + pid := authboss.MakeOAuth2PID(provider, uid) + + u, ok := s.Users[pid] + if ok { + u.Username = name + u.Email = email + return u, nil + } + + return &User{ + OAuth2UID: uid, + OAuth2Provider: provider, + Email: email, + Username: name, + }, nil +} + +// SaveOAuth2 creates a user if not found, or updates one that exists. +func (s *ServerStorer) SaveOAuth2(ctx context.Context, user authboss.OAuth2User) error { + u := user.(*User) + + pid := authboss.MakeOAuth2PID(u.OAuth2Provider, u.OAuth2UID) + // Since we don't have to differentiate between insert/update in a map, we just overwrite + s.Users[pid] = u + return nil +} + // LoadByConfirmToken finds a user by his confirm token func (s *ServerStorer) LoadByConfirmToken(ctx context.Context, token string) (authboss.ConfirmableUser, error) { for _, v := range s.Users { diff --git a/lock/lock.go b/lock/lock.go index bb99c72..d4f2769 100644 --- a/lock/lock.go +++ b/lock/lock.go @@ -35,8 +35,8 @@ type Lock struct { func (l *Lock) Init(ab *authboss.Authboss) error { l.Authboss = ab - // Events l.Events.Before(authboss.EventAuth, l.BeforeAuth) + l.Events.Before(authboss.EventOAuth, l.BeforeAuth) l.Events.After(authboss.EventAuth, l.AfterAuthSuccess) l.Events.After(authboss.EventAuthFail, l.AfterAuthFail) diff --git a/oauth2.go b/oauth2.go index e26dd03..d633bc0 100644 --- a/oauth2.go +++ b/oauth2.go @@ -7,11 +7,6 @@ import ( "golang.org/x/oauth2" ) -// FormValue constants -const ( - FormValueOAuth2State = "state" -) - /* OAuth2Provider is the entire configuration required to authenticate with this provider. @@ -23,22 +18,14 @@ AdditionalParams can be used to specify extra parameters to tack on to the end of the initial request, this allows for provider specific oauth options like access_type=offline to be passed to the provider. -Callback gives the config and the token allowing an http client using the -authenticated token to be created. Because each OAuth2 implementation has a different -API this must be handled for each provider separately. It is used to return two things -specifically: UID (the ID according to the provider) and the Email address. -The UID must be passed back or there will be an error as it is the means of identifying the -user in the system, e-mail is optional but should be returned in systems using -emailing. The keys authboss.StoreOAuth2UID and authboss.StoreEmail can be used to set -these values in the map returned by the callback. - -In addition to the required values mentioned above any additional -values that you wish to have in your user struct can be included here, such as the -Name of the user at the endpoint. This will be passed back in the Arbitrary() -function if it exists. +FindUserDetails gives the config and the token allowing an http client using the +authenticated token to be created, a call is then made to a known endpoint that will +return details about the user we've retrieved the token for. Those details are returned +as a map[string]string and subsequently passed into OAuth2ServerStorer.NewFromOAuth2. +API this must be handled for each provider separately. */ type OAuth2Provider struct { OAuth2Config *oauth2.Config AdditionalParams url.Values - Callback func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error) + FindUserDetails func(context.Context, oauth2.Config, *oauth2.Token) (map[string]string, error) } diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index beb1d08..0c92ef0 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -1,6 +1,36 @@ +// Package oauth2 allows users to be created and authenticated +// via oauth2 services like facebook, google etc. Currently +// only the web server flow is supported. +// +// The general flow looks like this: +// 1. User goes to Start handler and has his session packed with goodies +// then redirects to the OAuth service. +// 2. OAuth service returns to OAuthCallback which extracts state and parameters +// and generally checks that everything is ok. It uses the token received to +// get an access token from the oauth2 library +// 3. Calls the OAuth2Provider.FindUserDetails which should return the user's details +// in a generic form. +// 4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in order +// to create a user object we can work with. +// 5. Saves the user in the database, logs them in, redirects. +// +// In order to do this there are a number of parts: +// 1. The configuration of a provider (handled by authboss.Config.Modules.OAuth2Providers) +// 2. The flow of redirection of client, parameter passing etc (handled by this package) +// 3. The HTTP call to the service once a token has been retrieved to get user details +// (handled by OAuth2Provider.FindUserDetails) +// 4. The creation of a user from the user details returned from the FindUserDetails +// (authboss.OAuth2ServerStorer) +// +// Of these parts, the responsibility of the authboss library consumer is on 1, 3, and 4. +// Configuration of providers that should be used is totally up to the consumer. The FindUserDetails +// function is typically up to the user, but we have some basic ones included in this package too. +// The creation of users from the FindUserDetail's map[string]string return is handled as part +// of the implementation of the OAuth2ServerStorer. package oauth2 import ( + "context" "crypto/rand" "encoding/base64" "encoding/json" @@ -9,13 +39,19 @@ import ( "net/url" "path" "path/filepath" + "sort" "strings" "github.com/pkg/errors" + "golang.org/x/oauth2" "github.com/volatiletech/authboss" - "github.com/volatiletech/authboss/internal/response" - "golang.org/x/oauth2" +) + +// FormValue constants +const ( + FormValueOAuth2State = "state" + FormValueOAuth2Redir = "redir" ) var ( @@ -31,68 +67,60 @@ func init() { authboss.RegisterModule("oauth2", &OAuth2{}) } -// Initialize module -func (o *OAuth2) Initialize(ab *authboss.Authboss) error { +// Init module +func (o *OAuth2) Init(ab *authboss.Authboss) error { o.Authboss = ab - if o.OAuth2Storer == nil && o.OAuth2StoreMaker == nil { - return errors.New("need an oauth2Storer") + + // Do annoying sorting on keys so we can have predictible + // route registration (both for consistency inside the router but also for tests -_-) + var keys []string + for k := range o.Authboss.Config.Modules.OAuth2Providers { + keys = append(keys, k) } + sort.Strings(keys) + + for _, provider := range keys { + cfg := o.Authboss.Config.Modules.OAuth2Providers[provider] + provider = strings.ToLower(provider) + + init := fmt.Sprintf("/oauth2/%s", provider) + callback := fmt.Sprintf("/oauth2/callback/%s", provider) + + o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start)) + o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End)) + + if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 { + callback = path.Join(mount, callback) + } + + cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback + } + return nil } -// Routes for module -func (o *OAuth2) Routes() authboss.RouteTable { - routes := make(authboss.RouteTable) +// Start the oauth2 process +func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error { + logger := o.Authboss.RequestLogger(r) - for prov, cfg := range o.OAuth2Providers { - prov = strings.ToLower(prov) - - init := fmt.Sprintf("/oauth2/%s", prov) - callback := fmt.Sprintf("/oauth2/callback/%s", prov) - - routes[init] = o.oauthInit - routes[callback] = o.oauthCallback - - if len(o.MountPath) > 0 { - callback = path.Join(o.MountPath, callback) - } - - cfg.OAuth2Config.RedirectURL = o.RootURL + callback - } - - routes["/oauth2/logout"] = o.logout - - return routes -} - -// Storage requirements -func (o *OAuth2) Storage() authboss.StorageOptions { - return authboss.StorageOptions{ - authboss.StoreEmail: authboss.String, - authboss.StoreOAuth2UID: authboss.String, - authboss.StoreOAuth2Provider: authboss.String, - authboss.StoreOAuth2Token: authboss.String, - authboss.StoreOAuth2Refresh: authboss.String, - authboss.StoreOAuth2Expiry: authboss.DateTime, - } -} - -func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { provider := strings.ToLower(filepath.Base(r.URL.Path)) - cfg, ok := o.OAuth2Providers[provider] + logger.Infof("started oauth2 flow for provider: %s", provider) + cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider] if !ok { - return errors.Errorf("OAuth2 provider %q not found", provider) + return errors.Errorf("oauth2 provider %q not found", provider) } - random := make([]byte, 32) - _, err := rand.Read(random) - if err != nil { - return err + // Create nonce + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return errors.Wrap(err, "failed to create nonce") } - state := base64.URLEncoding.EncodeToString(random) - ctx.SessionStorer.Put(authboss.SessionOAuth2State, state) + state := base64.URLEncoding.EncodeToString(nonce) + authboss.PutSession(w, authboss.SessionOAuth2State, state) + // This clearly ignores the fact that query parameters can have multiple + // values but I guess we're ignoring that passAlongs := make(map[string]string) for k, vals := range r.URL.Query() { for _, val := range vals { @@ -101,13 +129,13 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http } if len(passAlongs) > 0 { - str, err := json.Marshal(passAlongs) + byt, err := json.Marshal(passAlongs) if err != nil { return err } - ctx.SessionStorer.Put(authboss.SessionOAuth2Params, string(str)) + authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt)) } else { - ctx.SessionStorer.Del(authboss.SessionOAuth2Params) + authboss.DelSession(w, authboss.SessionOAuth2Params) } url := cfg.OAuth2Config.AuthCodeURL(state) @@ -117,128 +145,153 @@ func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http url = fmt.Sprintf("%s&%s", url, extraParams) } - http.Redirect(w, r, url, http.StatusFound) - return nil + ro := authboss.RedirectOptions{ + Code: http.StatusTemporaryRedirect, + RedirectPath: url, + } + return o.Authboss.Core.Redirector.Redirect(w, r, ro) } -// for testing +// for testing, mocked out at the beginning var exchanger = (*oauth2.Config).Exchange -func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { +// End the oauth2 process, this is the handler for the oauth2 callback +// that the third party will redirect to. +func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error { + logger := o.Authboss.RequestLogger(r) provider := strings.ToLower(filepath.Base(r.URL.Path)) + logger.Infof("finishing oauth2 flow for provider: %s", provider) - sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State) - ctx.SessionStorer.Del(authboss.SessionOAuth2State) - if err != nil { - return err - } - - sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params) - // Don't delete this value from session immediately, Events use this too - var values map[string]string - if ok { - if err := json.Unmarshal([]byte(sessValues), &values); err != nil { - return err - } - } - - hasErr := r.FormValue("error") - if len(hasErr) > 0 { - if err := o.Events.FireAfter(authboss.EventOAuthFail, ctx); err != nil { - return err - } - - return authboss.ErrAndRedirect{ - Err: errors.New(r.FormValue("error_reason")), - Location: o.AuthLoginFailPath, - FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)), - } - } - - cfg, ok := o.OAuth2Providers[provider] + // This shouldn't happen because the router should 404 first, but just in case + cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider] if !ok { return errors.Errorf("oauth2 provider %q not found", provider) } - // Ensure request is genuine - state := r.FormValue(authboss.FormValueOAuth2State) - splState := strings.Split(state, ";") - if len(splState) == 0 || splState[0] != sessState { + wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State) + if !ok { + return errors.New("oauth2 endpoint hit without session state") + } + + // Verify we got the same state in the session as was passed to us in the + // query parameter. + state := r.FormValue(FormValueOAuth2State) + if state != wantState { return errOAuthStateValidation } - // Get the code + rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params) + var params map[string]string + if ok { + if err := json.Unmarshal([]byte(rawParams), ¶ms); err != nil { + return errors.Wrap(err, "failed to decode oauth2 params") + } + } + + authboss.DelSession(w, authboss.SessionOAuth2State) + authboss.DelSession(w, authboss.SessionOAuth2Params) + + hasErr := r.FormValue("error") + if len(hasErr) > 0 { + reason := r.FormValue("error_reason") + logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason) + + handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r) + if err != nil { + return err + } else if handled { + return nil + } + + ro := authboss.RedirectOptions{ + Code: http.StatusTemporaryRedirect, + RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK, + Failure: fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)), + } + return o.Authboss.Core.Redirector.Redirect(w, r, ro) + } + + // Get the code which we can use to make an access token code := r.FormValue("code") - token, err := exchanger(cfg.OAuth2Config, o.Config.ContextProvider(r), code) + token, err := exchanger(cfg.OAuth2Config, r.Context(), code) if err != nil { return errors.Wrap(err, "could not validate oauth2 code") } - user, err := cfg.Callback(o.Config.ContextProvider(r), *cfg.OAuth2Config, token) + details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token) if err != nil { return err } - // OAuth2UID is required. - uid, err := user.StringErr(authboss.StoreOAuth2UID) + storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server) + user, err := storer.NewFromOAuth2(r.Context(), provider, details) if err != nil { - return err + return errors.Wrap(err, "failed to create oauth2 user from values") } - user[authboss.StoreOAuth2UID] = uid - user[authboss.StoreOAuth2Provider] = provider - user[authboss.StoreOAuth2Expiry] = token.Expiry - user[authboss.StoreOAuth2Token] = token.AccessToken + user.PutOAuth2Provider(provider) + user.PutOAuth2AccessToken(token.AccessToken) + user.PutOAuth2Expiry(token.Expiry) if len(token.RefreshToken) != 0 { - user[authboss.StoreOAuth2Refresh] = token.RefreshToken + user.PutOAuth2RefreshToken(token.RefreshToken) } - if err = ctx.OAuth2Storer.PutOAuth(uid, provider, user); err != nil { + if err := storer.SaveOAuth2(r.Context(), user); err != nil { return err } - // Fully log user in - ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider)) - ctx.SessionStorer.Del(authboss.SessionHalfAuthKey) + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user)) - if err = o.Events.FireAfter(authboss.EventOAuth, ctx); err != nil { + handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r) + if err != nil { + return err + } else if handled { return nil } - ctx.SessionStorer.Del(authboss.SessionOAuth2Params) + // Fully log user in + authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID())) + authboss.DelSession(w, authboss.SessionHalfAuthKey) - redirect := o.AuthLoginOKPath + // Create a query string from all the pieces we've received + // as passthru from the original request. + redirect := o.Authboss.Config.Paths.OAuth2LoginOK query := make(url.Values) - for k, v := range values { + for k, v := range params { switch k { case authboss.CookieRemember: - case authboss.FormValueRedirect: + if v == "true" { + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{})) + } + case FormValueOAuth2Redir: redirect = v default: query.Set(k, v) } } + handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r) + if err != nil { + return err + } else if handled { + return nil + } + if len(query) > 0 { redirect = fmt.Sprintf("%s?%s", redirect, query.Encode()) } - sf := fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)) - response.Redirect(ctx, w, r, redirect, sf, "", false) - return nil -} - -func (o *OAuth2) logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { - switch r.Method { - case "GET": - ctx.SessionStorer.Del(authboss.SessionKey) - ctx.CookieStorer.Del(authboss.CookieRemember) - ctx.SessionStorer.Del(authboss.SessionLastAction) - - response.Redirect(ctx, w, r, o.AuthLogoutOKPath, "You have logged out", "", true) - default: - w.WriteHeader(http.StatusMethodNotAllowed) + ro := authboss.RedirectOptions{ + Code: http.StatusTemporaryRedirect, + RedirectPath: redirect, + Success: fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)), } - - return nil + return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro) } + +// RMTrue is a dummy struct implementing authboss.RememberValuer +// in order to tell the remember me module to remember them. +type RMTrue struct{} + +// GetShouldRemember always returns true +func (RMTrue) GetShouldRemember() bool { return true } diff --git a/oauth2/oauth2_test.go b/oauth2/oauth2_test.go index 27f6e06..f7b4cd8 100644 --- a/oauth2/oauth2_test.go +++ b/oauth2/oauth2_test.go @@ -2,11 +2,9 @@ package oauth2 import ( "context" - "fmt" "net/http" "net/http/httptest" "net/url" - "path" "strings" "testing" "time" @@ -18,6 +16,12 @@ import ( "golang.org/x/oauth2/google" ) +func init() { + exchanger = func(_ *oauth2.Config, _ context.Context, _ string) (*oauth2.Token, error) { + return testToken, nil + } +} + var testProviders = map[string]authboss.OAuth2Provider{ "google": authboss.OAuth2Provider{ OAuth2Config: &oauth2.Config{ @@ -25,8 +29,10 @@ var testProviders = map[string]authboss.OAuth2Provider{ ClientSecret: `hands`, Scopes: []string{`profile`, `email`}, Endpoint: google.Endpoint, + // This is typically set by Init() but some tests rely on it's existence + RedirectURL: "https://www.example.com/auth/oauth2/callback/google", }, - Callback: Google, + FindUserDetails: GoogleUserDetails, AdditionalParams: url.Values{"include_requested_scopes": []string{"true"}}, }, "facebook": authboss.OAuth2Provider{ @@ -35,11 +41,341 @@ var testProviders = map[string]authboss.OAuth2Provider{ ClientSecret: `hands`, Scopes: []string{`email`}, Endpoint: facebook.Endpoint, + // This is typically set by Init() but some tests rely on it's existence + RedirectURL: "https://www.example.com/auth/oauth2/callback/facebook", }, - Callback: Facebook, + FindUserDetails: FacebookUserDetails, }, } +var testToken = &oauth2.Token{ + AccessToken: "token", + TokenType: "Bearer", + RefreshToken: "refresh", + Expiry: time.Now().AddDate(0, 0, 1), +} + +func TestInit(t *testing.T) { + // No t.Parallel() since the cfg.RedirectURL is set in Init() + + ab := authboss.New() + oauth := &OAuth2{} + + router := &mocks.Router{} + ab.Config.Modules.OAuth2Providers = testProviders + ab.Config.Core.Router = router + ab.Config.Core.ErrorHandler = &mocks.ErrorHandler{} + + ab.Config.Paths.Mount = "/auth" + ab.Config.Paths.RootURL = "https://www.example.com" + + if err := oauth.Init(ab); err != nil { + t.Fatal(err) + } + + gets := []string{ + "/oauth2/facebook", "/oauth2/callback/facebook", + "/oauth2/google", "/oauth2/callback/google", + } + if err := router.HasGets(gets...); err != nil { + t.Error(err) + } +} + +type testHarness struct { + oauth *OAuth2 + ab *authboss.Authboss + + bodyReader *mocks.BodyReader + responder *mocks.Responder + redirector *mocks.Redirector + session *mocks.ClientStateRW + storer *mocks.ServerStorer +} + +func testSetup() *testHarness { + harness := &testHarness{} + + harness.ab = authboss.New() + harness.redirector = &mocks.Redirector{} + harness.session = mocks.NewClientRW() + harness.storer = mocks.NewServerStorer() + + harness.ab.Modules.OAuth2Providers = testProviders + + harness.ab.Paths.OAuth2LoginOK = "/auth/oauth2/ok" + harness.ab.Paths.OAuth2LoginNotOK = "/auth/oauth2/not/ok" + + harness.ab.Config.Core.Logger = mocks.Logger{} + harness.ab.Config.Core.Redirector = harness.redirector + harness.ab.Config.Storage.SessionState = harness.session + harness.ab.Config.Storage.Server = harness.storer + + harness.oauth = &OAuth2{harness.ab} + + return harness +} + +func TestStart(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + r := httptest.NewRequest("GET", "/oauth2/google?cake=yes&death=no", nil) + + if err := h.oauth.Start(w, r); err != nil { + t.Error(err) + } + + if h.redirector.Options.Code != http.StatusTemporaryRedirect { + t.Error("code was wrong:", h.redirector.Options.Code) + } + + url, err := url.Parse(h.redirector.Options.RedirectPath) + if err != nil { + t.Fatal(err) + } + query := url.Query() + if state := query.Get("state"); len(state) == 0 { + t.Error("our nonce should have been here") + } + if callback := query.Get("redirect_uri"); callback != "https://www.example.com/auth/oauth2/callback/google" { + t.Error("callback was wrong:", callback) + } + if clientID := query.Get("client_id"); clientID != "jazz" { + t.Error("clientID was wrong:", clientID) + } + if url.Host != "accounts.google.com" { + t.Error("host was wrong:", url.Host) + } + + if h.session.ClientValues[authboss.SessionOAuth2State] != query.Get("state") { + t.Error("the state should have been saved in the session") + } + if v := h.session.ClientValues[authboss.SessionOAuth2Params]; v != `{"cake":"yes","death":"no"}` { + t.Error("oauth2 session params are wrong:", v) + } +} + +func TestStartBadProvider(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + r := httptest.NewRequest("GET", "/oauth2/test", nil) + + err := h.oauth.Start(w, r) + if e := err.Error(); !strings.Contains(e, `provider "test" not found`) { + t.Error("it should have errored:", e) + } +} + +func TestEnd(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil)) + if err != nil { + t.Fatal(err) + } + + if err := h.oauth.End(w, r); err != nil { + t.Error(err) + } + + w.WriteHeader(http.StatusOK) // Flush headers + + opts := h.redirector.Options + if opts.Code != http.StatusTemporaryRedirect { + t.Error("it should have redirected") + } + if opts.RedirectPath != "/auth/oauth2/ok" { + t.Error("redir path was wrong:", opts.RedirectPath) + } + if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" { + t.Error("session id should have been set:", s) + } +} + +func TestEndBadProvider(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + r := httptest.NewRequest("GET", "/oauth2/callback/test", nil) + + err := h.oauth.End(w, r) + if e := err.Error(); !strings.Contains(e, `provider "test" not found`) { + t.Error("it should have errored:", e) + } +} + +func TestEndBadState(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + r := httptest.NewRequest("GET", "/oauth2/callback/google", nil) + + err := h.oauth.End(w, r) + if e := err.Error(); !strings.Contains(e, `oauth2 endpoint hit without session state`) { + t.Error("it should have errored:", e) + } + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err = h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=x", nil)) + if err != nil { + t.Fatal(err) + } + if err := h.oauth.End(w, r); err != errOAuthStateValidation { + t.Error("error was wrong:", err) + } +} + +func TestEndErrors(t *testing.T) { + t.Parallel() + + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil)) + if err != nil { + t.Fatal(err) + } + + if err := h.oauth.End(w, r); err != nil { + t.Error(err) + } + + opts := h.redirector.Options + if opts.Code != http.StatusTemporaryRedirect { + t.Error("code was wrong:", opts.Code) + } + if opts.RedirectPath != "/auth/oauth2/not/ok" { + t.Error("path was wrong:", opts.RedirectPath) + } +} + +func TestEndHandling(t *testing.T) { + t.Parallel() + + t.Run("AfterOAuth2Fail", func(t *testing.T) { + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state&error=badtimes&error_reason=reason", nil)) + if err != nil { + t.Fatal(err) + } + + called := false + h.ab.Events.After(authboss.EventOAuth2Fail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) { + called = true + return true, nil + }) + + if err := h.oauth.End(w, r); err != nil { + t.Error(err) + } + + if !called { + t.Error("it should have been called") + } + if h.redirector.Options.Code != 0 { + t.Error("it should not have tried to redirect") + } + }) + t.Run("BeforeOAuth2", func(t *testing.T) { + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil)) + if err != nil { + t.Fatal(err) + } + + called := false + h.ab.Events.Before(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) { + called = true + return true, nil + }) + + if err := h.oauth.End(w, r); err != nil { + t.Error(err) + } + + w.WriteHeader(http.StatusOK) // Flush headers + + if !called { + t.Error("it should have been called") + } + if h.redirector.Options.Code != 0 { + t.Error("it should not have tried to redirect") + } + if len(h.session.ClientValues[authboss.SessionKey]) != 0 { + t.Error("should have not logged the user in") + } + }) + + t.Run("AfterOAuth2", func(t *testing.T) { + h := testSetup() + + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec) + + h.session.ClientValues[authboss.SessionOAuth2State] = "state" + r, err := h.ab.LoadClientState(w, httptest.NewRequest("GET", "/oauth2/callback/google?state=state", nil)) + if err != nil { + t.Fatal(err) + } + + called := false + h.ab.Events.After(authboss.EventOAuth2, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) { + called = true + return true, nil + }) + + if err := h.oauth.End(w, r); err != nil { + t.Error(err) + } + + w.WriteHeader(http.StatusOK) // Flush headers + + if !called { + t.Error("it should have been called") + } + if h.redirector.Options.Code != 0 { + t.Error("it should not have tried to redirect") + } + if s := h.session.ClientValues[authboss.SessionKey]; s != "oauth2;;google;;id" { + t.Error("session id should have been set:", s) + } + }) +} + +/* func TestInitialize(t *testing.T) { t.Parallel() @@ -281,45 +617,4 @@ func TestOAuthFailure(t *testing.T) { t.Error("It should record the failure.") } } - -func TestLogout(t *testing.T) { - t.Parallel() - - ab := authboss.New() - oauth := OAuth2{ab} - ab.AuthLogoutOKPath = "/dashboard" - - r, _ := http.NewRequest("GET", "/oauth2/google?", nil) - w := httptest.NewRecorder() - - ctx := ab.NewContext() - session := mocks.NewMockClientStorer(authboss.SessionKey, "asdf", authboss.SessionLastAction, "1234") - cookies := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert") - ctx.SessionStorer = session - ctx.CookieStorer = cookies - - if err := oauth.logout(ctx, w, r); err != nil { - t.Error(err) - } - - if val, ok := session.Get(authboss.SessionKey); ok { - t.Error("Unexpected session key:", val) - } - - if val, ok := session.Get(authboss.SessionLastAction); ok { - t.Error("Unexpected last action:", val) - } - - if val, ok := cookies.Get(authboss.CookieRemember); ok { - t.Error("Unexpected rm cookie:", val) - } - - if http.StatusFound != w.Code { - t.Errorf("Expected status code %d, got %d", http.StatusFound, w.Code) - } - - location := w.Header().Get("Location") - if location != ab.AuthLogoutOKPath { - t.Error("Redirect wrong:", location) - } -} +*/ diff --git a/oauth2/providers.go b/oauth2/providers.go index ed32e22..383c3a5 100644 --- a/oauth2/providers.go +++ b/oauth2/providers.go @@ -3,12 +3,20 @@ package oauth2 import ( "context" "encoding/json" + "io/ioutil" "net/http" - "github.com/volatiletech/authboss" + "github.com/pkg/errors" "golang.org/x/oauth2" ) +// Constants for returning in the FindUserDetails call +const ( + OAuth2UID = "uid" + OAuth2Email = "email" + OAuth2Name = "name" +) + const ( googleInfoEndpoint = `https://www.googleapis.com/userinfo/v2/me` facebookInfoEndpoint = `https://graph.facebook.com/me?fields=name,email` @@ -22,8 +30,8 @@ type googleMeResponse struct { // testing var clientGet = (*http.Client).Get -// Google is a callback appropriate for use with Google's OAuth2 configuration. -func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) { +// GoogleUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider +func GoogleUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) { client := cfg.Client(ctx, token) resp, err := clientGet(client, googleInfoEndpoint) if err != nil { @@ -31,15 +39,19 @@ func Google(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[st } defer resp.Body.Close() - dec := json.NewDecoder(resp.Body) - var jsonResp googleMeResponse - if err = dec.Decode(&jsonResp); err != nil { + byt, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read body from google oauth2 endpoint") + } + + var response googleMeResponse + if err = json.Unmarshal(byt, &response); err != nil { return nil, err } return map[string]string{ - authboss.StoreOAuth2UID: jsonResp.ID, - authboss.StoreEmail: jsonResp.Email, + OAuth2UID: response.ID, + OAuth2Email: response.Email, }, nil } @@ -49,8 +61,8 @@ type facebookMeResponse struct { Name string `json:"name"` } -// Facebook is a callback appropriate for use with Facebook's OAuth2 configuration. -func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) { +// FacebookUserDetails can be used as a FindUserDetails function for an authboss.OAuth2Provider +func FacebookUserDetails(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[string]string, error) { client := cfg.Client(ctx, token) resp, err := clientGet(client, facebookInfoEndpoint) if err != nil { @@ -58,15 +70,19 @@ func Facebook(ctx context.Context, cfg oauth2.Config, token *oauth2.Token) (map[ } defer resp.Body.Close() - dec := json.NewDecoder(resp.Body) - var jsonResp facebookMeResponse - if err = dec.Decode(&jsonResp); err != nil { - return nil, err + byt, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read body from facebook oauth2 endpoint") + } + + var response facebookMeResponse + if err = json.Unmarshal(byt, &response); err != nil { + return nil, errors.Wrap(err, "failed to parse json from facebook oauth2 endpoint") } return map[string]string{ - "name": jsonResp.Name, - authboss.StoreOAuth2UID: jsonResp.ID, - authboss.StoreEmail: jsonResp.Email, + OAuth2UID: response.ID, + OAuth2Email: response.Email, + OAuth2Name: response.Name, }, nil } diff --git a/oauth2/providers_test.go b/oauth2/providers_test.go index 8f59a5c..ee62edb 100644 --- a/oauth2/providers_test.go +++ b/oauth2/providers_test.go @@ -8,21 +8,21 @@ import ( "testing" "time" - "github.com/volatiletech/authboss" "golang.org/x/oauth2" ) -func TestGoogle(t *testing.T) { - saveClientGet := clientGet - defer func() { - clientGet = saveClientGet - }() - +func init() { + // This has an extra parameter that the Google client wouldn't normally get, but it'll safely be + // ignored. clientGet = func(_ *http.Client, url string) (*http.Response, error) { return &http.Response{ - Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email"}`)), + Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name": "name"}`)), }, nil } +} + +func TestGoogle(t *testing.T) { + t.Parallel() cfg := *testProviders["google"].OAuth2Config tok := &oauth2.Token{ @@ -32,30 +32,21 @@ func TestGoogle(t *testing.T) { Expiry: time.Now().Add(60 * time.Minute), } - user, err := Google(context.TODO(), cfg, tok) + details, err := GoogleUserDetails(context.Background(), cfg, tok) if err != nil { t.Error(err) } - if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" { + if uid, ok := details[OAuth2UID]; !ok || uid != "id" { t.Error("UID wrong:", uid) } - if email, ok := user[authboss.StoreEmail]; !ok || email != "email" { + if email, ok := details[OAuth2Email]; !ok || email != "email" { t.Error("Email wrong:", email) } } func TestFacebook(t *testing.T) { - saveClientGet := clientGet - defer func() { - clientGet = saveClientGet - }() - - clientGet = func(_ *http.Client, url string) (*http.Response, error) { - return &http.Response{ - Body: ioutil.NopCloser(strings.NewReader(`{"id":"id", "email":"email", "name":"name"}`)), - }, nil - } + t.Parallel() cfg := *testProviders["facebook"].OAuth2Config tok := &oauth2.Token{ @@ -65,18 +56,18 @@ func TestFacebook(t *testing.T) { Expiry: time.Now().Add(60 * time.Minute), } - user, err := Facebook(context.TODO(), cfg, tok) + details, err := FacebookUserDetails(context.Background(), cfg, tok) if err != nil { t.Error(err) } - if uid, ok := user[authboss.StoreOAuth2UID]; !ok || uid != "id" { + if uid, ok := details[OAuth2UID]; !ok || uid != "id" { t.Error("UID wrong:", uid) } - if email, ok := user[authboss.StoreEmail]; !ok || email != "email" { + if email, ok := details[OAuth2Email]; !ok || email != "email" { t.Error("Email wrong:", email) } - if name, ok := user["name"]; !ok || name != "name" { + if name, ok := details[OAuth2Name]; !ok || name != "name" { t.Error("Name wrong:", name) } } diff --git a/remember/remember.go b/remember/remember.go index 56c5bab..fe44a09 100644 --- a/remember/remember.go +++ b/remember/remember.go @@ -35,8 +35,7 @@ func (r *Remember) Init(ab *authboss.Authboss) error { r.Authboss = ab r.Events.After(authboss.EventAuth, r.RememberAfterAuth) - //TODO(aarondl): Rectify this once oauth2 is done - // r.Events.After(authboss.EventOAuth, r.RememberAfterAuth) + r.Events.After(authboss.EventOAuth, r.RememberAfterAuth) r.Events.After(authboss.EventPasswordReset, r.AfterPasswordReset) return nil diff --git a/storage.go b/storage.go index 70a4eaa..45624af 100644 --- a/storage.go +++ b/storage.go @@ -13,22 +13,6 @@ import ( "github.com/pkg/errors" ) -// Data store constants for attribute names. -const ( - StoreEmail = "email" - StoreUsername = "username" - StorePassword = "password" -) - -// Data store constants for OAuth2 attribute names. -const ( - StoreOAuth2UID = "oauth2_uid" - StoreOAuth2Provider = "oauth2_provider" - StoreOAuth2Token = "oauth2_token" - StoreOAuth2Refresh = "oauth2_refresh" - StoreOAuth2Expiry = "oauth2_expiry" -) - var ( // ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID // of the record is found. @@ -52,7 +36,7 @@ type ServerStorer interface { } // CreatingServerStorer is used for creating new users -// like when Registration is being done. +// like when Registration or OAuth2 is being done. type CreatingServerStorer interface { ServerStorer @@ -64,6 +48,34 @@ type CreatingServerStorer interface { Create(ctx context.Context, user User) error } +// OAuth2ServerStorer has the ability to create users from data from the provider. +type OAuth2ServerStorer interface { + ServerStorer + + // NewFromOAuth2 should return an OAuth2User from a set + // of details returned from OAuth2Provider.FindUserDetails + // A more in-depth explanation is that once we've got an access token + // for the service in question (say a service that rhymes with book) + // the FindUserDetails function does an http request to a known endpoint that + // provides details about the user, those details are captured in a generic + // way as map[string]string and passed into this function to be turned + // into a real user. + // + // It's possible that the user exists in the database already, and so + // an attempt should be made to look that user up using the details. + // Any details that have changed should be updated. Do not save the user + // since that will be done by a later call to OAuth2ServerStorer.SaveOAuth2() + NewFromOAuth2(ctx context.Context, provider string, details map[string]string) (OAuth2User, error) + + // SaveOAuth2 has different semantics from the typical ServerStorer.Save, in this case + // we want to insert a user if they do not exist. The difference must be made clear because + // in the non-oauth2 case, we know exactly when we want to Create vs Update. However + // since we're simply trying to persist a user that may have been in our database, but if not + // should already be (since you can think of the operation as a caching of what's on the oauth2 provider's + // servers). + SaveOAuth2(ctx context.Context, user OAuth2User) error +} + // ConfirmingServerStorer can find a user by a confirm token type ConfirmingServerStorer interface { ServerStorer @@ -84,6 +96,8 @@ type RecoveringServerStorer interface { // RememberingServerStorer allows users to be remembered across sessions type RememberingServerStorer interface { + ServerStorer + // AddRememberToken to a user AddRememberToken(pid, token string) error // DelRememberTokens removes all tokens for the given pid @@ -132,3 +146,13 @@ func EnsureCanRemember(storer ServerStorer) RememberingServerStorer { return s } + +// EnsureCanOAuth2 makes sure the server storer supports oauth2 creation and lookup +func EnsureCanOAuth2(storer ServerStorer) OAuth2ServerStorer { + s, ok := storer.(OAuth2ServerStorer) + if !ok { + panic("could not upgrade ServerStorer to OAuth2ServerStorer, check your struct") + } + + return s +} diff --git a/stringers.go b/stringers.go index 4356fc7..08b4df7 100644 --- a/stringers.go +++ b/stringers.go @@ -4,9 +4,9 @@ package authboss import "strconv" -const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset" +const _Event_name = "EventRegisterEventAuthEventOAuth2EventAuthFailEventOAuth2FailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset" -var _Event_index = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 103, 122, 140} +var _Event_index = [...]uint8{0, 13, 22, 33, 46, 61, 78, 93, 105, 124, 142} func (i Event) String() string { if i < 0 || i >= Event(len(_Event_index)-1) { diff --git a/user.go b/user.go index dc73061..948ddc0 100644 --- a/user.go +++ b/user.go @@ -2,6 +2,7 @@ package authboss import ( "fmt" + "strings" "time" ) @@ -81,6 +82,8 @@ type ArbitraryUser interface { } // OAuth2User allows reading and writing values relating to OAuth2 +// Also see MakeOAuthPID/ParseOAuthPID for helpers to fullfill the User +// part of the interface. type OAuth2User interface { User @@ -88,17 +91,17 @@ type OAuth2User interface { // oauth2 user. IsOAuth2User() bool - GetUID() (uid string) - GetProvider() (provider string) - GetToken() (token string) - GetRefreshToken() (refreshToken string) - GetExpiry() (expiry time.Duration) + GetOAuth2UID() (uid string) + GetOAuth2Provider() (provider string) + GetOAuth2AccessToken() (token string) + GetOAuth2RefreshToken() (refreshToken string) + GetOAuth2Expiry() (expiry time.Time) - PutUID(uid string) - PutProvider(provider string) - PutToken(token string) - PutRefreshToken(refreshToken string) - PutExpiry(expiry time.Duration) + PutOAuth2UID(uid string) + PutOAuth2Provider(provider string) + PutOAuth2AccessToken(token string) + PutOAuth2RefreshToken(refreshToken string) + PutOAuth2Expiry(expiry time.Time) } // MustBeAuthable forces an upgrade to an AuthableUser or panic. @@ -132,3 +135,33 @@ func MustBeRecoverable(u User) RecoverableUser { } panic(fmt.Sprintf("could not upgrade user to a recoverable user, given type: %T", u)) } + +// MustBeOAuthable forces an upgrade to an OAuth2User or panic. +func MustBeOAuthable(u User) OAuth2User { + if ou, ok := u.(OAuth2User); ok { + return ou + } + panic(fmt.Sprintf("could not upgrade user to an oauthable user, given type: %T", u)) +} + +// MakeOAuth2PID is used to create a pid for users that don't have +// an e-mail address or username in the normal system. This allows +// all the modules to continue to working as intended without having +// a true primary id. As well as not having to divide the regular and oauth +// stuff all down the middle. +func MakeOAuth2PID(provider, uid string) string { + return fmt.Sprintf("oauth2;;%s;;%s", provider, uid) +} + +// ParseOAuth2PID returns the uid and provider for a given OAuth2 pid +func ParseOAuth2PID(pid string) (provider, uid string) { + splits := strings.Split(pid, ";;") + if len(splits) != 3 { + panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid)) + } + if splits[0] != "oauth2" { + panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid)) + } + + return splits[1], splits[2] +} diff --git a/user_test.go b/user_test.go new file mode 100644 index 0000000..e689bf1 --- /dev/null +++ b/user_test.go @@ -0,0 +1,23 @@ +package authboss + +import "testing" + +func TestOAuth2PIDs(t *testing.T) { + t.Parallel() + + provider := "provider" + uid := "uid" + pid := MakeOAuth2PID(provider, uid) + + if pid != "oauth2;;provider;;uid" { + t.Error("pid was wrong:", pid) + } + + gotProvider, gotUID := ParseOAuth2PID(pid) + if gotUID != uid { + t.Error("uid was wrong:", gotUID) + } + if gotProvider != provider { + t.Error("provider was wrong:", gotProvider) + } +}