mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
testing
This commit is contained in:
parent
42359333b2
commit
42f539109e
15
htpasswd.go
15
htpasswd.go
@ -4,6 +4,7 @@ import (
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/csv"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
@ -15,26 +16,30 @@ type HtpasswdFile struct {
|
||||
Users map[string]string
|
||||
}
|
||||
|
||||
func NewHtpasswdFile(path string) *HtpasswdFile {
|
||||
func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
|
||||
log.Printf("using htpasswd file %s", path)
|
||||
r, err := os.Open(path)
|
||||
if err != nil {
|
||||
log.Fatalf("failed opening %v, %s", path, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
csv_reader := csv.NewReader(r)
|
||||
return NewHtpasswd(r)
|
||||
}
|
||||
|
||||
func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) {
|
||||
csv_reader := csv.NewReader(file)
|
||||
csv_reader.Comma = ':'
|
||||
csv_reader.Comment = '#'
|
||||
csv_reader.TrimLeadingSpace = true
|
||||
|
||||
records, err := csv_reader.ReadAll()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed reading file %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
h := &HtpasswdFile{Users: make(map[string]string)}
|
||||
for _, record := range records {
|
||||
h.Users[record[0]] = record[1]
|
||||
}
|
||||
return h
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (h *HtpasswdFile) Validate(user string, password string) bool {
|
||||
|
16
htpasswd_test.go
Normal file
16
htpasswd_test.go
Normal file
@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/bmizerany/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHtpasswd(t *testing.T) {
|
||||
file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n"))
|
||||
h, err := NewHtpasswd(file)
|
||||
assert.Equal(t, err, nil)
|
||||
|
||||
valid := h.Validate("testuser", "asdf")
|
||||
assert.Equal(t, valid, true)
|
||||
}
|
7
main.go
7
main.go
@ -2,12 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const VERSION = "0.0.1"
|
||||
@ -72,7 +72,10 @@ func main() {
|
||||
oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain)
|
||||
}
|
||||
if *htpasswdFile != "" {
|
||||
oauthproxy.HtpasswdFile = NewHtpasswdFile(*htpasswdFile)
|
||||
oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: unable to open %s %s", *htpasswdFile, err.Error())
|
||||
}
|
||||
}
|
||||
listener, err := net.Listen("tcp", *httpAddr)
|
||||
if err != nil {
|
||||
|
@ -169,11 +169,11 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
||||
rw.WriteHeader(code)
|
||||
templates := getTemplates()
|
||||
t := struct {
|
||||
Title string
|
||||
Message string
|
||||
Title string
|
||||
Message string
|
||||
}{
|
||||
Title: fmt.Sprintf("%d %s", code, title),
|
||||
Message: message,
|
||||
Title: fmt.Sprintf("%d %s", code, title),
|
||||
Message: message,
|
||||
}
|
||||
templates.ExecuteTemplate(rw, "error.html", t)
|
||||
}
|
||||
@ -254,7 +254,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
cookie, err := req.Cookie(p.CookieKey)
|
||||
var ok bool
|
||||
var email string
|
||||
|
@ -18,7 +18,7 @@ func getTemplates() *template.Template {
|
||||
if err != nil {
|
||||
log.Fatalf("failed parsing template %s", err.Error())
|
||||
}
|
||||
|
||||
|
||||
t, err = t.Parse(`{{define "error.html"}}
|
||||
<html><head><title>{{.Title}}</title></head>
|
||||
<body>
|
||||
|
12
templates_test.go
Normal file
12
templates_test.go
Normal file
@ -0,0 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/bmizerany/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTemplatesCompile(t *testing.T) {
|
||||
templates := getTemplates()
|
||||
assert.NotEqual(t, templates, nil)
|
||||
|
||||
}
|
@ -1,10 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"log"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user