mirror of
https://github.com/volatiletech/authboss.git
synced 2025-07-15 01:24:33 +02:00
Introduce new type of client storage
- This addresses the problem of having to update multiple times during one request. It's hard to have a nice interface especially with JWT because you always end up having to decode the request, encode new response, write header, then a second write to it comes, and where do you grab the value from? Often you don't have access to the response as a "read" structure. So we store it as events instead, and play those events against the original data right before the response is written to set the headers.
This commit is contained in:
@ -1,13 +1,8 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestAuthBossInit(t *testing.T) {
|
||||
@ -22,74 +17,6 @@ func TestAuthBossInit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthBossCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
if err := ab.Init(); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "localhost", nil)
|
||||
|
||||
userStruct := ab.CurrentUserP(rec, req)
|
||||
us := userStruct.(mockStoredUser)
|
||||
|
||||
if us.Email != "john@john.com" || us.Password != "lies" {
|
||||
t.Error("Wrong user found!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthBossCurrentUserCallbacks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
if err := ab.Init(); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "localhost", nil)
|
||||
|
||||
afterGetUser := errors.New("afterGetUser")
|
||||
beforeGetUser := errors.New("beforeGetUser")
|
||||
beforeGetUserSession := errors.New("beforeGetUserSession")
|
||||
|
||||
ab.Callbacks.After(EventGetUser, func(context.Context) error {
|
||||
return afterGetUser
|
||||
})
|
||||
if _, err := ab.CurrentUser(rec, req); err != afterGetUser {
|
||||
t.Error("Want:", afterGetUser, "Got:", err)
|
||||
}
|
||||
|
||||
ab.Callbacks.Before(EventGetUser, func(context.Context) (Interrupt, error) {
|
||||
return InterruptNone, beforeGetUser
|
||||
})
|
||||
if _, err := ab.CurrentUser(rec, req); err != beforeGetUser {
|
||||
t.Error("Want:", beforeGetUser, "Got:", err)
|
||||
}
|
||||
|
||||
ab.Callbacks.Before(EventGetUserSession, func(context.Context) (Interrupt, error) {
|
||||
return InterruptNone, beforeGetUserSession
|
||||
})
|
||||
if _, err := ab.CurrentUser(rec, req); err != beforeGetUserSession {
|
||||
t.Error("Want:", beforeGetUserSession, "Got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthbossUpdatePassword(t *testing.T) {
|
||||
t.Skip("TODO(aarondl): Implement")
|
||||
/*
|
||||
|
254
client_state.go
Normal file
254
client_state.go
Normal file
@ -0,0 +1,254 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// SessionKey is the primarily used key by authboss.
|
||||
SessionKey = "uid"
|
||||
// SessionHalfAuthKey is used for sessions that have been authenticated by
|
||||
// the remember module. This serves as a way to force full authentication
|
||||
// by denying half-authed users acccess to sensitive areas.
|
||||
SessionHalfAuthKey = "halfauth"
|
||||
// SessionLastAction is the session key to retrieve the last action of a user.
|
||||
SessionLastAction = "last_action"
|
||||
// SessionOAuth2State is the xsrf protection key for oauth.
|
||||
SessionOAuth2State = "oauth2_state"
|
||||
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
||||
SessionOAuth2Params = "oauth2_params"
|
||||
|
||||
// CookieRemember is used for cookies and form input names.
|
||||
CookieRemember = "rm"
|
||||
|
||||
// FlashSuccessKey is used for storing sucess flash messages on the session
|
||||
FlashSuccessKey = "flash_success"
|
||||
// FlashErrorKey is used for storing sucess flash messages on the session
|
||||
FlashErrorKey = "flash_error"
|
||||
)
|
||||
|
||||
// ClientStateEventKind is an enum.
|
||||
type ClientStateEventKind int
|
||||
|
||||
const (
|
||||
ClientStateEventPut ClientStateEventKind = iota
|
||||
ClientStateEventDel
|
||||
)
|
||||
|
||||
// ClientStateEvent are the different events that can be recorded during
|
||||
type ClientStateEvent struct {
|
||||
Kind ClientStateEventKind
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
// ClientStateReadWriter is used to create a cookie storer from an http request.
|
||||
// Keep in mind security considerations for your implementation, Secure,
|
||||
// HTTP-Only, etc flags.
|
||||
//
|
||||
// There's two major uses for this. To create session storage, and remember me
|
||||
// cookies.
|
||||
type ClientStateReadWriter interface {
|
||||
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
|
||||
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
|
||||
}
|
||||
|
||||
// ClientState represents the client's current state and can answer queries
|
||||
// about it.
|
||||
type ClientState interface {
|
||||
Get(key string) (string, bool)
|
||||
}
|
||||
|
||||
// clientStateResponseWriter is used to write out the client state at the last
|
||||
// moment before the response code is written.
|
||||
type ClientStateResponseWriter struct {
|
||||
ab *Authboss
|
||||
http.ResponseWriter
|
||||
|
||||
hasWritten bool
|
||||
ctx context.Context
|
||||
sessionStateEvents []ClientStateEvent
|
||||
cookieStateEvents []ClientStateEvent
|
||||
}
|
||||
|
||||
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
|
||||
return &ClientStateResponseWriter{
|
||||
ab: a,
|
||||
ResponseWriter: w,
|
||||
ctx: r.Context(),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
|
||||
if a.SessionStateStorer != nil {
|
||||
state, err := a.SessionStateStorer.ReadState(w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if state == nil {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ctxKeySessionState, state)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
if a.CookieStateStorer != nil {
|
||||
state, err := a.CookieStateStorer.ReadState(w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if state == nil {
|
||||
return r, nil
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), ctxKeyCookieState, state)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteHeader writes the header, but in order to handle errors from the
|
||||
// underlying ClientStateReadWriter, it has to panic.
|
||||
func (c *ClientStateResponseWriter) WriteHeader(code int) {
|
||||
if !c.hasWritten {
|
||||
if err := c.putClientState(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
c.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Header retrieves the underlying headers
|
||||
func (c ClientStateResponseWriter) Header() http.Header {
|
||||
return c.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// Write ensures that the
|
||||
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
||||
if !c.hasWritten {
|
||||
if err := c.putClientState(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (c *ClientStateResponseWriter) putClientState() error {
|
||||
if c.hasWritten {
|
||||
panic("should not call putClientState twice")
|
||||
}
|
||||
c.hasWritten = true
|
||||
|
||||
sessionStateIntf := c.ctx.Value(ctxKeySessionState)
|
||||
cookieStateIntf := c.ctx.Value(ctxKeyCookieState)
|
||||
var session, cookie ClientState
|
||||
if sessionStateIntf != nil {
|
||||
session = sessionStateIntf.(ClientState)
|
||||
}
|
||||
if cookieStateIntf != nil {
|
||||
cookie = cookieStateIntf.(ClientState)
|
||||
}
|
||||
|
||||
if c.ab.SessionStateStorer != nil {
|
||||
err := c.ab.SessionStateStorer.WriteState(c, session, c.sessionStateEvents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.ab.CookieStateStorer != nil {
|
||||
err := c.ab.CookieStateStorer.WriteState(c, cookie, c.cookieStateEvents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutSession puts a value into the session
|
||||
func PutSession(w http.ResponseWriter, key, val string) {
|
||||
putState(w, ctxKeySessionState, key, val)
|
||||
}
|
||||
|
||||
// DelSession deletes a key-value from the session.
|
||||
func DelSession(w http.ResponseWriter, key string) {
|
||||
delState(w, ctxKeySessionState, key)
|
||||
}
|
||||
|
||||
// GetSession fetches a value from the session
|
||||
func GetSession(r *http.Request, key string) (string, bool) {
|
||||
return getState(r, ctxKeySessionState, key)
|
||||
}
|
||||
|
||||
// PutCookie puts a value into the session
|
||||
func PutCookie(w http.ResponseWriter, key, val string) {
|
||||
putState(w, ctxKeyCookieState, key, val)
|
||||
}
|
||||
|
||||
// DelCookie deletes a key-value from the session.
|
||||
func DelCookie(w http.ResponseWriter, key string) {
|
||||
delState(w, ctxKeyCookieState, key)
|
||||
}
|
||||
|
||||
// GetCookie fetches a value from the session
|
||||
func GetCookie(r *http.Request, key string) (string, bool) {
|
||||
return getState(r, ctxKeyCookieState, key)
|
||||
}
|
||||
|
||||
func putState(w http.ResponseWriter, ctxKey contextKey, key, val string) {
|
||||
setState(w, ctxKey, ClientStateEventPut, key, val)
|
||||
}
|
||||
|
||||
func delState(w http.ResponseWriter, ctxKey contextKey, key string) {
|
||||
setState(w, ctxKey, ClientStateEventDel, key, "")
|
||||
}
|
||||
|
||||
func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
ev := ClientStateEvent{
|
||||
Kind: op,
|
||||
Key: key,
|
||||
}
|
||||
|
||||
if op == ClientStateEventPut {
|
||||
ev.Value = val
|
||||
}
|
||||
|
||||
switch ctxKey {
|
||||
case ctxKeySessionState:
|
||||
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
|
||||
case ctxKeyCookieState:
|
||||
csrw.cookieStateEvents = append(csrw.cookieStateEvents, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
|
||||
val := r.Context().Value(ctxKey)
|
||||
if val == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
state := val.(ClientState)
|
||||
return state.Get(key)
|
||||
}
|
||||
|
||||
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
||||
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
||||
str, ok := GetSession(r, FlashSuccessKey)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
DelSession(w, FlashSuccessKey)
|
||||
return str
|
||||
}
|
||||
|
||||
// FlashError returns FlashError from the session and removes it.
|
||||
func FlashError(w http.ResponseWriter, r *http.Request) string {
|
||||
str, ok := GetSession(r, FlashErrorKey)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
DelSession(w, FlashErrorKey)
|
||||
return str
|
||||
}
|
197
client_state_test.go
Normal file
197
client_state_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStateGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||
ab.CookieStateStorer = newMockClientStateRW("three", "four")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, _ := GetSession(r, "one"); got != "two" {
|
||||
t.Error("session value was wrong:", got)
|
||||
}
|
||||
if got, _ := GetCookie(r, "three"); got != "four" {
|
||||
t.Error("cookie value was wrong:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
w.WriteHeader(200)
|
||||
// Check this doesn't panic
|
||||
w.WriteHeader(200)
|
||||
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Error("expected a panic")
|
||||
}
|
||||
}()
|
||||
|
||||
csrw.putClientState()
|
||||
}
|
||||
|
||||
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW("one", "two")
|
||||
ab.CookieStateStorer = newMockClientStateRW("three", "four")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
var w http.ResponseWriter = httptest.NewRecorder()
|
||||
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
w = ab.NewResponse(w, r)
|
||||
|
||||
w.WriteHeader(200)
|
||||
|
||||
// This is an odd test, since the mock will always overwrite the previous
|
||||
// write with the cookie values. Keeping it anyway for code coverage
|
||||
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||
if got != `{"three":"four"}` {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
|
||||
w.WriteHeader(200)
|
||||
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||
if got != `{"one":"two"}` {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
|
||||
io.WriteString(w, "Hello world!")
|
||||
|
||||
got := strings.TrimSpace(w.Header().Get("test_session"))
|
||||
if got != `{"one":"two"}` {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateResponseWriterEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
DelSession(w, "one")
|
||||
DelCookie(w, "one")
|
||||
PutCookie(w, "two", "one")
|
||||
|
||||
want := ClientStateEvent{Kind: ClientStateEventPut, Key: "one", Value: "two"}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||
if got := csrw.cookieStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventPut, Key: "two", Value: "one"}
|
||||
if got := csrw.cookieStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlashClearer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
if msg := FlashSuccess(w, r); msg != "" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
}
|
||||
|
||||
if msg := FlashError(w, r); msg != "" {
|
||||
t.Error("Unexpected flash error:", msg)
|
||||
}
|
||||
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if msg := FlashSuccess(w, r); msg != "a" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
}
|
||||
|
||||
if msg := FlashError(w, r); msg != "b" {
|
||||
t.Error("Unexpected flash error:", msg)
|
||||
}
|
||||
|
||||
want := ClientStateEvent{Kind: ClientStateEventDel, Key: FlashSuccessKey}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: FlashErrorKey}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
}
|
@ -1,86 +0,0 @@
|
||||
package authboss
|
||||
|
||||
import "net/http"
|
||||
|
||||
const (
|
||||
// SessionKey is the primarily used key by authboss.
|
||||
SessionKey = "uid"
|
||||
// SessionHalfAuthKey is used for sessions that have been authenticated by
|
||||
// the remember module. This serves as a way to force full authentication
|
||||
// by denying half-authed users acccess to sensitive areas.
|
||||
SessionHalfAuthKey = "halfauth"
|
||||
// SessionLastAction is the session key to retrieve the last action of a user.
|
||||
SessionLastAction = "last_action"
|
||||
// SessionOAuth2State is the xsrf protection key for oauth.
|
||||
SessionOAuth2State = "oauth2_state"
|
||||
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
||||
SessionOAuth2Params = "oauth2_params"
|
||||
|
||||
// CookieRemember is used for cookies and form input names.
|
||||
CookieRemember = "rm"
|
||||
|
||||
// FlashSuccessKey is used for storing sucess flash messages on the session
|
||||
FlashSuccessKey = "flash_success"
|
||||
// FlashErrorKey is used for storing sucess flash messages on the session
|
||||
FlashErrorKey = "flash_error"
|
||||
)
|
||||
|
||||
// ClientStoreMaker is used to create a cookie storer from an http request.
|
||||
// Keep in mind security considerations for your implementation, Secure,
|
||||
// HTTP-Only, etc flags.
|
||||
//
|
||||
// There's two major uses for this. To create session storage, and remember me
|
||||
// cookies.
|
||||
type ClientStoreMaker interface {
|
||||
Make(http.ResponseWriter, *http.Request) ClientStorer
|
||||
}
|
||||
|
||||
// ClientStorer should be able to store values on the clients machine. Cookie and
|
||||
// Session storers are built with this interface.
|
||||
type ClientStorer interface {
|
||||
Put(key, value string)
|
||||
Get(key string) (string, bool)
|
||||
Del(key string)
|
||||
}
|
||||
|
||||
// ClientStorerErr is a wrapper to return error values from failed Gets.
|
||||
type ClientStorerErr interface {
|
||||
ClientStorer
|
||||
GetErr(key string) (string, error)
|
||||
}
|
||||
|
||||
type clientStoreWrapper struct {
|
||||
ClientStorer
|
||||
}
|
||||
|
||||
// GetErr returns a value or an error.
|
||||
func (c clientStoreWrapper) GetErr(key string) (string, error) {
|
||||
str, ok := c.Get(key)
|
||||
if !ok {
|
||||
return str, ClientDataErr{key}
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
||||
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
||||
storer := a.SessionStoreMaker.Make(w, r)
|
||||
msg, ok := storer.Get(FlashSuccessKey)
|
||||
if ok {
|
||||
storer.Del(FlashSuccessKey)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// FlashError returns FlashError from the session and removes it.
|
||||
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
|
||||
storer := a.SessionStoreMaker.Make(w, r)
|
||||
msg, ok := storer.Get(FlashErrorKey)
|
||||
if ok {
|
||||
storer.Del(FlashErrorKey)
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
package authboss
|
||||
|
||||
import "testing"
|
||||
|
||||
type testClientStorerErr string
|
||||
|
||||
func (t testClientStorerErr) Put(key, value string) {}
|
||||
func (t testClientStorerErr) Get(key string) (string, bool) {
|
||||
return string(t), key == string(t)
|
||||
}
|
||||
func (t testClientStorerErr) Del(key string) {}
|
||||
|
||||
func TestClientStorerErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var cs testClientStorerErr
|
||||
|
||||
csw := clientStoreWrapper{&cs}
|
||||
if _, err := csw.GetErr("hello"); err == nil {
|
||||
t.Error("Expected an error")
|
||||
}
|
||||
|
||||
cs = "hello"
|
||||
if str, err := csw.GetErr("hello"); err != nil {
|
||||
t.Error(err)
|
||||
} else if str != "hello" {
|
||||
t.Error("Wrong value:", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlashClearer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
}
|
||||
if msg, ok := session.Get(FlashSuccessKey); ok {
|
||||
t.Error("Unexpected success flash:", msg)
|
||||
}
|
||||
|
||||
if msg := ab.FlashError(nil, nil); msg != "error" {
|
||||
t.Error("Unexpected flash error:", msg)
|
||||
}
|
||||
if msg, ok := session.Get(FlashErrorKey); ok {
|
||||
t.Error("Unexpected error flash:", msg)
|
||||
}
|
||||
|
||||
}
|
18
config.go
18
config.go
@ -97,14 +97,16 @@ type Config struct {
|
||||
// Storer is the interface through which Authboss accesses the web apps database.
|
||||
StoreLoader StoreLoader
|
||||
|
||||
// CookieStoreMaker must be defined to provide an interface capapable of storing cookies
|
||||
// for the given response, and reading them from the request.
|
||||
CookieStoreMaker ClientStoreMaker
|
||||
// SessionStoreMaker must be defined to provide an interface capable of storing session-only
|
||||
// values for the given response, and reading them from the request.
|
||||
SessionStoreMaker ClientStoreMaker
|
||||
// LogWriter is written to when errors occur, as well as on startup to show which modules are loaded
|
||||
// and which routes they registered. By default writes to io.Discard.
|
||||
// CookieStateStorer must be defined to provide an interface capapable of
|
||||
// storing cookies for the given response, and reading them from the request.
|
||||
CookieStateStorer ClientStateReadWriter
|
||||
// SessionStateStorer must be defined to provide an interface capable of
|
||||
// storing session-only values for the given response, and reading them
|
||||
// from the request.
|
||||
SessionStateStorer ClientStateReadWriter
|
||||
// LogWriter is written to when errors occur, as well as on startup to show
|
||||
// which modules are loaded and which routes they registered. By default
|
||||
// writes to io.Discard.
|
||||
LogWriter io.Writer
|
||||
// Mailer is the mailer being used to send e-mails out. Authboss defines two loggers for use
|
||||
// LogMailer and SMTPMailer, the default is a LogMailer to io.Discard.
|
||||
|
@ -10,6 +10,9 @@ type contextKey string
|
||||
const (
|
||||
ctxKeyPID contextKey = "pid"
|
||||
ctxKeyUser contextKey = "user"
|
||||
|
||||
ctxKeySessionState contextKey = "session"
|
||||
ctxKeyCookieState contextKey = "cookie"
|
||||
)
|
||||
|
||||
func (c contextKey) String() string {
|
||||
@ -27,8 +30,7 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
|
||||
return "", err
|
||||
}
|
||||
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
pid, _ := session.Get(SessionKey)
|
||||
pid, _ := GetSession(r, SessionKey)
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
/* TODO(aarondl): Re-enable
|
||||
|
||||
func TestCurrentUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -233,3 +229,4 @@ func TestLoadCurrentUserP(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserP(nil, &req)
|
||||
}
|
||||
*/
|
||||
|
36
expire.go
36
expire.go
@ -9,10 +9,12 @@ var nowTime = time.Now
|
||||
|
||||
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
|
||||
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
|
||||
return a.timeToExpiry(a.SessionStoreMaker.Make(w, r))
|
||||
//TODO(aarondl): Rewrite this so it makes sense with new ClientStorer idioms
|
||||
//return a.timeToExpiry(state.(ClientState))
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
||||
func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
|
||||
dateStr, ok := session.Get(SessionLastAction)
|
||||
if !ok {
|
||||
return a.ExpireAfter
|
||||
@ -33,12 +35,13 @@ func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
||||
|
||||
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
|
||||
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
a.refreshExpiry(session)
|
||||
//TODO(aarondl): Fix
|
||||
//a.refreshExpiry(session)
|
||||
}
|
||||
|
||||
func (a *Authboss) refreshExpiry(session ClientStorer) {
|
||||
session.Put(SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
||||
func (a *Authboss) refreshExpiry(session ClientState) {
|
||||
//TODO(aarondl): Fix
|
||||
PutSession(nil, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
type expireMiddleware struct {
|
||||
@ -57,16 +60,17 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
// ServeHTTP removes the session if it's passed the expire time.
|
||||
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
session := m.ab.SessionStoreMaker.Make(w, r)
|
||||
if _, ok := session.Get(SessionKey); ok {
|
||||
ttl := m.ab.timeToExpiry(session)
|
||||
if ttl == 0 {
|
||||
session.Del(SessionKey)
|
||||
session.Del(SessionLastAction)
|
||||
} else {
|
||||
m.ab.refreshExpiry(session)
|
||||
//TODO(aarondl): Fix
|
||||
/*
|
||||
if _, ok := GetSession(r, SessionKey); ok {
|
||||
ttl := m.ab.timeToExpiry(session)
|
||||
if ttl == 0 {
|
||||
session.Del(SessionKey)
|
||||
session.Del(SessionLastAction)
|
||||
} else {
|
||||
m.ab.refreshExpiry(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.next.ServeHTTP(w, r)
|
||||
m.next.ServeHTTP(w, r)*/
|
||||
}
|
||||
|
@ -1,11 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
/* TODO(aarondl): Re-enable
|
||||
|
||||
// These tests use the global variable nowTime so cannot be parallelized
|
||||
|
||||
@ -80,3 +75,4 @@ func TestDudeIsNotExpired(t *testing.T) {
|
||||
t.Error("Expected session key:", key)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@ -1,6 +1,7 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
@ -78,35 +79,59 @@ func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err e
|
||||
return m.Password, nil
|
||||
}
|
||||
|
||||
type mockClientStoreMaker struct {
|
||||
store mockClientStore
|
||||
type mockClientStateReadWriter struct {
|
||||
state mockClientState
|
||||
}
|
||||
type mockClientStore map[string]string
|
||||
|
||||
func newMockClientStoreMaker(store mockClientStore) mockClientStoreMaker {
|
||||
return mockClientStoreMaker{
|
||||
store: store,
|
||||
type mockClientState map[string]string
|
||||
|
||||
func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
|
||||
state := mockClientState{}
|
||||
for i := 0; i < len(keyValue); i += 2 {
|
||||
key, value := keyValue[i], keyValue[i+1]
|
||||
state[key] = value
|
||||
}
|
||||
}
|
||||
func (m mockClientStoreMaker) Make(w http.ResponseWriter, r *http.Request) ClientStorer {
|
||||
return m.store
|
||||
|
||||
return mockClientStateReadWriter{state}
|
||||
}
|
||||
|
||||
func (m mockClientStore) Get(key string) (string, bool) {
|
||||
v, ok := m[key]
|
||||
return v, ok
|
||||
func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) {
|
||||
return m.state, nil
|
||||
}
|
||||
func (m mockClientStore) GetErr(key string) (string, error) {
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return v, ClientDataErr{key}
|
||||
|
||||
func (m mockClientStateReadWriter) WriteState(w http.ResponseWriter, cs ClientState, evs []ClientStateEvent) error {
|
||||
var state mockClientState
|
||||
|
||||
if cs != nil {
|
||||
state = cs.(mockClientState)
|
||||
} else {
|
||||
state = mockClientState{}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
func (m mockClientStore) Put(key, val string) { m[key] = val }
|
||||
func (m mockClientStore) Del(key string) { delete(m, key) }
|
||||
|
||||
func mockRequest(postKeyValues ...string) *http.Request {
|
||||
for _, ev := range evs {
|
||||
switch ev.Kind {
|
||||
case ClientStateEventPut:
|
||||
state[ev.Key] = ev.Value
|
||||
case ClientStateEventDel:
|
||||
delete(state, ev.Key)
|
||||
}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.Header().Set("test_session", string(b))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockClientState) Get(key string) (string, bool) {
|
||||
val, ok := m[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func newMockRequest(postKeyValues ...string) *http.Request {
|
||||
urlValues := make(url.Values)
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
urlValues.Set(postKeyValues[i], postKeyValues[i+1])
|
||||
@ -121,6 +146,27 @@ func mockRequest(postKeyValues ...string) *http.Request {
|
||||
return req
|
||||
}
|
||||
|
||||
func newMockAPIRequest(postKeyValues ...string) *http.Request {
|
||||
kv := map[string]string{}
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
key, value := postKeyValues[i], postKeyValues[i+1]
|
||||
kv[key] = value
|
||||
}
|
||||
|
||||
b, err := json.Marshal(kv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
type mockValidator struct {
|
||||
FieldName string
|
||||
Errs ErrorList
|
||||
|
68
response.go
68
response.go
@ -7,30 +7,8 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// RedirectOptions packages up all the pieces a module needs to write out a
|
||||
// response.
|
||||
type RedirectOptions struct {
|
||||
// Success & Failure are used to set Flash messages / JSON messages
|
||||
// if set. They should be mutually exclusive.
|
||||
Success string
|
||||
Failure string
|
||||
|
||||
// Code is used when it's an API request instead of 200.
|
||||
Code int
|
||||
|
||||
// When a request should redirect a user somewhere on completion, these
|
||||
// should be set. RedirectURL tells it where to go. And optionally set
|
||||
// FollowRedirParam to override the RedirectURL if the form parameter defined
|
||||
// by FormValueRedirect is passed in the request.
|
||||
//
|
||||
// Redirecting works differently whether it's an API request or not.
|
||||
// If it's an API request, then it will leave the URL in a "redirect"
|
||||
// parameter.
|
||||
RedirectPath string
|
||||
FollowRedirParam bool
|
||||
}
|
||||
|
||||
// Respond to an HTTP request.
|
||||
// Respond to an HTTP request. Renders templates, flash messages, does XSRF
|
||||
// and writes the headers out.
|
||||
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
|
||||
data.MergeKV(
|
||||
"xsrfName", template.HTML(a.XSRFName),
|
||||
@ -41,14 +19,13 @@ func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, tem
|
||||
data.Merge(a.LayoutDataMaker(w, r))
|
||||
}
|
||||
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
if flash, ok := session.Get(FlashSuccessKey); ok {
|
||||
session.Del(FlashSuccessKey)
|
||||
data.MergeKV(FlashSuccessKey, flash)
|
||||
flashSuccess := FlashSuccess(w, r)
|
||||
flashError := FlashError(w, r)
|
||||
if len(flashSuccess) != 0 {
|
||||
data.MergeKV(FlashSuccessKey, flashSuccess)
|
||||
}
|
||||
if flash, ok := session.Get(FlashErrorKey); ok {
|
||||
session.Del(FlashErrorKey)
|
||||
data.MergeKV(FlashErrorKey, flash)
|
||||
if len(flashError) != 0 {
|
||||
data.MergeKV(FlashErrorKey, flashError)
|
||||
}
|
||||
|
||||
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
|
||||
@ -93,6 +70,29 @@ func (a *Authboss) Email(w http.ResponseWriter, r *http.Request, email Email, ro
|
||||
return a.Mailer.Send(ctx, email)
|
||||
}
|
||||
|
||||
// RedirectOptions packages up all the pieces a module needs to write out a
|
||||
// response.
|
||||
type RedirectOptions struct {
|
||||
// Success & Failure are used to set Flash messages / JSON messages
|
||||
// if set. They should be mutually exclusive.
|
||||
Success string
|
||||
Failure string
|
||||
|
||||
// Code is used when it's an API request instead of 200.
|
||||
Code int
|
||||
|
||||
// When a request should redirect a user somewhere on completion, these
|
||||
// should be set. RedirectURL tells it where to go. And optionally set
|
||||
// FollowRedirParam to override the RedirectURL if the form parameter defined
|
||||
// by FormValueRedirect is passed in the request.
|
||||
//
|
||||
// Redirecting works differently whether it's an API request or not.
|
||||
// If it's an API request, then it will leave the URL in a "redirect"
|
||||
// parameter.
|
||||
RedirectPath string
|
||||
FollowRedirParam bool
|
||||
}
|
||||
|
||||
// Redirect the client elsewhere. If it's an API request it will simply render
|
||||
// a JSON response with information that should help a client to decide what
|
||||
// to do.
|
||||
@ -155,12 +155,10 @@ func (a *Authboss) redirectNonAPI(w http.ResponseWriter, r *http.Request, ro Red
|
||||
}
|
||||
|
||||
if len(ro.Success) != 0 {
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
session.Put(FlashSuccessKey, ro.Success)
|
||||
PutSession(w, FlashSuccessKey, ro.Success)
|
||||
}
|
||||
if len(ro.Failure) != 0 {
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
session.Put(FlashErrorKey, ro.Failure)
|
||||
PutSession(w, FlashErrorKey, ro.Failure)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, path, http.StatusFound)
|
||||
|
@ -31,8 +31,8 @@ func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.Init(testRouterModName)
|
||||
ab.MountPath = "/prefix"
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
//ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
//ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
logger.Reset() // Clear out the module load messages
|
||||
|
||||
|
@ -61,6 +61,10 @@ type Storer interface {
|
||||
Load(ctx context.Context) error
|
||||
}
|
||||
|
||||
// TODO(aarondl): Document & move to Register module
|
||||
// ArbitraryStorer allows arbitrary data from the web form through. You should
|
||||
// definitely only pull the keys you want from the map, since this is unfiltered
|
||||
// input from a web request and is an attack vector.
|
||||
type ArbitraryStorer interface {
|
||||
Storer
|
||||
|
||||
|
@ -65,7 +65,7 @@ func TestErrorList_Map(t *testing.T) {
|
||||
func TestValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := mockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
req := newMockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
|
||||
|
||||
errList := Validate(req, []Validator{
|
||||
mockValidator{
|
||||
@ -96,19 +96,19 @@ func TestValidate(t *testing.T) {
|
||||
func TestValidate_Confirm(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := mockRequest(StoreUsername, "john", "confirmUsername", "johnny")
|
||||
req := newMockRequest(StoreUsername, "john", "confirmUsername", "johnny")
|
||||
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if errs["confirmUsername"][0] != "Does not match username" {
|
||||
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
|
||||
}
|
||||
|
||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
}
|
||||
|
||||
req = mockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
|
||||
errs = Validate(req, nil, StoreUsername).Map()
|
||||
if len(errs) != 0 {
|
||||
t.Error("Expected no errors:", errs)
|
||||
|
Reference in New Issue
Block a user