1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-18 04:58:57 +02:00
authboss/oauth2/oauth2.go
Aaron L 06edd2e615 Make OAuth2 implementation less shoddy.
- Add a new storer specifically for OAuth2 to enable clients to choose
  regular database storing OR OAuth2 but not have to have both.
- Stop storing OAuth2 credentials in a combined form inside username.
- Add new events to capture OAuth events just like auth.
- Have pass-through parameters for OAuth init urls, this allows us to
  pass additional behavior options (redirects and remember me) as well
  as other things that should be present on the page that is redirected
  to.
- Context.LoadUser is now OAuth aware.
- Remember's callbacks now include an OAuth check to see if a horribly
  packed state variable contains a flag to say that we want to be
  remembered.
- Change the OAuth2 Callback to use Attributes instead of that custom
  struct to allow people to append whatever attributes they want into
  the user that will be saved.
2015-03-13 16:29:58 -07:00

201 lines
4.8 KiB
Go

package oauth2
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"golang.org/x/oauth2"
"gopkg.in/authboss.v0"
)
// errOAuthStateValidation is returned by the callback handler when the state
// parameter on the request does not match the state kept in the session.
var errOAuthStateValidation = errors.New("Could not validate oauth2 state param")
// OAuth2 is the module type registered with authboss under the name "oauth2".
// It has no fields; its methods read their configuration from authboss.Cfg.
type OAuth2 struct{}
func init() {
authboss.RegisterModule("oauth2", &OAuth2{})
}
// Initialize checks the module's prerequisites. It returns an error when no
// OAuth2Storer has been configured and nil otherwise.
func (o *OAuth2) Initialize() error {
	if authboss.Cfg.OAuth2Storer != nil {
		return nil
	}

	return errors.New("oauth2: need an OAuth2Storer")
}
// Routes builds the route table for every configured provider. Each provider
// (name lowercased) gets an init route at /oauth2/<provider> and a callback
// route at /oauth2/callback/<provider>, both placed under the mount path when
// one is configured. As a side effect the provider's RedirectURL is set to the
// root URL followed by its callback route.
func (o *OAuth2) Routes() authboss.RouteTable {
	table := make(authboss.RouteTable)
	mount := authboss.Cfg.MountPath

	for name, provider := range authboss.Cfg.OAuth2Providers {
		name = strings.ToLower(name)

		initPath := "/oauth2/" + name
		callbackPath := "/oauth2/callback/" + name
		if mount != "" {
			initPath = path.Join(mount, initPath)
			callbackPath = path.Join(mount, callbackPath)
		}

		table[initPath] = oauthInit
		table[callbackPath] = oauthCallback

		// Written through the provider's OAuth2Config so that the auth code
		// URL built later carries the matching callback address.
		provider.OAuth2Config.RedirectURL = authboss.Cfg.RootURL + callbackPath
	}

	return table
}
// Storage returns the storage keys this module uses together with the data
// type of each: email, OAuth2 uid, provider, token and refresh token are
// strings, and the token expiry is a date-time.
func (o *OAuth2) Storage() authboss.StorageOptions {
	opts := make(authboss.StorageOptions, 6)

	opts[authboss.StoreEmail] = authboss.String
	opts[authboss.StoreOAuth2UID] = authboss.String
	opts[authboss.StoreOAuth2Provider] = authboss.String
	opts[authboss.StoreOAuth2Token] = authboss.String
	opts[authboss.StoreOAuth2Refresh] = authboss.String
	opts[authboss.StoreOAuth2Expiry] = authboss.DateTime

	return opts
}
// oauthInit starts the OAuth2 flow for the provider named by the last element
// of the request path. It generates a random state token, stores it in the
// session, appends every query parameter of the request to the state as a
// ";key=value" segment, and redirects the client to the provider's auth code
// URL (plus the provider's additional parameters, if any).
//
// It returns an error when the provider is not configured or when random
// bytes cannot be read.
func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
	name := strings.ToLower(filepath.Base(r.URL.Path))
	provider, found := authboss.Cfg.OAuth2Providers[name]
	if !found {
		return fmt.Errorf("OAuth2 provider %q not found", name)
	}

	nonce := make([]byte, 32)
	if _, err := rand.Read(nonce); err != nil {
		return err
	}

	token := base64.URLEncoding.EncodeToString(nonce)
	ctx.SessionStorer.Put(authboss.SessionOAuth2State, token)

	// Only the bare token is kept in the session; the pass-along parameters
	// travel in the state value itself, after the token.
	state := token
	for key, vals := range r.URL.Query() {
		for _, v := range vals {
			state += ";" + key + "=" + v
		}
	}

	target := provider.OAuth2Config.AuthCodeURL(state)
	if extra := provider.AdditionalParams.Encode(); extra != "" {
		target += "&" + extra
	}

	http.Redirect(w, r, target, http.StatusFound)
	return nil
}
// exchanger is the function oauthCallback calls to exchange the authorization
// code for a token. It defaults to (*oauth2.Config).Exchange and is a package
// variable so that it can be substituted for testing.
var exchanger = (*oauth2.Config).Exchange
// oauthCallback handles the redirect back from the OAuth2 provider named by
// the last element of the request path.
//
// Flow:
//   - If the request carries an "error" value, EventOAuthFail is fired and an
//     ErrAndRedirect to the login-fail path is returned.
//   - The first ";"-separated segment of the state parameter must equal the
//     state stored in the session, otherwise errOAuthStateValidation is
//     returned.
//   - The code is exchanged for a token, the provider's Callback turns the
//     token into user attributes, and the attributes (with uid, provider,
//     expiry, access token and, when present, refresh token) are saved through
//     the OAuth2Storer.
//   - The session key is set to "uid;provider", EventOAuth is fired, and the
//     client is redirected to the login-OK path or to the redirect carried in
//     the state. Remaining state segments become query parameters.
//
// Any error from the session, exchange, provider callback, storer or event
// callbacks is returned to the caller.
func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
	provider := strings.ToLower(filepath.Base(r.URL.Path))

	// A non-empty "error" value means the authorization was denied or failed.
	if hasErr := r.FormValue("error"); len(hasErr) > 0 {
		if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
			return err
		}

		return authboss.ErrAndRedirect{
			Err:        errors.New(r.FormValue("error_reason")),
			Location:   authboss.Cfg.AuthLoginFailPath,
			FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
		}
	}

	sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
	if err != nil {
		return err
	}

	cfg, ok := authboss.Cfg.OAuth2Providers[provider]
	if !ok {
		return fmt.Errorf("OAuth2 provider %q not found", provider)
	}

	// Ensure request is genuine: the first state segment must be the token
	// stored in the session. strings.Split always yields at least one element
	// for a non-empty separator, so indexing [0] is safe.
	state := r.FormValue(authboss.FormValueOAuth2State)
	splState := strings.Split(state, ";")
	if splState[0] != sessState {
		return errOAuthStateValidation
	}

	// Get the code
	code := r.FormValue("code")
	token, err := exchanger(cfg.OAuth2Config, oauth2.NoContext, code)
	if err != nil {
		return fmt.Errorf("Could not validate oauth2 code: %v", err)
	}

	user, err := cfg.Callback(*cfg.OAuth2Config, token)
	if err != nil {
		return err
	}

	// OAuth2UID is required.
	uid, err := user.StringErr(authboss.StoreOAuth2UID)
	if err != nil {
		return err
	}

	user[authboss.StoreOAuth2UID] = uid
	user[authboss.StoreOAuth2Provider] = provider
	user[authboss.StoreOAuth2Expiry] = token.Expiry
	user[authboss.StoreOAuth2Token] = token.AccessToken
	if len(token.RefreshToken) != 0 {
		user[authboss.StoreOAuth2Refresh] = token.RefreshToken
	}

	if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
		return err
	}

	// Log user in
	ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))

	// Propagate callback failures. Returning nil here would report success
	// without ever writing a response to the client.
	if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
		return err
	}

	redirect := authboss.Cfg.AuthLoginOKPath
	values := make(url.Values)

	// Everything after the first state segment is a "key=value" pass-along
	// parameter taken from the init request.
	for _, arg := range splState[1:] {
		// Split on the first "=" only so values containing "=" survive, and
		// skip segments without "=" instead of indexing out of range.
		spl := strings.SplitN(arg, "=", 2)
		if len(spl) != 2 {
			continue
		}
		key, val := spl[0], spl[1]

		switch key {
		case authboss.CookieRemember:
			// Deliberately not forwarded as a query parameter.
		case authboss.FormValueRedirect:
			// NOTE(review): this value originates from the client's request
			// and is not validated, so it can point at another site (open
			// redirect). Consider restricting it to local paths.
			redirect = val
		default:
			values.Set(key, val)
		}
	}

	if len(values) > 0 {
		// Extend an existing query string rather than adding a second "?".
		sep := "?"
		if strings.Contains(redirect, "?") {
			sep = "&"
		}
		redirect = redirect + sep + values.Encode()
	}

	http.Redirect(w, r, redirect, http.StatusFound)
	return nil
}