From 48b33b021734550d8e7f5011ce1d27617a61cdaf Mon Sep 17 00:00:00 2001 From: Aaron L Date: Tue, 8 May 2018 20:39:39 -0700 Subject: [PATCH] Add a few convenience pieces - Add helper to directly merge data into a request (common use case) - Allow parsing of OAuth2PID without panic - Add oauth2.* strings to the modules list in case people want to be able to switch on which oauth2 providers are available in their views. --- html_data.go | 20 ++++++++++++++++++++ html_data_test.go | 32 +++++++++++++++++++++++++++++++- module.go | 9 +++++++++ module_test.go | 14 ++++++++++++-- user.go | 21 +++++++++++++++++---- user_test.go | 32 +++++++++++++++++++++++++++++++- 6 files changed, 120 insertions(+), 8 deletions(-) diff --git a/html_data.go b/html_data.go index 9640410..0bfeb06 100644 --- a/html_data.go +++ b/html_data.go @@ -1,5 +1,10 @@ package authboss +import ( + "context" + "net/http" +) + // Keys for use in HTMLData that are meaningful const ( // DataErr is for one off errors that don't really belong to @@ -66,3 +71,18 @@ func (h HTMLData) MergeKV(data ...interface{}) HTMLData { return h } + +// MergeDataInRequest edits the request pointer to point to a new request with +// a modified context that contains the merged data. +func MergeDataInRequest(r **http.Request, other HTMLData) { + ctx := (*r).Context() + currentIntf := ctx.Value(CTXKeyData) + if currentIntf == nil { + *r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, other)) + return + } + + current := currentIntf.(HTMLData) + merged := current.Merge(other) + *r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, merged)) +} diff --git a/html_data_test.go b/html_data_test.go index f1c19a5..d083544 100644 --- a/html_data_test.go +++ b/html_data_test.go @@ -1,6 +1,10 @@ package authboss -import "testing" +import ( + "context" + "net/http/httptest" + "testing" +) func TestHTMLData(t *testing.T) { t.Parallel() @@ -51,3 +55,29 @@ func TestHTMLData_Panics(t *testing.T) { t.Error("They all should have paniced.") } } + +func TestHTMLDataMergeDataInRequest(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest("GET", "/", nil) + MergeDataInRequest(&r, HTMLData{"hello": "world"}) + + val := r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string) + if val != "world" { + t.Error("expected world, got:", val) + } + + r = httptest.NewRequest("GET", "/", nil) + r = r.WithContext(context.WithValue(context.Background(), CTXKeyData, HTMLData{"first": "here"})) + MergeDataInRequest(&r, HTMLData{"hello": "world"}) + + val = r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string) + if val != "world" { + t.Error("expected world, got:", val) + } + + val = r.Context().Value(CTXKeyData).(HTMLData)["first"].(string) + if val != "here" { + t.Error("expected world, got:", val) + } +} diff --git a/module.go b/module.go index b705211..853ea4d 100644 --- a/module.go +++ b/module.go @@ -91,6 +91,11 @@ func (a *Authboss) loadModule(name string) error { // of wether or not the module is loaded. // Data looks like: // map[modulename] = true +// +// oauth2 providers are also listed here using the syntax: +// oauth2.google for an example. Be careful since this doesn't actually mean +// that the oauth2 module has been loaded so you should do a conditional that checks +// for both. func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -109,6 +114,10 @@ func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler { loaded[k] = true } + for provider := range ab.Config.Modules.OAuth2Providers { + loaded["oauth2."+provider] = true + } + data[DataModules] = loaded r = r.WithContext(context.WithValue(ctx, CTXKeyData, data)) next.ServeHTTP(w, r) diff --git a/module_test.go b/module_test.go index 428b403..2614a8c 100644 --- a/module_test.go +++ b/module_test.go @@ -78,6 +78,10 @@ func TestModuleLoadedMiddleware(t *testing.T) { ab.loadedModules = map[string]Moduler{ "recover": nil, "auth": nil, + "oauth2": nil, + } + ab.Config.Modules.OAuth2Providers = map[string]OAuth2Provider{ + "google": OAuth2Provider{}, } var mods map[string]bool @@ -88,8 +92,8 @@ func TestModuleLoadedMiddleware(t *testing.T) { server.ServeHTTP(nil, httptest.NewRequest("GET", "/", nil)) - if len(mods) != 2 { - t.Error("want two modules, got:", len(mods)) + if len(mods) != 4 { + t.Error("want 4 modules, got:", len(mods)) } if _, ok := mods["auth"]; !ok { @@ -98,4 +102,10 @@ func TestModuleLoadedMiddleware(t *testing.T) { if _, ok := mods["recover"]; !ok { t.Error("recover should be loaded") } + if _, ok := mods["oauth2"]; !ok { + t.Error("modules should include oauth2.google") + } + if _, ok := mods["oauth2.google"]; !ok { + t.Error("modules should include oauth2.google") + } } diff --git a/user.go b/user.go index 948ddc0..f050594 100644 --- a/user.go +++ b/user.go @@ -4,6 +4,8 @@ import ( "fmt" "strings" "time" + + "github.com/pkg/errors" ) // User has functions for each piece of data it requires. @@ -154,14 +156,25 @@ func MakeOAuth2PID(provider, uid string) string { } // ParseOAuth2PID returns the uid and provider for a given OAuth2 pid -func ParseOAuth2PID(pid string) (provider, uid string) { +func ParseOAuth2PID(pid string) (provider, uid string, err error) { splits := strings.Split(pid, ";;") if len(splits) != 3 { - panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid)) + return "", "", errors.Errorf("failed to parse oauth2 pid, too many segments: %s", pid) } if splits[0] != "oauth2" { - panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid)) + return "", "", errors.Errorf("invalid oauth2 pid, did not start with oauth2: %s", pid) } - return splits[1], splits[2] + return splits[1], splits[2], nil +} + +// ParseOAuth2PIDP returns the uid and provider for a given OAuth2 pid +func ParseOAuth2PIDP(pid string) (provider, uid string) { + var err error + provider, uid, err = ParseOAuth2PID(pid) + if err != nil { + panic(err) + } + + return provider, uid } diff --git a/user_test.go b/user_test.go index e689bf1..5cda8bf 100644 --- a/user_test.go +++ b/user_test.go @@ -13,11 +13,41 @@ func TestOAuth2PIDs(t *testing.T) { t.Error("pid was wrong:", pid) } - gotProvider, gotUID := ParseOAuth2PID(pid) + gotProvider, gotUID := ParseOAuth2PIDP(pid) if gotUID != uid { t.Error("uid was wrong:", gotUID) } if gotProvider != provider { t.Error("provider was wrong:", gotProvider) } + + notEnoughSegments, didntStartWithOAuth2 := false, false + + func() { + defer func() { + if r := recover(); r != nil { + notEnoughSegments = true + } + }() + + _, _ = ParseOAuth2PIDP("nope") + }() + + if !notEnoughSegments { + t.Error("expected a panic when there's not enough segments") + } + + func() { + defer func() { + if r := recover(); r != nil { + didntStartWithOAuth2 = true + } + }() + + _, _ = ParseOAuth2PIDP("notoauth2;;but;;restisgood") + }() + + if !didntStartWithOAuth2 { + t.Error("expected a panic when the pid doesn't start with oauth2") + } }