mirror of
https://github.com/volatiletech/authboss.git
synced 2025-07-15 01:24:33 +02:00
Introduce new type of client storage
- This addresses the problem of having to update multiple times during one request. It's hard to have a nice interface especially with JWT because you always end up having to decode the request, encode new response, write header, then a second write to it comes, and where do you grab the value from? Often you don't have access to the response as a "read" structure. So we store it as events instead, and play those events against the original data right before the response is written to set the headers.
This commit is contained in:
@ -1,13 +1,8 @@
|
|||||||
package authboss
|
package authboss
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthBossInit(t *testing.T) {
|
func TestAuthBossInit(t *testing.T) {
|
||||||
@ -22,74 +17,6 @@ func TestAuthBossInit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthBossCurrentUser(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ab := New()
|
|
||||||
ab.LogWriter = ioutil.Discard
|
|
||||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
|
||||||
ab.ViewLoader = mockRenderLoader{}
|
|
||||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
|
||||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
|
||||||
|
|
||||||
if err := ab.Init(); err != nil {
|
|
||||||
t.Error("Unexpected error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "localhost", nil)
|
|
||||||
|
|
||||||
userStruct := ab.CurrentUserP(rec, req)
|
|
||||||
us := userStruct.(mockStoredUser)
|
|
||||||
|
|
||||||
if us.Email != "john@john.com" || us.Password != "lies" {
|
|
||||||
t.Error("Wrong user found!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthBossCurrentUserCallbacks(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ab := New()
|
|
||||||
ab.LogWriter = ioutil.Discard
|
|
||||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
|
||||||
ab.ViewLoader = mockRenderLoader{}
|
|
||||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
|
||||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
|
||||||
|
|
||||||
if err := ab.Init(); err != nil {
|
|
||||||
t.Error("Unexpected error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req, _ := http.NewRequest("GET", "localhost", nil)
|
|
||||||
|
|
||||||
afterGetUser := errors.New("afterGetUser")
|
|
||||||
beforeGetUser := errors.New("beforeGetUser")
|
|
||||||
beforeGetUserSession := errors.New("beforeGetUserSession")
|
|
||||||
|
|
||||||
ab.Callbacks.After(EventGetUser, func(context.Context) error {
|
|
||||||
return afterGetUser
|
|
||||||
})
|
|
||||||
if _, err := ab.CurrentUser(rec, req); err != afterGetUser {
|
|
||||||
t.Error("Want:", afterGetUser, "Got:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ab.Callbacks.Before(EventGetUser, func(context.Context) (Interrupt, error) {
|
|
||||||
return InterruptNone, beforeGetUser
|
|
||||||
})
|
|
||||||
if _, err := ab.CurrentUser(rec, req); err != beforeGetUser {
|
|
||||||
t.Error("Want:", beforeGetUser, "Got:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ab.Callbacks.Before(EventGetUserSession, func(context.Context) (Interrupt, error) {
|
|
||||||
return InterruptNone, beforeGetUserSession
|
|
||||||
})
|
|
||||||
if _, err := ab.CurrentUser(rec, req); err != beforeGetUserSession {
|
|
||||||
t.Error("Want:", beforeGetUserSession, "Got:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthbossUpdatePassword(t *testing.T) {
|
func TestAuthbossUpdatePassword(t *testing.T) {
|
||||||
t.Skip("TODO(aarondl): Implement")
|
t.Skip("TODO(aarondl): Implement")
|
||||||
/*
|
/*
|
||||||
|
254
client_state.go
Normal file
254
client_state.go
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
package authboss
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SessionKey is the primarily used key by authboss.
|
||||||
|
SessionKey = "uid"
|
||||||
|
// SessionHalfAuthKey is used for sessions that have been authenticated by
|
||||||
|
// the remember module. This serves as a way to force full authentication
|
||||||
|
// by denying half-authed users acccess to sensitive areas.
|
||||||
|
SessionHalfAuthKey = "halfauth"
|
||||||
|
// SessionLastAction is the session key to retrieve the last action of a user.
|
||||||
|
SessionLastAction = "last_action"
|
||||||
|
// SessionOAuth2State is the xsrf protection key for oauth.
|
||||||
|
SessionOAuth2State = "oauth2_state"
|
||||||
|
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
||||||
|
SessionOAuth2Params = "oauth2_params"
|
||||||
|
|
||||||
|
// CookieRemember is used for cookies and form input names.
|
||||||
|
CookieRemember = "rm"
|
||||||
|
|
||||||
|
// FlashSuccessKey is used for storing sucess flash messages on the session
|
||||||
|
FlashSuccessKey = "flash_success"
|
||||||
|
// FlashErrorKey is used for storing sucess flash messages on the session
|
||||||
|
FlashErrorKey = "flash_error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientStateEventKind is an enum.
|
||||||
|
type ClientStateEventKind int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ClientStateEventPut ClientStateEventKind = iota
|
||||||
|
ClientStateEventDel
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientStateEvent are the different events that can be recorded during
|
||||||
|
type ClientStateEvent struct {
|
||||||
|
Kind ClientStateEventKind
|
||||||
|
Key string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientStateReadWriter is used to create a cookie storer from an http request.
|
||||||
|
// Keep in mind security considerations for your implementation, Secure,
|
||||||
|
// HTTP-Only, etc flags.
|
||||||
|
//
|
||||||
|
// There's two major uses for this. To create session storage, and remember me
|
||||||
|
// cookies.
|
||||||
|
type ClientStateReadWriter interface {
|
||||||
|
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
|
||||||
|
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientState represents the client's current state and can answer queries
|
||||||
|
// about it.
|
||||||
|
type ClientState interface {
|
||||||
|
Get(key string) (string, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientStateResponseWriter is used to write out the client state at the last
|
||||||
|
// moment before the response code is written.
|
||||||
|
type ClientStateResponseWriter struct {
|
||||||
|
ab *Authboss
|
||||||
|
http.ResponseWriter
|
||||||
|
|
||||||
|
hasWritten bool
|
||||||
|
ctx context.Context
|
||||||
|
sessionStateEvents []ClientStateEvent
|
||||||
|
cookieStateEvents []ClientStateEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
|
||||||
|
return &ClientStateResponseWriter{
|
||||||
|
ab: a,
|
||||||
|
ResponseWriter: w,
|
||||||
|
ctx: r.Context(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
|
||||||
|
if a.SessionStateStorer != nil {
|
||||||
|
state, err := a.SessionStateStorer.ReadState(w, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if state == nil {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), ctxKeySessionState, state)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
if a.CookieStateStorer != nil {
|
||||||
|
state, err := a.CookieStateStorer.ReadState(w, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if state == nil {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
ctx := context.WithValue(r.Context(), ctxKeyCookieState, state)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader writes the header, but in order to handle errors from the
|
||||||
|
// underlying ClientStateReadWriter, it has to panic.
|
||||||
|
func (c *ClientStateResponseWriter) WriteHeader(code int) {
|
||||||
|
if !c.hasWritten {
|
||||||
|
if err := c.putClientState(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header retrieves the underlying headers
|
||||||
|
func (c ClientStateResponseWriter) Header() http.Header {
|
||||||
|
return c.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write ensures that the
|
||||||
|
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
if !c.hasWritten {
|
||||||
|
if err := c.putClientState(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientStateResponseWriter) putClientState() error {
|
||||||
|
if c.hasWritten {
|
||||||
|
panic("should not call putClientState twice")
|
||||||
|
}
|
||||||
|
c.hasWritten = true
|
||||||
|
|
||||||
|
sessionStateIntf := c.ctx.Value(ctxKeySessionState)
|
||||||
|
cookieStateIntf := c.ctx.Value(ctxKeyCookieState)
|
||||||
|
var session, cookie ClientState
|
||||||
|
if sessionStateIntf != nil {
|
||||||
|
session = sessionStateIntf.(ClientState)
|
||||||
|
}
|
||||||
|
if cookieStateIntf != nil {
|
||||||
|
cookie = cookieStateIntf.(ClientState)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ab.SessionStateStorer != nil {
|
||||||
|
err := c.ab.SessionStateStorer.WriteState(c, session, c.sessionStateEvents)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.ab.CookieStateStorer != nil {
|
||||||
|
err := c.ab.CookieStateStorer.WriteState(c, cookie, c.cookieStateEvents)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutSession puts a value into the session
|
||||||
|
func PutSession(w http.ResponseWriter, key, val string) {
|
||||||
|
putState(w, ctxKeySessionState, key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DelSession deletes a key-value from the session.
|
||||||
|
func DelSession(w http.ResponseWriter, key string) {
|
||||||
|
delState(w, ctxKeySessionState, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession fetches a value from the session
|
||||||
|
func GetSession(r *http.Request, key string) (string, bool) {
|
||||||
|
return getState(r, ctxKeySessionState, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutCookie puts a value into the session
|
||||||
|
func PutCookie(w http.ResponseWriter, key, val string) {
|
||||||
|
putState(w, ctxKeyCookieState, key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DelCookie deletes a key-value from the session.
|
||||||
|
func DelCookie(w http.ResponseWriter, key string) {
|
||||||
|
delState(w, ctxKeyCookieState, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCookie fetches a value from the session
|
||||||
|
func GetCookie(r *http.Request, key string) (string, bool) {
|
||||||
|
return getState(r, ctxKeyCookieState, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putState(w http.ResponseWriter, ctxKey contextKey, key, val string) {
|
||||||
|
setState(w, ctxKey, ClientStateEventPut, key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func delState(w http.ResponseWriter, ctxKey contextKey, key string) {
|
||||||
|
setState(w, ctxKey, ClientStateEventDel, key, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
|
||||||
|
csrw := w.(*ClientStateResponseWriter)
|
||||||
|
ev := ClientStateEvent{
|
||||||
|
Kind: op,
|
||||||
|
Key: key,
|
||||||
|
}
|
||||||
|
|
||||||
|
if op == ClientStateEventPut {
|
||||||
|
ev.Value = val
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ctxKey {
|
||||||
|
case ctxKeySessionState:
|
||||||
|
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
|
||||||
|
case ctxKeyCookieState:
|
||||||
|
csrw.cookieStateEvents = append(csrw.cookieStateEvents, ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
|
||||||
|
val := r.Context().Value(ctxKey)
|
||||||
|
if val == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
state := val.(ClientState)
|
||||||
|
return state.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
||||||
|
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
||||||
|
str, ok := GetSession(r, FlashSuccessKey)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
DelSession(w, FlashSuccessKey)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlashError returns FlashError from the session and removes it.
|
||||||
|
func FlashError(w http.ResponseWriter, r *http.Request) string {
|
||||||
|
str, ok := GetSession(r, FlashErrorKey)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
DelSession(w, FlashErrorKey)
|
||||||
|
return str
|
||||||
|
}
|
197
client_state_test.go
Normal file
197
client_state_test.go
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
package authboss
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStateGet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||||
|
ab.CookieStateStorer = newMockClientStateRW("three", "four")
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r, err = ab.LoadClientState(w, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, _ := GetSession(r, "one"); got != "two" {
|
||||||
|
t.Error("session value was wrong:", got)
|
||||||
|
}
|
||||||
|
if got, _ := GetCookie(r, "three"); got != "four" {
|
||||||
|
t.Error("cookie value was wrong:", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
csrw := w.(*ClientStateResponseWriter)
|
||||||
|
|
||||||
|
w.WriteHeader(200)
|
||||||
|
// Check this doesn't panic
|
||||||
|
w.WriteHeader(200)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if recover() == nil {
|
||||||
|
t.Error("expected a panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
csrw.putClientState()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||||
|
ab.CookieStateStorer = newMockClientStateRW("three", "four")
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
var w http.ResponseWriter = httptest.NewRecorder()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r, err = ab.LoadClientState(w, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
w = ab.NewResponse(w, r)
|
||||||
|
|
||||||
|
w.WriteHeader(200)
|
||||||
|
|
||||||
|
// This is an odd test, since the mock will always overwrite the previous
|
||||||
|
// write with the cookie values. Keeping it anyway for code coverage
|
||||||
|
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||||
|
if got != `{"three":"four"}` {
|
||||||
|
t.Error("got:", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW()
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
PutSession(w, "one", "two")
|
||||||
|
|
||||||
|
w.WriteHeader(200)
|
||||||
|
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||||
|
if got != `{"one":"two"}` {
|
||||||
|
t.Error("got:", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW()
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
PutSession(w, "one", "two")
|
||||||
|
|
||||||
|
io.WriteString(w, "Hello world!")
|
||||||
|
|
||||||
|
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||||
|
if got != `{"one":"two"}` {
|
||||||
|
t.Error("got:", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateResponseWriterEvents(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
|
||||||
|
csrw := w.(*ClientStateResponseWriter)
|
||||||
|
|
||||||
|
PutSession(w, "one", "two")
|
||||||
|
DelSession(w, "one")
|
||||||
|
DelCookie(w, "one")
|
||||||
|
PutCookie(w, "two", "one")
|
||||||
|
|
||||||
|
want := ClientStateEvent{Kind: ClientStateEventPut, Key: "one", Value: "two"}
|
||||||
|
if got := csrw.sessionStateEvents[0]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||||
|
if got := csrw.sessionStateEvents[1]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||||
|
if got := csrw.cookieStateEvents[0]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
want = ClientStateEvent{Kind: ClientStateEventPut, Key: "two", Value: "one"}
|
||||||
|
if got := csrw.cookieStateEvents[1]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlashClearer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ab := New()
|
||||||
|
ab.SessionStateStorer = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||||
|
csrw := w.(*ClientStateResponseWriter)
|
||||||
|
|
||||||
|
if msg := FlashSuccess(w, r); msg != "" {
|
||||||
|
t.Error("Unexpected flash success:", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg := FlashError(w, r); msg != "" {
|
||||||
|
t.Error("Unexpected flash error:", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
r, err = ab.LoadClientState(w, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg := FlashSuccess(w, r); msg != "a" {
|
||||||
|
t.Error("Unexpected flash success:", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg := FlashError(w, r); msg != "b" {
|
||||||
|
t.Error("Unexpected flash error:", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := ClientStateEvent{Kind: ClientStateEventDel, Key: FlashSuccessKey}
|
||||||
|
if got := csrw.sessionStateEvents[0]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
want = ClientStateEvent{Kind: ClientStateEventDel, Key: FlashErrorKey}
|
||||||
|
if got := csrw.sessionStateEvents[1]; got != want {
|
||||||
|
t.Error("event was wrong", got)
|
||||||
|
}
|
||||||
|
}
|
@ -1,86 +0,0 @@
|
|||||||
package authboss
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
const (
|
|
||||||
// SessionKey is the primarily used key by authboss.
|
|
||||||
SessionKey = "uid"
|
|
||||||
// SessionHalfAuthKey is used for sessions that have been authenticated by
|
|
||||||
// the remember module. This serves as a way to force full authentication
|
|
||||||
// by denying half-authed users acccess to sensitive areas.
|
|
||||||
SessionHalfAuthKey = "halfauth"
|
|
||||||
// SessionLastAction is the session key to retrieve the last action of a user.
|
|
||||||
SessionLastAction = "last_action"
|
|
||||||
// SessionOAuth2State is the xsrf protection key for oauth.
|
|
||||||
SessionOAuth2State = "oauth2_state"
|
|
||||||
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
|
||||||
SessionOAuth2Params = "oauth2_params"
|
|
||||||
|
|
||||||
// CookieRemember is used for cookies and form input names.
|
|
||||||
CookieRemember = "rm"
|
|
||||||
|
|
||||||
// FlashSuccessKey is used for storing sucess flash messages on the session
|
|
||||||
FlashSuccessKey = "flash_success"
|
|
||||||
// FlashErrorKey is used for storing sucess flash messages on the session
|
|
||||||
FlashErrorKey = "flash_error"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClientStoreMaker is used to create a cookie storer from an http request.
|
|
||||||
// Keep in mind security considerations for your implementation, Secure,
|
|
||||||
// HTTP-Only, etc flags.
|
|
||||||
//
|
|
||||||
// There's two major uses for this. To create session storage, and remember me
|
|
||||||
// cookies.
|
|
||||||
type ClientStoreMaker interface {
|
|
||||||
Make(http.ResponseWriter, *http.Request) ClientStorer
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientStorer should be able to store values on the clients machine. Cookie and
|
|
||||||
// Session storers are built with this interface.
|
|
||||||
type ClientStorer interface {
|
|
||||||
Put(key, value string)
|
|
||||||
Get(key string) (string, bool)
|
|
||||||
Del(key string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientStorerErr is a wrapper to return error values from failed Gets.
|
|
||||||
type ClientStorerErr interface {
|
|
||||||
ClientStorer
|
|
||||||
GetErr(key string) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientStoreWrapper struct {
|
|
||||||
ClientStorer
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetErr returns a value or an error.
|
|
||||||
func (c clientStoreWrapper) GetErr(key string) (string, error) {
|
|
||||||
str, ok := c.Get(key)
|
|
||||||
if !ok {
|
|
||||||
return str, ClientDataErr{key}
|
|
||||||
}
|
|
||||||
|
|
||||||
return str, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
|
||||||
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
|
||||||
storer := a.SessionStoreMaker.Make(w, r)
|
|
||||||
msg, ok := storer.Get(FlashSuccessKey)
|
|
||||||
if ok {
|
|
||||||
storer.Del(FlashSuccessKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlashError returns FlashError from the session and removes it.
|
|
||||||
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
|
|
||||||
storer := a.SessionStoreMaker.Make(w, r)
|
|
||||||
msg, ok := storer.Get(FlashErrorKey)
|
|
||||||
if ok {
|
|
||||||
storer.Del(FlashErrorKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
@ -1,52 +0,0 @@
|
|||||||
package authboss
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
type testClientStorerErr string
|
|
||||||
|
|
||||||
func (t testClientStorerErr) Put(key, value string) {}
|
|
||||||
func (t testClientStorerErr) Get(key string) (string, bool) {
|
|
||||||
return string(t), key == string(t)
|
|
||||||
}
|
|
||||||
func (t testClientStorerErr) Del(key string) {}
|
|
||||||
|
|
||||||
func TestClientStorerErr(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var cs testClientStorerErr
|
|
||||||
|
|
||||||
csw := clientStoreWrapper{&cs}
|
|
||||||
if _, err := csw.GetErr("hello"); err == nil {
|
|
||||||
t.Error("Expected an error")
|
|
||||||
}
|
|
||||||
|
|
||||||
cs = "hello"
|
|
||||||
if str, err := csw.GetErr("hello"); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
} else if str != "hello" {
|
|
||||||
t.Error("Wrong value:", str)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFlashClearer(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
|
|
||||||
ab := New()
|
|
||||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
|
||||||
|
|
||||||
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
|
|
||||||
t.Error("Unexpected flash success:", msg)
|
|
||||||
}
|
|
||||||
if msg, ok := session.Get(FlashSuccessKey); ok {
|
|
||||||
t.Error("Unexpected success flash:", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg := ab.FlashError(nil, nil); msg != "error" {
|
|
||||||
t.Error("Unexpected flash error:", msg)
|
|
||||||
}
|
|
||||||
if msg, ok := session.Get(FlashErrorKey); ok {
|
|
||||||
t.Error("Unexpected error flash:", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
18
config.go
18
config.go
@ -97,14 +97,16 @@ type Config struct {
|
|||||||
// Storer is the interface through which Authboss accesses the web apps database.
|
// Storer is the interface through which Authboss accesses the web apps database.
|
||||||
StoreLoader StoreLoader
|
StoreLoader StoreLoader
|
||||||
|
|
||||||
// CookieStoreMaker must be defined to provide an interface capapable of storing cookies
|
// CookieStateStorer must be defined to provide an interface capapable of
|
||||||
// for the given response, and reading them from the request.
|
// storing cookies for the given response, and reading them from the request.
|
||||||
CookieStoreMaker ClientStoreMaker
|
CookieStateStorer ClientStateReadWriter
|
||||||
// SessionStoreMaker must be defined to provide an interface capable of storing session-only
|
// SessionStateStorer must be defined to provide an interface capable of
|
||||||
// values for the given response, and reading them from the request.
|
// storing session-only values for the given response, and reading them
|
||||||
SessionStoreMaker ClientStoreMaker
|
// from the request.
|
||||||
// LogWriter is written to when errors occur, as well as on startup to show which modules are loaded
|
SessionStateStorer ClientStateReadWriter
|
||||||
// and which routes they registered. By default writes to io.Discard.
|
// LogWriter is written to when errors occur, as well as on startup to show
|
||||||
|
// which modules are loaded and which routes they registered. By default
|
||||||
|
// writes to io.Discard.
|
||||||
LogWriter io.Writer
|
LogWriter io.Writer
|
||||||
// Mailer is the mailer being used to send e-mails out. Authboss defines two loggers for use
|
// Mailer is the mailer being used to send e-mails out. Authboss defines two loggers for use
|
||||||
// LogMailer and SMTPMailer, the default is a LogMailer to io.Discard.
|
// LogMailer and SMTPMailer, the default is a LogMailer to io.Discard.
|
||||||
|
@ -10,6 +10,9 @@ type contextKey string
|
|||||||
const (
|
const (
|
||||||
ctxKeyPID contextKey = "pid"
|
ctxKeyPID contextKey = "pid"
|
||||||
ctxKeyUser contextKey = "user"
|
ctxKeyUser contextKey = "user"
|
||||||
|
|
||||||
|
ctxKeySessionState contextKey = "session"
|
||||||
|
ctxKeyCookieState contextKey = "cookie"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c contextKey) String() string {
|
func (c contextKey) String() string {
|
||||||
@ -27,8 +30,7 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
session := a.SessionStoreMaker.Make(w, r)
|
pid, _ := GetSession(r, SessionKey)
|
||||||
pid, _ := session.Get(SessionKey)
|
|
||||||
return pid, nil
|
return pid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,6 @@
|
|||||||
package authboss
|
package authboss
|
||||||
|
|
||||||
import (
|
/* TODO(aarondl): Re-enable
|
||||||
"context"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCurrentUserID(t *testing.T) {
|
func TestCurrentUserID(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
@ -233,3 +229,4 @@ func TestLoadCurrentUserP(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
_ = ab.LoadCurrentUserP(nil, &req)
|
_ = ab.LoadCurrentUserP(nil, &req)
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
22
expire.go
22
expire.go
@ -9,10 +9,12 @@ var nowTime = time.Now
|
|||||||
|
|
||||||
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
|
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
|
||||||
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
|
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
|
||||||
return a.timeToExpiry(a.SessionStoreMaker.Make(w, r))
|
//TODO(aarondl): Rewrite this so it makes sense with new ClientStorer idioms
|
||||||
|
//return a.timeToExpiry(state.(ClientState))
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
|
||||||
dateStr, ok := session.Get(SessionLastAction)
|
dateStr, ok := session.Get(SessionLastAction)
|
||||||
if !ok {
|
if !ok {
|
||||||
return a.ExpireAfter
|
return a.ExpireAfter
|
||||||
@ -33,12 +35,13 @@ func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
|||||||
|
|
||||||
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
|
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
|
||||||
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
||||||
session := a.SessionStoreMaker.Make(w, r)
|
//TODO(aarondl): Fix
|
||||||
a.refreshExpiry(session)
|
//a.refreshExpiry(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authboss) refreshExpiry(session ClientStorer) {
|
func (a *Authboss) refreshExpiry(session ClientState) {
|
||||||
session.Put(SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
//TODO(aarondl): Fix
|
||||||
|
PutSession(nil, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
||||||
}
|
}
|
||||||
|
|
||||||
type expireMiddleware struct {
|
type expireMiddleware struct {
|
||||||
@ -57,8 +60,9 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
// ServeHTTP removes the session if it's passed the expire time.
|
// ServeHTTP removes the session if it's passed the expire time.
|
||||||
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
session := m.ab.SessionStoreMaker.Make(w, r)
|
//TODO(aarondl): Fix
|
||||||
if _, ok := session.Get(SessionKey); ok {
|
/*
|
||||||
|
if _, ok := GetSession(r, SessionKey); ok {
|
||||||
ttl := m.ab.timeToExpiry(session)
|
ttl := m.ab.timeToExpiry(session)
|
||||||
if ttl == 0 {
|
if ttl == 0 {
|
||||||
session.Del(SessionKey)
|
session.Del(SessionKey)
|
||||||
@ -68,5 +72,5 @@ func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.next.ServeHTTP(w, r)
|
m.next.ServeHTTP(w, r)*/
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,6 @@
|
|||||||
package authboss
|
package authboss
|
||||||
|
|
||||||
import (
|
/* TODO(aarondl): Re-enable
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These tests use the global variable nowTime so cannot be parallelized
|
// These tests use the global variable nowTime so cannot be parallelized
|
||||||
|
|
||||||
@ -80,3 +75,4 @@ func TestDudeIsNotExpired(t *testing.T) {
|
|||||||
t.Error("Expected session key:", key)
|
t.Error("Expected session key:", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package authboss
|
package authboss
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -78,35 +79,59 @@ func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err e
|
|||||||
return m.Password, nil
|
return m.Password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockClientStoreMaker struct {
|
type mockClientStateReadWriter struct {
|
||||||
store mockClientStore
|
state mockClientState
|
||||||
}
|
|
||||||
type mockClientStore map[string]string
|
|
||||||
|
|
||||||
func newMockClientStoreMaker(store mockClientStore) mockClientStoreMaker {
|
|
||||||
return mockClientStoreMaker{
|
|
||||||
store: store,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (m mockClientStoreMaker) Make(w http.ResponseWriter, r *http.Request) ClientStorer {
|
|
||||||
return m.store
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockClientStore) Get(key string) (string, bool) {
|
type mockClientState map[string]string
|
||||||
v, ok := m[key]
|
|
||||||
return v, ok
|
|
||||||
}
|
|
||||||
func (m mockClientStore) GetErr(key string) (string, error) {
|
|
||||||
v, ok := m[key]
|
|
||||||
if !ok {
|
|
||||||
return v, ClientDataErr{key}
|
|
||||||
}
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
func (m mockClientStore) Put(key, val string) { m[key] = val }
|
|
||||||
func (m mockClientStore) Del(key string) { delete(m, key) }
|
|
||||||
|
|
||||||
func mockRequest(postKeyValues ...string) *http.Request {
|
func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
|
||||||
|
state := mockClientState{}
|
||||||
|
for i := 0; i < len(keyValue); i += 2 {
|
||||||
|
key, value := keyValue[i], keyValue[i+1]
|
||||||
|
state[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
return mockClientStateReadWriter{state}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) {
|
||||||
|
return m.state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockClientStateReadWriter) WriteState(w http.ResponseWriter, cs ClientState, evs []ClientStateEvent) error {
|
||||||
|
var state mockClientState
|
||||||
|
|
||||||
|
if cs != nil {
|
||||||
|
state = cs.(mockClientState)
|
||||||
|
} else {
|
||||||
|
state = mockClientState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range evs {
|
||||||
|
switch ev.Kind {
|
||||||
|
case ClientStateEventPut:
|
||||||
|
state[ev.Key] = ev.Value
|
||||||
|
case ClientStateEventDel:
|
||||||
|
delete(state, ev.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(state)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("test_session", string(b))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockClientState) Get(key string) (string, bool) {
|
||||||
|
val, ok := m[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockRequest(postKeyValues ...string) *http.Request {
|
||||||
urlValues := make(url.Values)
|
urlValues := make(url.Values)
|
||||||
for i := 0; i < len(postKeyValues); i += 2 {
|
for i := 0; i < len(postKeyValues); i += 2 {
|
||||||
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
||||||
@ -121,6 +146,27 @@ func mockRequest(postKeyValues ...string) *http.Request {
|
|||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newMockAPIRequest(postKeyValues ...string) *http.Request {
|
||||||
|
kv := map[string]string{}
|
||||||
|
for i := 0; i < len(postKeyValues); i += 2 {
|
||||||
|
key, value := postKeyValues[i], postKeyValues[i+1]
|
||||||
|
kv[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := json.Marshal(kv)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", "http://localhost", bytes.NewReader(b))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
type mockValidator struct {
|
type mockValidator struct {
|
||||||
FieldName string
|
FieldName string
|
||||||
Errs ErrorList
|
Errs ErrorList
|
||||||
|
68
response.go
68
response.go
@ -7,30 +7,8 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RedirectOptions packages up all the pieces a module needs to write out a
|
// Respond to an HTTP request. Renders templates, flash messages, does XSRF
|
||||||
// response.
|
// and writes the headers out.
|
||||||
type RedirectOptions struct {
|
|
||||||
// Success & Failure are used to set Flash messages / JSON messages
|
|
||||||
// if set. They should be mutually exclusive.
|
|
||||||
Success string
|
|
||||||
Failure string
|
|
||||||
|
|
||||||
// Code is used when it's an API request instead of 200.
|
|
||||||
Code int
|
|
||||||
|
|
||||||
// When a request should redirect a user somewhere on completion, these
|
|
||||||
// should be set. RedirectURL tells it where to go. And optionally set
|
|
||||||
// FollowRedirParam to override the RedirectURL if the form parameter defined
|
|
||||||
// by FormValueRedirect is passed in the request.
|
|
||||||
//
|
|
||||||
// Redirecting works differently whether it's an API request or not.
|
|
||||||
// If it's an API request, then it will leave the URL in a "redirect"
|
|
||||||
// parameter.
|
|
||||||
RedirectPath string
|
|
||||||
FollowRedirParam bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Respond to an HTTP request.
|
|
||||||
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
|
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
|
||||||
data.MergeKV(
|
data.MergeKV(
|
||||||
"xsrfName", template.HTML(a.XSRFName),
|
"xsrfName", template.HTML(a.XSRFName),
|
||||||
@ -41,14 +19,13 @@ func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, tem
|
|||||||
data.Merge(a.LayoutDataMaker(w, r))
|
data.Merge(a.LayoutDataMaker(w, r))
|
||||||
}
|
}
|
||||||
|
|
||||||
session := a.SessionStoreMaker.Make(w, r)
|
flashSuccess := FlashSuccess(w, r)
|
||||||
if flash, ok := session.Get(FlashSuccessKey); ok {
|
flashError := FlashError(w, r)
|
||||||
session.Del(FlashSuccessKey)
|
if len(flashSuccess) != 0 {
|
||||||
data.MergeKV(FlashSuccessKey, flash)
|
data.MergeKV(FlashSuccessKey, flashSuccess)
|
||||||
}
|
}
|
||||||
if flash, ok := session.Get(FlashErrorKey); ok {
|
if len(flashError) != 0 {
|
||||||
session.Del(FlashErrorKey)
|
data.MergeKV(FlashErrorKey, flashError)
|
||||||
data.MergeKV(FlashErrorKey, flash)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
|
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
|
||||||
@ -93,6 +70,29 @@ func (a *Authboss) Email(w http.ResponseWriter, r *http.Request, email Email, ro
|
|||||||
return a.Mailer.Send(ctx, email)
|
return a.Mailer.Send(ctx, email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RedirectOptions packages up all the pieces a module needs to write out a
|
||||||
|
// response.
|
||||||
|
type RedirectOptions struct {
|
||||||
|
// Success & Failure are used to set Flash messages / JSON messages
|
||||||
|
// if set. They should be mutually exclusive.
|
||||||
|
Success string
|
||||||
|
Failure string
|
||||||
|
|
||||||
|
// Code is used when it's an API request instead of 200.
|
||||||
|
Code int
|
||||||
|
|
||||||
|
// When a request should redirect a user somewhere on completion, these
|
||||||
|
// should be set. RedirectURL tells it where to go. And optionally set
|
||||||
|
// FollowRedirParam to override the RedirectURL if the form parameter defined
|
||||||
|
// by FormValueRedirect is passed in the request.
|
||||||
|
//
|
||||||
|
// Redirecting works differently whether it's an API request or not.
|
||||||
|
// If it's an API request, then it will leave the URL in a "redirect"
|
||||||
|
// parameter.
|
||||||
|
RedirectPath string
|
||||||
|
FollowRedirParam bool
|
||||||
|
}
|
||||||
|
|
||||||
// Redirect the client elsewhere. If it's an API request it will simply render
|
// Redirect the client elsewhere. If it's an API request it will simply render
|
||||||
// a JSON response with information that should help a client to decide what
|
// a JSON response with information that should help a client to decide what
|
||||||
// to do.
|
// to do.
|
||||||
@ -155,12 +155,10 @@ func (a *Authboss) redirectNonAPI(w http.ResponseWriter, r *http.Request, ro Red
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(ro.Success) != 0 {
|
if len(ro.Success) != 0 {
|
||||||
session := a.SessionStoreMaker.Make(w, r)
|
PutSession(w, FlashSuccessKey, ro.Success)
|
||||||
session.Put(FlashSuccessKey, ro.Success)
|
|
||||||
}
|
}
|
||||||
if len(ro.Failure) != 0 {
|
if len(ro.Failure) != 0 {
|
||||||
session := a.SessionStoreMaker.Make(w, r)
|
PutSession(w, FlashErrorKey, ro.Failure)
|
||||||
session.Put(FlashErrorKey, ro.Failure)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, path, http.StatusFound)
|
http.Redirect(w, r, path, http.StatusFound)
|
||||||
|
@ -31,8 +31,8 @@ func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
|||||||
ab.ViewLoader = mockRenderLoader{}
|
ab.ViewLoader = mockRenderLoader{}
|
||||||
ab.Init(testRouterModName)
|
ab.Init(testRouterModName)
|
||||||
ab.MountPath = "/prefix"
|
ab.MountPath = "/prefix"
|
||||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
//ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
//ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||||
|
|
||||||
logger.Reset() // Clear out the module load messages
|
logger.Reset() // Clear out the module load messages
|
||||||
|
|
||||||
|
@ -61,6 +61,10 @@ type Storer interface {
|
|||||||
Load(ctx context.Context) error
|
Load(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(aarondl): Document & move to Register module
|
||||||
|
// ArbitraryStorer allows arbitrary data from the web form through. You should
|
||||||
|
// definitely only pull the keys you want from the map, since this is unfiltered
|
||||||
|
// input from a web request and is an attack vector.
|
||||||
type ArbitraryStorer interface {
|
type ArbitraryStorer interface {
|
||||||
Storer
|
Storer
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ func TestErrorList_Map(t *testing.T) {
|
|||||||
func TestValidate(t *testing.T) {
|
func TestValidate(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
req := mockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
|
req := newMockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
|
||||||
|
|
||||||
errList := Validate(req, []Validator{
|
errList := Validate(req, []Validator{
|
||||||
mockValidator{
|
mockValidator{
|
||||||
@ -96,19 +96,19 @@ func TestValidate(t *testing.T) {
|
|||||||
func TestValidate_Confirm(t *testing.T) {
|
func TestValidate_Confirm(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
req := mockRequest(StoreUsername, "john", "confirmUsername", "johnny")
|
req := newMockRequest(StoreUsername, "john", "confirmUsername", "johnny")
|
||||||
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||||
if errs["confirmUsername"][0] != "Does not match username" {
|
if errs["confirmUsername"][0] != "Does not match username" {
|
||||||
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
|
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
|
||||||
}
|
}
|
||||||
|
|
||||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||||
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||||
if len(errs) != 0 {
|
if len(errs) != 0 {
|
||||||
t.Error("Expected no errors:", errs)
|
t.Error("Expected no errors:", errs)
|
||||||
}
|
}
|
||||||
|
|
||||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||||
errs = Validate(req, nil, StoreUsername).Map()
|
errs = Validate(req, nil, StoreUsername).Map()
|
||||||
if len(errs) != 0 {
|
if len(errs) != 0 {
|
||||||
t.Error("Expected no errors:", errs)
|
t.Error("Expected no errors:", errs)
|
||||||
|
Reference in New Issue
Block a user