diff --git a/authboss.go b/authboss.go index ce77038..bf3a4c7 100644 --- a/authboss.go +++ b/authboss.go @@ -21,27 +21,44 @@ import ( type Authboss struct { Config Callbacks *Callbacks + + loadedModules map[string]Modularizer + moduleAttributes AttributeMeta + mux *http.ServeMux } // New makes a new instance of authboss with a default // configuration. func New() *Authboss { ab := &Authboss{ - Callbacks: NewCallbacks(), + Callbacks: NewCallbacks(), + loadedModules: make(map[string]Modularizer), + moduleAttributes: make(AttributeMeta), } - ab.Defaults() + ab.Config.Defaults() return ab } -// Init authboss and it's loaded modules. -func (a *Authboss) Init() error { - for name, mod := range modules { - fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]") - if err := mod.Initialize(a); err != nil { +// Init authboss and the requested modules. modulesToLoad is left empty +// all registered modules will be loaded. +func (a *Authboss) Init(modulesToLoad ...string) error { + if len(modulesToLoad) == 0 { + modulesToLoad = RegisteredModules() + } + + for _, name := range modulesToLoad { + fmt.Fprintf(a.LogWriter, "%-10s Loading\n", "["+name+"]") + if err := a.loadModule(name); err != nil { return fmt.Errorf("[%s] Error Initializing: %v", name, err) } } + for _, mod := range a.loadedModules { + for k, v := range mod.Storage() { + a.moduleAttributes[k] = v + } + } + return nil } diff --git a/internal/mocks/mocks.go b/internal/mocks/mocks.go index bfe172a..bae91d6 100644 --- a/internal/mocks/mocks.go +++ b/internal/mocks/mocks.go @@ -279,7 +279,7 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val } func (m *MockClientStorer) Del(key string) { delete(m.Values, key) } // MockRequestContext returns a new context as if it came from POST request. -func MockRequestContext(postKeyValues ...string) *authboss.Context { +func MockRequestContext(ab authboss.Authboss, postKeyValues ...string) *authboss.Context { keyValues := &bytes.Buffer{} for i := 0; i < len(postKeyValues); i += 2 { if i != 0 { @@ -294,7 +294,7 @@ func MockRequestContext(postKeyValues ...string) *authboss.Context { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - ctx, err := authboss.ContextFromRequest(req) + ctx, err := ab.ContextFromRequest(req) if err != nil { panic(err) } diff --git a/internal/response/response.go b/internal/response/response.go index d007c90..ce43f0e 100644 --- a/internal/response/response.go +++ b/internal/response/response.go @@ -21,30 +21,30 @@ import ( var ( // ErrTemplateNotFound should be returned from Get when the view is not found ErrTemplateNotFound = errors.New("Template not found") - - funcMap = template.FuncMap{ - "title": strings.Title, - "mountpathed": func(location string) string { - if authboss.a.MountPath == "/" { - return location - } - return path.Join(authboss.a.MountPath, location) - }, - } ) // Templates is a map depicting the forms a template needs wrapped within the specified layout type Templates map[string]*template.Template -// LoadTemplates parses all specified files located in path. Each template is wrapped +// LoadTemplates parses all specified files located in fpath. Each template is wrapped // in a unique clone of layout. All templates are expecting {{authboss}} handlebars // for parsing. It will check the override directory specified in the config, replacing any // templates as necessary. -func LoadTemplates(layout *template.Template, path string, files ...string) (Templates, error) { +func LoadTemplates(ab *authboss.Authboss, layout *template.Template, fpath string, files ...string) (Templates, error) { m := make(Templates) + funcMap := template.FuncMap{ + "title": strings.Title, + "mountpathed": func(location string) string { + if ab.MountPath == "/" { + return location + } + return path.Join(ab.MountPath, location) + }, + } + for _, file := range files { - b, err := ioutil.ReadFile(filepath.Join(path, file)) + b, err := ioutil.ReadFile(filepath.Join(fpath, file)) if exists := !os.IsNotExist(err); err != nil && exists { return nil, err } else if !exists { @@ -77,10 +77,13 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http. return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound} } - data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r))) + data.MergeKV( + "xsrfName", template.HTML(ctx.XSRFName), + "xsrfToken", template.HTML(ctx.XSRFMaker(w, r)), + ) - if authboss.a.LayoutDataMaker != nil { - data.Merge(authboss.a.LayoutDataMaker(w, r)) + if ctx.LayoutDataMaker != nil { + data.Merge(ctx.LayoutDataMaker(w, r)) } if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok { @@ -107,7 +110,7 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http. } // RenderEmail renders the html and plaintext views for an email and sends it -func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error { +func Email(mailer authboss.Mailer, email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error { tplHTML, ok := htmlTpls[nameHTML] if !ok { return authboss.RenderErr{tplHTML.Name(), data, ErrTemplateNotFound} @@ -130,7 +133,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T } email.TextBody = plainBuffer.String() - if err := authboss.a.Mailer.Send(email); err != nil { + if err := mailer.Send(email); err != nil { return err } @@ -138,8 +141,11 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T } // Redirect sets any flash messages given and redirects the user. -func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, overrideableRedir bool) { - if redir := r.FormValue("redir"); redir != "" && overrideableRedir { +// If flashSuccess or flashError are set they will be set in the session. +// If followRedir is set to true, it will attempt to grab the redirect path from the +// query string. +func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, followRedir bool) { + if redir := r.FormValue(authboss.FormValueRedirect); redir != "" && followRedir { path = redir } diff --git a/module.go b/module.go index 5e11502..a2fc613 100644 --- a/module.go +++ b/module.go @@ -1,11 +1,8 @@ package authboss -var modules = make(map[string]Modularizer) +import "reflect" -// ModuleAttributes is the list of attributes required by all the loaded modules. -// Authboss implementers can use this at runtime to determine what data is necessary -// to store. -var ModuleAttributes = make(AttributeMeta) +var registeredModules = make(map[string]Modularizer) // Modularizer should be implemented by all the authboss modules. type Modularizer interface { @@ -17,18 +14,56 @@ type Modularizer interface { // RegisterModule with the core providing all the necessary information to // integrate into authboss. func RegisterModule(name string, m Modularizer) { - modules[name] = m + registeredModules[name] = m +} - for k, v := range m.Storage() { - ModuleAttributes[k] = v +// RegisteredModules returns a list of modules that are currently registered. +func RegisteredModules() []string { + mods := make([]string, len(registeredModules)) + i := 0 + for k := range registeredModules { + mods[i] = k + i++ } + + return mods +} + +// loadModule loads a particular module. It uses reflection to create a new +// instance of the module type. The original value is copied, but not deep copied +// so care should be taken to make sure most initialization happens inside the Initialize() +// method of the module. +func (a *Authboss) loadModule(name string) error { + module, ok := registeredModules[name] + if !ok { + panic("Could not find module: " + name) + } + + var wasPtr bool + modVal := reflect.ValueOf(module) + if modVal.Kind() == reflect.Ptr { + wasPtr = true + modVal = modVal.Elem() + } + + modType := modVal.Type() + value := reflect.New(modType) + if !wasPtr { + value = value.Elem() + value.Set(modVal) + } else { + value.Elem().Set(modVal) + } + mod, ok := value.Interface().(Modularizer) + a.loadedModules[name] = mod + return mod.Initialize(a) } // LoadedModules returns a list of modules that are currently loaded. -func LoadedModules() []string { - mods := make([]string, len(modules)) +func (a *Authboss) LoadedModules() []string { + mods := make([]string, len(a.loadedModules)) i := 0 - for k := range modules { + for k := range a.loadedModules { mods[i] = k i++ } @@ -37,7 +72,7 @@ func LoadedModules() []string { } // IsLoaded checks if a specific module is loaded. -func IsLoaded(mod string) bool { - _, ok := modules[mod] +func (a *Authboss) IsLoaded(mod string) bool { + _, ok := a.loadedModules[mod] return ok } diff --git a/module_test.go b/module_test.go index 4e39b99..85a89c5 100644 --- a/module_test.go +++ b/module_test.go @@ -1,12 +1,17 @@ package authboss import ( + "io/ioutil" "net/http" "testing" ) const testModName = "testmodule" +func init() { + RegisterModule(testModName, testMod) +} + type testModule struct { s StorageOptions r RouteTable @@ -28,26 +33,39 @@ func (t *testModule) Routes() RouteTable { return t.r } func (t *testModule) Storage() StorageOptions { return t.s } func TestRegister(t *testing.T) { - modules = make(map[string]Modularizer) - RegisterModule("testmodule", testMod) - - if _, ok := modules["testmodule"]; !ok { + // RegisterModule called by init() + if _, ok := registeredModules[testModName]; !ok { t.Error("Expected module to be saved.") } - - if !IsLoaded("testmodule") { - t.Error("Expected module to be loaded.") - } } func TestLoadedModules(t *testing.T) { - modules = make(map[string]Modularizer) - RegisterModule("testmodule", testMod) - - loadedMods := LoadedModules() - if len(loadedMods) != 1 { + // RegisterModule called by init() + registered := RegisteredModules() + if len(registered) != 2 { // There is another test module loaded from router t.Error("Expected only a single module to be loaded.") - } else if loadedMods[0] != "testmodule" { - t.Error("Expected testmodule to be loaded.") + } else { + found := false + for _, name := range registered { + if name == testModName { + found = true + break + } + } + if !found { + t.Error("It should have found the module:", registered) + } + } +} + +func TestIsLoaded(t *testing.T) { + ab := New() + ab.LogWriter = ioutil.Discard + if err := ab.Init(testModName); err != nil { + t.Error(err) + } + + if loaded := ab.LoadedModules(); len(loaded) == 0 || loaded[0] != testModName { + t.Error("Loaded modules wrong:", loaded) } } diff --git a/router.go b/router.go index ee4eb48..89d46e8 100644 --- a/router.go +++ b/router.go @@ -15,16 +15,19 @@ type RouteTable map[string]HandlerFunc // NewRouter returns a router to be mounted at some mountpoint. func (a *Authboss) NewRouter() http.Handler { - mux := http.NewServeMux() + if a.mux != nil { + return a.mux + } + a.mux = http.NewServeMux() - for name, mod := range modules { + for name, mod := range a.loadedModules { for route, handler := range mod.Routes() { fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route)) - mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler}) + a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler}) } } - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + a.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if a.NotFoundHandler != nil { a.NotFoundHandler.ServeHTTP(w, r) } else { @@ -33,7 +36,7 @@ func (a *Authboss) NewRouter() http.Handler { } }) - return mux + return a.mux } type contextRoute struct { diff --git a/router_test.go b/router_test.go index 56dae43..ca1a426 100644 --- a/router_test.go +++ b/router_test.go @@ -9,6 +9,12 @@ import ( "testing" ) +const testRouterModName = "testrouter" + +func init() { + RegisterModule(testRouterModName, testRouterModule{}) +} + type testRouterModule struct { routes RouteTable } @@ -19,21 +25,23 @@ func (t testRouterModule) Storage() StorageOptions { return nil } func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) { ab := New() + logger := &bytes.Buffer{} + ab.LogWriter = logger + ab.Init(testRouterModName) ab.MountPath = "/prefix" ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } - logger := &bytes.Buffer{} - ab.LogWriter = logger + + logger.Reset() // Clear out the module load messages return ab, ab.NewRouter(), logger } // testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) { - modules = map[string]Modularizer{} - RegisterModule("testrouter", testRouterModule{ + registeredModules[testRouterModName] = testRouterModule{ routes: map[string]HandlerFunc{path: h}, - }) + } w = httptest.NewRecorder() r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)