mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
Lint for non-comment linter errors
This commit is contained in:
parent
990873eb42
commit
8ee802d4e5
@ -3,10 +3,6 @@
|
||||
# for detailed Gopkg.toml documentation.
|
||||
#
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/18F/hmacauth"
|
||||
version = "~1.0.1"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/BurntSushi/toml"
|
||||
version = "~0.3.0"
|
||||
|
@ -32,7 +32,7 @@ func Request(req *http.Request) (*simplejson.Json, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func RequestJson(req *http.Request, v interface{}) error {
|
||||
func RequestJSON(req *http.Request, v interface{}) error {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("%s %s %s", req.Method, req.URL, err)
|
||||
|
@ -1,20 +1,21 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/bitly/go-simplejson"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testBackend(response_code int, payload string) *httptest.Server {
|
||||
func testBackend(responseCode int, payload string) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(response_code)
|
||||
w.WriteHeader(responseCode)
|
||||
w.Write([]byte(payload))
|
||||
}))
|
||||
}
|
||||
|
@ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
||||
const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
|
||||
const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
|
||||
const token = "my access token"
|
||||
|
||||
secret, err := base64.URLEncoding.DecodeString(secret_b64)
|
||||
secret, err := base64.URLEncoding.DecodeString(secretBase64)
|
||||
assert.Equal(t, nil, err)
|
||||
c, err := NewCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Nonce generates a random 16 byte string to be used as a nonce
|
||||
func Nonce() (nonce string, err error) {
|
||||
b := make([]byte, 16)
|
||||
_, err = rand.Read(b)
|
||||
|
10
htpasswd.go
10
htpasswd.go
@ -28,12 +28,12 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
|
||||
}
|
||||
|
||||
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
|
||||
csv_reader := csv.NewReader(file)
|
||||
csv_reader.Comma = ':'
|
||||
csv_reader.Comment = '#'
|
||||
csv_reader.TrimLeadingSpace = true
|
||||
csvReader := csv.NewReader(file)
|
||||
csvReader.Comma = ':'
|
||||
csvReader.Comment = '#'
|
||||
csvReader.TrimLeadingSpace = true
|
||||
|
||||
records, err := csv_reader.ReadAll()
|
||||
records, err := csvReader.ReadAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ func TestSHA(t *testing.T) {
|
||||
|
||||
func TestBcrypt(t *testing.T) {
|
||||
hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1)
|
||||
assert.Equal(t, err, nil)
|
||||
hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2)
|
||||
assert.Equal(t, err, nil)
|
||||
|
||||
|
10
http.go
10
http.go
@ -23,12 +23,12 @@ func (s *Server) ListenAndServe() {
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP() {
|
||||
httpAddress := s.Opts.HttpAddress
|
||||
HTTPAddress := s.Opts.HTTPAddress
|
||||
scheme := ""
|
||||
|
||||
i := strings.Index(httpAddress, "://")
|
||||
i := strings.Index(HTTPAddress, "://")
|
||||
if i > -1 {
|
||||
scheme = httpAddress[0:i]
|
||||
scheme = HTTPAddress[0:i]
|
||||
}
|
||||
|
||||
var networkType string
|
||||
@ -39,7 +39,7 @@ func (s *Server) ServeHTTP() {
|
||||
networkType = scheme
|
||||
}
|
||||
|
||||
slice := strings.SplitN(httpAddress, "//", 2)
|
||||
slice := strings.SplitN(HTTPAddress, "//", 2)
|
||||
listenAddr := slice[len(slice)-1]
|
||||
|
||||
listener, err := net.Listen(networkType, listenAddr)
|
||||
@ -58,7 +58,7 @@ func (s *Server) ServeHTTP() {
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTPS() {
|
||||
addr := s.Opts.HttpsAddress
|
||||
addr := s.Opts.HTTPSAddress
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
|
@ -14,14 +14,19 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/mbland/hmacauth"
|
||||
)
|
||||
|
||||
const SignatureHeader = "GAP-Signature"
|
||||
const (
|
||||
SignatureHeader = "GAP-Signature"
|
||||
|
||||
var SignatureHeaders []string = []string{
|
||||
httpScheme = "http"
|
||||
httpsScheme = "https"
|
||||
)
|
||||
|
||||
var SignatureHeaders = []string{
|
||||
"Content-Length",
|
||||
"Content-Md5",
|
||||
"Content-Type",
|
||||
@ -40,7 +45,7 @@ type OAuthProxy struct {
|
||||
CSRFCookieName string
|
||||
CookieDomain string
|
||||
CookieSecure bool
|
||||
CookieHttpOnly bool
|
||||
CookieHTTPOnly bool
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
Validator func(string) bool
|
||||
@ -125,7 +130,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
for _, u := range opts.proxyURLs {
|
||||
path := u.Path
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
case httpScheme, httpsScheme:
|
||||
u.Path = ""
|
||||
log.Printf("mapping path %q => upstream %q", path, u)
|
||||
proxy := NewReverseProxy(u)
|
||||
@ -160,7 +165,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
refresh = fmt.Sprintf("after %s", opts.CookieRefresh)
|
||||
}
|
||||
|
||||
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh)
|
||||
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh)
|
||||
|
||||
var cipher *cookie.Cipher
|
||||
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
|
||||
@ -177,7 +182,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
CookieSeed: opts.CookieSecret,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHttpOnly: opts.CookieHttpOnly,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
CookieRefresh: opts.CookieRefresh,
|
||||
Validator: validator,
|
||||
@ -218,9 +223,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
u = *p.redirectURL
|
||||
if u.Scheme == "" {
|
||||
if p.CookieSecure {
|
||||
u.Scheme = "https"
|
||||
u.Scheme = httpsScheme
|
||||
} else {
|
||||
u.Scheme = "http"
|
||||
u.Scheme = httpScheme
|
||||
}
|
||||
}
|
||||
u.Host = host
|
||||
@ -285,7 +290,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: p.CookieDomain,
|
||||
HttpOnly: p.CookieHttpOnly,
|
||||
HttpOnly: p.CookieHTTPOnly,
|
||||
Secure: p.CookieSecure,
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
@ -374,12 +379,12 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||
p.ClearSessionCookie(rw, req)
|
||||
rw.WriteHeader(code)
|
||||
|
||||
redirect_url := req.URL.RequestURI()
|
||||
redirecURL := req.URL.RequestURI()
|
||||
if req.Header.Get("X-Auth-Request-Redirect") != "" {
|
||||
redirect_url = req.Header.Get("X-Auth-Request-Redirect")
|
||||
redirecURL = req.Header.Get("X-Auth-Request-Redirect")
|
||||
}
|
||||
if redirect_url == p.SignInPath {
|
||||
redirect_url = "/"
|
||||
if redirecURL == p.SignInPath {
|
||||
redirecURL = "/"
|
||||
}
|
||||
|
||||
t := struct {
|
||||
@ -394,7 +399,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||
ProviderName: p.provider.Data().ProviderName,
|
||||
SignInMessage: p.SignInMessage,
|
||||
CustomLogin: p.displayCustomLoginForm(),
|
||||
Redirect: redirect_url,
|
||||
Redirect: redirecURL,
|
||||
Version: VERSION,
|
||||
ProxyPrefix: p.ProxyPrefix,
|
||||
Footer: template.HTML(p.Footer),
|
||||
@ -653,7 +658,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
}
|
||||
|
||||
if saveSession && session != nil {
|
||||
err := p.SaveSession(rw, req, session)
|
||||
err = p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
log.Printf("%s %s", remoteAddr, err)
|
||||
return http.StatusInternalServerError
|
||||
|
@ -15,8 +15,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -98,28 +98,28 @@ type TestProvider struct {
|
||||
ValidToken bool
|
||||
}
|
||||
|
||||
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider {
|
||||
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
||||
return &TestProvider{
|
||||
ProviderData: &providers.ProviderData{
|
||||
ProviderName: "Test Provider",
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: provider_url.Host,
|
||||
Host: providerURL.Host,
|
||||
Path: "/oauth/authorize",
|
||||
},
|
||||
RedeemURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: provider_url.Host,
|
||||
Host: providerURL.Host,
|
||||
Path: "/oauth/token",
|
||||
},
|
||||
ProfileURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: provider_url.Host,
|
||||
Host: providerURL.Host,
|
||||
Path: "/api/v1/profile",
|
||||
},
|
||||
Scope: "profile.email",
|
||||
},
|
||||
EmailAddress: email_address,
|
||||
EmailAddress: emailAddress,
|
||||
}
|
||||
}
|
||||
|
||||
@ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo
|
||||
}
|
||||
|
||||
func TestBasicAuthPassword(t *testing.T) {
|
||||
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%#v", r)
|
||||
url := r.URL
|
||||
payload := ""
|
||||
switch url.Path {
|
||||
var payload string
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
payload = `{"access_token": "my_auth_token"}`
|
||||
default:
|
||||
@ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
w.Write([]byte(payload))
|
||||
}))
|
||||
opts := NewOptions()
|
||||
opts.Upstreams = append(opts.Upstreams, provider_server.URL)
|
||||
opts.Upstreams = append(opts.Upstreams, providerServer.URL)
|
||||
// The CookieSecret must be 32 bytes in order to create the AES
|
||||
// cipher.
|
||||
opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
||||
@ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
opts.BasicAuthPassword = "This is a secure password"
|
||||
opts.Validate()
|
||||
|
||||
provider_url, _ := url.Parse(provider_server.URL)
|
||||
const email_address = "michael.bland@gsa.gov"
|
||||
const user_name = "michael.bland"
|
||||
providerURL, _ := url.Parse(providerServer.URL)
|
||||
const emailAddress = "michael.bland@gsa.gov"
|
||||
const username = "michael.bland"
|
||||
|
||||
opts.provider = NewTestProvider(provider_url, email_address)
|
||||
opts.provider = NewTestProvider(providerURL, emailAddress)
|
||||
proxy := NewOAuthProxy(opts, func(email string) bool {
|
||||
return email == email_address
|
||||
return email == emailAddress
|
||||
})
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
|
||||
cookieName := proxy.CookieName
|
||||
var value string
|
||||
key_prefix := cookieName + "="
|
||||
keyPrefix := cookieName + "="
|
||||
|
||||
for _, field := range strings.Split(cookie, "; ") {
|
||||
value = strings.TrimPrefix(field, key_prefix)
|
||||
value = strings.TrimPrefix(field, keyPrefix)
|
||||
if value != field {
|
||||
break
|
||||
} else {
|
||||
@ -206,15 +205,15 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
rw = httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rw, req)
|
||||
|
||||
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
|
||||
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword))
|
||||
assert.Equal(t, expectedHeader, rw.Body.String())
|
||||
provider_server.Close()
|
||||
providerServer.Close()
|
||||
}
|
||||
|
||||
type PassAccessTokenTest struct {
|
||||
provider_server *httptest.Server
|
||||
proxy *OAuthProxy
|
||||
opts *Options
|
||||
providerServer *httptest.Server
|
||||
proxy *OAuthProxy
|
||||
opts *Options
|
||||
}
|
||||
|
||||
type PassAccessTokenTestOptions struct {
|
||||
@ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct {
|
||||
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
|
||||
t := &PassAccessTokenTest{}
|
||||
|
||||
t.provider_server = httptest.NewServer(
|
||||
t.providerServer = httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("%#v", r)
|
||||
url := r.URL
|
||||
payload := ""
|
||||
switch url.Path {
|
||||
var payload string
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
payload = `{"access_token": "my_auth_token"}`
|
||||
default:
|
||||
@ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
|
||||
}))
|
||||
|
||||
t.opts = NewOptions()
|
||||
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
|
||||
t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
|
||||
// The CookieSecret must be 32 bytes in order to create the AES
|
||||
// cipher.
|
||||
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
||||
@ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
|
||||
t.opts.PassAccessToken = opts.PassAccessToken
|
||||
t.opts.Validate()
|
||||
|
||||
provider_url, _ := url.Parse(t.provider_server.URL)
|
||||
const email_address = "michael.bland@gsa.gov"
|
||||
providerURL, _ := url.Parse(t.providerServer.URL)
|
||||
const emailAddress = "michael.bland@gsa.gov"
|
||||
|
||||
t.opts.provider = NewTestProvider(provider_url, email_address)
|
||||
t.opts.provider = NewTestProvider(providerURL, emailAddress)
|
||||
t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
|
||||
return email == email_address
|
||||
return email == emailAddress
|
||||
})
|
||||
return t
|
||||
}
|
||||
|
||||
func (pat_test *PassAccessTokenTest) Close() {
|
||||
pat_test.provider_server.Close()
|
||||
func (patTest *PassAccessTokenTest) Close() {
|
||||
patTest.providerServer.Close()
|
||||
}
|
||||
|
||||
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
||||
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
|
||||
cookie string) {
|
||||
rw := httptest.NewRecorder()
|
||||
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
||||
@ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
||||
if err != nil {
|
||||
return 0, ""
|
||||
}
|
||||
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
||||
pat_test.proxy.ServeHTTP(rw, req)
|
||||
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
||||
patTest.proxy.ServeHTTP(rw, req)
|
||||
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
|
||||
}
|
||||
|
||||
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
|
||||
cookieName := pat_test.proxy.CookieName
|
||||
func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
|
||||
cookieName := patTest.proxy.CookieName
|
||||
var value string
|
||||
key_prefix := cookieName + "="
|
||||
keyPrefix := cookieName + "="
|
||||
|
||||
for _, field := range strings.Split(cookie, "; ") {
|
||||
value = strings.TrimPrefix(field, key_prefix)
|
||||
value = strings.TrimPrefix(field, keyPrefix)
|
||||
if value != field {
|
||||
break
|
||||
} else {
|
||||
@ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i
|
||||
})
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
pat_test.proxy.ServeHTTP(rw, req)
|
||||
patTest.proxy.ServeHTTP(rw, req)
|
||||
return rw.Code, rw.Body.String()
|
||||
}
|
||||
|
||||
func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||
PassAccessToken: true,
|
||||
})
|
||||
defer pat_test.Close()
|
||||
defer patTest.Close()
|
||||
|
||||
// A successful validation will redirect and set the auth cookie.
|
||||
code, cookie := pat_test.getCallbackEndpoint()
|
||||
code, cookie := patTest.getCallbackEndpoint()
|
||||
if code != 302 {
|
||||
t.Fatalf("expected 302; got %d", code)
|
||||
}
|
||||
@ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||
// Now we make a regular request; the access_token from the cookie is
|
||||
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
||||
// read by the test provider server and written in the response body.
|
||||
code, payload := pat_test.getRootEndpoint(cookie)
|
||||
code, payload := patTest.getRootEndpoint(cookie)
|
||||
if code != 200 {
|
||||
t.Fatalf("expected 200; got %d", code)
|
||||
}
|
||||
@ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||
PassAccessToken: false,
|
||||
})
|
||||
defer pat_test.Close()
|
||||
defer patTest.Close()
|
||||
|
||||
// A successful validation will redirect and set the auth cookie.
|
||||
code, cookie := pat_test.getCallbackEndpoint()
|
||||
code, cookie := patTest.getCallbackEndpoint()
|
||||
if code != 302 {
|
||||
t.Fatalf("expected 302; got %d", code)
|
||||
}
|
||||
@ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||
|
||||
// Now we make a regular request, but the access token header should
|
||||
// not be present.
|
||||
code, payload := pat_test.getRootEndpoint(cookie)
|
||||
code, payload := patTest.getRootEndpoint(cookie)
|
||||
if code != 200 {
|
||||
t.Fatalf("expected 200; got %d", code)
|
||||
}
|
||||
@ -360,49 +358,49 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
type SignInPageTest struct {
|
||||
opts *Options
|
||||
proxy *OAuthProxy
|
||||
sign_in_regexp *regexp.Regexp
|
||||
sign_in_provider_regexp *regexp.Regexp
|
||||
opts *Options
|
||||
proxy *OAuthProxy
|
||||
signInRegexp *regexp.Regexp
|
||||
signInProviderRegexp *regexp.Regexp
|
||||
}
|
||||
|
||||
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
|
||||
const signInSkipProvider = `>Found<`
|
||||
|
||||
func NewSignInPageTest(skipProvider bool) *SignInPageTest {
|
||||
var sip_test SignInPageTest
|
||||
var sipTest SignInPageTest
|
||||
|
||||
sip_test.opts = NewOptions()
|
||||
sip_test.opts.CookieSecret = "foobar"
|
||||
sip_test.opts.ClientID = "bazquux"
|
||||
sip_test.opts.ClientSecret = "xyzzyplugh"
|
||||
sip_test.opts.SkipProviderButton = skipProvider
|
||||
sip_test.opts.Validate()
|
||||
sipTest.opts = NewOptions()
|
||||
sipTest.opts.CookieSecret = "foobar"
|
||||
sipTest.opts.ClientID = "bazquux"
|
||||
sipTest.opts.ClientSecret = "xyzzyplugh"
|
||||
sipTest.opts.SkipProviderButton = skipProvider
|
||||
sipTest.opts.Validate()
|
||||
|
||||
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
|
||||
sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool {
|
||||
return true
|
||||
})
|
||||
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
|
||||
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider)
|
||||
sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
|
||||
sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
|
||||
|
||||
return &sip_test
|
||||
return &sipTest
|
||||
}
|
||||
|
||||
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
||||
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
||||
rw := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
||||
sip_test.proxy.ServeHTTP(rw, req)
|
||||
sipTest.proxy.ServeHTTP(rw, req)
|
||||
return rw.Code, rw.Body.String()
|
||||
}
|
||||
|
||||
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
||||
sip_test := NewSignInPageTest(false)
|
||||
sipTest := NewSignInPageTest(false)
|
||||
const endpoint = "/some/random/endpoint"
|
||||
|
||||
code, body := sip_test.GetEndpoint(endpoint)
|
||||
code, body := sipTest.GetEndpoint(endpoint)
|
||||
assert.Equal(t, 403, code)
|
||||
|
||||
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
||||
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInRedirectPattern + "\nBody:\n" + body)
|
||||
@ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
||||
sip_test := NewSignInPageTest(false)
|
||||
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
|
||||
sipTest := NewSignInPageTest(false)
|
||||
code, body := sipTest.GetEndpoint("/oauth2/sign_in")
|
||||
assert.Equal(t, 200, code)
|
||||
|
||||
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
||||
match := sipTest.signInRegexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInRedirectPattern + "\nBody:\n" + body)
|
||||
@ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignInPageSkipProvider(t *testing.T) {
|
||||
sip_test := NewSignInPageTest(true)
|
||||
sipTest := NewSignInPageTest(true)
|
||||
const endpoint = "/some/random/endpoint"
|
||||
|
||||
code, body := sip_test.GetEndpoint(endpoint)
|
||||
code, body := sipTest.GetEndpoint(endpoint)
|
||||
assert.Equal(t, 302, code)
|
||||
|
||||
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
||||
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInSkipProvider + "\nBody:\n" + body)
|
||||
@ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
||||
sip_test := NewSignInPageTest(true)
|
||||
sipTest := NewSignInPageTest(true)
|
||||
const endpoint = "/sign_in"
|
||||
|
||||
code, body := sip_test.GetEndpoint(endpoint)
|
||||
code, body := sipTest.GetEndpoint(endpoint)
|
||||
assert.Equal(t, 302, code)
|
||||
|
||||
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
||||
match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
|
||||
if match == nil {
|
||||
t.Fatal("Did not find pattern in body: " +
|
||||
signInSkipProvider + "\nBody:\n" + body)
|
||||
@ -457,50 +455,50 @@ func TestSignInPageSkipProviderDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
type ProcessCookieTest struct {
|
||||
opts *Options
|
||||
proxy *OAuthProxy
|
||||
rw *httptest.ResponseRecorder
|
||||
req *http.Request
|
||||
provider TestProvider
|
||||
response_code int
|
||||
validate_user bool
|
||||
opts *Options
|
||||
proxy *OAuthProxy
|
||||
rw *httptest.ResponseRecorder
|
||||
req *http.Request
|
||||
provider TestProvider
|
||||
responseCode int
|
||||
validateUser bool
|
||||
}
|
||||
|
||||
type ProcessCookieTestOpts struct {
|
||||
provider_validate_cookie_response bool
|
||||
providerValidateCookieResponse bool
|
||||
}
|
||||
|
||||
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
|
||||
var pc_test ProcessCookieTest
|
||||
var pcTest ProcessCookieTest
|
||||
|
||||
pc_test.opts = NewOptions()
|
||||
pc_test.opts.ClientID = "bazquux"
|
||||
pc_test.opts.ClientSecret = "xyzzyplugh"
|
||||
pc_test.opts.CookieSecret = "0123456789abcdefabcd"
|
||||
pcTest.opts = NewOptions()
|
||||
pcTest.opts.ClientID = "bazquux"
|
||||
pcTest.opts.ClientSecret = "xyzzyplugh"
|
||||
pcTest.opts.CookieSecret = "0123456789abcdefabcd"
|
||||
// First, set the CookieRefresh option so proxy.AesCipher is created,
|
||||
// needed to encrypt the access_token.
|
||||
pc_test.opts.CookieRefresh = time.Hour
|
||||
pc_test.opts.Validate()
|
||||
pcTest.opts.CookieRefresh = time.Hour
|
||||
pcTest.opts.Validate()
|
||||
|
||||
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
|
||||
return pc_test.validate_user
|
||||
pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
||||
return pcTest.validateUser
|
||||
})
|
||||
pc_test.proxy.provider = &TestProvider{
|
||||
ValidToken: opts.provider_validate_cookie_response,
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: opts.providerValidateCookieResponse,
|
||||
}
|
||||
|
||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||
// access_token validation.
|
||||
pc_test.proxy.CookieRefresh = time.Duration(0)
|
||||
pc_test.rw = httptest.NewRecorder()
|
||||
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
||||
pc_test.validate_user = true
|
||||
return &pc_test
|
||||
pcTest.proxy.CookieRefresh = time.Duration(0)
|
||||
pcTest.rw = httptest.NewRecorder()
|
||||
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
||||
pcTest.validateUser = true
|
||||
return &pcTest
|
||||
}
|
||||
|
||||
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
||||
return NewProcessCookieTest(ProcessCookieTestOpts{
|
||||
provider_validate_cookie_response: true,
|
||||
providerValidateCookieResponse: true,
|
||||
})
|
||||
}
|
||||
|
||||
@ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pc_test.SaveSession(startSession, time.Now())
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
session, _, err := pc_test.LoadCookiedSession()
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, startSession.Email, session.Email)
|
||||
assert.Equal(t, "michael.bland", session.User)
|
||||
@ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieNoCookieError(t *testing.T) {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
session, _, err := pc_test.LoadCookiedSession()
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session. got %#v", session)
|
||||
@ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pc_test.SaveSession(startSession, reference)
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, age, err := pc_test.LoadCookiedSession()
|
||||
session, age, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, nil, err)
|
||||
if age < time.Duration(-2)*time.Hour {
|
||||
t.Errorf("cookie too young %v", age)
|
||||
@ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pc_test.SaveSession(startSession, reference)
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, _, err := pc_test.LoadCookiedSession()
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
assert.NotEqual(t, nil, err)
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session %#v", session)
|
||||
@ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pc_test.SaveSession(startSession, reference)
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
pc_test.proxy.CookieRefresh = time.Hour
|
||||
session, _, err := pc_test.LoadCookiedSession()
|
||||
pcTest.proxy.CookieRefresh = time.Hour
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
assert.NotEqual(t, nil, err)
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session %#v", session)
|
||||
@ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
||||
pc_test := NewProcessCookieTestWithDefaults()
|
||||
pc_test.req, _ = http.NewRequest("GET",
|
||||
pc_test.opts.ProxyPrefix+"/auth", nil)
|
||||
return pc_test
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
return pcTest
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||
@ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
startSession := &providers.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
test.validate_user = false
|
||||
test.validateUser = false
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
@ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
var pc_test ProcessCookieTest
|
||||
var pcTest ProcessCookieTest
|
||||
|
||||
pc_test.opts = NewOptions()
|
||||
pc_test.opts.SetXAuthRequest = true
|
||||
pc_test.opts.Validate()
|
||||
pcTest.opts = NewOptions()
|
||||
pcTest.opts.SetXAuthRequest = true
|
||||
pcTest.opts.Validate()
|
||||
|
||||
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
|
||||
return pc_test.validate_user
|
||||
pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
||||
return pcTest.validateUser
|
||||
})
|
||||
pc_test.proxy.provider = &TestProvider{
|
||||
pcTest.proxy.provider = &TestProvider{
|
||||
ValidToken: true,
|
||||
}
|
||||
|
||||
pc_test.validate_user = true
|
||||
pcTest.validateUser = true
|
||||
|
||||
pc_test.rw = httptest.NewRecorder()
|
||||
pc_test.req, _ = http.NewRequest("GET",
|
||||
pc_test.opts.ProxyPrefix+"/auth", nil)
|
||||
pcTest.rw = httptest.NewRecorder()
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
|
||||
startSession := &providers.SessionState{
|
||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||
pc_test.SaveSession(startSession, time.Now())
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req)
|
||||
assert.Equal(t, http.StatusAccepted, pc_test.rw.Code)
|
||||
assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0])
|
||||
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0])
|
||||
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
||||
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
||||
assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0])
|
||||
assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0])
|
||||
}
|
||||
|
||||
func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
||||
@ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
||||
opts.SkipAuthPreflight = true
|
||||
opts.Validate()
|
||||
|
||||
upstream_url, _ := url.Parse(upstream.URL)
|
||||
opts.provider = NewTestProvider(upstream_url, "")
|
||||
upstreamURL, _ := url.Parse(upstream.URL)
|
||||
opts.provider = NewTestProvider(upstreamURL, "")
|
||||
|
||||
proxy := NewOAuthProxy(opts, func(string) bool { return false })
|
||||
rw := httptest.NewRecorder()
|
||||
@ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req
|
||||
type SignatureTest struct {
|
||||
opts *Options
|
||||
upstream *httptest.Server
|
||||
upstream_host string
|
||||
upstreamHost string
|
||||
provider *httptest.Server
|
||||
header http.Header
|
||||
rw *httptest.ResponseRecorder
|
||||
@ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest {
|
||||
authenticator := &SignatureAuthenticator{}
|
||||
upstream := httptest.NewServer(
|
||||
http.HandlerFunc(authenticator.Authenticate))
|
||||
upstream_url, _ := url.Parse(upstream.URL)
|
||||
upstreamURL, _ := url.Parse(upstream.URL)
|
||||
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
||||
|
||||
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
||||
}
|
||||
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
|
||||
provider_url, _ := url.Parse(provider.URL)
|
||||
opts.provider = NewTestProvider(provider_url, "mbland@acm.org")
|
||||
providerURL, _ := url.Parse(provider.URL)
|
||||
opts.provider = NewTestProvider(providerURL, "mbland@acm.org")
|
||||
|
||||
return &SignatureTest{
|
||||
opts,
|
||||
upstream,
|
||||
upstream_url.Host,
|
||||
upstreamURL.Host,
|
||||
provider,
|
||||
make(http.Header),
|
||||
httptest.NewRecorder(),
|
||||
|
36
options.go
36
options.go
@ -13,16 +13,17 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
)
|
||||
|
||||
// Configuration Options that can be set by Command Line Flag, or Config File
|
||||
// Options holds Configuration Options that can be set by Command Line Flag,
|
||||
// or Config File
|
||||
type Options struct {
|
||||
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"`
|
||||
HttpAddress string `flag:"http-address" cfg:"http_address"`
|
||||
HttpsAddress string `flag:"https-address" cfg:"https_address"`
|
||||
HTTPAddress string `flag:"http-address" cfg:"http_address"`
|
||||
HTTPSAddress string `flag:"https-address" cfg:"https_address"`
|
||||
RedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
|
||||
ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
|
||||
ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
|
||||
@ -48,7 +49,7 @@ type Options struct {
|
||||
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
|
||||
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"`
|
||||
CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
|
||||
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
|
||||
|
||||
Upstreams []string `flag:"upstream" cfg:"upstreams"`
|
||||
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
|
||||
@ -96,12 +97,12 @@ type SignatureData struct {
|
||||
func NewOptions() *Options {
|
||||
return &Options{
|
||||
ProxyPrefix: "/oauth2",
|
||||
HttpAddress: "127.0.0.1:4180",
|
||||
HttpsAddress: ":443",
|
||||
HTTPAddress: "127.0.0.1:4180",
|
||||
HTTPSAddress: ":443",
|
||||
DisplayHtpasswdForm: true,
|
||||
CookieName: "_oauth2_proxy",
|
||||
CookieSecure: true,
|
||||
CookieHttpOnly: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(0),
|
||||
SetXAuthRequest: false,
|
||||
@ -116,11 +117,11 @@ func NewOptions() *Options {
|
||||
}
|
||||
}
|
||||
|
||||
func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) {
|
||||
parsed, err := url.Parse(to_parse)
|
||||
func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) {
|
||||
parsed, err := url.Parse(toParse)
|
||||
if err != nil {
|
||||
return nil, append(msgs, fmt.Sprintf(
|
||||
"error parsing %s-url=%q %s", urltype, to_parse, err))
|
||||
"error parsing %s-url=%q %s", urltype, toParse, err))
|
||||
}
|
||||
return parsed, msgs
|
||||
}
|
||||
@ -190,17 +191,17 @@ func (o *Options) Validate() error {
|
||||
msgs = parseProviderInfo(o, msgs)
|
||||
|
||||
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
|
||||
valid_cookie_secret_size := false
|
||||
validCookieSecretSize := false
|
||||
for _, i := range []int{16, 24, 32} {
|
||||
if len(secretBytes(o.CookieSecret)) == i {
|
||||
valid_cookie_secret_size = true
|
||||
validCookieSecretSize = true
|
||||
}
|
||||
}
|
||||
var decoded bool
|
||||
if string(secretBytes(o.CookieSecret)) != o.CookieSecret {
|
||||
decoded = true
|
||||
}
|
||||
if valid_cookie_secret_size == false {
|
||||
if validCookieSecretSize == false {
|
||||
var suffix string
|
||||
if decoded {
|
||||
suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret)
|
||||
@ -294,12 +295,13 @@ func parseSignatureKey(o *Options, msgs []string) []string {
|
||||
}
|
||||
|
||||
algorithm, secretKey := components[0], components[1]
|
||||
if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
|
||||
var hash crypto.Hash
|
||||
var err error
|
||||
if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
|
||||
return append(msgs, "unsupported signature hash algorithm: "+
|
||||
o.SignatureKey)
|
||||
} else {
|
||||
o.signatureData = &SignatureData{hash, secretKey}
|
||||
}
|
||||
o.signatureData = &SignatureData{hash, secretKey}
|
||||
return msgs
|
||||
}
|
||||
|
||||
|
@ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) {
|
||||
o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081")
|
||||
assert.Equal(t, nil, o.Validate())
|
||||
expected := []*url.URL{
|
||||
&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
|
||||
{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
|
||||
// note the '/' was added
|
||||
&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
|
||||
{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
|
||||
}
|
||||
assert.Equal(t, expected, o.proxyURLs)
|
||||
}
|
||||
|
@ -3,11 +3,12 @@ package providers
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
)
|
||||
|
||||
type AzureProvider struct {
|
||||
@ -60,9 +61,9 @@ func (p *AzureProvider) Configure(tenant string) {
|
||||
}
|
||||
}
|
||||
|
||||
func getAzureHeader(access_token string) http.Header {
|
||||
func getAzureHeader(accessToken string) http.Header {
|
||||
header := make(http.Header)
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
return header
|
||||
}
|
||||
|
||||
|
@ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server {
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
if url.Path != path || url.RawQuery != query {
|
||||
if r.URL.Path != path || r.URL.RawQuery != query {
|
||||
w.WriteHeader(404)
|
||||
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
|
||||
w.WriteHeader(403)
|
||||
|
@ -43,11 +43,11 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
|
||||
return &FacebookProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
func getFacebookHeader(access_token string) http.Header {
|
||||
func getFacebookHeader(accessToken string) http.Header {
|
||||
header := make(http.Header)
|
||||
header.Set("Accept", "application/json")
|
||||
header.Set("x-li-format", "json")
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
return header
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
Email string
|
||||
}
|
||||
var r result
|
||||
err = api.RequestJson(req, &r)
|
||||
err = api.RequestJSON(req, &r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
|
||||
}
|
||||
|
||||
orgs = append(orgs, op...)
|
||||
pn += 1
|
||||
pn++
|
||||
}
|
||||
|
||||
var presentOrgs []string
|
||||
@ -186,7 +186,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
||||
log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
|
||||
} else {
|
||||
var allOrgs []string
|
||||
for org, _ := range presentOrgs {
|
||||
for org := range presentOrgs {
|
||||
allOrgs = append(allOrgs, org)
|
||||
}
|
||||
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)
|
||||
|
@ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider {
|
||||
|
||||
func testGitHubBackend(payload []string) *httptest.Server {
|
||||
pathToQueryMap := map[string][]string{
|
||||
"/user": []string{""},
|
||||
"/user/emails": []string{""},
|
||||
"/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
|
||||
"/user": {""},
|
||||
"/user/emails": {""},
|
||||
"/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
|
||||
}
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
query, ok := pathToQueryMap[url.Path]
|
||||
query, ok := pathToQueryMap[r.URL.Path]
|
||||
validQuery := false
|
||||
index := 0
|
||||
for i, q := range query {
|
||||
if q == url.RawQuery {
|
||||
if q == r.URL.RawQuery {
|
||||
validQuery = true
|
||||
index = i
|
||||
}
|
||||
|
@ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server {
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
if url.Path != path || url.RawQuery != query {
|
||||
if r.URL.Path != path || r.URL.RawQuery != query {
|
||||
w.WriteHeader(404)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
@ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
|
||||
b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}")
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
@ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testGitLabBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
@ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testGitLabBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
|
@ -62,7 +62,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
|
||||
}
|
||||
}
|
||||
|
||||
func emailFromIdToken(idToken string) (string, error) {
|
||||
func emailFromIDToken(idToken string) (string, error) {
|
||||
|
||||
// id_token is a base64 encode ID token payload
|
||||
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
|
||||
@ -129,14 +129,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IdToken string `json:"id_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var email string
|
||||
email, err = emailFromIdToken(jsonResponse.IdToken)
|
||||
email, err = emailFromIDToken(jsonResponse.IDToken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ type redeemResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IdToken string `json:"id_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
||||
@ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
|
||||
AccessToken: "a1234",
|
||||
ExpiresIn: 10,
|
||||
RefreshToken: "refresh12345",
|
||||
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
|
||||
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
@ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
|
||||
p := newGoogleProvider()
|
||||
body, err := json.Marshal(redeemResponse{
|
||||
AccessToken: "a1234",
|
||||
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
|
||||
IDToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
@ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
|
||||
|
||||
body, err := json.Marshal(redeemResponse{
|
||||
AccessToken: "a1234",
|
||||
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
|
||||
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
@ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
|
||||
p := newGoogleProvider()
|
||||
body, err := json.Marshal(redeemResponse{
|
||||
AccessToken: "a1234",
|
||||
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
|
||||
IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
|
@ -46,13 +46,13 @@ func stripParam(param, endpoint string) string {
|
||||
}
|
||||
|
||||
// validateToken returns true if token is valid
|
||||
func validateToken(p Provider, access_token string, header http.Header) bool {
|
||||
if access_token == "" || p.Data().ValidateURL == nil {
|
||||
func validateToken(p Provider, accessToken string, header http.Header) bool {
|
||||
if accessToken == "" || p.Data().ValidateURL == nil {
|
||||
return false
|
||||
}
|
||||
endpoint := p.Data().ValidateURL.String()
|
||||
if len(header) == 0 {
|
||||
params := url.Values{"access_token": {access_token}}
|
||||
params := url.Values{"access_token": {accessToken}}
|
||||
endpoint = endpoint + "?" + params.Encode()
|
||||
}
|
||||
resp, err := api.RequestUnparsedResponse(endpoint, header)
|
||||
@ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
|
||||
log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
|
||||
return false
|
||||
}
|
||||
|
||||
func updateURL(url *url.URL, hostname string) {
|
||||
url.Scheme = "http"
|
||||
url.Host = hostname
|
||||
}
|
||||
|
@ -10,6 +10,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func updateURL(url *url.URL, hostname string) {
|
||||
url.Scheme = "http"
|
||||
url.Host = hostname
|
||||
}
|
||||
|
||||
type ValidateSessionStateTestProvider struct {
|
||||
*ProviderData
|
||||
}
|
||||
@ -25,28 +30,28 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState
|
||||
}
|
||||
|
||||
type ValidateSessionStateTest struct {
|
||||
backend *httptest.Server
|
||||
response_code int
|
||||
provider *ValidateSessionStateTestProvider
|
||||
header http.Header
|
||||
backend *httptest.Server
|
||||
responseCode int
|
||||
provider *ValidateSessionStateTestProvider
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func NewValidateSessionStateTest() *ValidateSessionStateTest {
|
||||
var vt_test ValidateSessionStateTest
|
||||
var vtTest ValidateSessionStateTest
|
||||
|
||||
vt_test.backend = httptest.NewServer(
|
||||
vtTest.backend = httptest.NewServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/tokeninfo" {
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte("unknown URL"))
|
||||
}
|
||||
token_param := r.FormValue("access_token")
|
||||
if token_param == "" {
|
||||
tokenParam := r.FormValue("access_token")
|
||||
if tokenParam == "" {
|
||||
missing := false
|
||||
received_headers := r.Header
|
||||
for k, _ := range vt_test.header {
|
||||
received := received_headers.Get(k)
|
||||
expected := vt_test.header.Get(k)
|
||||
receivedHeaders := r.Header
|
||||
for k := range vtTest.header {
|
||||
received := receivedHeaders.Get(k)
|
||||
expected := vtTest.header.Get(k)
|
||||
if received == "" || received != expected {
|
||||
missing = true
|
||||
}
|
||||
@ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest {
|
||||
w.Write([]byte("no token param and missing or incorrect headers"))
|
||||
}
|
||||
}
|
||||
w.WriteHeader(vt_test.response_code)
|
||||
w.WriteHeader(vtTest.responseCode)
|
||||
w.Write([]byte("only code matters; contents disregarded"))
|
||||
|
||||
}))
|
||||
backend_url, _ := url.Parse(vt_test.backend.URL)
|
||||
vt_test.provider = &ValidateSessionStateTestProvider{
|
||||
backendURL, _ := url.Parse(vtTest.backend.URL)
|
||||
vtTest.provider = &ValidateSessionStateTestProvider{
|
||||
ProviderData: &ProviderData{
|
||||
ValidateURL: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: backend_url.Host,
|
||||
Host: backendURL.Host,
|
||||
Path: "/oauth/tokeninfo",
|
||||
},
|
||||
},
|
||||
}
|
||||
vt_test.response_code = 200
|
||||
return &vt_test
|
||||
vtTest.responseCode = 200
|
||||
return &vtTest
|
||||
}
|
||||
|
||||
func (vt_test *ValidateSessionStateTest) Close() {
|
||||
vt_test.backend.Close()
|
||||
func (vtTest *ValidateSessionStateTest) Close() {
|
||||
vtTest.backend.Close()
|
||||
}
|
||||
|
||||
func TestValidateSessionStateValidToken(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
defer vt_test.Close()
|
||||
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
defer vtTest.Close()
|
||||
assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil))
|
||||
}
|
||||
|
||||
func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
defer vt_test.Close()
|
||||
vt_test.header = make(http.Header)
|
||||
vt_test.header.Set("Authorization", "Bearer foobar")
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
defer vtTest.Close()
|
||||
vtTest.header = make(http.Header)
|
||||
vtTest.header.Set("Authorization", "Bearer foobar")
|
||||
assert.Equal(t, true,
|
||||
validateToken(vt_test.provider, "foobar", vt_test.header))
|
||||
validateToken(vtTest.provider, "foobar", vtTest.header))
|
||||
}
|
||||
|
||||
func TestValidateSessionStateEmptyToken(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
defer vt_test.Close()
|
||||
assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
defer vtTest.Close()
|
||||
assert.Equal(t, false, validateToken(vtTest.provider, "", nil))
|
||||
}
|
||||
|
||||
func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
defer vt_test.Close()
|
||||
vt_test.provider.Data().ValidateURL = nil
|
||||
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
defer vtTest.Close()
|
||||
vtTest.provider.Data().ValidateURL = nil
|
||||
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
|
||||
}
|
||||
|
||||
func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
// Close immediately to simulate a network failure
|
||||
vt_test.Close()
|
||||
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||
vtTest.Close()
|
||||
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
|
||||
}
|
||||
|
||||
func TestValidateSessionStateExpiredToken(t *testing.T) {
|
||||
vt_test := NewValidateSessionStateTest()
|
||||
defer vt_test.Close()
|
||||
vt_test.response_code = 401
|
||||
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||
vtTest := NewValidateSessionStateTest()
|
||||
defer vtTest.Close()
|
||||
vtTest.responseCode = 401
|
||||
assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
|
||||
}
|
||||
|
||||
func TestStripTokenNotPresent(t *testing.T) {
|
||||
|
@ -39,11 +39,11 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
|
||||
return &LinkedInProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
func getLinkedInHeader(access_token string) http.Header {
|
||||
func getLinkedInHeader(accessToken string) http.Header {
|
||||
header := make(http.Header)
|
||||
header.Set("Accept", "application/json")
|
||||
header.Set("x-li-format", "json")
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
return header
|
||||
}
|
||||
|
||||
|
@ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server {
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
if url.Path != path {
|
||||
if r.URL.Path != path {
|
||||
w.WriteHeader(404)
|
||||
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
|
||||
w.WriteHeader(403)
|
||||
@ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
||||
b := testLinkedInBackend(`"user@linkedin.com"`)
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
@ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testLinkedInBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
@ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testLinkedInBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
b_url, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(b_url.Host)
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
|
@ -121,7 +121,8 @@ func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, nil)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
10
validator.go
10
validator.go
@ -42,11 +42,11 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() {
|
||||
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
|
||||
}
|
||||
defer r.Close()
|
||||
csv_reader := csv.NewReader(r)
|
||||
csv_reader.Comma = ','
|
||||
csv_reader.Comment = '#'
|
||||
csv_reader.TrimLeadingSpace = true
|
||||
records, err := csv_reader.ReadAll()
|
||||
csvReader := csv.NewReader(r)
|
||||
csvReader.Comma = ','
|
||||
csvReader.Comment = '#'
|
||||
csvReader.TrimLeadingSpace = true
|
||||
records, err := csvReader.ReadAll()
|
||||
if err != nil {
|
||||
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
|
||||
return
|
||||
|
@ -8,15 +8,15 @@ import (
|
||||
)
|
||||
|
||||
type ValidatorTest struct {
|
||||
auth_email_file *os.File
|
||||
done chan bool
|
||||
update_seen bool
|
||||
authEmailFile *os.File
|
||||
done chan bool
|
||||
updateSeen bool
|
||||
}
|
||||
|
||||
func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||
vt := &ValidatorTest{}
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file: " + err.Error())
|
||||
}
|
||||
@ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||
|
||||
func (vt *ValidatorTest) TearDown() {
|
||||
vt.done <- true
|
||||
os.Remove(vt.auth_email_file.Name())
|
||||
os.Remove(vt.authEmailFile.Name())
|
||||
}
|
||||
|
||||
func (vt *ValidatorTest) NewValidator(domains []string,
|
||||
updated chan<- bool) func(string) bool {
|
||||
return newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
return newValidatorImpl(domains, vt.authEmailFile.Name(),
|
||||
vt.done, func() {
|
||||
if vt.update_seen == false {
|
||||
if vt.updateSeen == false {
|
||||
updated <- true
|
||||
vt.update_seen = true
|
||||
vt.updateSeen = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// This will close vt.auth_email_file.
|
||||
// This will close vt.authEmailFile.
|
||||
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
|
||||
defer vt.auth_email_file.Close()
|
||||
vt.auth_email_file.WriteString(strings.Join(emails, "\n"))
|
||||
if err := vt.auth_email_file.Close(); err != nil {
|
||||
defer vt.authEmailFile.Close()
|
||||
vt.authEmailFile.WriteString(strings.Join(emails, "\n"))
|
||||
if err := vt.authEmailFile.Close(); err != nil {
|
||||
t.Fatal("failed to close temp file " +
|
||||
vt.auth_email_file.Name() + ": " + err.Error())
|
||||
vt.authEmailFile.Name() + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,18 +12,18 @@ import (
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
|
||||
t *testing.T, emails []string) {
|
||||
orig_file := vt.auth_email_file
|
||||
origFile := vt.authEmailFile
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file for copy: " + err.Error())
|
||||
}
|
||||
vt.WriteEmails(t, emails)
|
||||
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
|
||||
err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
|
||||
if err != nil {
|
||||
t.Fatal("failed to copy over temp file: " + err.Error())
|
||||
}
|
||||
vt.auth_email_file = orig_file
|
||||
vt.authEmailFile = origFile
|
||||
}
|
||||
|
||||
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
|
||||
var err error
|
||||
vt.auth_email_file, err = os.OpenFile(
|
||||
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600)
|
||||
vt.authEmailFile, err = os.OpenFile(
|
||||
vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
t.Fatal("failed to re-open temp file for updates")
|
||||
}
|
||||
@ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
|
||||
t *testing.T, emails []string) {
|
||||
orig_file := vt.auth_email_file
|
||||
origFile := vt.authEmailFile
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file for rename and replace: " +
|
||||
err.Error())
|
||||
}
|
||||
vt.WriteEmails(t, emails)
|
||||
|
||||
moved_name := orig_file.Name() + "-moved"
|
||||
err = os.Rename(orig_file.Name(), moved_name)
|
||||
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
|
||||
movedName := origFile.Name() + "-moved"
|
||||
err = os.Rename(origFile.Name(), movedName)
|
||||
err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
|
||||
if err != nil {
|
||||
t.Fatal("failed to rename and replace temp file: " +
|
||||
err.Error())
|
||||
}
|
||||
vt.auth_email_file = orig_file
|
||||
os.Remove(moved_name)
|
||||
vt.authEmailFile = origFile
|
||||
os.Remove(movedName)
|
||||
}
|
||||
|
||||
func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
|
||||
|
@ -13,11 +13,11 @@ import (
|
||||
|
||||
func WaitForReplacement(filename string, op fsnotify.Op,
|
||||
watcher *fsnotify.Watcher) {
|
||||
const sleep_interval = 50 * time.Millisecond
|
||||
const sleepInterval = 50 * time.Millisecond
|
||||
|
||||
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
|
||||
if op&fsnotify.Chmod != 0 {
|
||||
time.Sleep(sleep_interval)
|
||||
time.Sleep(sleepInterval)
|
||||
}
|
||||
for {
|
||||
if _, err := os.Stat(filename); err == nil {
|
||||
@ -26,7 +26,7 @@ func WaitForReplacement(filename string, op fsnotify.Op,
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(sleep_interval)
|
||||
time.Sleep(sleepInterval)
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +56,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
|
||||
}
|
||||
log.Printf("reloading after event: %s", event)
|
||||
action()
|
||||
case err := <-watcher.Errors:
|
||||
case err = <-watcher.Errors:
|
||||
log.Printf("error watching %s: %s", filename, err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user