You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
SessionState refactoring; improve token renewal and cookie refresh
* New SessionState to consolidate email, access token and refresh token * split ServeHttp into individual methods * log on session renewal * log on access token refresh * refactor cookie encription/decription and session state serialization
This commit is contained in:
363
oauthproxy.go
363
oauthproxy.go
@ -1,8 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -16,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bitly/oauth2_proxy/cookie"
|
||||
"github.com/bitly/oauth2_proxy/providers"
|
||||
)
|
||||
|
||||
@ -44,7 +43,7 @@ type OauthProxy struct {
|
||||
serveMux http.Handler
|
||||
PassBasicAuth bool
|
||||
PassAccessToken bool
|
||||
AesCipher cipher.Block
|
||||
CookieCipher *cookie.Cipher
|
||||
skipAuthRegex []string
|
||||
compiledRegex []*regexp.Regexp
|
||||
templates *template.Template
|
||||
@ -116,10 +115,10 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
|
||||
|
||||
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh)
|
||||
|
||||
var aes_cipher cipher.Block
|
||||
var cipher *cookie.Cipher
|
||||
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
|
||||
var err error
|
||||
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
|
||||
cipher, err = cookie.NewCipher(opts.CookieSecret)
|
||||
if err != nil {
|
||||
log.Fatal("error creating AES cipher with "+
|
||||
"cookie-secret ", opts.CookieSecret, ": ", err)
|
||||
@ -150,7 +149,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
|
||||
compiledRegex: opts.CompiledRegex,
|
||||
PassBasicAuth: opts.PassBasicAuth,
|
||||
PassAccessToken: opts.PassAccessToken,
|
||||
AesCipher: aes_cipher,
|
||||
CookieCipher: cipher,
|
||||
templates: loadTemplates(opts.CustomTemplatesDir),
|
||||
}
|
||||
}
|
||||
@ -177,22 +176,20 @@ func (p *OauthProxy) displayCustomLoginForm() bool {
|
||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||
}
|
||||
|
||||
func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
|
||||
func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
|
||||
if code == "" {
|
||||
return "", "", errors.New("missing code")
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
redirectUri := p.GetRedirectURI(host)
|
||||
body, access_token, err := p.provider.Redeem(redirectUri, code)
|
||||
s, err = p.provider.Redeem(redirectUri, code)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return
|
||||
}
|
||||
|
||||
email, err := p.provider.GetEmailAddress(body, access_token)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(s)
|
||||
}
|
||||
|
||||
return access_token, email, nil
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
@ -208,9 +205,8 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time
|
||||
}
|
||||
|
||||
if value != "" {
|
||||
value = signedCookieValue(p.CookieSeed, p.CookieName, value, now)
|
||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: p.CookieName,
|
||||
Value: value,
|
||||
@ -230,35 +226,34 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st
|
||||
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
|
||||
var value string
|
||||
var timestamp time.Time
|
||||
cookie, err := req.Cookie(p.CookieName)
|
||||
if err == nil {
|
||||
value, timestamp, ok = validateCookie(cookie, p.CookieSeed, p.CookieExpire)
|
||||
if ok {
|
||||
email, user, access_token, err = parseCookieValue(value, p.AesCipher)
|
||||
}
|
||||
}
|
||||
func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||
var age time.Duration
|
||||
c, err := req.Cookie(p.CookieName)
|
||||
if err != nil {
|
||||
log.Printf(err.Error())
|
||||
ok = false
|
||||
} else if ok && p.CookieRefresh != time.Duration(0) {
|
||||
refresh := timestamp.Add(p.CookieRefresh)
|
||||
if refresh.Before(time.Now()) {
|
||||
log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh)
|
||||
ok = p.Validator(email)
|
||||
log.Printf("re-validating %s valid:%v", email, ok)
|
||||
if ok {
|
||||
ok = p.provider.ValidateToken(access_token)
|
||||
log.Printf("re-validating access token. valid:%v", ok)
|
||||
}
|
||||
if ok {
|
||||
p.SetCookie(rw, req, value)
|
||||
}
|
||||
}
|
||||
// always http.ErrNoCookie
|
||||
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
|
||||
}
|
||||
return
|
||||
val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
|
||||
if !ok {
|
||||
return nil, age, errors.New("Cookie Signature not valid")
|
||||
}
|
||||
|
||||
session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
|
||||
if err != nil {
|
||||
return nil, age, err
|
||||
}
|
||||
|
||||
age = time.Now().Truncate(time.Second).Sub(timestamp)
|
||||
return session, age, nil
|
||||
}
|
||||
|
||||
func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
|
||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetCookie(rw, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) {
|
||||
@ -344,156 +339,226 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) {
|
||||
return redirect, err
|
||||
}
|
||||
|
||||
func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// check if this is a redirect back at the end of oauth
|
||||
remoteAddr := req.RemoteAddr
|
||||
if req.Header.Get("X-Real-IP") != "" {
|
||||
remoteAddr += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
|
||||
}
|
||||
|
||||
var ok bool
|
||||
var user string
|
||||
var email string
|
||||
var access_token string
|
||||
|
||||
if req.URL.Path == p.RobotsPath {
|
||||
p.RobotsTxt(rw)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == p.PingPath {
|
||||
p.PingPage(rw)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) {
|
||||
for _, u := range p.compiledRegex {
|
||||
match := u.MatchString(req.URL.Path)
|
||||
if match {
|
||||
p.serveMux.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if req.URL.Path == p.SignInPath {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, ok = p.ManualSignIn(rw, req)
|
||||
ok = u.MatchString(path)
|
||||
if ok {
|
||||
p.SetCookie(rw, req, user)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
p.SignInPage(rw, req, 200)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getRemoteAddr(req *http.Request) (s string) {
|
||||
s = req.RemoteAddr
|
||||
if req.Header.Get("X-Real-IP") != "" {
|
||||
s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
switch path := req.URL.Path; {
|
||||
case path == p.RobotsPath:
|
||||
p.RobotsTxt(rw)
|
||||
case path == p.PingPath:
|
||||
p.PingPage(rw)
|
||||
case p.IsWhitelistedPath(path):
|
||||
p.serveMux.ServeHTTP(rw, req)
|
||||
case path == p.SignInPath:
|
||||
p.SignIn(rw, req)
|
||||
case path == p.OauthStartPath:
|
||||
p.OauthStart(rw, req)
|
||||
case path == p.OauthCallbackPath:
|
||||
p.OauthCallback(rw, req)
|
||||
default:
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
if req.URL.Path == p.OauthStartPath {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(req.Host)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
|
||||
|
||||
user, ok := p.ManualSignIn(rw, req)
|
||||
if ok {
|
||||
session := &providers.SessionState{User: user}
|
||||
p.SaveSession(rw, req, session)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
p.SignInPage(rw, req, 200)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
if req.URL.Path == p.OauthCallbackPath {
|
||||
// finish the oauth cycle
|
||||
err := req.ParseForm()
|
||||
redirectURI := p.GetRedirectURI(req.Host)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
|
||||
}
|
||||
|
||||
func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
remoteAddr := getRemoteAddr(req)
|
||||
|
||||
// finish the oauth cycle
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
p.ErrorPage(rw, 403, "Permission Denied", errorString)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := p.redeemCode(req.Host, req.Form.Get("code"))
|
||||
if err != nil {
|
||||
log.Printf("%s error redeeming code %s", remoteAddr, err)
|
||||
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
redirect := req.Form.Get("state")
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
// set cookie, or deny
|
||||
if p.Validator(session.Email) {
|
||||
log.Printf("%s authentication complete %s", remoteAddr, session)
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
errorString := req.Form.Get("error")
|
||||
if errorString != "" {
|
||||
p.ErrorPage(rw, 403, "Permission Denied", errorString)
|
||||
log.Printf("%s %s", remoteAddr, err)
|
||||
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email)
|
||||
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
|
||||
}
|
||||
}
|
||||
|
||||
access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code"))
|
||||
func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||
var saveSession, clearSession, revalidated bool
|
||||
remoteAddr := getRemoteAddr(req)
|
||||
|
||||
session, sessionAge, err := p.LoadCookiedSession(req)
|
||||
if err != nil {
|
||||
log.Printf("%s %s", remoteAddr, err)
|
||||
}
|
||||
if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
||||
log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh)
|
||||
saveSession = true
|
||||
}
|
||||
|
||||
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
|
||||
log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
|
||||
clearSession = true
|
||||
session = nil
|
||||
} else if ok {
|
||||
saveSession = true
|
||||
revalidated = true
|
||||
}
|
||||
|
||||
if session != nil && session.IsExpired() {
|
||||
log.Printf("%s removing session. token expired %s", remoteAddr, session)
|
||||
session = nil
|
||||
saveSession = false
|
||||
clearSession = true
|
||||
}
|
||||
|
||||
if saveSession && !revalidated && session.AccessToken != "" {
|
||||
if !p.provider.ValidateSessionState(session) {
|
||||
log.Printf("%s removing session. error validating %s", remoteAddr, session)
|
||||
saveSession = false
|
||||
session = nil
|
||||
clearSession = true
|
||||
}
|
||||
}
|
||||
|
||||
if saveSession && session.Email != "" && !p.Validator(session.Email) {
|
||||
log.Printf("%s Permission Denied: removing session %s", remoteAddr, session)
|
||||
session = nil
|
||||
saveSession = false
|
||||
clearSession = true
|
||||
}
|
||||
|
||||
if saveSession {
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
log.Printf("%s error redeeming code %s", remoteAddr, err)
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
redirect := req.Form.Get("state")
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
// set cookie, or deny
|
||||
if p.Validator(email) {
|
||||
log.Printf("%s authenticating %s completed", remoteAddr, email)
|
||||
value, err := buildCookieValue(
|
||||
email, p.AesCipher, access_token)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
}
|
||||
p.SetCookie(rw, req, value)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
return
|
||||
} else {
|
||||
log.Printf("validating: %s is unauthorized")
|
||||
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
|
||||
log.Printf("%s %s", remoteAddr, err)
|
||||
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
email, user, access_token, ok = p.ProcessCookie(rw, req)
|
||||
if clearSession {
|
||||
p.ClearCookie(rw, req)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
user, ok = p.CheckBasicAuth(req)
|
||||
if session == nil {
|
||||
session, err = p.CheckBasicAuth(req)
|
||||
if err != nil {
|
||||
log.Printf("%s %s", remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if session == nil {
|
||||
p.SignInPage(rw, req, 403)
|
||||
return
|
||||
}
|
||||
|
||||
// At this point, the user is authenticated. proxy normally
|
||||
if p.PassBasicAuth {
|
||||
req.SetBasicAuth(user, "")
|
||||
req.Header["X-Forwarded-User"] = []string{user}
|
||||
req.Header["X-Forwarded-Email"] = []string{email}
|
||||
req.SetBasicAuth(session.User, "")
|
||||
req.Header["X-Forwarded-User"] = []string{session.User}
|
||||
if session.Email != "" {
|
||||
req.Header["X-Forwarded-Email"] = []string{session.Email}
|
||||
}
|
||||
}
|
||||
if p.PassAccessToken {
|
||||
req.Header["X-Forwarded-Access-Token"] = []string{access_token}
|
||||
if p.PassAccessToken && session.AccessToken != "" {
|
||||
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
|
||||
}
|
||||
if email == "" {
|
||||
rw.Header().Set("GAP-Auth", user)
|
||||
if session.Email == "" {
|
||||
rw.Header().Set("GAP-Auth", session.User)
|
||||
} else {
|
||||
rw.Header().Set("GAP-Auth", email)
|
||||
rw.Header().Set("GAP-Auth", session.Email)
|
||||
}
|
||||
|
||||
p.serveMux.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (p *OauthProxy) CheckBasicAuth(req *http.Request) (string, bool) {
|
||||
func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
|
||||
if p.HtpasswdFile == nil {
|
||||
return "", false
|
||||
return nil, nil
|
||||
}
|
||||
s := strings.SplitN(req.Header.Get("Authorization"), " ", 2)
|
||||
auth := req.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return nil, nil
|
||||
}
|
||||
s := strings.SplitN(auth, " ", 2)
|
||||
if len(s) != 2 || s[0] != "Basic" {
|
||||
return "", false
|
||||
return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
|
||||
}
|
||||
b, err := base64.StdEncoding.DecodeString(s[1])
|
||||
if err != nil {
|
||||
return "", false
|
||||
return nil, err
|
||||
}
|
||||
pair := strings.SplitN(string(b), ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return "", false
|
||||
return nil, fmt.Errorf("invalid format %s", b)
|
||||
}
|
||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||
log.Printf("authenticated %q via basic auth", pair[0])
|
||||
return pair[0], true
|
||||
return &providers.SessionState{User: pair[0]}, nil
|
||||
}
|
||||
return "", false
|
||||
return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0])
|
||||
}
|
||||
|
Reference in New Issue
Block a user