1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-09 13:47:09 +02:00

Merge branch 'context-request-separation'

This commit is contained in:
Aaron L 2015-08-30 06:41:19 -07:00
commit c4eb529fd9
21 changed files with 202 additions and 350 deletions

View File

@ -80,8 +80,8 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
)
return a.templates.Render(ctx, w, r, tplLogin, data)
case methodPOST:
key, _ := ctx.FirstPostFormValue(a.PrimaryID)
password, _ := ctx.FirstPostFormValue("password")
key := r.FormValue(a.PrimaryID)
password := r.FormValue("password")
errData := authboss.NewHTMLData(
"error", fmt.Sprintf("invalid %s and/or password", a.PrimaryID),
@ -120,6 +120,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
ctx.SessionStorer.Put(authboss.SessionKey, key)
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
ctx.Values = map[string]string{authboss.CookieRemember: r.FormValue(authboss.CookieRemember)}
if err := a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
return err

View File

@ -36,13 +36,9 @@ func testSetup() (a *Auth, s *mocks.MockStorer) {
}
func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
r, err := http.NewRequest(method, "", nil)
if err != nil {
panic(err)
}
sessionStorer := mocks.NewMockClientStorer()
ctx := mocks.MockRequestContext(ab, postFormValues...)
ctx := ab.NewContext()
r := mocks.MockRequest(method, postFormValues...)
ctx.SessionStorer = sessionStorer
return ctx, httptest.NewRecorder(), r, sessionStorer
@ -243,6 +239,9 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
t.Error("Unexpected error:", err)
}
if _, ok := ctx.Values[authboss.CookieRemember]; !ok {
t.Error("Authboss cookie remember should be set for the callback")
}
if !cb.HasBeenCalled {
t.Error("Expected after callback to have been called")
}

View File

@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
// CurrentUser retrieves the current user from the session and the database.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx, err := a.ContextFromRequest(r)
if err != nil {
return nil, err
}
ctx := a.NewContext()
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
return nil
}
ctx, err := a.ContextFromRequest(r)
if err != nil {
return err
}
ctx := a.NewContext()
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
return a.Callbacks.FireAfter(EventPasswordReset, ctx)

View File

@ -154,9 +154,9 @@ func (c *Confirm) confirmEmail(to, token string) {
}
func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
token, err := ctx.FirstFormValueErr(FormValueConfirm)
if err != nil {
return err
token := r.FormValue(FormValueConfirm)
if len(token) == 0 {
return authboss.ClientDataErr{FormValueConfirm}
}
toHash, err := base64.URLEncoding.DecodeString(token)

View File

@ -163,7 +163,7 @@ func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
for i, test := range tests {
r, _ := http.NewRequest("GET", test.URL, nil)
w := httptest.NewRecorder()
ctx, _ := c.ContextFromRequest(r)
ctx := c.NewContext()
err := c.confirmHandler(ctx, w, r)
if err == nil {
@ -207,7 +207,7 @@ func TestConfirm_Confirm(t *testing.T) {
// Make a request with session and context support.
r, _ := http.NewRequest("GET", "http://localhost?cnf="+base64.URLEncoding.EncodeToString(token), nil)
w := httptest.NewRecorder()
ctx, _ = c.ContextFromRequest(r)
ctx = c.NewContext()
ctx.CookieStorer = mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer()
ctx.User = user

View File

@ -2,11 +2,7 @@ package authboss
import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
// FormValue constants
@ -25,8 +21,8 @@ type Context struct {
CookieStorer ClientStorerErr
User Attributes
postFormValues map[string][]string
formValues map[string][]string
// Values is a free-form key-value store to pass data to callbacks
Values map[string]string
}
// NewContext is exported for testing modules.
@ -36,75 +32,6 @@ func (a *Authboss) NewContext() *Context {
}
}
// ContextFromRequest creates a context from an http request.
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
if err := r.ParseForm(); err != nil {
return nil, err
}
c := a.NewContext()
c.formValues = map[string][]string(r.Form)
c.postFormValues = map[string][]string(r.PostForm)
return c, nil
}
// FormValue gets a form value from a context created with a request.
func (c *Context) FormValue(key string) ([]string, bool) {
val, ok := c.formValues[key]
return val, ok
}
// PostFormValue gets a form value from a context created with a request.
func (c *Context) PostFormValue(key string) ([]string, bool) {
val, ok := c.postFormValues[key]
return val, ok
}
// FirstFormValue gets the first form value from a context created with a request.
func (c *Context) FirstFormValue(key string) (string, bool) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstPostFormValue gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValue(key string) (string, bool) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstFormValueErr gets the first form value from a context created with a request
// and additionally returns an error not a bool if it's not found.
func (c *Context) FirstFormValueErr(key string) (string, error) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// FirstPostFormValueErr gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValueErr(key string) (string, error) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// LoadUser loads the user Attributes if they haven't already been loaded.
func (c *Context) LoadUser(key string) error {
if c.User != nil {
@ -155,34 +82,3 @@ func (c *Context) SaveUser() error {
return c.Storer.Put(key, c.User)
}
// Attributes converts the post form values into an attributes map.
func (c *Context) Attributes() (Attributes, error) {
attr := make(Attributes)
for name, values := range c.postFormValues {
if len(values) == 0 {
continue
}
val := values[0]
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}

View File

@ -1,76 +1,6 @@
package authboss
import (
"bytes"
"net/http"
"testing"
"time"
)
func TestContext_Request(t *testing.T) {
t.Parallel()
ab := New()
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
if err != nil {
t.Error("Unexpected Error:", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected Error:", err)
}
if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if _, err := ctx.FirstFormValueErr("query"); err != nil {
t.Error(err)
}
if _, err := ctx.FirstPostFormValueErr("post"); err != nil {
t.Error(err)
}
if query, ok := ctx.FormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.PostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, ok := ctx.FirstFormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.FirstPostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, err := ctx.FirstFormValueErr("query1"); err == nil {
t.Error("Expected query1 not to be found:", query)
}
if post, err := ctx.FirstPostFormValueErr("post1"); err == nil {
t.Error("Expected post1 not to be found:", post)
}
}
import "testing"
func TestContext_SaveUser(t *testing.T) {
t.Parallel()
@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) {
}
}
}
func TestContext_Attributes(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
ab := New()
ctx := ab.NewContext()
ctx.postFormValues = map[string][]string{
"a": []string{"a", "1"},
"b_int": []string{"5", "hello"},
"wildcard": nil,
"c_date": []string{now.Format(time.RFC3339)},
}
attr, err := ctx.Attributes()
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}

View File

@ -2,10 +2,11 @@
package mocks
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"gopkg.in/authboss.v0"
@ -278,28 +279,34 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
// Del a key/value pair
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request.
func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context {
keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
keyValues.WriteByte('&')
// MockRequest returns a new mock request with optional key-value body (form-post)
func MockRequest(method string, postKeyValues ...string) *http.Request {
var body io.Reader
location := "http://localhost"
if len(postKeyValues) > 0 {
urlValues := make(url.Values)
for i := 0; i < len(postKeyValues); i += 2 {
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
}
if method == "POST" || method == "PUT" {
body = strings.NewReader(urlValues.Encode())
} else {
location += "?" + urlValues.Encode()
}
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
}
req, err := http.NewRequest("POST", "http://localhost", keyValues)
req, err := http.NewRequest(method, location, body)
if err != nil {
panic(err.Error())
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
if len(postKeyValues) > 0 {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
return ctx
return req
}
// MockMailer helps simplify mailer testing by storing the last sent email

View File

@ -68,7 +68,7 @@ func TestTemplates_Render(t *testing.T) {
r, _ := http.NewRequest("GET", "http://localhost", nil)
w := httptest.NewRecorder()
ctx, _ := ab.ContextFromRequest(r)
ctx := ab.NewContext()
ctx.SessionStorer = cookies
tpls := Templates{
@ -128,7 +128,7 @@ func TestRedirect(t *testing.T) {
r, _ := http.NewRequest("GET", "http://localhost", nil)
w := httptest.NewRecorder()
ctx, _ := ab.ContextFromRequest(r)
ctx := ab.NewContext()
ctx.SessionStorer = cookies
Redirect(ctx, w, r, "/", "success", "failure", false)
@ -157,7 +157,7 @@ func TestRedirect_Override(t *testing.T) {
r, _ := http.NewRequest("GET", "http://localhost?redir=foo/bar", nil)
w := httptest.NewRecorder()
ctx, _ := ab.ContextFromRequest(r)
ctx := ab.NewContext()
ctx.SessionStorer = cookies
Redirect(ctx, w, r, "/shouldNotGo", "success", "failure", true)

View File

@ -1,9 +1,9 @@
package authboss
import (
"bytes"
"fmt"
"net/http"
"net/url"
"strings"
)
type mockUser struct {
@ -56,27 +56,19 @@ func (m mockClientStore) GetErr(key string) (string, error) {
func (m mockClientStore) Put(key, val string) { m[key] = val }
func (m mockClientStore) Del(key string) { delete(m, key) }
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
keyValues := &bytes.Buffer{}
func mockRequest(postKeyValues ...string) *http.Request {
urlValues := make(url.Values)
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
keyValues.WriteByte('&')
}
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
}
req, err := http.NewRequest("POST", "http://localhost", keyValues)
req, err := http.NewRequest("POST", "http://localhost", strings.NewReader(urlValues.Encode()))
if err != nil {
panic(err.Error())
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
}
return ctx
return req
}
type mockValidator struct {

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"path"
"time"
@ -22,6 +23,10 @@ const (
StoreRecoverTokenExpiry = "recover_token_expiry"
)
const (
formValueToken = "token"
)
const (
methodGET = "GET"
methodPOST = "POST"
@ -129,8 +134,8 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
return rec.templates.Render(ctx, w, r, tplRecover, data)
case methodPOST:
primaryID, _ := ctx.FirstPostFormValue(rec.PrimaryID)
confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
primaryID := r.FormValue(rec.PrimaryID)
confirmPrimaryID := r.FormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
errData := authboss.NewHTMLData(
"primaryID", rec.PrimaryID,
@ -139,7 +144,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
)
policies := authboss.FilterValidators(rec.Policies, rec.PrimaryID)
if validationErrs := ctx.Validate(policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
if validationErrs := authboss.Validate(r, policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
errData.MergeKV("errs", validationErrs)
return rec.templates.Render(ctx, w, r, tplRecover, errData)
}
@ -195,7 +200,8 @@ var goRecoverEmail = func(r *Recover, to, encodedToken string) {
func (r *Recover) sendRecoverEmail(to, encodedToken string) {
p := path.Join(r.MountPath, "recover/complete")
url := fmt.Sprintf("%s%s?token=%s", r.RootURL, p, encodedToken)
query := url.Values{formValueToken: []string{encodedToken}}
url := fmt.Sprintf("%s%s?%s", r.RootURL, p, query.Encode())
email := authboss.Email{
To: []string{to},
@ -211,35 +217,35 @@ func (r *Recover) sendRecoverEmail(to, encodedToken string) {
func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, req *http.Request) (err error) {
switch req.Method {
case methodGET:
_, err = verifyToken(ctx)
_, err = verifyToken(ctx, req)
if err == errRecoveryTokenExpired {
return authboss.ErrAndRedirect{err, "/recover", "", recoverTokenExpiredFlash}
} else if err != nil {
return authboss.ErrAndRedirect{err, "/", "", ""}
}
token, _ := ctx.FirstFormValue("token")
data := authboss.NewHTMLData("token", token)
token := req.FormValue(formValueToken)
data := authboss.NewHTMLData(formValueToken, token)
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
case methodPOST:
token, err := ctx.FirstFormValueErr("token")
if err != nil {
return err
token := req.FormValue(formValueToken)
if len(token) == 0 {
return authboss.ClientDataErr{formValueToken}
}
password, _ := ctx.FirstPostFormValue("password")
password := req.FormValue(authboss.StorePassword)
//confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword")
policies := authboss.FilterValidators(r.Policies, "password")
if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
policies := authboss.FilterValidators(r.Policies, authboss.StorePassword)
if validationErrs := authboss.Validate(req, policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
data := authboss.NewHTMLData(
"token", token,
formValueToken, token,
"errs", validationErrs,
)
return r.templates.Render(ctx, w, req, tplRecoverComplete, data)
}
if ctx.User, err = verifyToken(ctx); err != nil {
if ctx.User, err = verifyToken(ctx, req); err != nil {
return err
}
@ -276,10 +282,10 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
}
// verifyToken expects a base64.URLEncoded token.
func verifyToken(ctx *authboss.Context) (attrs authboss.Attributes, err error) {
token, err := ctx.FirstFormValueErr("token")
if err != nil {
return nil, err
func verifyToken(ctx *authboss.Context, r *http.Request) (attrs authboss.Attributes, err error) {
token := r.FormValue(formValueToken)
if len(token) == 0 {
return nil, authboss.ClientDataErr{token}
}
decoded, err := base64.URLEncoding.DecodeString(token)

View File

@ -63,13 +63,9 @@ func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) {
}
func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
r, err := http.NewRequest(method, "", nil)
if err != nil {
panic(err)
}
sessionStorer := mocks.NewMockClientStorer()
ctx := mocks.MockRequestContext(ab, postFormValues...)
ctx := ab.NewContext()
r := mocks.MockRequest(method, postFormValues...)
ctx.SessionStorer = sessionStorer
return ctx, httptest.NewRecorder(), r, sessionStorer
@ -300,7 +296,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
t.Error("Unexpected subject:", mailer.Last.Subject)
}
url := fmt.Sprintf("%s/recover/complete?token=abc=", r.RootURL)
url := fmt.Sprintf("%s/recover/complete?token=abc%%3D", r.RootURL)
if !strings.Contains(mailer.Last.HTMLBody, url) {
t.Error("Expected HTMLBody to contain url:", url)
}
@ -319,7 +315,7 @@ func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
err := rec.completeHandlerFunc(ctx, w, r)
rerr, ok := err.(authboss.ErrAndRedirect)
if !ok {
t.Error("Expected ErrAndRedirect")
t.Error("Expected ErrAndRedirect:", err)
}
if rerr.Location != "/" {
t.Error("Unexpected location:", rerr.Location)
@ -382,7 +378,7 @@ func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) {
ctx, w, r, _ := testRequest(rec.Authboss, "POST")
err := rec.completeHandlerFunc(ctx, w, r)
if err.Error() != "Failed to retrieve client attribute: token" {
if err == nil || err.Error() != "Failed to retrieve client attribute: token" {
t.Error("Unexpected error:", err)
}
@ -477,9 +473,9 @@ func Test_verifyToken_MissingToken(t *testing.T) {
t.Parallel()
testSetup()
r := mocks.MockRequest("GET")
ctx := &authboss.Context{}
if _, err := verifyToken(ctx); err == nil {
if _, err := verifyToken(nil, r); err == nil {
t.Error("Expected error about missing token")
}
}
@ -492,8 +488,9 @@ func Test_verifyToken_InvalidToken(t *testing.T) {
StoreRecoverToken: testStdBase64Token,
}
ctx := mocks.MockRequestContext(rec.Authboss, "token", "asdf")
if _, err := verifyToken(ctx); err != authboss.ErrUserNotFound {
ctx := rec.Authboss.NewContext()
req, _ := http.NewRequest("GET", "/?token=asdf", nil)
if _, err := verifyToken(ctx, req); err != authboss.ErrUserNotFound {
t.Error("Unexpected error:", err)
}
}
@ -507,8 +504,9 @@ func Test_verifyToken_ExpiredToken(t *testing.T) {
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(-24) * time.Hour),
}
ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
if _, err := verifyToken(ctx); err != errRecoveryTokenExpired {
ctx := rec.Authboss.NewContext()
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
if _, err := verifyToken(ctx, req); err != errRecoveryTokenExpired {
t.Error("Unexpected error:", err)
}
}
@ -522,8 +520,9 @@ func Test_verifyToken(t *testing.T) {
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(24) * time.Hour),
}
ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
attrs, err := verifyToken(ctx)
ctx := rec.Authboss.NewContext()
req, _ := http.NewRequest("GET", "/?token="+testURLBase64Token, nil)
attrs, err := verifyToken(ctx, req)
if err != nil {
t.Error("Unexpected error:", err)
}

View File

@ -82,10 +82,10 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
}
func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
key, _ := ctx.FirstPostFormValue(reg.PrimaryID)
password, _ := ctx.FirstPostFormValue(authboss.StorePassword)
key := r.FormValue(reg.PrimaryID)
password := r.FormValue(authboss.StorePassword)
validationErrs := ctx.Validate(reg.Policies, reg.ConfirmFields...)
validationErrs := authboss.Validate(r, reg.Policies, reg.ConfirmFields...)
if user, err := ctx.Storer.Get(key); err != nil && err != authboss.ErrUserNotFound {
return err
@ -101,13 +101,13 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
}
for _, f := range reg.PreserveFields {
data[f], _ = ctx.FirstFormValue(f)
data[f] = r.FormValue(f)
}
return reg.templates.Render(ctx, w, r, tplRegister, data)
}
attr, err := ctx.Attributes() // Attributes from overriden forms
attr, err := authboss.AttributesFromRequest(r) // Attributes from overriden forms
if err != nil {
return err
}
@ -129,7 +129,7 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
}
for _, f := range reg.PreserveFields {
data[f], _ = ctx.FirstFormValue(f)
data[f] = r.FormValue(f)
}
return reg.templates.Render(ctx, w, r, tplRegister, data)

View File

@ -58,7 +58,7 @@ func TestRegisterGet(t *testing.T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/register", nil)
ctx, _ := reg.ContextFromRequest(r)
ctx := reg.NewContext()
ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil {
@ -93,7 +93,7 @@ func TestRegisterPostValidationErrs(t *testing.T) {
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, _ := reg.ContextFromRequest(r)
ctx := reg.NewContext()
ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil {
@ -131,7 +131,7 @@ func TestRegisterPostSuccess(t *testing.T) {
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, _ := reg.ContextFromRequest(r)
ctx := reg.NewContext()
ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil {

View File

@ -83,7 +83,7 @@ func (r *Remember) Storage() authboss.StorageOptions {
// afterAuth is called after authentication is successful.
func (r *Remember) afterAuth(ctx *authboss.Context) error {
if val, ok := ctx.FirstPostFormValue(authboss.CookieRemember); !ok || val != "true" {
if val := ctx.Values[authboss.CookieRemember]; val != "true" {
return nil
}

View File

@ -2,7 +2,6 @@ package remember
import (
"bytes"
"fmt"
"net/http"
"testing"
@ -49,15 +48,13 @@ func TestAfterAuth(t *testing.T) {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := r.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected error:", err)
}
ctx := r.NewContext()
ctx.SessionStorer = session
ctx.CookieStorer = cookies
ctx.User = authboss.Attributes{r.PrimaryID: "test@email.com"}
ctx.Values = map[string]string{authboss.CookieRemember: "true"}
if err := r.afterAuth(ctx); err != nil {
t.Error(err)
}
@ -77,17 +74,7 @@ func TestAfterOAuth(t *testing.T) {
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
uri := fmt.Sprintf("%s?state=%s", "localhost/oauthed", "xsrf")
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
t.Error("Unexpected Error:", err)
}
ctx, err := r.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected error:", err)
}
ctx := r.NewContext()
ctx.SessionStorer = session
ctx.CookieStorer = cookies
ctx.User = authboss.Attributes{

View File

@ -47,11 +47,7 @@ type contextRoute struct {
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Instantiate the context
ctx, err := c.Authboss.ContextFromRequest(r)
if err != nil {
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
return
}
ctx := c.Authboss.NewContext()
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Call the handler
err = c.fn(ctx, w, r)
err := c.fn(ctx, w, r)
if err == nil {
return
}
@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h
io.WriteString(w, "500 An error has occurred")
return true
} else if cu != nil {
if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 {
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
http.Redirect(w, r, redir, http.StatusFound)
} else {
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)

View File

@ -6,7 +6,10 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)
@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string {
// Attributes is just a key-value mapping of data.
type Attributes map[string]interface{}
// Attributes converts the post form values into an attributes map.
func AttributesFromRequest(r *http.Request) (Attributes, error) {
attr := make(Attributes)
if err := r.ParseForm(); err != nil {
return nil, err
}
for name, values := range r.Form {
if len(values) == 0 {
continue
}
val := values[0]
if len(val) == 0 {
continue
}
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}
// Names returns the names of all the attributes.
func (a Attributes) Names() []string {
names := make([]string, len(a))

View File

@ -4,6 +4,8 @@ import (
"bytes"
"database/sql"
"database/sql/driver"
"net/http"
"net/url"
"strings"
"testing"
"time"
@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil
}
func TestAttributes_FromRequest(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
vals := make(url.Values)
vals.Set("a", "a")
vals.Set("b_int", "5")
vals.Set("wildcard", "")
vals.Set("c_date", now.Format(time.RFC3339))
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if err != nil {
t.Error(err)
}
attr, err := AttributesFromRequest(req)
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}
func TestAttributes_Names(t *testing.T) {
t.Parallel()

View File

@ -3,6 +3,7 @@ package authboss
import (
"bytes"
"fmt"
"net/http"
)
const (
@ -64,26 +65,26 @@ func (f FieldError) Error() string {
}
// Validate validates a request using the given ruleset.
func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList {
func Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList {
errList := make(ErrorList, 0)
for _, validator := range ruleset {
field := validator.Field()
val, _ := ctx.FirstFormValue(field)
val := r.FormValue(field)
if errs := validator.Errors(val); len(errs) > 0 {
errList = append(errList, errs...)
}
}
for i := 0; i < len(confirmFields)-1; i += 2 {
main, ok := ctx.FirstPostFormValue(confirmFields[i])
if !ok {
main := r.FormValue(confirmFields[i])
if len(main) == 0 {
continue
}
confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1])
if !ok || main != confirm {
confirm := r.FormValue(confirmFields[i+1])
if len(confirm) == 0 || main != confirm {
errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])})
}
}

View File

@ -64,10 +64,9 @@ func TestErrorList_Map(t *testing.T) {
func TestValidate(t *testing.T) {
t.Parallel()
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
req := mockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
errList := ctx.Validate([]Validator{
errList := Validate(req, []Validator{
mockValidator{
FieldName: StoreUsername,
Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}},
@ -96,21 +95,20 @@ func TestValidate(t *testing.T) {
func TestValidate_Confirm(t *testing.T) {
t.Parallel()
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
req := mockRequest(StoreUsername, "john", "confirmUsername", "johnny")
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
if errs["confirmUsername"][0] != "Does not match username" {
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
}
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)
}
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername).Map()
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
errs = Validate(req, nil, StoreUsername).Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)
}