1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-24 08:52:25 +02:00

Lint for non-comment linter errors

This commit is contained in:
Joel Speed 2018-11-29 14:26:41 +00:00
parent 990873eb42
commit 8ee802d4e5
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
30 changed files with 337 additions and 334 deletions

View File

@ -3,10 +3,6 @@
# for detailed Gopkg.toml documentation. # for detailed Gopkg.toml documentation.
# #
[[constraint]]
name = "github.com/18F/hmacauth"
version = "~1.0.1"
[[constraint]] [[constraint]]
name = "github.com/BurntSushi/toml" name = "github.com/BurntSushi/toml"
version = "~0.3.0" version = "~0.3.0"

View File

@ -32,7 +32,7 @@ func Request(req *http.Request) (*simplejson.Json, error) {
return data, nil return data, nil
} }
func RequestJson(req *http.Request, v interface{}) error { func RequestJSON(req *http.Request, v interface{}) error {
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Printf("%s %s %s", req.Method, req.URL, err) log.Printf("%s %s %s", req.Method, req.URL, err)

View File

@ -1,20 +1,21 @@
package api package api
import ( import (
"github.com/bitly/go-simplejson"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/bitly/go-simplejson"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func testBackend(response_code int, payload string) *httptest.Server { func testBackend(responseCode int, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(response_code) w.WriteHeader(responseCode)
w.Write([]byte(payload)) w.Write([]byte(payload))
})) }))
} }

View File

@ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
} }
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk="
const token = "my access token" const token = "my access token"
secret, err := base64.URLEncoding.DecodeString(secret_b64) secret, err := base64.URLEncoding.DecodeString(secretBase64)
assert.Equal(t, nil, err)
c, err := NewCipher([]byte(secret)) c, err := NewCipher([]byte(secret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
) )
// Nonce generates a random 16 byte string to be used as a nonce
func Nonce() (nonce string, err error) { func Nonce() (nonce string, err error) {
b := make([]byte, 16) b := make([]byte, 16)
_, err = rand.Read(b) _, err = rand.Read(b)

View File

@ -28,12 +28,12 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
} }
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
csv_reader := csv.NewReader(file) csvReader := csv.NewReader(file)
csv_reader.Comma = ':' csvReader.Comma = ':'
csv_reader.Comment = '#' csvReader.Comment = '#'
csv_reader.TrimLeadingSpace = true csvReader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csvReader.ReadAll()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -20,6 +20,7 @@ func TestSHA(t *testing.T) {
func TestBcrypt(t *testing.T) { func TestBcrypt(t *testing.T) {
hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1)
assert.Equal(t, err, nil)
hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)

10
http.go
View File

@ -23,12 +23,12 @@ func (s *Server) ListenAndServe() {
} }
func (s *Server) ServeHTTP() { func (s *Server) ServeHTTP() {
httpAddress := s.Opts.HttpAddress HTTPAddress := s.Opts.HTTPAddress
scheme := "" scheme := ""
i := strings.Index(httpAddress, "://") i := strings.Index(HTTPAddress, "://")
if i > -1 { if i > -1 {
scheme = httpAddress[0:i] scheme = HTTPAddress[0:i]
} }
var networkType string var networkType string
@ -39,7 +39,7 @@ func (s *Server) ServeHTTP() {
networkType = scheme networkType = scheme
} }
slice := strings.SplitN(httpAddress, "//", 2) slice := strings.SplitN(HTTPAddress, "//", 2)
listenAddr := slice[len(slice)-1] listenAddr := slice[len(slice)-1]
listener, err := net.Listen(networkType, listenAddr) listener, err := net.Listen(networkType, listenAddr)
@ -58,7 +58,7 @@ func (s *Server) ServeHTTP() {
} }
func (s *Server) ServeHTTPS() { func (s *Server) ServeHTTPS() {
addr := s.Opts.HttpsAddress addr := s.Opts.HTTPSAddress
config := &tls.Config{ config := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS12,

View File

@ -14,14 +14,19 @@ import (
"strings" "strings"
"time" "time"
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/mbland/hmacauth"
) )
const SignatureHeader = "GAP-Signature" const (
SignatureHeader = "GAP-Signature"
var SignatureHeaders []string = []string{ httpScheme = "http"
httpsScheme = "https"
)
var SignatureHeaders = []string{
"Content-Length", "Content-Length",
"Content-Md5", "Content-Md5",
"Content-Type", "Content-Type",
@ -40,7 +45,7 @@ type OAuthProxy struct {
CSRFCookieName string CSRFCookieName string
CookieDomain string CookieDomain string
CookieSecure bool CookieSecure bool
CookieHttpOnly bool CookieHTTPOnly bool
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieRefresh time.Duration
Validator func(string) bool Validator func(string) bool
@ -125,7 +130,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
for _, u := range opts.proxyURLs { for _, u := range opts.proxyURLs {
path := u.Path path := u.Path
switch u.Scheme { switch u.Scheme {
case "http", "https": case httpScheme, httpsScheme:
u.Path = "" u.Path = ""
log.Printf("mapping path %q => upstream %q", path, u) log.Printf("mapping path %q => upstream %q", path, u)
proxy := NewReverseProxy(u) proxy := NewReverseProxy(u)
@ -160,7 +165,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
refresh = fmt.Sprintf("after %s", opts.CookieRefresh) refresh = fmt.Sprintf("after %s", opts.CookieRefresh)
} }
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh)
var cipher *cookie.Cipher var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
@ -177,7 +182,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
CookieSeed: opts.CookieSecret, CookieSeed: opts.CookieSecret,
CookieDomain: opts.CookieDomain, CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHttpOnly: opts.CookieHttpOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieRefresh: opts.CookieRefresh, CookieRefresh: opts.CookieRefresh,
Validator: validator, Validator: validator,
@ -218,9 +223,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
u = *p.redirectURL u = *p.redirectURL
if u.Scheme == "" { if u.Scheme == "" {
if p.CookieSecure { if p.CookieSecure {
u.Scheme = "https" u.Scheme = httpsScheme
} else { } else {
u.Scheme = "http" u.Scheme = httpScheme
} }
} }
u.Host = host u.Host = host
@ -285,7 +290,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
Value: value, Value: value,
Path: "/", Path: "/",
Domain: p.CookieDomain, Domain: p.CookieDomain,
HttpOnly: p.CookieHttpOnly, HttpOnly: p.CookieHTTPOnly,
Secure: p.CookieSecure, Secure: p.CookieSecure,
Expires: now.Add(expiration), Expires: now.Add(expiration),
} }
@ -374,12 +379,12 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
p.ClearSessionCookie(rw, req) p.ClearSessionCookie(rw, req)
rw.WriteHeader(code) rw.WriteHeader(code)
redirect_url := req.URL.RequestURI() redirecURL := req.URL.RequestURI()
if req.Header.Get("X-Auth-Request-Redirect") != "" { if req.Header.Get("X-Auth-Request-Redirect") != "" {
redirect_url = req.Header.Get("X-Auth-Request-Redirect") redirecURL = req.Header.Get("X-Auth-Request-Redirect")
} }
if redirect_url == p.SignInPath { if redirecURL == p.SignInPath {
redirect_url = "/" redirecURL = "/"
} }
t := struct { t := struct {
@ -394,7 +399,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
ProviderName: p.provider.Data().ProviderName, ProviderName: p.provider.Data().ProviderName,
SignInMessage: p.SignInMessage, SignInMessage: p.SignInMessage,
CustomLogin: p.displayCustomLoginForm(), CustomLogin: p.displayCustomLoginForm(),
Redirect: redirect_url, Redirect: redirecURL,
Version: VERSION, Version: VERSION,
ProxyPrefix: p.ProxyPrefix, ProxyPrefix: p.ProxyPrefix,
Footer: template.HTML(p.Footer), Footer: template.HTML(p.Footer),
@ -653,7 +658,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
} }
if saveSession && session != nil { if saveSession && session != nil {
err := p.SaveSession(rw, req, session) err = p.SaveSession(rw, req, session)
if err != nil { if err != nil {
log.Printf("%s %s", remoteAddr, err) log.Printf("%s %s", remoteAddr, err)
return http.StatusInternalServerError return http.StatusInternalServerError

View File

@ -15,8 +15,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/pusher/oauth2_proxy/providers"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -98,28 +98,28 @@ type TestProvider struct {
ValidToken bool ValidToken bool
} }
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider { func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
return &TestProvider{ return &TestProvider{
ProviderData: &providers.ProviderData{ ProviderData: &providers.ProviderData{
ProviderName: "Test Provider", ProviderName: "Test Provider",
LoginURL: &url.URL{ LoginURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/oauth/authorize", Path: "/oauth/authorize",
}, },
RedeemURL: &url.URL{ RedeemURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/oauth/token", Path: "/oauth/token",
}, },
ProfileURL: &url.URL{ ProfileURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: provider_url.Host, Host: providerURL.Host,
Path: "/api/v1/profile", Path: "/api/v1/profile",
}, },
Scope: "profile.email", Scope: "profile.email",
}, },
EmailAddress: email_address, EmailAddress: emailAddress,
} }
} }
@ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo
} }
func TestBasicAuthPassword(t *testing.T) { func TestBasicAuthPassword(t *testing.T) {
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r) log.Printf("%#v", r)
url := r.URL var payload string
payload := "" switch r.URL.Path {
switch url.Path {
case "/oauth/token": case "/oauth/token":
payload = `{"access_token": "my_auth_token"}` payload = `{"access_token": "my_auth_token"}`
default: default:
@ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) {
w.Write([]byte(payload)) w.Write([]byte(payload))
})) }))
opts := NewOptions() opts := NewOptions()
opts.Upstreams = append(opts.Upstreams, provider_server.URL) opts.Upstreams = append(opts.Upstreams, providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES // The CookieSecret must be 32 bytes in order to create the AES
// cipher. // cipher.
opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) {
opts.BasicAuthPassword = "This is a secure password" opts.BasicAuthPassword = "This is a secure password"
opts.Validate() opts.Validate()
provider_url, _ := url.Parse(provider_server.URL) providerURL, _ := url.Parse(providerServer.URL)
const email_address = "michael.bland@gsa.gov" const emailAddress = "michael.bland@gsa.gov"
const user_name = "michael.bland" const username = "michael.bland"
opts.provider = NewTestProvider(provider_url, email_address) opts.provider = NewTestProvider(providerURL, emailAddress)
proxy := NewOAuthProxy(opts, func(email string) bool { proxy := NewOAuthProxy(opts, func(email string) bool {
return email == email_address return email == emailAddress
}) })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) {
cookieName := proxy.CookieName cookieName := proxy.CookieName
var value string var value string
key_prefix := cookieName + "=" keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") { for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix) value = strings.TrimPrefix(field, keyPrefix)
if value != field { if value != field {
break break
} else { } else {
@ -206,15 +205,15 @@ func TestBasicAuthPassword(t *testing.T) {
rw = httptest.NewRecorder() rw = httptest.NewRecorder()
proxy.ServeHTTP(rw, req) proxy.ServeHTTP(rw, req)
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword)) expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword))
assert.Equal(t, expectedHeader, rw.Body.String()) assert.Equal(t, expectedHeader, rw.Body.String())
provider_server.Close() providerServer.Close()
} }
type PassAccessTokenTest struct { type PassAccessTokenTest struct {
provider_server *httptest.Server providerServer *httptest.Server
proxy *OAuthProxy proxy *OAuthProxy
opts *Options opts *Options
} }
type PassAccessTokenTestOptions struct { type PassAccessTokenTestOptions struct {
@ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct {
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
t := &PassAccessTokenTest{} t := &PassAccessTokenTest{}
t.provider_server = httptest.NewServer( t.providerServer = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r) log.Printf("%#v", r)
url := r.URL var payload string
payload := "" switch r.URL.Path {
switch url.Path {
case "/oauth/token": case "/oauth/token":
payload = `{"access_token": "my_auth_token"}` payload = `{"access_token": "my_auth_token"}`
default: default:
@ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
})) }))
t.opts = NewOptions() t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
// The CookieSecret must be 32 bytes in order to create the AES // The CookieSecret must be 32 bytes in order to create the AES
// cipher. // cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.opts.PassAccessToken = opts.PassAccessToken t.opts.PassAccessToken = opts.PassAccessToken
t.opts.Validate() t.opts.Validate()
provider_url, _ := url.Parse(t.provider_server.URL) providerURL, _ := url.Parse(t.providerServer.URL)
const email_address = "michael.bland@gsa.gov" const emailAddress = "michael.bland@gsa.gov"
t.opts.provider = NewTestProvider(provider_url, email_address) t.opts.provider = NewTestProvider(providerURL, emailAddress)
t.proxy = NewOAuthProxy(t.opts, func(email string) bool { t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
return email == email_address return email == emailAddress
}) })
return t return t
} }
func (pat_test *PassAccessTokenTest) Close() { func (patTest *PassAccessTokenTest) Close() {
pat_test.provider_server.Close() patTest.providerServer.Close()
} }
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
cookie string) { cookie string) {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
@ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
if err != nil { if err != nil {
return 0, "" return 0, ""
} }
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
pat_test.proxy.ServeHTTP(rw, req) patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][1] return rw.Code, rw.HeaderMap["Set-Cookie"][1]
} }
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
cookieName := pat_test.proxy.CookieName cookieName := patTest.proxy.CookieName
var value string var value string
key_prefix := cookieName + "=" keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") { for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix) value = strings.TrimPrefix(field, keyPrefix)
if value != field { if value != field {
break break
} else { } else {
@ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i
}) })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
pat_test.proxy.ServeHTTP(rw, req) patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String() return rw.Code, rw.Body.String()
} }
func TestForwardAccessTokenUpstream(t *testing.T) { func TestForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true, PassAccessToken: true,
}) })
defer pat_test.Close() defer patTest.Close()
// A successful validation will redirect and set the auth cookie. // A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint() code, cookie := patTest.getCallbackEndpoint()
if code != 302 { if code != 302 {
t.Fatalf("expected 302; got %d", code) t.Fatalf("expected 302; got %d", code)
} }
@ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request; the access_token from the cookie is // Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is // forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body. // read by the test provider server and written in the response body.
code, payload := pat_test.getRootEndpoint(cookie) code, payload := patTest.getRootEndpoint(cookie)
if code != 200 { if code != 200 {
t.Fatalf("expected 200; got %d", code) t.Fatalf("expected 200; got %d", code)
} }
@ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
} }
func TestDoNotForwardAccessTokenUpstream(t *testing.T) { func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false, PassAccessToken: false,
}) })
defer pat_test.Close() defer patTest.Close()
// A successful validation will redirect and set the auth cookie. // A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint() code, cookie := patTest.getCallbackEndpoint()
if code != 302 { if code != 302 {
t.Fatalf("expected 302; got %d", code) t.Fatalf("expected 302; got %d", code)
} }
@ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
// Now we make a regular request, but the access token header should // Now we make a regular request, but the access token header should
// not be present. // not be present.
code, payload := pat_test.getRootEndpoint(cookie) code, payload := patTest.getRootEndpoint(cookie)
if code != 200 { if code != 200 {
t.Fatalf("expected 200; got %d", code) t.Fatalf("expected 200; got %d", code)
} }
@ -360,49 +358,49 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
} }
type SignInPageTest struct { type SignInPageTest struct {
opts *Options opts *Options
proxy *OAuthProxy proxy *OAuthProxy
sign_in_regexp *regexp.Regexp signInRegexp *regexp.Regexp
sign_in_provider_regexp *regexp.Regexp signInProviderRegexp *regexp.Regexp
} }
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
const signInSkipProvider = `>Found<` const signInSkipProvider = `>Found<`
func NewSignInPageTest(skipProvider bool) *SignInPageTest { func NewSignInPageTest(skipProvider bool) *SignInPageTest {
var sip_test SignInPageTest var sipTest SignInPageTest
sip_test.opts = NewOptions() sipTest.opts = NewOptions()
sip_test.opts.CookieSecret = "foobar" sipTest.opts.CookieSecret = "foobar"
sip_test.opts.ClientID = "bazquux" sipTest.opts.ClientID = "bazquux"
sip_test.opts.ClientSecret = "xyzzyplugh" sipTest.opts.ClientSecret = "xyzzyplugh"
sip_test.opts.SkipProviderButton = skipProvider sipTest.opts.SkipProviderButton = skipProvider
sip_test.opts.Validate() sipTest.opts.Validate()
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool {
return true return true
}) })
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
return &sip_test return &sipTest
} }
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
sip_test.proxy.ServeHTTP(rw, req) sipTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String() return rw.Code, rw.Body.String()
} }
func TestSignInPageIncludesTargetRedirect(t *testing.T) { func TestSignInPageIncludesTargetRedirect(t *testing.T) {
sip_test := NewSignInPageTest(false) sipTest := NewSignInPageTest(false)
const endpoint = "/some/random/endpoint" const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 403, code) assert.Equal(t, 403, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body) match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body) signInRedirectPattern + "\nBody:\n" + body)
@ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) {
} }
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
sip_test := NewSignInPageTest(false) sipTest := NewSignInPageTest(false)
code, body := sip_test.GetEndpoint("/oauth2/sign_in") code, body := sipTest.GetEndpoint("/oauth2/sign_in")
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body) match := sipTest.signInRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body) signInRedirectPattern + "\nBody:\n" + body)
@ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
} }
func TestSignInPageSkipProvider(t *testing.T) { func TestSignInPageSkipProvider(t *testing.T) {
sip_test := NewSignInPageTest(true) sipTest := NewSignInPageTest(true)
const endpoint = "/some/random/endpoint" const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code) assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body) signInSkipProvider + "\nBody:\n" + body)
@ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) {
} }
func TestSignInPageSkipProviderDirect(t *testing.T) { func TestSignInPageSkipProviderDirect(t *testing.T) {
sip_test := NewSignInPageTest(true) sipTest := NewSignInPageTest(true)
const endpoint = "/sign_in" const endpoint = "/sign_in"
code, body := sip_test.GetEndpoint(endpoint) code, body := sipTest.GetEndpoint(endpoint)
assert.Equal(t, 302, code) assert.Equal(t, 302, code)
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
if match == nil { if match == nil {
t.Fatal("Did not find pattern in body: " + t.Fatal("Did not find pattern in body: " +
signInSkipProvider + "\nBody:\n" + body) signInSkipProvider + "\nBody:\n" + body)
@ -457,50 +455,50 @@ func TestSignInPageSkipProviderDirect(t *testing.T) {
} }
type ProcessCookieTest struct { type ProcessCookieTest struct {
opts *Options opts *Options
proxy *OAuthProxy proxy *OAuthProxy
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
req *http.Request req *http.Request
provider TestProvider provider TestProvider
response_code int responseCode int
validate_user bool validateUser bool
} }
type ProcessCookieTestOpts struct { type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool providerValidateCookieResponse bool
} }
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
var pc_test ProcessCookieTest var pcTest ProcessCookieTest
pc_test.opts = NewOptions() pcTest.opts = NewOptions()
pc_test.opts.ClientID = "bazquux" pcTest.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh" pcTest.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdefabcd" pcTest.opts.CookieSecret = "0123456789abcdefabcd"
// First, set the CookieRefresh option so proxy.AesCipher is created, // First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token. // needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Hour pcTest.opts.CookieRefresh = time.Hour
pc_test.opts.Validate() pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pc_test.validate_user return pcTest.validateUser
}) })
pc_test.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: opts.provider_validate_cookie_response, ValidToken: opts.providerValidateCookieResponse,
} }
// Now, zero-out proxy.CookieRefresh for the cases that don't involve // Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation. // access_token validation.
pc_test.proxy.CookieRefresh = time.Duration(0) pcTest.proxy.CookieRefresh = time.Duration(0)
pc_test.rw = httptest.NewRecorder() pcTest.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pc_test.validate_user = true pcTest.validateUser = true
return &pc_test return &pcTest
} }
func NewProcessCookieTestWithDefaults() *ProcessCookieTest { func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{ return NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: true, providerValidateCookieResponse: true,
}) })
} }
@ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.
} }
func TestLoadCookiedSession(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "michael.bland", session.User) assert.Equal(t, "michael.bland", session.User)
@ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) {
} }
func TestProcessCookieNoCookieError(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil { if session != nil {
t.Errorf("expected nil session. got %#v", session) t.Errorf("expected nil session. got %#v", session)
@ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) {
} }
func TestProcessCookieRefreshNotSet(t *testing.T) { func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, age, err := pc_test.LoadCookiedSession() session, age, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour { if age < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age) t.Errorf("cookie too young %v", age)
@ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
} }
func TestProcessCookieFailIfCookieExpired(t *testing.T) { func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
@ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
} }
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
pc_test.proxy.CookieRefresh = time.Hour pcTest.proxy.CookieRefresh = time.Hour
session, _, err := pc_test.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
@ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
} }
func NewAuthOnlyEndpointTest() *ProcessCookieTest { func NewAuthOnlyEndpointTest() *ProcessCookieTest {
pc_test := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pc_test.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
return pc_test return pcTest
} }
func TestAuthOnlyEndpointAccepted(t *testing.T) { func TestAuthOnlyEndpointAccepted(t *testing.T) {
@ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
startSession := &providers.SessionState{ startSession := &providers.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
test.validate_user = false test.validateUser = false
test.proxy.ServeHTTP(test.rw, test.req) test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code) assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
@ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
} }
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
var pc_test ProcessCookieTest var pcTest ProcessCookieTest
pc_test.opts = NewOptions() pcTest.opts = NewOptions()
pc_test.opts.SetXAuthRequest = true pcTest.opts.SetXAuthRequest = true
pc_test.opts.Validate() pcTest.opts.Validate()
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
return pc_test.validate_user return pcTest.validateUser
}) })
pc_test.proxy.provider = &TestProvider{ pcTest.proxy.provider = &TestProvider{
ValidToken: true, ValidToken: true,
} }
pc_test.validate_user = true pcTest.validateUser = true
pc_test.rw = httptest.NewRecorder() pcTest.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pc_test.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{ startSession := &providers.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pc_test.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pc_test.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0]) assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0])
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0])
} }
func TestAuthSkippedForPreflightRequests(t *testing.T) { func TestAuthSkippedForPreflightRequests(t *testing.T) {
@ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
opts.SkipAuthPreflight = true opts.SkipAuthPreflight = true
opts.Validate() opts.Validate()
upstream_url, _ := url.Parse(upstream.URL) upstreamURL, _ := url.Parse(upstream.URL)
opts.provider = NewTestProvider(upstream_url, "") opts.provider = NewTestProvider(upstreamURL, "")
proxy := NewOAuthProxy(opts, func(string) bool { return false }) proxy := NewOAuthProxy(opts, func(string) bool { return false })
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req
type SignatureTest struct { type SignatureTest struct {
opts *Options opts *Options
upstream *httptest.Server upstream *httptest.Server
upstream_host string upstreamHost string
provider *httptest.Server provider *httptest.Server
header http.Header header http.Header
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
@ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest {
authenticator := &SignatureAuthenticator{} authenticator := &SignatureAuthenticator{}
upstream := httptest.NewServer( upstream := httptest.NewServer(
http.HandlerFunc(authenticator.Authenticate)) http.HandlerFunc(authenticator.Authenticate))
upstream_url, _ := url.Parse(upstream.URL) upstreamURL, _ := url.Parse(upstream.URL)
opts.Upstreams = append(opts.Upstreams, upstream.URL) opts.Upstreams = append(opts.Upstreams, upstream.URL)
providerHandler := func(w http.ResponseWriter, r *http.Request) { providerHandler := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"access_token": "my_auth_token"}`)) w.Write([]byte(`{"access_token": "my_auth_token"}`))
} }
provider := httptest.NewServer(http.HandlerFunc(providerHandler)) provider := httptest.NewServer(http.HandlerFunc(providerHandler))
provider_url, _ := url.Parse(provider.URL) providerURL, _ := url.Parse(provider.URL)
opts.provider = NewTestProvider(provider_url, "mbland@acm.org") opts.provider = NewTestProvider(providerURL, "mbland@acm.org")
return &SignatureTest{ return &SignatureTest{
opts, opts,
upstream, upstream,
upstream_url.Host, upstreamURL.Host,
provider, provider,
make(http.Header), make(http.Header),
httptest.NewRecorder(), httptest.NewRecorder(),

View File

@ -13,16 +13,17 @@ import (
"strings" "strings"
"time" "time"
"github.com/pusher/oauth2_proxy/providers"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers"
) )
// Configuration Options that can be set by Command Line Flag, or Config File // Options holds Configuration Options that can be set by Command Line Flag,
// or Config File
type Options struct { type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"`
HttpAddress string `flag:"http-address" cfg:"http_address"` HTTPAddress string `flag:"http-address" cfg:"http_address"`
HttpsAddress string `flag:"https-address" cfg:"https_address"` HTTPSAddress string `flag:"https-address" cfg:"https_address"`
RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
@ -48,7 +49,7 @@ type Options struct {
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"`
CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
Upstreams []string `flag:"upstream" cfg:"upstreams"` Upstreams []string `flag:"upstream" cfg:"upstreams"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
@ -96,12 +97,12 @@ type SignatureData struct {
func NewOptions() *Options { func NewOptions() *Options {
return &Options{ return &Options{
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
HttpAddress: "127.0.0.1:4180", HTTPAddress: "127.0.0.1:4180",
HttpsAddress: ":443", HTTPSAddress: ":443",
DisplayHtpasswdForm: true, DisplayHtpasswdForm: true,
CookieName: "_oauth2_proxy", CookieName: "_oauth2_proxy",
CookieSecure: true, CookieSecure: true,
CookieHttpOnly: true, CookieHTTPOnly: true,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0), CookieRefresh: time.Duration(0),
SetXAuthRequest: false, SetXAuthRequest: false,
@ -116,11 +117,11 @@ func NewOptions() *Options {
} }
} }
func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) {
parsed, err := url.Parse(to_parse) parsed, err := url.Parse(toParse)
if err != nil { if err != nil {
return nil, append(msgs, fmt.Sprintf( return nil, append(msgs, fmt.Sprintf(
"error parsing %s-url=%q %s", urltype, to_parse, err)) "error parsing %s-url=%q %s", urltype, toParse, err))
} }
return parsed, msgs return parsed, msgs
} }
@ -190,17 +191,17 @@ func (o *Options) Validate() error {
msgs = parseProviderInfo(o, msgs) msgs = parseProviderInfo(o, msgs)
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
valid_cookie_secret_size := false validCookieSecretSize := false
for _, i := range []int{16, 24, 32} { for _, i := range []int{16, 24, 32} {
if len(secretBytes(o.CookieSecret)) == i { if len(secretBytes(o.CookieSecret)) == i {
valid_cookie_secret_size = true validCookieSecretSize = true
} }
} }
var decoded bool var decoded bool
if string(secretBytes(o.CookieSecret)) != o.CookieSecret { if string(secretBytes(o.CookieSecret)) != o.CookieSecret {
decoded = true decoded = true
} }
if valid_cookie_secret_size == false { if validCookieSecretSize == false {
var suffix string var suffix string
if decoded { if decoded {
suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret)
@ -294,12 +295,13 @@ func parseSignatureKey(o *Options, msgs []string) []string {
} }
algorithm, secretKey := components[0], components[1] algorithm, secretKey := components[0], components[1]
if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { var hash crypto.Hash
var err error
if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil {
return append(msgs, "unsupported signature hash algorithm: "+ return append(msgs, "unsupported signature hash algorithm: "+
o.SignatureKey) o.SignatureKey)
} else {
o.signatureData = &SignatureData{hash, secretKey}
} }
o.signatureData = &SignatureData{hash, secretKey}
return msgs return msgs
} }

View File

@ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) {
o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081")
assert.Equal(t, nil, o.Validate()) assert.Equal(t, nil, o.Validate())
expected := []*url.URL{ expected := []*url.URL{
&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, {Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
// note the '/' was added // note the '/' was added
&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, {Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
} }
assert.Equal(t, expected, o.proxyURLs) assert.Equal(t, expected, o.proxyURLs)
} }

View File

@ -3,11 +3,12 @@ package providers
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api"
) )
type AzureProvider struct { type AzureProvider struct {
@ -60,9 +61,9 @@ func (p *AzureProvider) Configure(tenant string) {
} }
} }
func getAzureHeader(access_token string) http.Header { func getAzureHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }

View File

@ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path || r.URL.RawQuery != query {
if url.Path != path || url.RawQuery != query {
w.WriteHeader(404) w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403) w.WriteHeader(403)

View File

@ -43,11 +43,11 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider {
return &FacebookProvider{ProviderData: p} return &FacebookProvider{ProviderData: p}
} }
func getFacebookHeader(access_token string) http.Header { func getFacebookHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Accept", "application/json") header.Set("Accept", "application/json")
header.Set("x-li-format", "json") header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }
@ -65,7 +65,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
Email string Email string
} }
var r result var r result
err = api.RequestJson(req, &r) err = api.RequestJSON(req, &r)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -106,7 +106,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) {
} }
orgs = append(orgs, op...) orgs = append(orgs, op...)
pn += 1 pn++
} }
var presentOrgs []string var presentOrgs []string
@ -186,7 +186,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams)
} else { } else {
var allOrgs []string var allOrgs []string
for org, _ := range presentOrgs { for org := range presentOrgs {
allOrgs = append(allOrgs, org) allOrgs = append(allOrgs, org)
} }
log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs)

View File

@ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider {
func testGitHubBackend(payload []string) *httptest.Server { func testGitHubBackend(payload []string) *httptest.Server {
pathToQueryMap := map[string][]string{ pathToQueryMap := map[string][]string{
"/user": []string{""}, "/user": {""},
"/user/emails": []string{""}, "/user/emails": {""},
"/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, "/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"},
} }
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL query, ok := pathToQueryMap[r.URL.Path]
query, ok := pathToQueryMap[url.Path]
validQuery := false validQuery := false
index := 0 index := 0
for i, q := range query { for i, q := range query {
if q == url.RawQuery { if q == r.URL.RawQuery {
validQuery = true validQuery = true
index = i index = i
} }

View File

@ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path || r.URL.RawQuery != query {
if url.Path != path || url.RawQuery != query {
w.WriteHeader(404) w.WriteHeader(404)
} else { } else {
w.WriteHeader(200) w.WriteHeader(200)
@ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
@ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitLabBackend("unused payload") b := testGitLabBackend("unused payload")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
@ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitLabBackend("{\"foo\": \"bar\"}") b := testGitLabBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(b_url.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)

View File

@ -62,7 +62,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
} }
} }
func emailFromIdToken(idToken string) (string, error) { func emailFromIDToken(idToken string) (string, error) {
// id_token is a base64 encode ID token payload // id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
@ -129,14 +129,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err != nil { if err != nil {
return return
} }
var email string var email string
email, err = emailFromIdToken(jsonResponse.IdToken) email, err = emailFromIDToken(jsonResponse.IDToken)
if err != nil { if err != nil {
return return
} }

View File

@ -81,7 +81,7 @@ type redeemResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"` IDToken string `json:"id_token"`
} }
func TestGoogleProviderGetEmailAddress(t *testing.T) { func TestGoogleProviderGetEmailAddress(t *testing.T) {
@ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) {
AccessToken: "a1234", AccessToken: "a1234",
ExpiresIn: 10, ExpiresIn: 10,
RefreshToken: "refresh12345", RefreshToken: "refresh12345",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, IDToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server
@ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{ body, err := json.Marshal(redeemResponse{
AccessToken: "a1234", AccessToken: "a1234",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
}) })
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
var server *httptest.Server var server *httptest.Server

View File

@ -46,13 +46,13 @@ func stripParam(param, endpoint string) string {
} }
// validateToken returns true if token is valid // validateToken returns true if token is valid
func validateToken(p Provider, access_token string, header http.Header) bool { func validateToken(p Provider, accessToken string, header http.Header) bool {
if access_token == "" || p.Data().ValidateURL == nil { if accessToken == "" || p.Data().ValidateURL == nil {
return false return false
} }
endpoint := p.Data().ValidateURL.String() endpoint := p.Data().ValidateURL.String()
if len(header) == 0 { if len(header) == 0 {
params := url.Values{"access_token": {access_token}} params := url.Values{"access_token": {accessToken}}
endpoint = endpoint + "?" + params.Encode() endpoint = endpoint + "?" + params.Encode()
} }
resp, err := api.RequestUnparsedResponse(endpoint, header) resp, err := api.RequestUnparsedResponse(endpoint, header)
@ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
return false return false
} }
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}

View File

@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func updateURL(url *url.URL, hostname string) {
url.Scheme = "http"
url.Host = hostname
}
type ValidateSessionStateTestProvider struct { type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
@ -25,28 +30,28 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState
} }
type ValidateSessionStateTest struct { type ValidateSessionStateTest struct {
backend *httptest.Server backend *httptest.Server
response_code int responseCode int
provider *ValidateSessionStateTestProvider provider *ValidateSessionStateTestProvider
header http.Header header http.Header
} }
func NewValidateSessionStateTest() *ValidateSessionStateTest { func NewValidateSessionStateTest() *ValidateSessionStateTest {
var vt_test ValidateSessionStateTest var vtTest ValidateSessionStateTest
vt_test.backend = httptest.NewServer( vtTest.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/tokeninfo" { if r.URL.Path != "/oauth/tokeninfo" {
w.WriteHeader(500) w.WriteHeader(500)
w.Write([]byte("unknown URL")) w.Write([]byte("unknown URL"))
} }
token_param := r.FormValue("access_token") tokenParam := r.FormValue("access_token")
if token_param == "" { if tokenParam == "" {
missing := false missing := false
received_headers := r.Header receivedHeaders := r.Header
for k, _ := range vt_test.header { for k := range vtTest.header {
received := received_headers.Get(k) received := receivedHeaders.Get(k)
expected := vt_test.header.Get(k) expected := vtTest.header.Get(k)
if received == "" || received != expected { if received == "" || received != expected {
missing = true missing = true
} }
@ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest {
w.Write([]byte("no token param and missing or incorrect headers")) w.Write([]byte("no token param and missing or incorrect headers"))
} }
} }
w.WriteHeader(vt_test.response_code) w.WriteHeader(vtTest.responseCode)
w.Write([]byte("only code matters; contents disregarded")) w.Write([]byte("only code matters; contents disregarded"))
})) }))
backend_url, _ := url.Parse(vt_test.backend.URL) backendURL, _ := url.Parse(vtTest.backend.URL)
vt_test.provider = &ValidateSessionStateTestProvider{ vtTest.provider = &ValidateSessionStateTestProvider{
ProviderData: &ProviderData{ ProviderData: &ProviderData{
ValidateURL: &url.URL{ ValidateURL: &url.URL{
Scheme: "http", Scheme: "http",
Host: backend_url.Host, Host: backendURL.Host,
Path: "/oauth/tokeninfo", Path: "/oauth/tokeninfo",
}, },
}, },
} }
vt_test.response_code = 200 vtTest.responseCode = 200
return &vt_test return &vtTest
} }
func (vt_test *ValidateSessionStateTest) Close() { func (vtTest *ValidateSessionStateTest) Close() {
vt_test.backend.Close() vtTest.backend.Close()
} }
func TestValidateSessionStateValidToken(t *testing.T) { func TestValidateSessionStateValidToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.header = make(http.Header) vtTest.header = make(http.Header)
vt_test.header.Set("Authorization", "Bearer foobar") vtTest.header.Set("Authorization", "Bearer foobar")
assert.Equal(t, true, assert.Equal(t, true,
validateToken(vt_test.provider, "foobar", vt_test.header)) validateToken(vtTest.provider, "foobar", vtTest.header))
} }
func TestValidateSessionStateEmptyToken(t *testing.T) { func TestValidateSessionStateEmptyToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "", nil))
} }
func TestValidateSessionStateEmptyValidateURL(t *testing.T) { func TestValidateSessionStateEmptyValidateURL(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.provider.Data().ValidateURL = nil vtTest.provider.Data().ValidateURL = nil
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
// Close immediately to simulate a network failure // Close immediately to simulate a network failure
vt_test.Close() vtTest.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestValidateSessionStateExpiredToken(t *testing.T) { func TestValidateSessionStateExpiredToken(t *testing.T) {
vt_test := NewValidateSessionStateTest() vtTest := NewValidateSessionStateTest()
defer vt_test.Close() defer vtTest.Close()
vt_test.response_code = 401 vtTest.responseCode = 401
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil))
} }
func TestStripTokenNotPresent(t *testing.T) { func TestStripTokenNotPresent(t *testing.T) {

View File

@ -39,11 +39,11 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
return &LinkedInProvider{ProviderData: p} return &LinkedInProvider{ProviderData: p}
} }
func getLinkedInHeader(access_token string) http.Header { func getLinkedInHeader(accessToken string) http.Header {
header := make(http.Header) header := make(http.Header)
header.Set("Accept", "application/json") header.Set("Accept", "application/json")
header.Set("x-li-format", "json") header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
return header return header
} }

View File

@ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc( return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
url := r.URL if r.URL.Path != path {
if url.Path != path {
w.WriteHeader(404) w.WriteHeader(404)
} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" {
w.WriteHeader(403) w.WriteHeader(403)
@ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
b := testLinkedInBackend(`"user@linkedin.com"`) b := testLinkedInBackend(`"user@linkedin.com"`)
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
@ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testLinkedInBackend("unused payload") b := testLinkedInBackend("unused payload")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
@ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testLinkedInBackend("{\"foo\": \"bar\"}") b := testLinkedInBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
b_url, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)

View File

@ -121,7 +121,8 @@ func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return false, nil return false, nil
} }

View File

@ -42,11 +42,11 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() {
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
} }
defer r.Close() defer r.Close()
csv_reader := csv.NewReader(r) csvReader := csv.NewReader(r)
csv_reader.Comma = ',' csvReader.Comma = ','
csv_reader.Comment = '#' csvReader.Comment = '#'
csv_reader.TrimLeadingSpace = true csvReader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll() records, err := csvReader.ReadAll()
if err != nil { if err != nil {
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
return return

View File

@ -8,15 +8,15 @@ import (
) )
type ValidatorTest struct { type ValidatorTest struct {
auth_email_file *os.File authEmailFile *os.File
done chan bool done chan bool
update_seen bool updateSeen bool
} }
func NewValidatorTest(t *testing.T) *ValidatorTest { func NewValidatorTest(t *testing.T) *ValidatorTest {
vt := &ValidatorTest{} vt := &ValidatorTest{}
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file: " + err.Error()) t.Fatal("failed to create temp file: " + err.Error())
} }
@ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
func (vt *ValidatorTest) TearDown() { func (vt *ValidatorTest) TearDown() {
vt.done <- true vt.done <- true
os.Remove(vt.auth_email_file.Name()) os.Remove(vt.authEmailFile.Name())
} }
func (vt *ValidatorTest) NewValidator(domains []string, func (vt *ValidatorTest) NewValidator(domains []string,
updated chan<- bool) func(string) bool { updated chan<- bool) func(string) bool {
return newValidatorImpl(domains, vt.auth_email_file.Name(), return newValidatorImpl(domains, vt.authEmailFile.Name(),
vt.done, func() { vt.done, func() {
if vt.update_seen == false { if vt.updateSeen == false {
updated <- true updated <- true
vt.update_seen = true vt.updateSeen = true
} }
}) })
} }
// This will close vt.auth_email_file. // This will close vt.authEmailFile.
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
defer vt.auth_email_file.Close() defer vt.authEmailFile.Close()
vt.auth_email_file.WriteString(strings.Join(emails, "\n")) vt.authEmailFile.WriteString(strings.Join(emails, "\n"))
if err := vt.auth_email_file.Close(); err != nil { if err := vt.authEmailFile.Close(); err != nil {
t.Fatal("failed to close temp file " + t.Fatal("failed to close temp file " +
vt.auth_email_file.Name() + ": " + err.Error()) vt.authEmailFile.Name() + ": " + err.Error())
} }
} }

View File

@ -12,18 +12,18 @@ import (
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
t *testing.T, emails []string) { t *testing.T, emails []string) {
orig_file := vt.auth_email_file origFile := vt.authEmailFile
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file for copy: " + err.Error()) t.Fatal("failed to create temp file for copy: " + err.Error())
} }
vt.WriteEmails(t, emails) vt.WriteEmails(t, emails)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil { if err != nil {
t.Fatal("failed to copy over temp file: " + err.Error()) t.Fatal("failed to copy over temp file: " + err.Error())
} }
vt.auth_email_file = orig_file vt.authEmailFile = origFile
} }
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {

View File

@ -10,8 +10,8 @@ import (
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
var err error var err error
vt.auth_email_file, err = os.OpenFile( vt.authEmailFile, err = os.OpenFile(
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil { if err != nil {
t.Fatal("failed to re-open temp file for updates") t.Fatal("failed to re-open temp file for updates")
} }
@ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
t *testing.T, emails []string) { t *testing.T, emails []string) {
orig_file := vt.auth_email_file origFile := vt.authEmailFile
var err error var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file for rename and replace: " + t.Fatal("failed to create temp file for rename and replace: " +
err.Error()) err.Error())
} }
vt.WriteEmails(t, emails) vt.WriteEmails(t, emails)
moved_name := orig_file.Name() + "-moved" movedName := origFile.Name() + "-moved"
err = os.Rename(orig_file.Name(), moved_name) err = os.Rename(origFile.Name(), movedName)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) err = os.Rename(vt.authEmailFile.Name(), origFile.Name())
if err != nil { if err != nil {
t.Fatal("failed to rename and replace temp file: " + t.Fatal("failed to rename and replace temp file: " +
err.Error()) err.Error())
} }
vt.auth_email_file = orig_file vt.authEmailFile = origFile
os.Remove(moved_name) os.Remove(movedName)
} }
func TestValidatorOverwriteEmailListDirectly(t *testing.T) { func TestValidatorOverwriteEmailListDirectly(t *testing.T) {

View File

@ -13,11 +13,11 @@ import (
func WaitForReplacement(filename string, op fsnotify.Op, func WaitForReplacement(filename string, op fsnotify.Op,
watcher *fsnotify.Watcher) { watcher *fsnotify.Watcher) {
const sleep_interval = 50 * time.Millisecond const sleepInterval = 50 * time.Millisecond
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if op&fsnotify.Chmod != 0 { if op&fsnotify.Chmod != 0 {
time.Sleep(sleep_interval) time.Sleep(sleepInterval)
} }
for { for {
if _, err := os.Stat(filename); err == nil { if _, err := os.Stat(filename); err == nil {
@ -26,7 +26,7 @@ func WaitForReplacement(filename string, op fsnotify.Op,
return return
} }
} }
time.Sleep(sleep_interval) time.Sleep(sleepInterval)
} }
} }
@ -56,7 +56,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) {
} }
log.Printf("reloading after event: %s", event) log.Printf("reloading after event: %s", event)
action() action()
case err := <-watcher.Errors: case err = <-watcher.Errors:
log.Printf("error watching %s: %s", filename, err) log.Printf("error watching %s: %s", filename, err)
} }
} }