1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-15 14:03:17 +02:00

Rewrite module loading to be per-instance

This commit is contained in:
Aaron 2015-03-31 15:08:43 -07:00
parent d6c0eb8684
commit 9ff0b65629
7 changed files with 154 additions and 67 deletions

View File

@ -21,6 +21,10 @@ import (
type Authboss struct { type Authboss struct {
Config Config
Callbacks *Callbacks Callbacks *Callbacks
loadedModules map[string]Modularizer
moduleAttributes AttributeMeta
mux *http.ServeMux
} }
// New makes a new instance of authboss with a default // New makes a new instance of authboss with a default
@ -28,20 +32,33 @@ type Authboss struct {
func New() *Authboss { func New() *Authboss {
ab := &Authboss{ ab := &Authboss{
Callbacks: NewCallbacks(), Callbacks: NewCallbacks(),
loadedModules: make(map[string]Modularizer),
moduleAttributes: make(AttributeMeta),
} }
ab.Defaults() ab.Config.Defaults()
return ab return ab
} }
// Init authboss and it's loaded modules. // Init authboss and the requested modules. modulesToLoad is left empty
func (a *Authboss) Init() error { // all registered modules will be loaded.
for name, mod := range modules { func (a *Authboss) Init(modulesToLoad ...string) error {
fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]") if len(modulesToLoad) == 0 {
if err := mod.Initialize(a); err != nil { modulesToLoad = RegisteredModules()
}
for _, name := range modulesToLoad {
fmt.Fprintf(a.LogWriter, "%-10s Loading\n", "["+name+"]")
if err := a.loadModule(name); err != nil {
return fmt.Errorf("[%s] Error Initializing: %v", name, err) return fmt.Errorf("[%s] Error Initializing: %v", name, err)
} }
} }
for _, mod := range a.loadedModules {
for k, v := range mod.Storage() {
a.moduleAttributes[k] = v
}
}
return nil return nil
} }

View File

@ -279,7 +279,7 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) } func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request. // MockRequestContext returns a new context as if it came from POST request.
func MockRequestContext(postKeyValues ...string) *authboss.Context { func MockRequestContext(ab authboss.Authboss, postKeyValues ...string) *authboss.Context {
keyValues := &bytes.Buffer{} keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 { for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 { if i != 0 {
@ -294,7 +294,7 @@ func MockRequestContext(postKeyValues ...string) *authboss.Context {
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := authboss.ContextFromRequest(req) ctx, err := ab.ContextFromRequest(req)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -21,30 +21,30 @@ import (
var ( var (
// ErrTemplateNotFound should be returned from Get when the view is not found // ErrTemplateNotFound should be returned from Get when the view is not found
ErrTemplateNotFound = errors.New("Template not found") ErrTemplateNotFound = errors.New("Template not found")
funcMap = template.FuncMap{
"title": strings.Title,
"mountpathed": func(location string) string {
if authboss.a.MountPath == "/" {
return location
}
return path.Join(authboss.a.MountPath, location)
},
}
) )
// Templates is a map depicting the forms a template needs wrapped within the specified layout // Templates is a map depicting the forms a template needs wrapped within the specified layout
type Templates map[string]*template.Template type Templates map[string]*template.Template
// LoadTemplates parses all specified files located in path. Each template is wrapped // LoadTemplates parses all specified files located in fpath. Each template is wrapped
// in a unique clone of layout. All templates are expecting {{authboss}} handlebars // in a unique clone of layout. All templates are expecting {{authboss}} handlebars
// for parsing. It will check the override directory specified in the config, replacing any // for parsing. It will check the override directory specified in the config, replacing any
// templates as necessary. // templates as necessary.
func LoadTemplates(layout *template.Template, path string, files ...string) (Templates, error) { func LoadTemplates(ab *authboss.Authboss, layout *template.Template, fpath string, files ...string) (Templates, error) {
m := make(Templates) m := make(Templates)
funcMap := template.FuncMap{
"title": strings.Title,
"mountpathed": func(location string) string {
if ab.MountPath == "/" {
return location
}
return path.Join(ab.MountPath, location)
},
}
for _, file := range files { for _, file := range files {
b, err := ioutil.ReadFile(filepath.Join(path, file)) b, err := ioutil.ReadFile(filepath.Join(fpath, file))
if exists := !os.IsNotExist(err); err != nil && exists { if exists := !os.IsNotExist(err); err != nil && exists {
return nil, err return nil, err
} else if !exists { } else if !exists {
@ -77,10 +77,13 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound} return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound}
} }
data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r))) data.MergeKV(
"xsrfName", template.HTML(ctx.XSRFName),
"xsrfToken", template.HTML(ctx.XSRFMaker(w, r)),
)
if authboss.a.LayoutDataMaker != nil { if ctx.LayoutDataMaker != nil {
data.Merge(authboss.a.LayoutDataMaker(w, r)) data.Merge(ctx.LayoutDataMaker(w, r))
} }
if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok { if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok {
@ -107,7 +110,7 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
} }
// RenderEmail renders the html and plaintext views for an email and sends it // RenderEmail renders the html and plaintext views for an email and sends it
func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error { func Email(mailer authboss.Mailer, email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error {
tplHTML, ok := htmlTpls[nameHTML] tplHTML, ok := htmlTpls[nameHTML]
if !ok { if !ok {
return authboss.RenderErr{tplHTML.Name(), data, ErrTemplateNotFound} return authboss.RenderErr{tplHTML.Name(), data, ErrTemplateNotFound}
@ -130,7 +133,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
} }
email.TextBody = plainBuffer.String() email.TextBody = plainBuffer.String()
if err := authboss.a.Mailer.Send(email); err != nil { if err := mailer.Send(email); err != nil {
return err return err
} }
@ -138,8 +141,11 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
} }
// Redirect sets any flash messages given and redirects the user. // Redirect sets any flash messages given and redirects the user.
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, overrideableRedir bool) { // If flashSuccess or flashError are set they will be set in the session.
if redir := r.FormValue("redir"); redir != "" && overrideableRedir { // If followRedir is set to true, it will attempt to grab the redirect path from the
// query string.
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, followRedir bool) {
if redir := r.FormValue(authboss.FormValueRedirect); redir != "" && followRedir {
path = redir path = redir
} }

View File

@ -1,11 +1,8 @@
package authboss package authboss
var modules = make(map[string]Modularizer) import "reflect"
// ModuleAttributes is the list of attributes required by all the loaded modules. var registeredModules = make(map[string]Modularizer)
// Authboss implementers can use this at runtime to determine what data is necessary
// to store.
var ModuleAttributes = make(AttributeMeta)
// Modularizer should be implemented by all the authboss modules. // Modularizer should be implemented by all the authboss modules.
type Modularizer interface { type Modularizer interface {
@ -17,18 +14,56 @@ type Modularizer interface {
// RegisterModule with the core providing all the necessary information to // RegisterModule with the core providing all the necessary information to
// integrate into authboss. // integrate into authboss.
func RegisterModule(name string, m Modularizer) { func RegisterModule(name string, m Modularizer) {
modules[name] = m registeredModules[name] = m
}
for k, v := range m.Storage() { // RegisteredModules returns a list of modules that are currently registered.
ModuleAttributes[k] = v func RegisteredModules() []string {
mods := make([]string, len(registeredModules))
i := 0
for k := range registeredModules {
mods[i] = k
i++
} }
return mods
}
// loadModule loads a particular module. It uses reflection to create a new
// instance of the module type. The original value is copied, but not deep copied
// so care should be taken to make sure most initialization happens inside the Initialize()
// method of the module.
func (a *Authboss) loadModule(name string) error {
module, ok := registeredModules[name]
if !ok {
panic("Could not find module: " + name)
}
var wasPtr bool
modVal := reflect.ValueOf(module)
if modVal.Kind() == reflect.Ptr {
wasPtr = true
modVal = modVal.Elem()
}
modType := modVal.Type()
value := reflect.New(modType)
if !wasPtr {
value = value.Elem()
value.Set(modVal)
} else {
value.Elem().Set(modVal)
}
mod, ok := value.Interface().(Modularizer)
a.loadedModules[name] = mod
return mod.Initialize(a)
} }
// LoadedModules returns a list of modules that are currently loaded. // LoadedModules returns a list of modules that are currently loaded.
func LoadedModules() []string { func (a *Authboss) LoadedModules() []string {
mods := make([]string, len(modules)) mods := make([]string, len(a.loadedModules))
i := 0 i := 0
for k := range modules { for k := range a.loadedModules {
mods[i] = k mods[i] = k
i++ i++
} }
@ -37,7 +72,7 @@ func LoadedModules() []string {
} }
// IsLoaded checks if a specific module is loaded. // IsLoaded checks if a specific module is loaded.
func IsLoaded(mod string) bool { func (a *Authboss) IsLoaded(mod string) bool {
_, ok := modules[mod] _, ok := a.loadedModules[mod]
return ok return ok
} }

View File

@ -1,12 +1,17 @@
package authboss package authboss
import ( import (
"io/ioutil"
"net/http" "net/http"
"testing" "testing"
) )
const testModName = "testmodule" const testModName = "testmodule"
func init() {
RegisterModule(testModName, testMod)
}
type testModule struct { type testModule struct {
s StorageOptions s StorageOptions
r RouteTable r RouteTable
@ -28,26 +33,39 @@ func (t *testModule) Routes() RouteTable { return t.r }
func (t *testModule) Storage() StorageOptions { return t.s } func (t *testModule) Storage() StorageOptions { return t.s }
func TestRegister(t *testing.T) { func TestRegister(t *testing.T) {
modules = make(map[string]Modularizer) // RegisterModule called by init()
RegisterModule("testmodule", testMod) if _, ok := registeredModules[testModName]; !ok {
if _, ok := modules["testmodule"]; !ok {
t.Error("Expected module to be saved.") t.Error("Expected module to be saved.")
} }
if !IsLoaded("testmodule") {
t.Error("Expected module to be loaded.")
}
} }
func TestLoadedModules(t *testing.T) { func TestLoadedModules(t *testing.T) {
modules = make(map[string]Modularizer) // RegisterModule called by init()
RegisterModule("testmodule", testMod) registered := RegisteredModules()
if len(registered) != 2 { // There is another test module loaded from router
loadedMods := LoadedModules()
if len(loadedMods) != 1 {
t.Error("Expected only a single module to be loaded.") t.Error("Expected only a single module to be loaded.")
} else if loadedMods[0] != "testmodule" { } else {
t.Error("Expected testmodule to be loaded.") found := false
for _, name := range registered {
if name == testModName {
found = true
break
}
}
if !found {
t.Error("It should have found the module:", registered)
}
}
}
func TestIsLoaded(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
if err := ab.Init(testModName); err != nil {
t.Error(err)
}
if loaded := ab.LoadedModules(); len(loaded) == 0 || loaded[0] != testModName {
t.Error("Loaded modules wrong:", loaded)
} }
} }

View File

@ -15,16 +15,19 @@ type RouteTable map[string]HandlerFunc
// NewRouter returns a router to be mounted at some mountpoint. // NewRouter returns a router to be mounted at some mountpoint.
func (a *Authboss) NewRouter() http.Handler { func (a *Authboss) NewRouter() http.Handler {
mux := http.NewServeMux() if a.mux != nil {
return a.mux
}
a.mux = http.NewServeMux()
for name, mod := range modules { for name, mod := range a.loadedModules {
for route, handler := range mod.Routes() { for route, handler := range mod.Routes() {
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route)) fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler}) a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
} }
} }
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { a.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if a.NotFoundHandler != nil { if a.NotFoundHandler != nil {
a.NotFoundHandler.ServeHTTP(w, r) a.NotFoundHandler.ServeHTTP(w, r)
} else { } else {
@ -33,7 +36,7 @@ func (a *Authboss) NewRouter() http.Handler {
} }
}) })
return mux return a.mux
} }
type contextRoute struct { type contextRoute struct {

View File

@ -9,6 +9,12 @@ import (
"testing" "testing"
) )
const testRouterModName = "testrouter"
func init() {
RegisterModule(testRouterModName, testRouterModule{})
}
type testRouterModule struct { type testRouterModule struct {
routes RouteTable routes RouteTable
} }
@ -19,21 +25,23 @@ func (t testRouterModule) Storage() StorageOptions { return nil }
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) { func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab := New() ab := New()
logger := &bytes.Buffer{}
ab.LogWriter = logger
ab.Init(testRouterModName)
ab.MountPath = "/prefix" ab.MountPath = "/prefix"
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
logger := &bytes.Buffer{}
ab.LogWriter = logger logger.Reset() // Clear out the module load messages
return ab, ab.NewRouter(), logger return ab, ab.NewRouter(), logger
} }
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel // testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) { func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
modules = map[string]Modularizer{} registeredModules[testRouterModName] = testRouterModule{
RegisterModule("testrouter", testRouterModule{
routes: map[string]HandlerFunc{path: h}, routes: map[string]HandlerFunc{path: h},
}) }
w = httptest.NewRecorder() w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil) r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)