1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-07 13:41:55 +02:00

Rewrite module loading to be per-instance

This commit is contained in:
Aaron 2015-03-31 15:08:43 -07:00
parent d6c0eb8684
commit 9ff0b65629
7 changed files with 154 additions and 67 deletions

View File

@ -21,27 +21,44 @@ import (
type Authboss struct {
Config
Callbacks *Callbacks
loadedModules map[string]Modularizer
moduleAttributes AttributeMeta
mux *http.ServeMux
}
// New makes a new instance of authboss with a default
// configuration.
func New() *Authboss {
ab := &Authboss{
Callbacks: NewCallbacks(),
Callbacks: NewCallbacks(),
loadedModules: make(map[string]Modularizer),
moduleAttributes: make(AttributeMeta),
}
ab.Defaults()
ab.Config.Defaults()
return ab
}
// Init authboss and it's loaded modules.
func (a *Authboss) Init() error {
for name, mod := range modules {
fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]")
if err := mod.Initialize(a); err != nil {
// Init authboss and the requested modules. modulesToLoad is left empty
// all registered modules will be loaded.
func (a *Authboss) Init(modulesToLoad ...string) error {
if len(modulesToLoad) == 0 {
modulesToLoad = RegisteredModules()
}
for _, name := range modulesToLoad {
fmt.Fprintf(a.LogWriter, "%-10s Loading\n", "["+name+"]")
if err := a.loadModule(name); err != nil {
return fmt.Errorf("[%s] Error Initializing: %v", name, err)
}
}
for _, mod := range a.loadedModules {
for k, v := range mod.Storage() {
a.moduleAttributes[k] = v
}
}
return nil
}

View File

@ -279,7 +279,7 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request.
func MockRequestContext(postKeyValues ...string) *authboss.Context {
func MockRequestContext(ab authboss.Authboss, postKeyValues ...string) *authboss.Context {
keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
@ -294,7 +294,7 @@ func MockRequestContext(postKeyValues ...string) *authboss.Context {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := authboss.ContextFromRequest(req)
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
}

View File

@ -21,30 +21,30 @@ import (
var (
// ErrTemplateNotFound should be returned from Get when the view is not found
ErrTemplateNotFound = errors.New("Template not found")
funcMap = template.FuncMap{
"title": strings.Title,
"mountpathed": func(location string) string {
if authboss.a.MountPath == "/" {
return location
}
return path.Join(authboss.a.MountPath, location)
},
}
)
// Templates is a map depicting the forms a template needs wrapped within the specified layout
type Templates map[string]*template.Template
// LoadTemplates parses all specified files located in path. Each template is wrapped
// LoadTemplates parses all specified files located in fpath. Each template is wrapped
// in a unique clone of layout. All templates are expecting {{authboss}} handlebars
// for parsing. It will check the override directory specified in the config, replacing any
// templates as necessary.
func LoadTemplates(layout *template.Template, path string, files ...string) (Templates, error) {
func LoadTemplates(ab *authboss.Authboss, layout *template.Template, fpath string, files ...string) (Templates, error) {
m := make(Templates)
funcMap := template.FuncMap{
"title": strings.Title,
"mountpathed": func(location string) string {
if ab.MountPath == "/" {
return location
}
return path.Join(ab.MountPath, location)
},
}
for _, file := range files {
b, err := ioutil.ReadFile(filepath.Join(path, file))
b, err := ioutil.ReadFile(filepath.Join(fpath, file))
if exists := !os.IsNotExist(err); err != nil && exists {
return nil, err
} else if !exists {
@ -77,10 +77,13 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound}
}
data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r)))
data.MergeKV(
"xsrfName", template.HTML(ctx.XSRFName),
"xsrfToken", template.HTML(ctx.XSRFMaker(w, r)),
)
if authboss.a.LayoutDataMaker != nil {
data.Merge(authboss.a.LayoutDataMaker(w, r))
if ctx.LayoutDataMaker != nil {
data.Merge(ctx.LayoutDataMaker(w, r))
}
if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok {
@ -107,7 +110,7 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
}
// RenderEmail renders the html and plaintext views for an email and sends it
func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error {
func Email(mailer authboss.Mailer, email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error {
tplHTML, ok := htmlTpls[nameHTML]
if !ok {
return authboss.RenderErr{tplHTML.Name(), data, ErrTemplateNotFound}
@ -130,7 +133,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
}
email.TextBody = plainBuffer.String()
if err := authboss.a.Mailer.Send(email); err != nil {
if err := mailer.Send(email); err != nil {
return err
}
@ -138,8 +141,11 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
}
// Redirect sets any flash messages given and redirects the user.
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, overrideableRedir bool) {
if redir := r.FormValue("redir"); redir != "" && overrideableRedir {
// If flashSuccess or flashError are set they will be set in the session.
// If followRedir is set to true, it will attempt to grab the redirect path from the
// query string.
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, followRedir bool) {
if redir := r.FormValue(authboss.FormValueRedirect); redir != "" && followRedir {
path = redir
}

View File

@ -1,11 +1,8 @@
package authboss
var modules = make(map[string]Modularizer)
import "reflect"
// ModuleAttributes is the list of attributes required by all the loaded modules.
// Authboss implementers can use this at runtime to determine what data is necessary
// to store.
var ModuleAttributes = make(AttributeMeta)
var registeredModules = make(map[string]Modularizer)
// Modularizer should be implemented by all the authboss modules.
type Modularizer interface {
@ -17,18 +14,56 @@ type Modularizer interface {
// RegisterModule with the core providing all the necessary information to
// integrate into authboss.
func RegisterModule(name string, m Modularizer) {
modules[name] = m
registeredModules[name] = m
}
for k, v := range m.Storage() {
ModuleAttributes[k] = v
// RegisteredModules returns a list of modules that are currently registered.
func RegisteredModules() []string {
mods := make([]string, len(registeredModules))
i := 0
for k := range registeredModules {
mods[i] = k
i++
}
return mods
}
// loadModule loads a particular module. It uses reflection to create a new
// instance of the module type. The original value is copied, but not deep copied
// so care should be taken to make sure most initialization happens inside the Initialize()
// method of the module.
func (a *Authboss) loadModule(name string) error {
module, ok := registeredModules[name]
if !ok {
panic("Could not find module: " + name)
}
var wasPtr bool
modVal := reflect.ValueOf(module)
if modVal.Kind() == reflect.Ptr {
wasPtr = true
modVal = modVal.Elem()
}
modType := modVal.Type()
value := reflect.New(modType)
if !wasPtr {
value = value.Elem()
value.Set(modVal)
} else {
value.Elem().Set(modVal)
}
mod, ok := value.Interface().(Modularizer)
a.loadedModules[name] = mod
return mod.Initialize(a)
}
// LoadedModules returns a list of modules that are currently loaded.
func LoadedModules() []string {
mods := make([]string, len(modules))
func (a *Authboss) LoadedModules() []string {
mods := make([]string, len(a.loadedModules))
i := 0
for k := range modules {
for k := range a.loadedModules {
mods[i] = k
i++
}
@ -37,7 +72,7 @@ func LoadedModules() []string {
}
// IsLoaded checks if a specific module is loaded.
func IsLoaded(mod string) bool {
_, ok := modules[mod]
func (a *Authboss) IsLoaded(mod string) bool {
_, ok := a.loadedModules[mod]
return ok
}

View File

@ -1,12 +1,17 @@
package authboss
import (
"io/ioutil"
"net/http"
"testing"
)
const testModName = "testmodule"
func init() {
RegisterModule(testModName, testMod)
}
type testModule struct {
s StorageOptions
r RouteTable
@ -28,26 +33,39 @@ func (t *testModule) Routes() RouteTable { return t.r }
func (t *testModule) Storage() StorageOptions { return t.s }
func TestRegister(t *testing.T) {
modules = make(map[string]Modularizer)
RegisterModule("testmodule", testMod)
if _, ok := modules["testmodule"]; !ok {
// RegisterModule called by init()
if _, ok := registeredModules[testModName]; !ok {
t.Error("Expected module to be saved.")
}
if !IsLoaded("testmodule") {
t.Error("Expected module to be loaded.")
}
}
func TestLoadedModules(t *testing.T) {
modules = make(map[string]Modularizer)
RegisterModule("testmodule", testMod)
loadedMods := LoadedModules()
if len(loadedMods) != 1 {
// RegisterModule called by init()
registered := RegisteredModules()
if len(registered) != 2 { // There is another test module loaded from router
t.Error("Expected only a single module to be loaded.")
} else if loadedMods[0] != "testmodule" {
t.Error("Expected testmodule to be loaded.")
} else {
found := false
for _, name := range registered {
if name == testModName {
found = true
break
}
}
if !found {
t.Error("It should have found the module:", registered)
}
}
}
func TestIsLoaded(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
if err := ab.Init(testModName); err != nil {
t.Error(err)
}
if loaded := ab.LoadedModules(); len(loaded) == 0 || loaded[0] != testModName {
t.Error("Loaded modules wrong:", loaded)
}
}

View File

@ -15,16 +15,19 @@ type RouteTable map[string]HandlerFunc
// NewRouter returns a router to be mounted at some mountpoint.
func (a *Authboss) NewRouter() http.Handler {
mux := http.NewServeMux()
if a.mux != nil {
return a.mux
}
a.mux = http.NewServeMux()
for name, mod := range modules {
for name, mod := range a.loadedModules {
for route, handler := range mod.Routes() {
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
}
}
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
a.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if a.NotFoundHandler != nil {
a.NotFoundHandler.ServeHTTP(w, r)
} else {
@ -33,7 +36,7 @@ func (a *Authboss) NewRouter() http.Handler {
}
})
return mux
return a.mux
}
type contextRoute struct {

View File

@ -9,6 +9,12 @@ import (
"testing"
)
const testRouterModName = "testrouter"
func init() {
RegisterModule(testRouterModName, testRouterModule{})
}
type testRouterModule struct {
routes RouteTable
}
@ -19,21 +25,23 @@ func (t testRouterModule) Storage() StorageOptions { return nil }
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab := New()
logger := &bytes.Buffer{}
ab.LogWriter = logger
ab.Init(testRouterModName)
ab.MountPath = "/prefix"
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
logger := &bytes.Buffer{}
ab.LogWriter = logger
logger.Reset() // Clear out the module load messages
return ab, ab.NewRouter(), logger
}
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
modules = map[string]Modularizer{}
RegisterModule("testrouter", testRouterModule{
registeredModules[testRouterModName] = testRouterModule{
routes: map[string]HandlerFunc{path: h},
})
}
w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)