mirror of
https://github.com/volatiletech/authboss.git
synced 2025-02-07 13:41:55 +02:00
Rewrite module loading to be per-instance
This commit is contained in:
parent
d6c0eb8684
commit
9ff0b65629
31
authboss.go
31
authboss.go
@ -21,27 +21,44 @@ import (
|
||||
type Authboss struct {
|
||||
Config
|
||||
Callbacks *Callbacks
|
||||
|
||||
loadedModules map[string]Modularizer
|
||||
moduleAttributes AttributeMeta
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New makes a new instance of authboss with a default
|
||||
// configuration.
|
||||
func New() *Authboss {
|
||||
ab := &Authboss{
|
||||
Callbacks: NewCallbacks(),
|
||||
Callbacks: NewCallbacks(),
|
||||
loadedModules: make(map[string]Modularizer),
|
||||
moduleAttributes: make(AttributeMeta),
|
||||
}
|
||||
ab.Defaults()
|
||||
ab.Config.Defaults()
|
||||
return ab
|
||||
}
|
||||
|
||||
// Init authboss and it's loaded modules.
|
||||
func (a *Authboss) Init() error {
|
||||
for name, mod := range modules {
|
||||
fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]")
|
||||
if err := mod.Initialize(a); err != nil {
|
||||
// Init authboss and the requested modules. modulesToLoad is left empty
|
||||
// all registered modules will be loaded.
|
||||
func (a *Authboss) Init(modulesToLoad ...string) error {
|
||||
if len(modulesToLoad) == 0 {
|
||||
modulesToLoad = RegisteredModules()
|
||||
}
|
||||
|
||||
for _, name := range modulesToLoad {
|
||||
fmt.Fprintf(a.LogWriter, "%-10s Loading\n", "["+name+"]")
|
||||
if err := a.loadModule(name); err != nil {
|
||||
return fmt.Errorf("[%s] Error Initializing: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, mod := range a.loadedModules {
|
||||
for k, v := range mod.Storage() {
|
||||
a.moduleAttributes[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -279,7 +279,7 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
|
||||
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
|
||||
|
||||
// MockRequestContext returns a new context as if it came from POST request.
|
||||
func MockRequestContext(postKeyValues ...string) *authboss.Context {
|
||||
func MockRequestContext(ab authboss.Authboss, postKeyValues ...string) *authboss.Context {
|
||||
keyValues := &bytes.Buffer{}
|
||||
for i := 0; i < len(postKeyValues); i += 2 {
|
||||
if i != 0 {
|
||||
@ -294,7 +294,7 @@ func MockRequestContext(postKeyValues ...string) *authboss.Context {
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
ctx, err := authboss.ContextFromRequest(req)
|
||||
ctx, err := ab.ContextFromRequest(req)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -21,30 +21,30 @@ import (
|
||||
var (
|
||||
// ErrTemplateNotFound should be returned from Get when the view is not found
|
||||
ErrTemplateNotFound = errors.New("Template not found")
|
||||
|
||||
funcMap = template.FuncMap{
|
||||
"title": strings.Title,
|
||||
"mountpathed": func(location string) string {
|
||||
if authboss.a.MountPath == "/" {
|
||||
return location
|
||||
}
|
||||
return path.Join(authboss.a.MountPath, location)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// Templates is a map depicting the forms a template needs wrapped within the specified layout
|
||||
type Templates map[string]*template.Template
|
||||
|
||||
// LoadTemplates parses all specified files located in path. Each template is wrapped
|
||||
// LoadTemplates parses all specified files located in fpath. Each template is wrapped
|
||||
// in a unique clone of layout. All templates are expecting {{authboss}} handlebars
|
||||
// for parsing. It will check the override directory specified in the config, replacing any
|
||||
// templates as necessary.
|
||||
func LoadTemplates(layout *template.Template, path string, files ...string) (Templates, error) {
|
||||
func LoadTemplates(ab *authboss.Authboss, layout *template.Template, fpath string, files ...string) (Templates, error) {
|
||||
m := make(Templates)
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"title": strings.Title,
|
||||
"mountpathed": func(location string) string {
|
||||
if ab.MountPath == "/" {
|
||||
return location
|
||||
}
|
||||
return path.Join(ab.MountPath, location)
|
||||
},
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
b, err := ioutil.ReadFile(filepath.Join(path, file))
|
||||
b, err := ioutil.ReadFile(filepath.Join(fpath, file))
|
||||
if exists := !os.IsNotExist(err); err != nil && exists {
|
||||
return nil, err
|
||||
} else if !exists {
|
||||
@ -77,10 +77,13 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
|
||||
return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound}
|
||||
}
|
||||
|
||||
data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r)))
|
||||
data.MergeKV(
|
||||
"xsrfName", template.HTML(ctx.XSRFName),
|
||||
"xsrfToken", template.HTML(ctx.XSRFMaker(w, r)),
|
||||
)
|
||||
|
||||
if authboss.a.LayoutDataMaker != nil {
|
||||
data.Merge(authboss.a.LayoutDataMaker(w, r))
|
||||
if ctx.LayoutDataMaker != nil {
|
||||
data.Merge(ctx.LayoutDataMaker(w, r))
|
||||
}
|
||||
|
||||
if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok {
|
||||
@ -107,7 +110,7 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
|
||||
}
|
||||
|
||||
// RenderEmail renders the html and plaintext views for an email and sends it
|
||||
func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error {
|
||||
func Email(mailer authboss.Mailer, email authboss.Email, htmlTpls Templates, nameHTML string, textTpls Templates, namePlain string, data interface{}) error {
|
||||
tplHTML, ok := htmlTpls[nameHTML]
|
||||
if !ok {
|
||||
return authboss.RenderErr{tplHTML.Name(), data, ErrTemplateNotFound}
|
||||
@ -130,7 +133,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
|
||||
}
|
||||
email.TextBody = plainBuffer.String()
|
||||
|
||||
if err := authboss.a.Mailer.Send(email); err != nil {
|
||||
if err := mailer.Send(email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -138,8 +141,11 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
|
||||
}
|
||||
|
||||
// Redirect sets any flash messages given and redirects the user.
|
||||
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, overrideableRedir bool) {
|
||||
if redir := r.FormValue("redir"); redir != "" && overrideableRedir {
|
||||
// If flashSuccess or flashError are set they will be set in the session.
|
||||
// If followRedir is set to true, it will attempt to grab the redirect path from the
|
||||
// query string.
|
||||
func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, path, flashSuccess, flashError string, followRedir bool) {
|
||||
if redir := r.FormValue(authboss.FormValueRedirect); redir != "" && followRedir {
|
||||
path = redir
|
||||
}
|
||||
|
||||
|
61
module.go
61
module.go
@ -1,11 +1,8 @@
|
||||
package authboss
|
||||
|
||||
var modules = make(map[string]Modularizer)
|
||||
import "reflect"
|
||||
|
||||
// ModuleAttributes is the list of attributes required by all the loaded modules.
|
||||
// Authboss implementers can use this at runtime to determine what data is necessary
|
||||
// to store.
|
||||
var ModuleAttributes = make(AttributeMeta)
|
||||
var registeredModules = make(map[string]Modularizer)
|
||||
|
||||
// Modularizer should be implemented by all the authboss modules.
|
||||
type Modularizer interface {
|
||||
@ -17,18 +14,56 @@ type Modularizer interface {
|
||||
// RegisterModule with the core providing all the necessary information to
|
||||
// integrate into authboss.
|
||||
func RegisterModule(name string, m Modularizer) {
|
||||
modules[name] = m
|
||||
registeredModules[name] = m
|
||||
}
|
||||
|
||||
for k, v := range m.Storage() {
|
||||
ModuleAttributes[k] = v
|
||||
// RegisteredModules returns a list of modules that are currently registered.
|
||||
func RegisteredModules() []string {
|
||||
mods := make([]string, len(registeredModules))
|
||||
i := 0
|
||||
for k := range registeredModules {
|
||||
mods[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
return mods
|
||||
}
|
||||
|
||||
// loadModule loads a particular module. It uses reflection to create a new
|
||||
// instance of the module type. The original value is copied, but not deep copied
|
||||
// so care should be taken to make sure most initialization happens inside the Initialize()
|
||||
// method of the module.
|
||||
func (a *Authboss) loadModule(name string) error {
|
||||
module, ok := registeredModules[name]
|
||||
if !ok {
|
||||
panic("Could not find module: " + name)
|
||||
}
|
||||
|
||||
var wasPtr bool
|
||||
modVal := reflect.ValueOf(module)
|
||||
if modVal.Kind() == reflect.Ptr {
|
||||
wasPtr = true
|
||||
modVal = modVal.Elem()
|
||||
}
|
||||
|
||||
modType := modVal.Type()
|
||||
value := reflect.New(modType)
|
||||
if !wasPtr {
|
||||
value = value.Elem()
|
||||
value.Set(modVal)
|
||||
} else {
|
||||
value.Elem().Set(modVal)
|
||||
}
|
||||
mod, ok := value.Interface().(Modularizer)
|
||||
a.loadedModules[name] = mod
|
||||
return mod.Initialize(a)
|
||||
}
|
||||
|
||||
// LoadedModules returns a list of modules that are currently loaded.
|
||||
func LoadedModules() []string {
|
||||
mods := make([]string, len(modules))
|
||||
func (a *Authboss) LoadedModules() []string {
|
||||
mods := make([]string, len(a.loadedModules))
|
||||
i := 0
|
||||
for k := range modules {
|
||||
for k := range a.loadedModules {
|
||||
mods[i] = k
|
||||
i++
|
||||
}
|
||||
@ -37,7 +72,7 @@ func LoadedModules() []string {
|
||||
}
|
||||
|
||||
// IsLoaded checks if a specific module is loaded.
|
||||
func IsLoaded(mod string) bool {
|
||||
_, ok := modules[mod]
|
||||
func (a *Authboss) IsLoaded(mod string) bool {
|
||||
_, ok := a.loadedModules[mod]
|
||||
return ok
|
||||
}
|
||||
|
@ -1,12 +1,17 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testModName = "testmodule"
|
||||
|
||||
func init() {
|
||||
RegisterModule(testModName, testMod)
|
||||
}
|
||||
|
||||
type testModule struct {
|
||||
s StorageOptions
|
||||
r RouteTable
|
||||
@ -28,26 +33,39 @@ func (t *testModule) Routes() RouteTable { return t.r }
|
||||
func (t *testModule) Storage() StorageOptions { return t.s }
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
modules = make(map[string]Modularizer)
|
||||
RegisterModule("testmodule", testMod)
|
||||
|
||||
if _, ok := modules["testmodule"]; !ok {
|
||||
// RegisterModule called by init()
|
||||
if _, ok := registeredModules[testModName]; !ok {
|
||||
t.Error("Expected module to be saved.")
|
||||
}
|
||||
|
||||
if !IsLoaded("testmodule") {
|
||||
t.Error("Expected module to be loaded.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadedModules(t *testing.T) {
|
||||
modules = make(map[string]Modularizer)
|
||||
RegisterModule("testmodule", testMod)
|
||||
|
||||
loadedMods := LoadedModules()
|
||||
if len(loadedMods) != 1 {
|
||||
// RegisterModule called by init()
|
||||
registered := RegisteredModules()
|
||||
if len(registered) != 2 { // There is another test module loaded from router
|
||||
t.Error("Expected only a single module to be loaded.")
|
||||
} else if loadedMods[0] != "testmodule" {
|
||||
t.Error("Expected testmodule to be loaded.")
|
||||
} else {
|
||||
found := false
|
||||
for _, name := range registered {
|
||||
if name == testModName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("It should have found the module:", registered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLoaded(t *testing.T) {
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
if err := ab.Init(testModName); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if loaded := ab.LoadedModules(); len(loaded) == 0 || loaded[0] != testModName {
|
||||
t.Error("Loaded modules wrong:", loaded)
|
||||
}
|
||||
}
|
||||
|
13
router.go
13
router.go
@ -15,16 +15,19 @@ type RouteTable map[string]HandlerFunc
|
||||
|
||||
// NewRouter returns a router to be mounted at some mountpoint.
|
||||
func (a *Authboss) NewRouter() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if a.mux != nil {
|
||||
return a.mux
|
||||
}
|
||||
a.mux = http.NewServeMux()
|
||||
|
||||
for name, mod := range modules {
|
||||
for name, mod := range a.loadedModules {
|
||||
for route, handler := range mod.Routes() {
|
||||
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
|
||||
mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
|
||||
a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
|
||||
}
|
||||
}
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
a.mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if a.NotFoundHandler != nil {
|
||||
a.NotFoundHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
@ -33,7 +36,7 @@ func (a *Authboss) NewRouter() http.Handler {
|
||||
}
|
||||
})
|
||||
|
||||
return mux
|
||||
return a.mux
|
||||
}
|
||||
|
||||
type contextRoute struct {
|
||||
|
@ -9,6 +9,12 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testRouterModName = "testrouter"
|
||||
|
||||
func init() {
|
||||
RegisterModule(testRouterModName, testRouterModule{})
|
||||
}
|
||||
|
||||
type testRouterModule struct {
|
||||
routes RouteTable
|
||||
}
|
||||
@ -19,21 +25,23 @@ func (t testRouterModule) Storage() StorageOptions { return nil }
|
||||
|
||||
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
||||
ab := New()
|
||||
logger := &bytes.Buffer{}
|
||||
ab.LogWriter = logger
|
||||
ab.Init(testRouterModName)
|
||||
ab.MountPath = "/prefix"
|
||||
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
||||
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
||||
logger := &bytes.Buffer{}
|
||||
ab.LogWriter = logger
|
||||
|
||||
logger.Reset() // Clear out the module load messages
|
||||
|
||||
return ab, ab.NewRouter(), logger
|
||||
}
|
||||
|
||||
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
|
||||
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
|
||||
modules = map[string]Modularizer{}
|
||||
RegisterModule("testrouter", testRouterModule{
|
||||
registeredModules[testRouterModName] = testRouterModule{
|
||||
routes: map[string]HandlerFunc{path: h},
|
||||
})
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)
|
||||
|
Loading…
x
Reference in New Issue
Block a user