1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-03 13:21:22 +02:00

Load and verify user logged in middleware

This commit is contained in:
Aaron L 2018-04-30 18:17:07 -07:00
parent 6dee0259e1
commit 4aa961f758
2 changed files with 87 additions and 1 deletions

View File

@ -8,6 +8,7 @@ package authboss
import (
"context"
"net/http"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
@ -41,7 +42,7 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
for _, name := range modulesToLoad {
if err := a.loadModule(name); err != nil {
return errors.Errorf("module %s failed to load", name)
return errors.Errorf("module %s failed to load: %+v", name, err)
}
}
@ -75,3 +76,24 @@ func (a *Authboss) UpdatePassword(ctx context.Context, user AuthableUser, newPas
return rmStorer.DelRememberTokens(user.GetPID())
}
// Middleware prevents someone from accessing a route by returning a 404 if they are not logged in.
// This middleware also loads the current user.
func Middleware(ab *Authboss) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := ab.RequestLogger(r)
if u, err := ab.LoadCurrentUser(&r); err != nil {
log.Errorf("error fetching current user: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if u == nil {
log.Infof("providing not found for unauthorized user at: %s", r.URL.Path)
w.WriteHeader(http.StatusNotFound)
return
} else {
next.ServeHTTP(w, r)
}
})
}
}

View File

@ -2,6 +2,8 @@ package authboss
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
@ -32,3 +34,65 @@ func TestAuthbossUpdatePassword(t *testing.T) {
t.Error("password was not updated")
}
}
func TestAuthbossMiddleware(t *testing.T) {
t.Parallel()
ab := New()
ab.Core.Logger = mockLogger{}
mid := Middleware(ab)
r := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
w := ab.NewResponse(rec)
called := false
hadUser := false
server := mid(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
hadUser = r.Context().Value(CTXKeyUser) != nil
w.WriteHeader(http.StatusOK)
}))
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Fatal(err)
}
server.ServeHTTP(w, r)
if called || hadUser {
t.Error("should not be called or have a user when no session variables have been provided")
}
if rec.Code != http.StatusNotFound {
t.Error("want a not found code")
}
ab.Storage.SessionState = mockClientStateReadWriter{
state: mockClientState{SessionKey: "test@test.com"},
}
ab.Storage.Server = &mockServerStorer{
Users: map[string]*mockUser{
"test@test.com": &mockUser{},
},
}
r = httptest.NewRequest("GET", "/", nil)
rec = httptest.NewRecorder()
w = ab.NewResponse(rec)
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Fatal(err)
}
server.ServeHTTP(w, r)
if !called {
t.Error("it should have been called")
}
if !hadUser {
t.Error("it should have had a user loaded")
}
if rec.Code != http.StatusOK {
t.Error("want a not found code")
}
}