From 8a87d0de6399681e4bf638600198cba607ec4a2a Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sun, 2 Aug 2015 11:51:35 -0700 Subject: [PATCH] Separate the request from context. --- authboss.go | 10 +--- context.go | 107 ---------------------------------------- context_test.go | 105 +-------------------------------------- internal/mocks/mocks.go | 31 ++++++------ mocks_test.go | 9 +--- router.go | 10 ++-- storer.go | 42 ++++++++++++++++ storer_test.go | 37 ++++++++++++++ validation.go | 13 ++--- validation_test.go | 16 +++--- 10 files changed, 118 insertions(+), 262 deletions(-) diff --git a/authboss.go b/authboss.go index ce2bac8..7c8f4b5 100644 --- a/authboss.go +++ b/authboss.go @@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error { // CurrentUser retrieves the current user from the session and the database. func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) { - ctx, err := a.ContextFromRequest(r) - if err != nil { - return nil, err - } + ctx := a.NewContext() ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} @@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request, return nil } - ctx, err := a.ContextFromRequest(r) - if err != nil { - return err - } + ctx := a.NewContext() ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} return a.Callbacks.FireAfter(EventPasswordReset, ctx) diff --git a/context.go b/context.go index ae1d1d4..53922db 100644 --- a/context.go +++ b/context.go @@ -2,11 +2,7 @@ package authboss import ( "errors" - "fmt" - "net/http" - "strconv" "strings" - "time" ) // FormValue constants @@ -24,9 +20,6 @@ type Context struct { SessionStorer ClientStorerErr CookieStorer ClientStorerErr User Attributes - - postFormValues map[string][]string - formValues map[string][]string } // NewContext is exported for testing modules. @@ -36,75 +29,6 @@ func (a *Authboss) NewContext() *Context { } } -// ContextFromRequest creates a context from an http request. -func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) { - if err := r.ParseForm(); err != nil { - return nil, err - } - - c := a.NewContext() - c.formValues = map[string][]string(r.Form) - c.postFormValues = map[string][]string(r.PostForm) - return c, nil -} - -// FormValue gets a form value from a context created with a request. -func (c *Context) FormValue(key string) ([]string, bool) { - val, ok := c.formValues[key] - return val, ok -} - -// PostFormValue gets a form value from a context created with a request. -func (c *Context) PostFormValue(key string) ([]string, bool) { - val, ok := c.postFormValues[key] - return val, ok -} - -// FirstFormValue gets the first form value from a context created with a request. -func (c *Context) FirstFormValue(key string) (string, bool) { - val, ok := c.formValues[key] - - if !ok || len(val) == 0 || len(val[0]) == 0 { - return "", false - } - - return val[0], ok -} - -// FirstPostFormValue gets the first form value from a context created with a request. -func (c *Context) FirstPostFormValue(key string) (string, bool) { - val, ok := c.postFormValues[key] - - if !ok || len(val) == 0 || len(val[0]) == 0 { - return "", false - } - - return val[0], ok -} - -// FirstFormValueErr gets the first form value from a context created with a request -// and additionally returns an error not a bool if it's not found. -func (c *Context) FirstFormValueErr(key string) (string, error) { - val, ok := c.formValues[key] - - if !ok || len(val) == 0 || len(val[0]) == 0 { - return "", ClientDataErr{key} - } - - return val[0], nil -} - -// FirstPostFormValueErr gets the first form value from a context created with a request. -func (c *Context) FirstPostFormValueErr(key string) (string, error) { - val, ok := c.postFormValues[key] - - if !ok || len(val) == 0 || len(val[0]) == 0 { - return "", ClientDataErr{key} - } - - return val[0], nil -} - // LoadUser loads the user Attributes if they haven't already been loaded. func (c *Context) LoadUser(key string) error { if c.User != nil { @@ -155,34 +79,3 @@ func (c *Context) SaveUser() error { return c.Storer.Put(key, c.User) } - -// Attributes converts the post form values into an attributes map. -func (c *Context) Attributes() (Attributes, error) { - attr := make(Attributes) - - for name, values := range c.postFormValues { - if len(values) == 0 { - continue - } - - val := values[0] - switch { - case strings.HasSuffix(name, "_int"): - integer, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err) - } - attr[strings.TrimRight(name, "_int")] = integer - case strings.HasSuffix(name, "_date"): - date, err := time.Parse(time.RFC3339, val) - if err != nil { - return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err) - } - attr[strings.TrimRight(name, "_date")] = date.UTC() - default: - attr[name] = val - } - } - - return attr, nil -} diff --git a/context_test.go b/context_test.go index a4bec43..20436f5 100644 --- a/context_test.go +++ b/context_test.go @@ -1,76 +1,6 @@ package authboss -import ( - "bytes" - "net/http" - "testing" - "time" -) - -func TestContext_Request(t *testing.T) { - t.Parallel() - - ab := New() - - req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form")) - if err != nil { - t.Error("Unexpected Error:", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ctx, err := ab.ContextFromRequest(req) - if err != nil { - t.Error("Unexpected Error:", err) - } - - if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" { - t.Error("Form value not getting recorded correctly.") - } - - if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" { - t.Error("Postform value not getting recorded correctly.") - } - - if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" { - t.Error("Form value not getting recorded correctly.") - } - - if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" { - t.Error("Postform value not getting recorded correctly.") - } - - if _, err := ctx.FirstFormValueErr("query"); err != nil { - t.Error(err) - } - - if _, err := ctx.FirstPostFormValueErr("post"); err != nil { - t.Error(err) - } - - if query, ok := ctx.FormValue("query1"); ok { - t.Error("Expected query1 not to be found:", query) - } - - if post, ok := ctx.PostFormValue("post1"); ok { - t.Error("Expected post1 not to be found:", post) - } - - if query, ok := ctx.FirstFormValue("query1"); ok { - t.Error("Expected query1 not to be found:", query) - } - - if post, ok := ctx.FirstPostFormValue("post1"); ok { - t.Error("Expected post1 not to be found:", post) - } - - if query, err := ctx.FirstFormValueErr("query1"); err == nil { - t.Error("Expected query1 not to be found:", query) - } - - if post, err := ctx.FirstPostFormValueErr("post1"); err == nil { - t.Error("Expected post1 not to be found:", post) - } -} +import "testing" func TestContext_SaveUser(t *testing.T) { t.Parallel() @@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) { } } } - -func TestContext_Attributes(t *testing.T) { - t.Parallel() - - now := time.Now().UTC() - - ab := New() - ctx := ab.NewContext() - ctx.postFormValues = map[string][]string{ - "a": []string{"a", "1"}, - "b_int": []string{"5", "hello"}, - "wildcard": nil, - "c_date": []string{now.Format(time.RFC3339)}, - } - - attr, err := ctx.Attributes() - if err != nil { - t.Error(err) - } - - if got := attr["a"].(string); got != "a" { - t.Error("a's value is wrong:", got) - } - if got := attr["b"].(int); got != 5 { - t.Error("b's value is wrong:", got) - } - if got := attr["c"].(time.Time); got.Unix() != now.Unix() { - t.Error("c's value is wrong:", now, got) - } - if _, ok := attr["wildcard"]; ok { - t.Error("We don't need totally empty fields.") - } -} diff --git a/internal/mocks/mocks.go b/internal/mocks/mocks.go index cf3d75b..f067384 100644 --- a/internal/mocks/mocks.go +++ b/internal/mocks/mocks.go @@ -2,10 +2,11 @@ package mocks import ( - "bytes" "errors" - "fmt" + "io" "net/http" + "net/url" + "strings" "time" "gopkg.in/authboss.v0" @@ -278,28 +279,28 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val } // Del a key/value pair func (m *MockClientStorer) Del(key string) { delete(m.Values, key) } -// MockRequestContext returns a new context as if it came from POST request. -func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context { - keyValues := &bytes.Buffer{} - for i := 0; i < len(postKeyValues); i += 2 { - if i != 0 { - keyValues.WriteByte('&') +// MockRequest returns a new mock request with optional key-value body (form-post) +func MockRequest(method string, postKeyValues ...string) *http.Request { + var body io.Reader + + if len(postKeyValues) > 0 { + urlValues := make(url.Values) + for i := 0; i < len(postKeyValues); i += 2 { + urlValues.Set(postKeyValues[i], postKeyValues[i+1]) } - fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1]) + body = strings.NewReader(urlValues.Encode()) } - req, err := http.NewRequest("POST", "http://localhost", keyValues) + req, err := http.NewRequest(method, "http://localhost", body) if err != nil { panic(err.Error()) } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - ctx, err := ab.ContextFromRequest(req) - if err != nil { - panic(err) + if len(postKeyValues) > 0 { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") } - return ctx + return req } // MockMailer helps simplify mailer testing by storing the last sent email diff --git a/mocks_test.go b/mocks_test.go index ffce18c..78ded53 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) { func (m mockClientStore) Put(key, val string) { m[key] = val } func (m mockClientStore) Del(key string) { delete(m, key) } -func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context { +func mockRequestContext(ab *Authboss, postKeyValues ...string) (*Context, *http.Request) { keyValues := &bytes.Buffer{} for i := 0; i < len(postKeyValues); i += 2 { if i != 0 { @@ -71,12 +71,7 @@ func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - ctx, err := ab.ContextFromRequest(req) - if err != nil { - panic(err) - } - - return ctx + return ab.NewContext(), req } type mockValidator struct { diff --git a/router.go b/router.go index 4679a79..9607f5f 100644 --- a/router.go +++ b/router.go @@ -47,11 +47,7 @@ type contextRoute struct { func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Instantiate the context - ctx, err := c.Authboss.ContextFromRequest(r) - if err != nil { - fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err) - return - } + ctx := c.Authboss.NewContext() ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)} ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)} @@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Call the handler - err = c.fn(ctx, w, r) + err := c.fn(ctx, w, r) if err == nil { return } @@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h io.WriteString(w, "500 An error has occurred") return true } else if cu != nil { - if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 { + if redir := r.FormValue(FormValueRedirect); len(redir) > 0 { http.Redirect(w, r, redir, http.StatusFound) } else { http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound) diff --git a/storer.go b/storer.go index d5b8edc..c0c9824 100644 --- a/storer.go +++ b/storer.go @@ -6,7 +6,10 @@ import ( "database/sql/driver" "errors" "fmt" + "net/http" "reflect" + "strconv" + "strings" "time" "unicode" ) @@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string { // Attributes is just a key-value mapping of data. type Attributes map[string]interface{} +// Attributes converts the post form values into an attributes map. +func AttributesFromRequest(r *http.Request) (Attributes, error) { + attr := make(Attributes) + + if err := r.ParseForm(); err != nil { + return nil, err + } + + for name, values := range r.Form { + if len(values) == 0 { + continue + } + + val := values[0] + if len(val) == 0 { + continue + } + + switch { + case strings.HasSuffix(name, "_int"): + integer, err := strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err) + } + attr[strings.TrimRight(name, "_int")] = integer + case strings.HasSuffix(name, "_date"): + date, err := time.Parse(time.RFC3339, val) + if err != nil { + return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err) + } + attr[strings.TrimRight(name, "_date")] = date.UTC() + default: + attr[name] = val + } + } + + return attr, nil +} + // Names returns the names of all the attributes. func (a Attributes) Names() []string { names := make([]string, len(a)) diff --git a/storer_test.go b/storer_test.go index dee90e1..a33f93f 100644 --- a/storer_test.go +++ b/storer_test.go @@ -4,6 +4,8 @@ import ( "bytes" "database/sql" "database/sql/driver" + "net/http" + "net/url" "strings" "testing" "time" @@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) { return nt.Time, nil } +func TestAttributes_FromRequest(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + + vals := make(url.Values) + vals.Set("a", "a") + vals.Set("b_int", "5") + vals.Set("wildcard", "") + vals.Set("c_date", now.Format(time.RFC3339)) + req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if err != nil { + t.Error(err) + } + + attr, err := AttributesFromRequest(req) + if err != nil { + t.Error(err) + } + + if got := attr["a"].(string); got != "a" { + t.Error("a's value is wrong:", got) + } + if got := attr["b"].(int); got != 5 { + t.Error("b's value is wrong:", got) + } + if got := attr["c"].(time.Time); got.Unix() != now.Unix() { + t.Error("c's value is wrong:", now, got) + } + if _, ok := attr["wildcard"]; ok { + t.Error("We don't need totally empty fields.") + } +} + func TestAttributes_Names(t *testing.T) { t.Parallel() diff --git a/validation.go b/validation.go index 7960b82..a66bf32 100644 --- a/validation.go +++ b/validation.go @@ -3,6 +3,7 @@ package authboss import ( "bytes" "fmt" + "net/http" ) const ( @@ -64,26 +65,26 @@ func (f FieldError) Error() string { } // Validate validates a request using the given ruleset. -func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList { +func (ctx *Context) Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList { errList := make(ErrorList, 0) for _, validator := range ruleset { field := validator.Field() - val, _ := ctx.FirstFormValue(field) + val := r.FormValue(field) if errs := validator.Errors(val); len(errs) > 0 { errList = append(errList, errs...) } } for i := 0; i < len(confirmFields)-1; i += 2 { - main, ok := ctx.FirstPostFormValue(confirmFields[i]) - if !ok { + main := r.FormValue(confirmFields[i]) + if len(main) == 0 { continue } - confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1]) - if !ok || main != confirm { + confirm := r.FormValue(confirmFields[i+1]) + if len(confirm) == 0 || main != confirm { errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])}) } } diff --git a/validation_test.go b/validation_test.go index 1d65c12..0edfb9e 100644 --- a/validation_test.go +++ b/validation_test.go @@ -65,9 +65,9 @@ func TestValidate(t *testing.T) { t.Parallel() ab := New() - ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com") + ctx, req := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com") - errList := ctx.Validate([]Validator{ + errList := ctx.Validate(req, []Validator{ mockValidator{ FieldName: StoreUsername, Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}}, @@ -97,20 +97,20 @@ func TestValidate_Confirm(t *testing.T) { t.Parallel() ab := New() - ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny") - errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map() + ctx, req := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny") + errs := ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map() if errs["confirmUsername"][0] != "Does not match username" { t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0]) } - ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") - errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map() + ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") + errs = ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map() if len(errs) != 0 { t.Error("Expected no errors:", errs) } - ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") - errs = ctx.Validate(nil, StoreUsername).Map() + ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") + errs = ctx.Validate(req, nil, StoreUsername).Map() if len(errs) != 0 { t.Error("Expected no errors:", errs) }