1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-01 13:17:43 +02:00

Separate the request from context.

This commit is contained in:
Aaron L 2015-08-02 11:51:35 -07:00
parent 2eff32e3c8
commit 8a87d0de63
10 changed files with 118 additions and 262 deletions

View File

@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
// CurrentUser retrieves the current user from the session and the database.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx, err := a.ContextFromRequest(r)
if err != nil {
return nil, err
}
ctx := a.NewContext()
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
return nil
}
ctx, err := a.ContextFromRequest(r)
if err != nil {
return err
}
ctx := a.NewContext()
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
return a.Callbacks.FireAfter(EventPasswordReset, ctx)

View File

@ -2,11 +2,7 @@ package authboss
import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
// FormValue constants
@ -24,9 +20,6 @@ type Context struct {
SessionStorer ClientStorerErr
CookieStorer ClientStorerErr
User Attributes
postFormValues map[string][]string
formValues map[string][]string
}
// NewContext is exported for testing modules.
@ -36,75 +29,6 @@ func (a *Authboss) NewContext() *Context {
}
}
// ContextFromRequest creates a context from an http request.
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
if err := r.ParseForm(); err != nil {
return nil, err
}
c := a.NewContext()
c.formValues = map[string][]string(r.Form)
c.postFormValues = map[string][]string(r.PostForm)
return c, nil
}
// FormValue gets a form value from a context created with a request.
func (c *Context) FormValue(key string) ([]string, bool) {
val, ok := c.formValues[key]
return val, ok
}
// PostFormValue gets a form value from a context created with a request.
func (c *Context) PostFormValue(key string) ([]string, bool) {
val, ok := c.postFormValues[key]
return val, ok
}
// FirstFormValue gets the first form value from a context created with a request.
func (c *Context) FirstFormValue(key string) (string, bool) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstPostFormValue gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValue(key string) (string, bool) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", false
}
return val[0], ok
}
// FirstFormValueErr gets the first form value from a context created with a request
// and additionally returns an error not a bool if it's not found.
func (c *Context) FirstFormValueErr(key string) (string, error) {
val, ok := c.formValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// FirstPostFormValueErr gets the first form value from a context created with a request.
func (c *Context) FirstPostFormValueErr(key string) (string, error) {
val, ok := c.postFormValues[key]
if !ok || len(val) == 0 || len(val[0]) == 0 {
return "", ClientDataErr{key}
}
return val[0], nil
}
// LoadUser loads the user Attributes if they haven't already been loaded.
func (c *Context) LoadUser(key string) error {
if c.User != nil {
@ -155,34 +79,3 @@ func (c *Context) SaveUser() error {
return c.Storer.Put(key, c.User)
}
// Attributes converts the post form values into an attributes map.
func (c *Context) Attributes() (Attributes, error) {
attr := make(Attributes)
for name, values := range c.postFormValues {
if len(values) == 0 {
continue
}
val := values[0]
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}

View File

@ -1,76 +1,6 @@
package authboss
import (
"bytes"
"net/http"
"testing"
"time"
)
func TestContext_Request(t *testing.T) {
t.Parallel()
ab := New()
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
if err != nil {
t.Error("Unexpected Error:", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected Error:", err)
}
if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" {
t.Error("Form value not getting recorded correctly.")
}
if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" {
t.Error("Postform value not getting recorded correctly.")
}
if _, err := ctx.FirstFormValueErr("query"); err != nil {
t.Error(err)
}
if _, err := ctx.FirstPostFormValueErr("post"); err != nil {
t.Error(err)
}
if query, ok := ctx.FormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.PostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, ok := ctx.FirstFormValue("query1"); ok {
t.Error("Expected query1 not to be found:", query)
}
if post, ok := ctx.FirstPostFormValue("post1"); ok {
t.Error("Expected post1 not to be found:", post)
}
if query, err := ctx.FirstFormValueErr("query1"); err == nil {
t.Error("Expected query1 not to be found:", query)
}
if post, err := ctx.FirstPostFormValueErr("post1"); err == nil {
t.Error("Expected post1 not to be found:", post)
}
}
import "testing"
func TestContext_SaveUser(t *testing.T) {
t.Parallel()
@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) {
}
}
}
func TestContext_Attributes(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
ab := New()
ctx := ab.NewContext()
ctx.postFormValues = map[string][]string{
"a": []string{"a", "1"},
"b_int": []string{"5", "hello"},
"wildcard": nil,
"c_date": []string{now.Format(time.RFC3339)},
}
attr, err := ctx.Attributes()
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}

View File

@ -2,10 +2,11 @@
package mocks
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"gopkg.in/authboss.v0"
@ -278,28 +279,28 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
// Del a key/value pair
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request.
func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context {
keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
keyValues.WriteByte('&')
// MockRequest returns a new mock request with optional key-value body (form-post)
func MockRequest(method string, postKeyValues ...string) *http.Request {
var body io.Reader
if len(postKeyValues) > 0 {
urlValues := make(url.Values)
for i := 0; i < len(postKeyValues); i += 2 {
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
}
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
body = strings.NewReader(urlValues.Encode())
}
req, err := http.NewRequest("POST", "http://localhost", keyValues)
req, err := http.NewRequest(method, "http://localhost", body)
if err != nil {
panic(err.Error())
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
if len(postKeyValues) > 0 {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
return ctx
return req
}
// MockMailer helps simplify mailer testing by storing the last sent email

View File

@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) {
func (m mockClientStore) Put(key, val string) { m[key] = val }
func (m mockClientStore) Del(key string) { delete(m, key) }
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
func mockRequestContext(ab *Authboss, postKeyValues ...string) (*Context, *http.Request) {
keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
@ -71,12 +71,7 @@ func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
}
return ctx
return ab.NewContext(), req
}
type mockValidator struct {

View File

@ -47,11 +47,7 @@ type contextRoute struct {
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Instantiate the context
ctx, err := c.Authboss.ContextFromRequest(r)
if err != nil {
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
return
}
ctx := c.Authboss.NewContext()
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Call the handler
err = c.fn(ctx, w, r)
err := c.fn(ctx, w, r)
if err == nil {
return
}
@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h
io.WriteString(w, "500 An error has occurred")
return true
} else if cu != nil {
if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 {
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
http.Redirect(w, r, redir, http.StatusFound)
} else {
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)

View File

@ -6,7 +6,10 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)
@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string {
// Attributes is just a key-value mapping of data.
type Attributes map[string]interface{}
// Attributes converts the post form values into an attributes map.
func AttributesFromRequest(r *http.Request) (Attributes, error) {
attr := make(Attributes)
if err := r.ParseForm(); err != nil {
return nil, err
}
for name, values := range r.Form {
if len(values) == 0 {
continue
}
val := values[0]
if len(val) == 0 {
continue
}
switch {
case strings.HasSuffix(name, "_int"):
integer, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
}
attr[strings.TrimRight(name, "_int")] = integer
case strings.HasSuffix(name, "_date"):
date, err := time.Parse(time.RFC3339, val)
if err != nil {
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
}
attr[strings.TrimRight(name, "_date")] = date.UTC()
default:
attr[name] = val
}
}
return attr, nil
}
// Names returns the names of all the attributes.
func (a Attributes) Names() []string {
names := make([]string, len(a))

View File

@ -4,6 +4,8 @@ import (
"bytes"
"database/sql"
"database/sql/driver"
"net/http"
"net/url"
"strings"
"testing"
"time"
@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) {
return nt.Time, nil
}
func TestAttributes_FromRequest(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
vals := make(url.Values)
vals.Set("a", "a")
vals.Set("b_int", "5")
vals.Set("wildcard", "")
vals.Set("c_date", now.Format(time.RFC3339))
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if err != nil {
t.Error(err)
}
attr, err := AttributesFromRequest(req)
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}
func TestAttributes_Names(t *testing.T) {
t.Parallel()

View File

@ -3,6 +3,7 @@ package authboss
import (
"bytes"
"fmt"
"net/http"
)
const (
@ -64,26 +65,26 @@ func (f FieldError) Error() string {
}
// Validate validates a request using the given ruleset.
func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList {
func (ctx *Context) Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList {
errList := make(ErrorList, 0)
for _, validator := range ruleset {
field := validator.Field()
val, _ := ctx.FirstFormValue(field)
val := r.FormValue(field)
if errs := validator.Errors(val); len(errs) > 0 {
errList = append(errList, errs...)
}
}
for i := 0; i < len(confirmFields)-1; i += 2 {
main, ok := ctx.FirstPostFormValue(confirmFields[i])
if !ok {
main := r.FormValue(confirmFields[i])
if len(main) == 0 {
continue
}
confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1])
if !ok || main != confirm {
confirm := r.FormValue(confirmFields[i+1])
if len(confirm) == 0 || main != confirm {
errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])})
}
}

View File

@ -65,9 +65,9 @@ func TestValidate(t *testing.T) {
t.Parallel()
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
ctx, req := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
errList := ctx.Validate([]Validator{
errList := ctx.Validate(req, []Validator{
mockValidator{
FieldName: StoreUsername,
Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}},
@ -97,20 +97,20 @@ func TestValidate_Confirm(t *testing.T) {
t.Parallel()
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
ctx, req := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
errs := ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
if errs["confirmUsername"][0] != "Does not match username" {
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
}
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)
}
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername).Map()
ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(req, nil, StoreUsername).Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)
}