1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-13 13:58:38 +02:00

Separate the request from context.

This commit is contained in:
Aaron L 2015-08-02 11:51:35 -07:00
parent 2eff32e3c8
commit 8a87d0de63
10 changed files with 118 additions and 262 deletions

View File

@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
// CurrentUser retrieves the current user from the session and the database. // CurrentUser retrieves the current user from the session and the database.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) { func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx, err := a.ContextFromRequest(r) ctx := a.NewContext()
if err != nil {
return nil, err
}
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
return nil return nil
} }
ctx, err := a.ContextFromRequest(r) ctx := a.NewContext()
if err != nil {
return err
}
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
return a.Callbacks.FireAfter(EventPasswordReset, ctx) return a.Callbacks.FireAfter(EventPasswordReset, ctx)

View File

@ -2,11 +2,7 @@ package authboss
import ( import (
"errors" "errors"
"fmt"
"net/http"
"strconv"
"strings" "strings"
"time"
) )
// FormValue constants // FormValue constants
@ -24,9 +20,6 @@ type Context struct {
SessionStorer ClientStorerErr SessionStorer ClientStorerErr
CookieStorer ClientStorerErr CookieStorer ClientStorerErr
User Attributes User Attributes
postFormValues map[string][]string
formValues map[string][]string
} }
// NewContext is exported for testing modules. // NewContext is exported for testing modules.
@ -36,75 +29,6 @@ func (a *Authboss) NewContext() *Context {
} }
} }
// ContextFromRequest creates a context from an http request.
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
if err := r.ParseForm(); err != nil {
return nil, err
}
c := a.NewContext()
c.formValues = map[string][]string(r.Form)
c.postFormValues = map[string][]string(r.PostForm)
return c, nil
}
// FormValue gets a form value from a context created with a request.
func (c *Context) FormValue(key string) ([]string, bool) {
val, ok := c.formValues[key]
return val, ok
}
// PostFormValue gets a form value from a context created with a request.
func (c *Context) PostFormValue(key string) ([]string, bool) {
val, ok := c.postFormValues[key]
return val, ok
}
// FirstFormValue gets the first form value from a context created with a request.
func (c *Context) FirstFormValue(key string) (string, bool) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstPostFormValue gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValue(key string) (string, bool) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstFormValueErr gets the first form value from a context created with a request
// and additionally returns an error not a bool if it's not found.
func (c *Context) FirstFormValueErr(key string) (string, error) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// FirstPostFormValueErr gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValueErr(key string) (string, error) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// LoadUser loads the user Attributes if they haven't already been loaded. // LoadUser loads the user Attributes if they haven't already been loaded.
func (c *Context) LoadUser(key string) error { func (c *Context) LoadUser(key string) error {
if c.User != nil { if c.User != nil {
@ -155,34 +79,3 @@ func (c *Context) SaveUser() error {
return c.Storer.Put(key, c.User) return c.Storer.Put(key, c.User)
} }
// Attributes converts the post form values into an attributes map.
func (c *Context) Attributes() (Attributes, error) {
attr := make(Attributes)
for name, values := range c.postFormValues {
if len(values) == 0 {
continue
}
val := values[0]
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}

View File

@ -1,76 +1,6 @@
package authboss package authboss
import ( import "testing"
"bytes"
"net/http"
"testing"
"time"
)
func TestContext_Request(t *testing.T) {
t.Parallel()
ab := New()
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
if err != nil {
t.Error("Unexpected Error:", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected Error:", err)
}
if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if _, err := ctx.FirstFormValueErr("query"); err != nil {
t.Error(err)
}
if _, err := ctx.FirstPostFormValueErr("post"); err != nil {
t.Error(err)
}
if query, ok := ctx.FormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.PostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, ok := ctx.FirstFormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.FirstPostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, err := ctx.FirstFormValueErr("query1"); err == nil {
t.Error("Expected query1 not to be found:", query)
}
if post, err := ctx.FirstPostFormValueErr("post1"); err == nil {
t.Error("Expected post1 not to be found:", post)
}
}
func TestContext_SaveUser(t *testing.T) { func TestContext_SaveUser(t *testing.T) {
t.Parallel() t.Parallel()
@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) {
} }
} }
} }
func TestContext_Attributes(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
ab := New()
ctx := ab.NewContext()
ctx.postFormValues = map[string][]string{
"a": []string{"a", "1"},
"b_int": []string{"5", "hello"},
"wildcard": nil,
"c_date": []string{now.Format(time.RFC3339)},
}
attr, err := ctx.Attributes()
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}

View File

@ -2,10 +2,11 @@
package mocks package mocks
import ( import (
"bytes"
"errors" "errors"
"fmt" "io"
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
"gopkg.in/authboss.v0" "gopkg.in/authboss.v0"
@ -278,28 +279,28 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
// Del a key/value pair // Del a key/value pair
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) } func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request. // MockRequest returns a new mock request with optional key-value body (form-post)
func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context { func MockRequest(method string, postKeyValues ...string) *http.Request {
keyValues := &bytes.Buffer{} var body io.Reader
if len(postKeyValues) > 0 {
urlValues := make(url.Values)
for i := 0; i < len(postKeyValues); i += 2 { for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 { urlValues.Set(postKeyValues[i], postKeyValues[i+1])
keyValues.WriteByte('&')
} }
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1]) body = strings.NewReader(urlValues.Encode())
} }
req, err := http.NewRequest("POST", "http://localhost", keyValues) req, err := http.NewRequest(method, "http://localhost", body)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req) if len(postKeyValues) > 0 {
if err != nil { req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
panic(err)
} }
return ctx return req
} }
// MockMailer helps simplify mailer testing by storing the last sent email // MockMailer helps simplify mailer testing by storing the last sent email

View File

@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) {
func (m mockClientStore) Put(key, val string) { m[key] = val } func (m mockClientStore) Put(key, val string) { m[key] = val }
func (m mockClientStore) Del(key string) { delete(m, key) } func (m mockClientStore) Del(key string) { delete(m, key) }
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context { func mockRequestContext(ab *Authboss, postKeyValues ...string) (*Context, *http.Request) {
keyValues := &bytes.Buffer{} keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 { for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 { if i != 0 {
@ -71,12 +71,7 @@ func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req) return ab.NewContext(), req
if err != nil {
panic(err)
}
return ctx
} }
type mockValidator struct { type mockValidator struct {

View File

@ -47,11 +47,7 @@ type contextRoute struct {
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Instantiate the context // Instantiate the context
ctx, err := c.Authboss.ContextFromRequest(r) ctx := c.Authboss.NewContext()
if err != nil {
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
return
}
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)} ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)} ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// Call the handler // Call the handler
err = c.fn(ctx, w, r) err := c.fn(ctx, w, r)
if err == nil { if err == nil {
return return
} }
@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h
io.WriteString(w, "500 An error has occurred") io.WriteString(w, "500 An error has occurred")
return true return true
} else if cu != nil { } else if cu != nil {
if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 { if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
http.Redirect(w, r, redir, http.StatusFound) http.Redirect(w, r, redir, http.StatusFound)
} else { } else {
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound) http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)

View File

@ -6,7 +6,10 @@ import (
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"strconv"
"strings"
"time" "time"
"unicode" "unicode"
) )
@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string {
// Attributes is just a key-value mapping of data. // Attributes is just a key-value mapping of data.
type Attributes map[string]interface{} type Attributes map[string]interface{}
// Attributes converts the post form values into an attributes map.
func AttributesFromRequest(r *http.Request) (Attributes, error) {
attr := make(Attributes)
if err := r.ParseForm(); err != nil {
return nil, err
}
for name, values := range r.Form {
if len(values) == 0 {
continue
}
val := values[0]
if len(val) == 0 {
continue
}
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}
// Names returns the names of all the attributes. // Names returns the names of all the attributes.
func (a Attributes) Names() []string { func (a Attributes) Names() []string {
names := make([]string, len(a)) names := make([]string, len(a))

View File

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"net/http"
"net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil return nt.Time, nil
} }
func TestAttributes_FromRequest(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
vals := make(url.Values)
vals.Set("a", "a")
vals.Set("b_int", "5")
vals.Set("wildcard", "")
vals.Set("c_date", now.Format(time.RFC3339))
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if err != nil {
t.Error(err)
}
attr, err := AttributesFromRequest(req)
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}
func TestAttributes_Names(t *testing.T) { func TestAttributes_Names(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -3,6 +3,7 @@ package authboss
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net/http"
) )
const ( const (
@ -64,26 +65,26 @@ func (f FieldError) Error() string {
} }
// Validate validates a request using the given ruleset. // Validate validates a request using the given ruleset.
func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList { func (ctx *Context) Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList {
errList := make(ErrorList, 0) errList := make(ErrorList, 0)
for _, validator := range ruleset { for _, validator := range ruleset {
field := validator.Field() field := validator.Field()
val, _ := ctx.FirstFormValue(field) val := r.FormValue(field)
if errs := validator.Errors(val); len(errs) > 0 { if errs := validator.Errors(val); len(errs) > 0 {
errList = append(errList, errs...) errList = append(errList, errs...)
} }
} }
for i := 0; i < len(confirmFields)-1; i += 2 { for i := 0; i < len(confirmFields)-1; i += 2 {
main, ok := ctx.FirstPostFormValue(confirmFields[i]) main := r.FormValue(confirmFields[i])
if !ok { if len(main) == 0 {
continue continue
} }
confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1]) confirm := r.FormValue(confirmFields[i+1])
if !ok || main != confirm { if len(confirm) == 0 || main != confirm {
errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])}) errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])})
} }
} }

View File

@ -65,9 +65,9 @@ func TestValidate(t *testing.T) {
t.Parallel() t.Parallel()
ab := New() ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com") ctx, req := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
errList := ctx.Validate([]Validator{ errList := ctx.Validate(req, []Validator{
mockValidator{ mockValidator{
FieldName: StoreUsername, FieldName: StoreUsername,
Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}}, Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}},
@ -97,20 +97,20 @@ func TestValidate_Confirm(t *testing.T) {
t.Parallel() t.Parallel()
ab := New() ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny") ctx, req := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map() errs := ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
if errs["confirmUsername"][0] != "Does not match username" { if errs["confirmUsername"][0] != "Does not match username" {
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0]) t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
} }
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map() errs = ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
if len(errs) != 0 { if len(errs) != 0 {
t.Error("Expected no errors:", errs) t.Error("Expected no errors:", errs)
} }
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername).Map() errs = ctx.Validate(req, nil, StoreUsername).Map()
if len(errs) != 0 { if len(errs) != 0 {
t.Error("Expected no errors:", errs) t.Error("Expected no errors:", errs)
} }