mirror of
https://github.com/volatiletech/authboss.git
synced 2025-02-01 13:17:43 +02:00
Separate the request from context.
This commit is contained in:
parent
2eff32e3c8
commit
8a87d0de63
10
authboss.go
10
authboss.go
@ -64,10 +64,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
|
||||
|
||||
// CurrentUser retrieves the current user from the session and the database.
|
||||
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
|
||||
ctx, err := a.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx := a.NewContext()
|
||||
|
||||
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
|
||||
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
|
||||
@ -168,10 +165,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, err := a.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := a.NewContext()
|
||||
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
|
||||
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
|
||||
return a.Callbacks.FireAfter(EventPasswordReset, ctx)
|
||||
|
107
context.go
107
context.go
@ -2,11 +2,7 @@ package authboss
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FormValue constants
|
||||
@ -24,9 +20,6 @@ type Context struct {
|
||||
SessionStorer ClientStorerErr
|
||||
CookieStorer ClientStorerErr
|
||||
User Attributes
|
||||
|
||||
postFormValues map[string][]string
|
||||
formValues map[string][]string
|
||||
}
|
||||
|
||||
// NewContext is exported for testing modules.
|
||||
@ -36,75 +29,6 @@ func (a *Authboss) NewContext() *Context {
|
||||
}
|
||||
}
|
||||
|
||||
// ContextFromRequest creates a context from an http request.
|
||||
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := a.NewContext()
|
||||
c.formValues = map[string][]string(r.Form)
|
||||
c.postFormValues = map[string][]string(r.PostForm)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// FormValue gets a form value from a context created with a request.
|
||||
func (c *Context) FormValue(key string) ([]string, bool) {
|
||||
val, ok := c.formValues[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// PostFormValue gets a form value from a context created with a request.
|
||||
func (c *Context) PostFormValue(key string) ([]string, bool) {
|
||||
val, ok := c.postFormValues[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// FirstFormValue gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstFormValue(key string) (string, bool) {
|
||||
val, ok := c.formValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return val[0], ok
|
||||
}
|
||||
|
||||
// FirstPostFormValue gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstPostFormValue(key string) (string, bool) {
|
||||
val, ok := c.postFormValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return val[0], ok
|
||||
}
|
||||
|
||||
// FirstFormValueErr gets the first form value from a context created with a request
|
||||
// and additionally returns an error not a bool if it's not found.
|
||||
func (c *Context) FirstFormValueErr(key string) (string, error) {
|
||||
val, ok := c.formValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", ClientDataErr{key}
|
||||
}
|
||||
|
||||
return val[0], nil
|
||||
}
|
||||
|
||||
// FirstPostFormValueErr gets the first form value from a context created with a request.
|
||||
func (c *Context) FirstPostFormValueErr(key string) (string, error) {
|
||||
val, ok := c.postFormValues[key]
|
||||
|
||||
if !ok || len(val) == 0 || len(val[0]) == 0 {
|
||||
return "", ClientDataErr{key}
|
||||
}
|
||||
|
||||
return val[0], nil
|
||||
}
|
||||
|
||||
// LoadUser loads the user Attributes if they haven't already been loaded.
|
||||
func (c *Context) LoadUser(key string) error {
|
||||
if c.User != nil {
|
||||
@ -155,34 +79,3 @@ func (c *Context) SaveUser() error {
|
||||
|
||||
return c.Storer.Put(key, c.User)
|
||||
}
|
||||
|
||||
// Attributes converts the post form values into an attributes map.
|
||||
func (c *Context) Attributes() (Attributes, error) {
|
||||
attr := make(Attributes)
|
||||
|
||||
for name, values := range c.postFormValues {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := values[0]
|
||||
switch {
|
||||
case strings.HasSuffix(name, "_int"):
|
||||
integer, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_int")] = integer
|
||||
case strings.HasSuffix(name, "_date"):
|
||||
date, err := time.Parse(time.RFC3339, val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_date")] = date.UTC()
|
||||
default:
|
||||
attr[name] = val
|
||||
}
|
||||
}
|
||||
|
||||
return attr, nil
|
||||
}
|
||||
|
105
context_test.go
105
context_test.go
@ -1,76 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestContext_Request(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FormValue("query"); !ok || query[0] != "string" {
|
||||
t.Error("Form value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if post, ok := ctx.PostFormValue("post"); !ok || post[0] != "form" {
|
||||
t.Error("Postform value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if query, ok := ctx.FirstFormValue("query"); !ok || query != "string" {
|
||||
t.Error("Form value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if post, ok := ctx.FirstPostFormValue("post"); !ok || post != "form" {
|
||||
t.Error("Postform value not getting recorded correctly.")
|
||||
}
|
||||
|
||||
if _, err := ctx.FirstFormValueErr("query"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err := ctx.FirstPostFormValueErr("post"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FormValue("query1"); ok {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, ok := ctx.PostFormValue("post1"); ok {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
|
||||
if query, ok := ctx.FirstFormValue("query1"); ok {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, ok := ctx.FirstPostFormValue("post1"); ok {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
|
||||
if query, err := ctx.FirstFormValueErr("query1"); err == nil {
|
||||
t.Error("Expected query1 not to be found:", query)
|
||||
}
|
||||
|
||||
if post, err := ctx.FirstPostFormValueErr("post1"); err == nil {
|
||||
t.Error("Expected post1 not to be found:", post)
|
||||
}
|
||||
}
|
||||
import "testing"
|
||||
|
||||
func TestContext_SaveUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -177,36 +107,3 @@ func TestContext_LoadSessionUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_Attributes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
ab := New()
|
||||
ctx := ab.NewContext()
|
||||
ctx.postFormValues = map[string][]string{
|
||||
"a": []string{"a", "1"},
|
||||
"b_int": []string{"5", "hello"},
|
||||
"wildcard": nil,
|
||||
"c_date": []string{now.Format(time.RFC3339)},
|
||||
}
|
||||
|
||||
attr, err := ctx.Attributes()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := attr["a"].(string); got != "a" {
|
||||
t.Error("a's value is wrong:", got)
|
||||
}
|
||||
if got := attr["b"].(int); got != 5 {
|
||||
t.Error("b's value is wrong:", got)
|
||||
}
|
||||
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
|
||||
t.Error("c's value is wrong:", now, got)
|
||||
}
|
||||
if _, ok := attr["wildcard"]; ok {
|
||||
t.Error("We don't need totally empty fields.")
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,11 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/authboss.v0"
|
||||
@ -278,28 +279,28 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
|
||||
// Del a key/value pair
|
||||
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
|
||||
|
||||
// MockRequestContext returns a new context as if it came from POST request.
|
||||
func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context {
|
||||
keyValues := &bytes.Buffer{}
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
if i != 0 {
|
||||
keyValues.WriteByte('&')
|
||||
// MockRequest returns a new mock request with optional key-value body (form-post)
|
||||
func MockRequest(method string, postKeyValues ...string) *http.Request {
|
||||
var body io.Reader
|
||||
|
||||
if len(postKeyValues) > 0 {
|
||||
urlValues := make(url.Values)
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
||||
}
|
||||
fmt.Fprintf(keyValues, "%s=%s", postKeyValues[i], postKeyValues[i+1])
|
||||
body = strings.NewReader(urlValues.Encode())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost", keyValues)
|
||||
req, err := http.NewRequest(method, "http://localhost", body)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
if len(postKeyValues) > 0 {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
|
||||
return ctx
|
||||
return req
|
||||
}
|
||||
|
||||
// MockMailer helps simplify mailer testing by storing the last sent email
|
||||
|
@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) {
|
||||
func (m mockClientStore) Put(key, val string) { m[key] = val }
|
||||
func (m mockClientStore) Del(key string) { delete(m, key) }
|
||||
|
||||
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
|
||||
func mockRequestContext(ab *Authboss, postKeyValues ...string) (*Context, *http.Request) {
|
||||
keyValues := &bytes.Buffer{}
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
if i != 0 {
|
||||
@ -71,12 +71,7 @@ func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return ctx
|
||||
return ab.NewContext(), req
|
||||
}
|
||||
|
||||
type mockValidator struct {
|
||||
|
10
router.go
10
router.go
@ -47,11 +47,7 @@ type contextRoute struct {
|
||||
|
||||
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Instantiate the context
|
||||
ctx, err := c.Authboss.ContextFromRequest(r)
|
||||
if err != nil {
|
||||
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
|
||||
return
|
||||
}
|
||||
ctx := c.Authboss.NewContext()
|
||||
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
|
||||
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
|
||||
|
||||
@ -61,7 +57,7 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Call the handler
|
||||
err = c.fn(ctx, w, r)
|
||||
err := c.fn(ctx, w, r)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
@ -121,7 +117,7 @@ func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (h
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
return true
|
||||
} else if cu != nil {
|
||||
if redir, ok := ctx.FirstFormValue(FormValueRedirect); ok && len(redir) > 0 {
|
||||
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
|
||||
http.Redirect(w, r, redir, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
|
||||
|
42
storer.go
42
storer.go
@ -6,7 +6,10 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
@ -109,6 +112,45 @@ func (a AttributeMeta) Names() []string {
|
||||
// Attributes is just a key-value mapping of data.
|
||||
type Attributes map[string]interface{}
|
||||
|
||||
// Attributes converts the post form values into an attributes map.
|
||||
func AttributesFromRequest(r *http.Request) (Attributes, error) {
|
||||
attr := make(Attributes)
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for name, values := range r.Form {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
val := values[0]
|
||||
if len(val) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(name, "_int"):
|
||||
integer, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_int")] = integer
|
||||
case strings.HasSuffix(name, "_date"):
|
||||
date, err := time.Parse(time.RFC3339, val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
|
||||
}
|
||||
attr[strings.TrimRight(name, "_date")] = date.UTC()
|
||||
default:
|
||||
attr[name] = val
|
||||
}
|
||||
}
|
||||
|
||||
return attr, nil
|
||||
}
|
||||
|
||||
// Names returns the names of all the attributes.
|
||||
func (a Attributes) Names() []string {
|
||||
names := make([]string, len(a))
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -26,6 +28,41 @@ func (nt NullTime) Value() (driver.Value, error) {
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
func TestAttributes_FromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
vals := make(url.Values)
|
||||
vals.Set("a", "a")
|
||||
vals.Set("b_int", "5")
|
||||
vals.Set("wildcard", "")
|
||||
vals.Set("c_date", now.Format(time.RFC3339))
|
||||
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
attr, err := AttributesFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := attr["a"].(string); got != "a" {
|
||||
t.Error("a's value is wrong:", got)
|
||||
}
|
||||
if got := attr["b"].(int); got != 5 {
|
||||
t.Error("b's value is wrong:", got)
|
||||
}
|
||||
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
|
||||
t.Error("c's value is wrong:", now, got)
|
||||
}
|
||||
if _, ok := attr["wildcard"]; ok {
|
||||
t.Error("We don't need totally empty fields.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_Names(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -3,6 +3,7 @@ package authboss
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -64,26 +65,26 @@ func (f FieldError) Error() string {
|
||||
}
|
||||
|
||||
// Validate validates a request using the given ruleset.
|
||||
func (ctx *Context) Validate(ruleset []Validator, confirmFields ...string) ErrorList {
|
||||
func (ctx *Context) Validate(r *http.Request, ruleset []Validator, confirmFields ...string) ErrorList {
|
||||
errList := make(ErrorList, 0)
|
||||
|
||||
for _, validator := range ruleset {
|
||||
field := validator.Field()
|
||||
|
||||
val, _ := ctx.FirstFormValue(field)
|
||||
val := r.FormValue(field)
|
||||
if errs := validator.Errors(val); len(errs) > 0 {
|
||||
errList = append(errList, errs...)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(confirmFields)-1; i += 2 {
|
||||
main, ok := ctx.FirstPostFormValue(confirmFields[i])
|
||||
if !ok {
|
||||
main := r.FormValue(confirmFields[i])
|
||||
if len(main) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
confirm, ok := ctx.FirstPostFormValue(confirmFields[i+1])
|
||||
if !ok || main != confirm {
|
||||
confirm := r.FormValue(confirmFields[i+1])
|
||||
if len(confirm) == 0 || main != confirm {
|
||||
errList = append(errList, FieldError{confirmFields[i+1], fmt.Errorf("Does not match %s", confirmFields[i])})
|
||||
}
|
||||
}
|
||||
|
@ -65,9 +65,9 @@ func TestValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
ctx, req := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
|
||||
errList := ctx.Validate([]Validator{
|
||||
errList := ctx.Validate(req, []Validator{
|
||||
mockValidator{
|
||||
FieldName: StoreUsername,
|
||||
Errs: ErrorList{FieldError{StoreUsername, errors.New("must be longer than 4")}},
|
||||
@ -97,20 +97,20 @@ func TestValidate_Confirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
|
||||
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
|
||||
ctx, req := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
|
||||
errs := ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if errs["confirmUsername"][0] != "Does not match username" {
|
||||
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
|
||||
}
|
||||
|
||||
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
|
||||
ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
}
|
||||
|
||||
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(nil, StoreUsername).Map()
|
||||
ctx, req = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = ctx.Validate(req, nil, StoreUsername).Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user