mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-18 04:58:57 +02:00
06edd2e615
- Add a new storer specifically for OAuth2 to enable clients to choose regular database storing OR Oauth2 but not have to have both. - Stop storing OAuth2 credentials in a combined form inside username. - Add new events to capture OAuth events just like auth. - Have pass-through parameters for OAuth init urls, this allows us to pass additional behavior options (redirects and remember me) as well as other things that should be present on the page that is redirected to. - Context.LoadUser is now OAuth aware. - Remember's callbacks now include an OAuth check to see if a horribly packed state variable contains a flag to say that we want to be remembered. - Change the OAuth2 Callback to use Attributes instead of that custom struct to allow people to append whatever attributes they want into the user that will be saved.
201 lines
4.8 KiB
Go
201 lines
4.8 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"golang.org/x/oauth2"
|
|
"gopkg.in/authboss.v0"
|
|
)
|
|
|
|
// errOAuthStateValidation is returned by the callback handler when the state
// parameter sent back by the provider does not match the one kept in the
// session.
var errOAuthStateValidation = errors.New("Could not validate oauth2 state param")
|
|
|
|
// OAuth2 is the module type registered with authboss under the name "oauth2".
// It holds no fields of its own.
type OAuth2 struct{}
|
|
|
|
func init() {
|
|
authboss.RegisterModule("oauth2", &OAuth2{})
|
|
}
|
|
|
|
func (o *OAuth2) Initialize() error {
|
|
if authboss.Cfg.OAuth2Storer == nil {
|
|
return errors.New("oauth2: need an OAuth2Storer")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o *OAuth2) Routes() authboss.RouteTable {
|
|
routes := make(authboss.RouteTable)
|
|
|
|
for prov, cfg := range authboss.Cfg.OAuth2Providers {
|
|
prov = strings.ToLower(prov)
|
|
|
|
init := fmt.Sprintf("/oauth2/%s", prov)
|
|
callback := fmt.Sprintf("/oauth2/callback/%s", prov)
|
|
|
|
if len(authboss.Cfg.MountPath) > 0 {
|
|
init = path.Join(authboss.Cfg.MountPath, init)
|
|
callback = path.Join(authboss.Cfg.MountPath, callback)
|
|
}
|
|
|
|
routes[init] = oauthInit
|
|
routes[callback] = oauthCallback
|
|
|
|
cfg.OAuth2Config.RedirectURL = authboss.Cfg.RootURL + callback
|
|
}
|
|
|
|
return routes
|
|
}
|
|
|
|
func (o *OAuth2) Storage() authboss.StorageOptions {
|
|
return authboss.StorageOptions{
|
|
authboss.StoreEmail: authboss.String,
|
|
authboss.StoreOAuth2UID: authboss.String,
|
|
authboss.StoreOAuth2Provider: authboss.String,
|
|
authboss.StoreOAuth2Token: authboss.String,
|
|
authboss.StoreOAuth2Refresh: authboss.String,
|
|
authboss.StoreOAuth2Expiry: authboss.DateTime,
|
|
}
|
|
}
|
|
|
|
func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
|
provider := strings.ToLower(filepath.Base(r.URL.Path))
|
|
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
|
|
if !ok {
|
|
return fmt.Errorf("OAuth2 provider %q not found", provider)
|
|
}
|
|
|
|
random := make([]byte, 32)
|
|
_, err := rand.Read(random)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
state := base64.URLEncoding.EncodeToString(random)
|
|
ctx.SessionStorer.Put(authboss.SessionOAuth2State, state)
|
|
|
|
var passAlongs []string
|
|
for k, vals := range r.URL.Query() {
|
|
for _, val := range vals {
|
|
passAlongs = append(passAlongs, fmt.Sprintf("%s=%s", k, val))
|
|
}
|
|
}
|
|
if len(passAlongs) > 0 {
|
|
state += ";" + strings.Join(passAlongs, ";")
|
|
}
|
|
|
|
url := cfg.OAuth2Config.AuthCodeURL(state)
|
|
|
|
extraParams := cfg.AdditionalParams.Encode()
|
|
if len(extraParams) > 0 {
|
|
url = fmt.Sprintf("%s&%s", url, extraParams)
|
|
}
|
|
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
return nil
|
|
}
|
|
|
|
// for testing
|
|
// exchanger is the function oauthCallback uses to trade an authorization code
// for a token. It is a package variable so tests can replace it.
var exchanger = (*oauth2.Config).Exchange
|
|
|
|
func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
|
provider := strings.ToLower(filepath.Base(r.URL.Path))
|
|
|
|
hasErr := r.FormValue("error")
|
|
if len(hasErr) > 0 {
|
|
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
return authboss.ErrAndRedirect{
|
|
Err: errors.New(r.FormValue("error_reason")),
|
|
Location: authboss.Cfg.AuthLoginFailPath,
|
|
FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
|
|
}
|
|
}
|
|
|
|
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
|
|
if !ok {
|
|
return fmt.Errorf("OAuth2 provider %q not found", provider)
|
|
}
|
|
|
|
// Ensure request is genuine
|
|
state := r.FormValue(authboss.FormValueOAuth2State)
|
|
splState := strings.Split(state, ";")
|
|
if len(splState) == 0 || splState[0] != sessState {
|
|
return errOAuthStateValidation
|
|
}
|
|
|
|
// Get the code
|
|
code := r.FormValue("code")
|
|
token, err := exchanger(cfg.OAuth2Config, oauth2.NoContext, code)
|
|
if err != nil {
|
|
return fmt.Errorf("Could not validate oauth2 code: %v", err)
|
|
}
|
|
|
|
user, err := cfg.Callback(*cfg.OAuth2Config, token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// OAuth2UID is required.
|
|
uid, err := user.StringErr(authboss.StoreOAuth2UID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user[authboss.StoreOAuth2UID] = uid
|
|
user[authboss.StoreOAuth2Provider] = provider
|
|
user[authboss.StoreOAuth2Expiry] = token.Expiry
|
|
user[authboss.StoreOAuth2Token] = token.AccessToken
|
|
if len(token.RefreshToken) != 0 {
|
|
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
|
|
}
|
|
|
|
if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Log user in
|
|
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
|
|
|
|
if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
|
|
return nil
|
|
}
|
|
|
|
redirect := authboss.Cfg.AuthLoginOKPath
|
|
values := make(url.Values)
|
|
if len(splState) > 0 {
|
|
for _, arg := range splState[1:] {
|
|
spl := strings.Split(arg, "=")
|
|
switch spl[0] {
|
|
case authboss.CookieRemember:
|
|
case authboss.FormValueRedirect:
|
|
redirect = spl[1]
|
|
default:
|
|
values.Set(spl[0], spl[1])
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(values) > 0 {
|
|
redirect = fmt.Sprintf("%s?%s", redirect, values.Encode())
|
|
}
|
|
|
|
http.Redirect(w, r, redirect, http.StatusFound)
|
|
return nil
|
|
}
|