mirror of
https://github.com/volatiletech/authboss.git
synced 2025-02-09 13:47:09 +02:00
Merge branch 'context-request-separation'
This commit is contained in:
commit
c4eb529fd9
@ -80,8 +80,8 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
|
||||
)
|
||||
return a.templates.Render(ctx, w, r, tplLogin, data)
|
||||
case methodPOST:
|
||||
key, _ := ctx.FirstPostFormValue(a.PrimaryID)
|
||||
password, _ := ctx.FirstPostFormValue("password")
|
||||
key := r.FormValue(a.PrimaryID)
|
||||
password := r.FormValue("password")
|
||||
|
||||
errData := authboss.NewHTMLData(
|
||||
"error", fmt.Sprintf("invalid %s and/or password", a.PrimaryID),
|
||||
@ -120,6 +120,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
|
||||
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, key)
|
||||
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
|
||||
ctx.Values = map[string]string{authboss.CookieRemember: r.FormValue(authboss.CookieRemember)}
|
||||
|
||||
if err := a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
|
||||
return err
|
||||
|
@ -36,13 +36,9 @@ func testSetup() (a *Auth, s *mocks.MockStorer) {
|
||||
}
|
||||
|
||||
func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
|
||||
r, err := http.NewRequest(method, "", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sessionStorer := mocks.NewMockClientStorer()
|
||||
ctx := mocks.MockRequestContext(ab, postFormValues...)
|
||||
ctx := ab.NewContext()
|
||||
r := mocks.MockRequest(method, postFormValues...)
|
||||
ctx.SessionStorer = sessionStorer
|
||||
|
||||
return ctx, httptest.NewRecorder(), r, sessionStorer
|
||||
@ -243,6 +239,9 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if _, ok := ctx.Values[authboss.CookieRemember]; !ok {
|
||||
t.Error("Authboss cookie remember should be set for the callback")
|
||||
}
|
||||
if !cb.HasBeenCalled {
|
||||
t.Error("Expected after callback to have been called")
|
||||
}
|
||||
|
10
authboss.go
10
authboss.go
@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
|
||||
|
||||
// CurrentUser retrieves the current user from the session and the database.
|
||||
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
|
||||
ctx, err := a.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx := a.NewContext()
|
||||
|
||||
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
|
||||
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
|
||||
@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, err := a.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := a.NewContext()
|
||||
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
|
||||
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
|
||||
return a.Callbacks.FireAfter(EventPasswordReset, ctx)
|
||||
|
@ -154,9 +154,9 @@ func (c *Confirm) confirmEmail(to, token string) {
|
||||
}
|
||||
|
||||
func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
token, err := ctx.FirstFormValueErr(FormValueConfirm)
|
||||
if err != nil {
|
||||
return err
|
||||
token := r.FormValue(FormValueConfirm)
|
||||
if len(token) == 0 {
|
||||
return authboss.ClientDataErr{FormValueConfirm}
|
||||
}
|
||||
|
||||
toHash, err := base64.URLEncoding.DecodeString(token)
|
||||
|
@ -163,7 +163,7 @@ func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
|
||||
for i, test := range tests {
|
||||
r, _ := http.NewRequest("GET", test.URL, nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := c.ContextFromRequest(r)
|
||||
ctx := c.NewContext()
|
||||
|
||||
err := c.confirmHandler(ctx, w, r)
|
||||
if err == nil {
|
||||
@ -207,7 +207,7 @@ func TestConfirm_Confirm(t *testing.T) {
|
||||
// Make a request with session and context support.
|
||||
r, _ := http.NewRequest("GET", "http://localhost?cnf="+base64.URLEncoding.EncodeToString(token), nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ = c.ContextFromRequest(r)
|
||||
ctx = c.NewContext()
|
||||
ctx.CookieStorer = mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer()
|
||||
ctx.User = user
|
||||
|
108
context.go
108
context.go
@ -2,11 +2,7 @@ package authboss
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FormValue constants
|
||||
@ -25,8 +21,8 @@ type Context struct {
|
||||
CookieStorer ClientStorerErr
|
||||
User Attributes
|
||||
|
||||
postFormValues map[string][]string
|
||||
formValues map[string][]string
|
||||
// Values is a free-form key-value store to pass data to callbacks
|
||||
Values map[string]string
|
||||
}
|
||||
|
||||
// NewContext is exported for testing modules.
|
||||
@ -36,75 +32,6 @@ func (a *Authboss) NewContext() *Context {
|
||||
}
|
||||
}
|
||||
|
||||
// ContextFromRequest creates a context from an http request.
|
||||
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := a.NewContext()
|
||||
c.formValues = map[string][]string(r.Form)
|
||||
c.postFormValues = map[string][]string(r.PostForm)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// FormValue gets a form value from a context created with a request.
|
||||
func (c *Context) FormValue(key string) ([]string, bool) {
|
||||
val, ok := c.formValues[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// PostFormValue gets a form value from a context created with a request.
|
||||
func (c *Context) PostFormValue(key string) ([]string, bool) {
|
||||
val, ok := c.postFormValues[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// FirstFormValue gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstFormValue(key string) (string, bool) {
|
||||
val, ok := c.formValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return val[0], ok
|
||||
}
|
||||
|
||||
// FirstPostFormValue gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstPostFormValue(key string) (string, bool) {
|
||||
val, ok := c.postFormValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return val[0], ok
|
||||
}
|
||||
|
||||
// FirstFormValueErr gets the first form value from a context created with a request
|
||||
// and additionally returns an error not a bool if it's not found.
|
||||
func (c *Context) FirstFormValueErr(key string) (string, error) {
|
||||
val, ok := c.formValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", ClientDataErr{key}
|
||||
}
|
||||
|
||||
return val[0], nil
|
||||
}
|
||||
|
||||
// FirstPostFormValueErr gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstPostFormValueErr(key string) (string, error) {
|
||||
val, ok := c.postFormValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", ClientDataErr{key}
|
||||
}
|
||||
|
||||
return val[0], nil
|
||||
}
|
||||
|
||||
// LoadUser loads the user Attributes if they haven't already been loaded.
|
||||
func (c *Context) LoadUser(key string) error {
|
||||
if c.User != nil {
|
||||
@ -155,34 +82,3 @@ func (c *Context) SaveUser() error {
|
||||
|
||||
return c.Storer.Put(key, c.User)
|
||||
}
|
||||
|
||||
// Attributes converts the post form values into an attributes map.
|
||||
func (c *Context) Attributes() (Attributes, error) {
|
||||
attr := make(Attributes)
|
||||
|
||||
for name, values := range c.postFormValues {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := values[0]
|
||||
switch {
|
||||
case strings.HasSuffix(name, "_int"):
|
||||
integer, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_int")] = integer
|
||||
case strings.HasSuffix(name, "_date"):
|
||||
date, err := time.Parse(time.RFC3339, val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_date")] = date.UTC()
|
||||
default:
|
||||
attr[name] = val
|
||||
}
|
||||
}
|
||||
|
||||
return attr, nil
|
||||
}
|
||||
|
105
context_test.go
105
context_test.go
@ -1,76 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestContext_Request(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" {
|
||||
t.Error("Form value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" {
|
||||
t.Error("Postform value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" {
|
||||
t.Error("Form value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" {
|
||||
t.Error("Postform value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if _, err := ctx.FirstFormValueErr("query"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err := ctx.FirstPostFormValueErr("post"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FormValue("query1"); ok {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, ok := ctx.PostFormValue("post1"); ok {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FirstFormValue("query1"); ok {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, ok := ctx.FirstPostFormValue("post1"); ok {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
|
||||
if query, err := ctx.FirstFormValueErr("query1"); err == nil {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, err := ctx.FirstPostFormValueErr("post1"); err == nil {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
}
|
||||
import "testing"
|
||||
|
||||
func TestContext_SaveUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_Attributes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
ab := New()
|
||||
ctx := ab.NewContext()
|
||||
ctx.postFormValues = map[string][]string{
|
||||
"a": []string{"a", "1"},
|
||||
"b_int": []string{"5", "hello"},
|
||||
"wildcard": nil,
|
||||
"c_date": []string{now.Format(time.RFC3339)},
|
||||
}
|
||||
|
||||
attr, err := ctx.Attributes()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := attr["a"].(string); got != "a" {
|
||||
t.Error("a's value is wrong:", got)
|
||||
}
|
||||
if got := attr["b"].(int); got != 5 {
|
||||
t.Error("b's value is wrong:", got)
|
||||
}
|
||||
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
|
||||
t.Error("c's value is wrong:", now, got)
|
||||
}
|
||||
if _, ok := attr["wildcard"]; ok {
|
||||
t.Error("We don't need totally empty fields.")
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,11 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
@ -278,28 +279,34 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
|
||||
// Del a key/value pair
|
||||
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
|
||||
|
||||
// MockRequestContext returns a new context as if it came from POST request.
|
||||
func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context {
|
||||
keyValues := &bytes.Buffer{}
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
if i != 0 {
|
||||
keyValues.WriteByte('&')
|
||||
// MockRequest returns a new mock request with optional key-value body (form-post)
|
||||
func MockRequest(method string, postKeyValues ...string) *http.Request {
|
||||
var body io.Reader
|
||||
location := "http://localhost"
|
||||
|
||||
if len(postKeyValues) > 0 {
|
||||
urlValues := make(url.Values)
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
||||
}
|
||||
|
||||
if method == "POST" || method == "PUT" {
|
||||
body = strings.NewReader(urlValues.Encode())
|
||||
} else {
|
||||
location += "?" + urlValues.Encode()
|
||||
}
|
||||
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost", keyValues)
|
||||
req, err := http.NewRequest(method, location, body)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
if len(postKeyValues) > 0 {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
|
||||
return ctx
|
||||
return req
|
||||
}
|
||||
|
||||
// MockMailer helps simplify mailer testing by storing the last sent email
|
||||
|
@ -68,7 +68,7 @@ func TestTemplates_Render(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://localhost", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := ab.ContextFromRequest(r)
|
||||
ctx := ab.NewContext()
|
||||
ctx.SessionStorer = cookies
|
||||
|
||||
tpls := Templates{
|
||||
@ -128,7 +128,7 @@ func TestRedirect(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://localhost", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := ab.ContextFromRequest(r)
|
||||
ctx := ab.NewContext()
|
||||
ctx.SessionStorer = cookies
|
||||
|
||||
Redirect(ctx, w, r, "/", "success", "failure", false)
|
||||
@ -157,7 +157,7 @@ func TestRedirect_Override(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://localhost?redir=foo/bar", nil)
|
||||
w := httptest.NewRecorder()
|
||||
ctx, _ := ab.ContextFromRequest(r)
|
||||
ctx := ab.NewContext()
|
||||
ctx.SessionStorer = cookies
|
||||
|
||||
Redirect(ctx, w, r, "/shouldNotGo", "success", "failure", true)
|
||||
|
@ -1,9 +1,9 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type mockUser struct {
|
||||
@ -56,27 +56,19 @@ func (m mockClientStore) GetErr(key string) (string, error) {
|
||||
func (m mockClientStore) Put(key, val string) { m[key] = val }
|
||||
func (m mockClientStore) Del(key string) { delete(m, key) }
|
||||
|
||||
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
|
||||
keyValues := &bytes.Buffer{}
|
||||
func mockRequest(postKeyValues ...string) *http.Request {
|
||||
urlValues := make(url.Values)
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
if i != 0 {
|
||||
keyValues.WriteByte('&')
|
||||
}
|
||||
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
|
||||
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost", keyValues)
|
||||
req, err := http.NewRequest("POST", "http://localhost", strings.NewReader(urlValues.Encode()))
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return ctx
|
||||
return req
|
||||
}
|
||||
|
||||
type mockValidator struct {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
@ -22,6 +23,10 @@ const (
|
||||
StoreRecoverTokenExpiry = "recover_token_expiry"
|
||||
)
|
||||
|
||||
const (
|
||||
formValueToken = "token"
|
||||
)
|
||||
|
||||
const (
|
||||
methodGET = "GET"
|
||||
methodPOST = "POST"
|
||||
@ -129,8 +134,8 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
|
||||
|
||||
return rec.templates.Render(ctx, w, r, tplRecover, data)
|
||||
case methodPOST:
|
||||
primaryID, _ := ctx.FirstPostFormValue(rec.PrimaryID)
|
||||
confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
|
||||
primaryID := r.FormValue(rec.PrimaryID)
|
||||
confirmPrimaryID := r.FormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
|
||||
|
||||
errData := authboss.NewHTMLData(
|
||||
"primaryID", rec.PrimaryID,
|
||||
@ -139,7 +144,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
|
||||
)
|
||||
|
||||
policies := authboss.FilterValidators(rec.Policies, rec.PrimaryID)
|
||||
if validationErrs := ctx.Validate(policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
|
||||
if validationErrs := authboss.Validate(r, policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
|
||||
errData.MergeKV("errs", validationErrs)
|
||||
return rec.templates.Render(ctx, w, r, tplRecover, errData)
|
||||
}
|
||||
@ -195,7 +200,8 @@ var goRecoverEmail = func(r *Recover, to, encodedToken string) {
|
||||
|
||||
func (r *Recover) sendRecoverEmail(to, encodedToken string) {
|
||||
p := path.Join(r.MountPath, "recover/complete")
|
||||
url := fmt.Sprintf("%s%s?token=%s", r.RootURL, p, encodedToken)
|
||||
query := url.Values{formValueToken: []string{encodedToken}}
|
||||
url := fmt.Sprintf("%s%s?%s", r.RootURL, p, query.Encode())
|
||||
|
||||
email := authboss.Email{
|
||||
To: []string{to},
|
||||
@ -211,35 +217,35 @@ func (r *Recover) sendRecoverEmail(to, encodedToken string) {
|
||||
func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, req *http.Request) (err error) {
|
||||
switch req.Method {
|
||||
case methodGET:
|
||||
_, err = verifyToken(ctx)
|
||||
_, err = verifyToken(ctx, req)
|
||||
if err == errRecoveryTokenExpired {
|
||||
return authboss.ErrAndRedirect{err, "/recover", "", recoverTokenExpiredFlash}
|
||||
} else if err != nil {
|
||||
return authboss.ErrAndRedirect{err, "/", "", ""}
|
||||
}
|
||||
|
||||
token, _ := ctx.FirstFormValue("token")
|
||||
data := authboss.NewHTMLData("token", token)
|
||||
token := req.FormValue(formValueToken)
|
||||
data := authboss.NewHTMLData(formValueToken, token)
|
||||
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
|
||||
case methodPOST:
|
||||
token, err := ctx.FirstFormValueErr("token")
|
||||
if err != nil {
|
||||
return err
|
||||
token := req.FormValue(formValueToken)
|
||||
if len(token) == 0 {
|
||||
return authboss.ClientDataErr{formValueToken}
|
||||
}
|
||||
|
||||
password, _ := ctx.FirstPostFormValue("password")
|
||||
password := req.FormValue(authboss.StorePassword)
|
||||
//confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword")
|
||||
|
||||
policies := authboss.FilterValidators(r.Policies, "password")
|
||||
if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
|
||||
policies := authboss.FilterValidators(r.Policies, authboss.StorePassword)
|
||||
if validationErrs := authboss.Validate(req, policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
|
||||
data := authboss.NewHTMLData(
|
||||
"token", token,
|
||||
formValueToken, token,
|
||||
"errs", validationErrs,
|
||||
)
|
||||
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
|
||||
}
|
||||
|
||||
if ctx.User, err = verifyToken(ctx); err != nil {
|
||||
if ctx.User, err = verifyToken(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -276,10 +282,10 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
|
||||
}
|
||||
|
||||
// verifyToken expects a base64.URLEncoded token.
|
||||
func verifyToken(ctx *authboss.Context) (attrs authboss.Attributes, err error) {
|
||||
token, err := ctx.FirstFormValueErr("token")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func verifyToken(ctx *authboss.Context, r *http.Request) (attrs authboss.Attributes, err error) {
|
||||
token := r.FormValue(formValueToken)
|
||||
if len(token) == 0 {
|
||||
return nil, authboss.ClientDataErr{token}
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(token)
|
||||
|
@ -63,13 +63,9 @@ func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) {
|
||||
}
|
||||
|
||||
func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
|
||||
r, err := http.NewRequest(method, "", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sessionStorer := mocks.NewMockClientStorer()
|
||||
ctx := mocks.MockRequestContext(ab, postFormValues...)
|
||||
ctx := ab.NewContext()
|
||||
r := mocks.MockRequest(method, postFormValues...)
|
||||
ctx.SessionStorer = sessionStorer
|
||||
|
||||
return ctx, httptest.NewRecorder(), r, sessionStorer
|
||||
@ -300,7 +296,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
|
||||
t.Error("Unexpected subject:", mailer.Last.Subject)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/recover/complete?token=abc=", r.RootURL)
|
||||
url := fmt.Sprintf("%s/recover/complete?token=abc%%3D", r.RootURL)
|
||||
if !strings.Contains(mailer.Last.HTMLBody, url) {
|
||||
t.Error("Expected HTMLBody to contain url:", url)
|
||||
}
|
||||
@ -319,7 +315,7 @@ func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
|
||||
err := rec.completeHandlerFunc(ctx, w, r)
|
||||
rerr, ok := err.(authboss.ErrAndRedirect)
|
||||
if !ok {
|
||||
t.Error("Expected ErrAndRedirect")
|
||||
t.Error("Expected ErrAndRedirect:", err)
|
||||
}
|
||||
if rerr.Location != "/" {
|
||||
t.Error("Unexpected location:", rerr.Location)
|
||||
@ -382,7 +378,7 @@ func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) {
|
||||
ctx, w, r, _ := testRequest(rec.Authboss, "POST")
|
||||
|
||||
err := rec.completeHandlerFunc(ctx, w, r)
|
||||
if err.Error() != "Failed to retrieve client attribute: token" {
|
||||
if err == nil || err.Error() != "Failed to retrieve client attribute: token" {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
@ -477,9 +473,9 @@ func Test_verifyToken_MissingToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testSetup()
|
||||
r := mocks.MockRequest("GET")
|
||||
|
||||
ctx := &authboss.Context{}
|
||||
if _, err := verifyToken(ctx); err == nil {
|
||||
if _, err := verifyToken(nil, r); err == nil {
|
||||
t.Error("Expected error about missing token")
|
||||
}
|
||||
}
|
||||
@ -492,8 +488,9 @@ func Test_verifyToken_InvalidToken(t *testing.T) {
|
||||
StoreRecoverToken: testStdBase64Token,
|
||||
}
|
||||
|
||||
ctx := mocks.MockRequestContext(rec.Authboss, "token", "asdf")
|
||||
if _, err := verifyToken(ctx); err != authboss.ErrUserNotFound {
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token=asdf", nil)
|
||||
if _, err := verifyToken(ctx, req); err != authboss.ErrUserNotFound {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
@ -507,8 +504,9 @@ func Test_verifyToken_ExpiredToken(t *testing.T) {
|
||||
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(-24) * time.Hour),
|
||||
}
|
||||
|
||||
ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
|
||||
if _, err := verifyToken(ctx); err != errRecoveryTokenExpired {
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
|
||||
if _, err := verifyToken(ctx, req); err != errRecoveryTokenExpired {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
@ -522,8 +520,9 @@ func Test_verifyToken(t *testing.T) {
|
||||
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(24) * time.Hour),
|
||||
}
|
||||
|
||||
ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
|
||||
attrs, err := verifyToken(ctx)
|
||||
ctx := rec.Authboss.NewContext()
|
||||
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
|
||||
attrs, err := verifyToken(ctx, req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
@ -82,10 +82,10 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
|
||||
}
|
||||
|
||||
func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
key, _ := ctx.FirstPostFormValue(reg.PrimaryID)
|
||||
password, _ := ctx.FirstPostFormValue(authboss.StorePassword)
|
||||
key := r.FormValue(reg.PrimaryID)
|
||||
password := r.FormValue(authboss.StorePassword)
|
||||
|
||||
validationErrs := ctx.Validate(reg.Policies, reg.ConfirmFields...)
|
||||
validationErrs := authboss.Validate(r, reg.Policies, reg.ConfirmFields...)
|
||||
|
||||
if user, err := ctx.Storer.Get(key); err != nil && err != authboss.ErrUserNotFound {
|
||||
return err
|
||||
@ -101,13 +101,13 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
|
||||
}
|
||||
|
||||
for _, f := range reg.PreserveFields {
|
||||
data[f], _ = ctx.FirstFormValue(f)
|
||||
data[f] = r.FormValue(f)
|
||||
}
|
||||
|
||||
return reg.templates.Render(ctx, w, r, tplRegister, data)
|
||||
}
|
||||
|
||||
attr, err := ctx.Attributes() // Attributes from overriden forms
|
||||
attr, err := authboss.AttributesFromRequest(r) // Attributes from overriden forms
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -129,7 +129,7 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
|
||||
}
|
||||
|
||||
for _, f := range reg.PreserveFields {
|
||||
data[f], _ = ctx.FirstFormValue(f)
|
||||
data[f] = r.FormValue(f)
|
||||
}
|
||||
|
||||
return reg.templates.Render(ctx, w, r, tplRegister, data)
|
||||
|
@ -58,7 +58,7 @@ func TestRegisterGet(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "/register", nil)
|
||||
ctx, _ := reg.ContextFromRequest(r)
|
||||
ctx := reg.NewContext()
|
||||
ctx.SessionStorer = mocks.NewMockClientStorer()
|
||||
|
||||
if err := reg.registerHandler(ctx, w, r); err != nil {
|
||||
@ -93,7 +93,7 @@ func TestRegisterPostValidationErrs(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
ctx, _ := reg.ContextFromRequest(r)
|
||||
ctx := reg.NewContext()
|
||||
ctx.SessionStorer = mocks.NewMockClientStorer()
|
||||
|
||||
if err := reg.registerHandler(ctx, w, r); err != nil {
|
||||
@ -131,7 +131,7 @@ func TestRegisterPostSuccess(t *testing.T) {
|
||||
|
||||
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
ctx, _ := reg.ContextFromRequest(r)
|
||||
ctx := reg.NewContext()
|
||||
ctx.SessionStorer = mocks.NewMockClientStorer()
|
||||
|
||||
if err := reg.registerHandler(ctx, w, r); err != nil {
|
||||
|
@ -83,7 +83,7 @@ func (r *Remember) Storage() authboss.StorageOptions {
|
||||
|
||||
// afterAuth is called after authentication is successful.
|
||||
func (r *Remember) afterAuth(ctx *authboss.Context) error {
|
||||
if val, ok := ctx.FirstPostFormValue(authboss.CookieRemember); !ok || val != "true" {
|
||||
if val := ctx.Values[authboss.CookieRemember]; val != "true" {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package remember
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@ -49,15 +48,13 @@ func TestAfterAuth(t *testing.T) {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := r.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
ctx := r.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.User = authboss.Attributes{r.PrimaryID: "test@email.com"}
|
||||
|
||||
ctx.Values = map[string]string{authboss.CookieRemember: "true"}
|
||||
|
||||
if err := r.afterAuth(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -77,17 +74,7 @@ func TestAfterOAuth(t *testing.T) {
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
|
||||
|
||||
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", "xsrf")
|
||||
req, err := http.NewRequest("GET", uri, nil)
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
|
||||
ctx, err := r.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
ctx := r.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.User = authboss.Attributes{
|
||||
|
10
router.go
10
router.go
@ -47,11 +47,7 @@ type contextRoute struct {
|
||||
|
||||
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Instantiate the context
|
||||
ctx, err := c.Authboss.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
|
||||
return
|
||||
}
|
||||
ctx := c.Authboss.NewContext()
|
||||
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
|
||||
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
|
||||
|
||||
@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Call the handler
|
||||
err = c.fn(ctx, w, r)
|
||||
err := c.fn(ctx, w, r)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
return true
|
||||
} else if cu != nil {
|
||||
if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 {
|
||||
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
|
||||
http.Redirect(w, r, redir, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
|
||||
|
42
storer.go
42
storer.go
@ -6,7 +6,10 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string {
|
||||
// Attributes is just a key-value mapping of data.
|
||||
type Attributes map[string]interface{}
|
||||
|
||||
// Attributes converts the post form values into an attributes map.
|
||||
func AttributesFromRequest(r *http.Request) (Attributes, error) {
|
||||
attr := make(Attributes)
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for name, values := range r.Form {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := values[0]
|
||||
if len(val) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(name, "_int"):
|
||||
integer, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_int")] = integer
|
||||
case strings.HasSuffix(name, "_date"):
|
||||
date, err := time.Parse(time.RFC3339, val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_date")] = date.UTC()
|
||||
default:
|
||||
attr[name] = val
|
||||
}
|
||||
}
|
||||
|
||||
return attr, nil
|
||||
}
|
||||
|
||||
// Names returns the names of all the attributes.
|
||||
func (a Attributes) Names() []string {
|
||||
names := make([]string, len(a))
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) {
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
func TestAttributes_FromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
vals := make(url.Values)
|
||||
vals.Set("a", "a")
|
||||
vals.Set("b_int", "5")
|
||||
vals.Set("wildcard", "")
|
||||
vals.Set("c_date", now.Format(time.RFC3339))
|
||||
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
attr, err := AttributesFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := attr["a"].(string); got != "a" {
|
||||
t.Error("a's value is wrong:", got)
|
||||
}
|
||||
if got := attr["b"].(int); got != 5 {
|
||||
t.Error("b's value is wrong:", got)
|
||||
}
|
||||
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
|
||||
t.Error("c's value is wrong:", now, got)
|
||||
}
|
||||
if _, ok := attr["wildcard"]; ok {
|
||||
t.Error("We don't need totally empty fields.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_Names(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -3,6 +3,7 @@ package authboss
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -64,26 +65,26 @@ func (f FieldError) Error() string {
|
||||
}
|
||||
|
||||
// Validate validates a request using the given ruleset.
|
||||
func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList {
|
||||
func Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList {
|
||||
errList := make(ErrorList, 0)
|
||||
|
||||
for _, validator := range ruleset {
|
||||
field := validator.Field()
|
||||
|
||||
val, _ := ctx.FirstFormValue(field)
|
||||
val := r.FormValue(field)
|
||||
if errs := validator.Errors(val); len(errs) > 0 {
|
||||
errList = append(errList, errs...)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(confirmFields)-1; i += 2 {
|
||||
main, ok := ctx.FirstPostFormValue(confirmFields[i])
|
||||
if !ok {
|
||||
main := r.FormValue(confirmFields[i])
|
||||
if len(main) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1])
|
||||
if !ok || main != confirm {
|
||||
confirm := r.FormValue(confirmFields[i+1])
|
||||
if len(confirm) == 0 || main != confirm {
|
||||
errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])})
|
||||
}
|
||||
}
|
||||
|
@ -64,10 +64,9 @@ func TestErrorList_Map(t *testing.T) {
|
||||
func TestValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
req := mockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
|
||||
errList := ctx.Validate([]Validator{
|
||||
errList := Validate(req, []Validator{
|
||||
mockValidator{
|
||||
FieldName: StoreUsername,
|
||||
Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}},
|
||||
@ -96,21 +95,20 @@ func TestValidate(t *testing.T) {
|
||||
func TestValidate_Confirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
|
||||
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
|
||||
req := mockRequest(StoreUsername, "john", "confirmUsername", "johnny")
|
||||
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if errs["confirmUsername"][0] != "Does not match username" {
|
||||
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
|
||||
}
|
||||
|
||||
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
|
||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
}
|
||||
|
||||
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(nil, StoreUsername).Map()
|
||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = Validate(req, nil, StoreUsername).Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user