1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-24 05:17:10 +02:00
authboss/oauth2/oauth2.go
2018-09-15 15:39:26 -07:00

303 lines
9.0 KiB
Go

// Package oauth2 allows users to be created and authenticated
// via oauth2 services like facebook, google etc. Currently
// only the web server flow is supported.
//
// The general flow looks like this:
// 1. User goes to Start handler and has his session packed with goodies
// then redirects to the OAuth service.
// 2. OAuth service returns to OAuthCallback which extracts state and
// parameters and generally checks that everything is ok. It uses the
// token received to get an access token from the oauth2 library
// 3. Calls the OAuth2Provider.FindUserDetails which should return the user's
// details in a generic form.
// 4. Passes the user details into the OAuth2ServerStorer.NewFromOAuth2 in
// order to create a user object we can work with.
// 5. Saves the user in the database, logs them in, redirects.
//
// In order to do this there are a number of parts:
// 1. The configuration of a provider
// (handled by authboss.Config.Modules.OAuth2Providers).
// 2. The flow of redirection of client, parameter passing etc
// (handled by this package)
// 3. The HTTP call to the service once a token has been retrieved to
// get user details (handled by OAuth2Provider.FindUserDetails)
// 4. The creation of a user from the user details returned from the
// FindUserDetails (authboss.OAuth2ServerStorer)
//
// Of these parts, the responsibility of the authboss library consumer
// is on 1, 3, and 4. Configuration of providers that should be used is totally
// up to the consumer. The FindUserDetails function is typically up to the
// user, but we have some basic ones included in this package too.
// The creation of users from the FindUserDetail's map[string]string return
// is handled as part of the implementation of the OAuth2ServerStorer.
package oauth2
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"sort"
"strings"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"github.com/volatiletech/authboss"
)
// FormValue constants
const (
FormValueOAuth2State = "state"
FormValueOAuth2Redir = "redir"
)
var (
errOAuthStateValidation = errors.New("could not validate oauth2 state param")
)
// OAuth2 module
type OAuth2 struct {
*authboss.Authboss
}
func init() {
authboss.RegisterModule("oauth2", &OAuth2{})
}
// Init module
func (o *OAuth2) Init(ab *authboss.Authboss) error {
o.Authboss = ab
// Do annoying sorting on keys so we can have predictible
// route registration (both for consistency inside the router but
// also for tests -_-)
var keys []string
for k := range o.Authboss.Config.Modules.OAuth2Providers {
keys = append(keys, k)
}
sort.Strings(keys)
for _, provider := range keys {
cfg := o.Authboss.Config.Modules.OAuth2Providers[provider]
provider = strings.ToLower(provider)
init := fmt.Sprintf("/oauth2/%s", provider)
callback := fmt.Sprintf("/oauth2/callback/%s", provider)
o.Authboss.Config.Core.Router.Get(init, o.Authboss.Core.ErrorHandler.Wrap(o.Start))
o.Authboss.Config.Core.Router.Get(callback, o.Authboss.Core.ErrorHandler.Wrap(o.End))
if mount := o.Authboss.Config.Paths.Mount; len(mount) > 0 {
callback = path.Join(mount, callback)
}
cfg.OAuth2Config.RedirectURL = o.Authboss.Config.Paths.RootURL + callback
}
return nil
}
// Start the oauth2 process
func (o *OAuth2) Start(w http.ResponseWriter, r *http.Request) error {
logger := o.Authboss.RequestLogger(r)
provider := strings.ToLower(filepath.Base(r.URL.Path))
logger.Infof("started oauth2 flow for provider: %s", provider)
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
if !ok {
return errors.Errorf("oauth2 provider %q not found", provider)
}
// Create nonce
nonce := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return errors.Wrap(err, "failed to create nonce")
}
state := base64.URLEncoding.EncodeToString(nonce)
authboss.PutSession(w, authboss.SessionOAuth2State, state)
// This clearly ignores the fact that query parameters can have multiple
// values but I guess we're ignoring that
passAlongs := make(map[string]string)
for k, vals := range r.URL.Query() {
for _, val := range vals {
passAlongs[k] = val
}
}
if len(passAlongs) > 0 {
byt, err := json.Marshal(passAlongs)
if err != nil {
return err
}
authboss.PutSession(w, authboss.SessionOAuth2Params, string(byt))
} else {
authboss.DelSession(w, authboss.SessionOAuth2Params)
}
url := cfg.OAuth2Config.AuthCodeURL(state)
extraParams := cfg.AdditionalParams.Encode()
if len(extraParams) > 0 {
url = fmt.Sprintf("%s&%s", url, extraParams)
}
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: url,
}
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
}
// for testing, mocked out at the beginning
var exchanger = (*oauth2.Config).Exchange
// End the oauth2 process, this is the handler for the oauth2 callback
// that the third party will redirect to.
func (o *OAuth2) End(w http.ResponseWriter, r *http.Request) error {
logger := o.Authboss.RequestLogger(r)
provider := strings.ToLower(filepath.Base(r.URL.Path))
logger.Infof("finishing oauth2 flow for provider: %s", provider)
// This shouldn't happen because the router should 404 first, but just in case
cfg, ok := o.Authboss.Config.Modules.OAuth2Providers[provider]
if !ok {
return errors.Errorf("oauth2 provider %q not found", provider)
}
wantState, ok := authboss.GetSession(r, authboss.SessionOAuth2State)
if !ok {
return errors.New("oauth2 endpoint hit without session state")
}
// Verify we got the same state in the session as was passed to us in the
// query parameter.
state := r.FormValue(FormValueOAuth2State)
if state != wantState {
return errOAuthStateValidation
}
rawParams, ok := authboss.GetSession(r, authboss.SessionOAuth2Params)
var params map[string]string
if ok {
if err := json.Unmarshal([]byte(rawParams), &params); err != nil {
return errors.Wrap(err, "failed to decode oauth2 params")
}
}
authboss.DelSession(w, authboss.SessionOAuth2State)
authboss.DelSession(w, authboss.SessionOAuth2Params)
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
reason := r.FormValue("error_reason")
logger.Infof("oauth2 login failed: %s, reason: %s", hasErr, reason)
handled, err := o.Authboss.Events.FireAfter(authboss.EventOAuth2Fail, w, r)
if err != nil {
return err
} else if handled {
return nil
}
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: o.Authboss.Config.Paths.OAuth2LoginNotOK,
Failure: fmt.Sprintf("%s login cancelled or failed", strings.Title(provider)),
}
return o.Authboss.Core.Redirector.Redirect(w, r, ro)
}
// Get the code which we can use to make an access token
code := r.FormValue("code")
token, err := exchanger(cfg.OAuth2Config, r.Context(), code)
if err != nil {
return errors.Wrap(err, "could not validate oauth2 code")
}
details, err := cfg.FindUserDetails(r.Context(), *cfg.OAuth2Config, token)
if err != nil {
return err
}
storer := authboss.EnsureCanOAuth2(o.Authboss.Config.Storage.Server)
user, err := storer.NewFromOAuth2(r.Context(), provider, details)
if err != nil {
return errors.Wrap(err, "failed to create oauth2 user from values")
}
user.PutOAuth2Provider(provider)
user.PutOAuth2AccessToken(token.AccessToken)
user.PutOAuth2Expiry(token.Expiry)
if len(token.RefreshToken) != 0 {
user.PutOAuth2RefreshToken(token.RefreshToken)
}
if err := storer.SaveOAuth2(r.Context(), user); err != nil {
return err
}
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
handled, err := o.Authboss.Events.FireBefore(authboss.EventOAuth2, w, r)
if err != nil {
return err
} else if handled {
return nil
}
// Fully log user in
authboss.PutSession(w, authboss.SessionKey, authboss.MakeOAuth2PID(provider, user.GetOAuth2UID()))
authboss.DelSession(w, authboss.SessionHalfAuthKey)
// Create a query string from all the pieces we've received
// as passthru from the original request.
redirect := o.Authboss.Config.Paths.OAuth2LoginOK
query := make(url.Values)
for k, v := range params {
switch k {
case authboss.CookieRemember:
if v == "true" {
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyValues, RMTrue{}))
}
case FormValueOAuth2Redir:
redirect = v
default:
query.Set(k, v)
}
}
handled, err = o.Authboss.Events.FireAfter(authboss.EventOAuth2, w, r)
if err != nil {
return err
} else if handled {
return nil
}
if len(query) > 0 {
redirect = fmt.Sprintf("%s?%s", redirect, query.Encode())
}
ro := authboss.RedirectOptions{
Code: http.StatusTemporaryRedirect,
RedirectPath: redirect,
Success: fmt.Sprintf("Logged in successfully with %s.", strings.Title(provider)),
}
return o.Authboss.Config.Core.Redirector.Redirect(w, r, ro)
}
// RMTrue is a dummy struct implementing authboss.RememberValuer
// in order to tell the remember me module to remember them.
type RMTrue struct{}
// GetShouldRemember always returns true
func (RMTrue) GetShouldRemember() bool { return true }