mirror of
https://github.com/volatiletech/authboss.git
synced 2025-03-05 15:15:45 +02:00
Add a few convenience pieces
- Add helper to directly merge data into a request (common use case) - Allow parsing of OAuth2PID without panic - Add oauth2.* strings to the modules list in case people want to be able to switch on which oauth2 providers are available in their views.
This commit is contained in:
parent
2399b4c089
commit
48b33b0217
20
html_data.go
20
html_data.go
@ -1,5 +1,10 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Keys for use in HTMLData that are meaningful
|
||||
const (
|
||||
// DataErr is for one off errors that don't really belong to
|
||||
@ -66,3 +71,18 @@ func (h HTMLData) MergeKV(data ...interface{}) HTMLData {
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// MergeDataInRequest edits the request pointer to point to a new request with
|
||||
// a modified context that contains the merged data.
|
||||
func MergeDataInRequest(r **http.Request, other HTMLData) {
|
||||
ctx := (*r).Context()
|
||||
currentIntf := ctx.Value(CTXKeyData)
|
||||
if currentIntf == nil {
|
||||
*r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, other))
|
||||
return
|
||||
}
|
||||
|
||||
current := currentIntf.(HTMLData)
|
||||
merged := current.Merge(other)
|
||||
*r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, merged))
|
||||
}
|
||||
|
@ -1,6 +1,10 @@
|
||||
package authboss
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTMLData(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -51,3 +55,29 @@ func TestHTMLData_Panics(t *testing.T) {
|
||||
t.Error("They all should have paniced.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTMLDataMergeDataInRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
MergeDataInRequest(&r, HTMLData{"hello": "world"})
|
||||
|
||||
val := r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string)
|
||||
if val != "world" {
|
||||
t.Error("expected world, got:", val)
|
||||
}
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
r = r.WithContext(context.WithValue(context.Background(), CTXKeyData, HTMLData{"first": "here"}))
|
||||
MergeDataInRequest(&r, HTMLData{"hello": "world"})
|
||||
|
||||
val = r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string)
|
||||
if val != "world" {
|
||||
t.Error("expected world, got:", val)
|
||||
}
|
||||
|
||||
val = r.Context().Value(CTXKeyData).(HTMLData)["first"].(string)
|
||||
if val != "here" {
|
||||
t.Error("expected world, got:", val)
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +91,11 @@ func (a *Authboss) loadModule(name string) error {
|
||||
// of wether or not the module is loaded.
|
||||
// Data looks like:
|
||||
// map[modulename] = true
|
||||
//
|
||||
// oauth2 providers are also listed here using the syntax:
|
||||
// oauth2.google for an example. Be careful since this doesn't actually mean
|
||||
// that the oauth2 module has been loaded so you should do a conditional that checks
|
||||
// for both.
|
||||
func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -109,6 +114,10 @@ func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler {
|
||||
loaded[k] = true
|
||||
}
|
||||
|
||||
for provider := range ab.Config.Modules.OAuth2Providers {
|
||||
loaded["oauth2."+provider] = true
|
||||
}
|
||||
|
||||
data[DataModules] = loaded
|
||||
r = r.WithContext(context.WithValue(ctx, CTXKeyData, data))
|
||||
next.ServeHTTP(w, r)
|
||||
|
@ -78,6 +78,10 @@ func TestModuleLoadedMiddleware(t *testing.T) {
|
||||
ab.loadedModules = map[string]Moduler{
|
||||
"recover": nil,
|
||||
"auth": nil,
|
||||
"oauth2": nil,
|
||||
}
|
||||
ab.Config.Modules.OAuth2Providers = map[string]OAuth2Provider{
|
||||
"google": OAuth2Provider{},
|
||||
}
|
||||
|
||||
var mods map[string]bool
|
||||
@ -88,8 +92,8 @@ func TestModuleLoadedMiddleware(t *testing.T) {
|
||||
|
||||
server.ServeHTTP(nil, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
if len(mods) != 2 {
|
||||
t.Error("want two modules, got:", len(mods))
|
||||
if len(mods) != 4 {
|
||||
t.Error("want 4 modules, got:", len(mods))
|
||||
}
|
||||
|
||||
if _, ok := mods["auth"]; !ok {
|
||||
@ -98,4 +102,10 @@ func TestModuleLoadedMiddleware(t *testing.T) {
|
||||
if _, ok := mods["recover"]; !ok {
|
||||
t.Error("recover should be loaded")
|
||||
}
|
||||
if _, ok := mods["oauth2"]; !ok {
|
||||
t.Error("modules should include oauth2.google")
|
||||
}
|
||||
if _, ok := mods["oauth2.google"]; !ok {
|
||||
t.Error("modules should include oauth2.google")
|
||||
}
|
||||
}
|
||||
|
21
user.go
21
user.go
@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// User has functions for each piece of data it requires.
|
||||
@ -154,14 +156,25 @@ func MakeOAuth2PID(provider, uid string) string {
|
||||
}
|
||||
|
||||
// ParseOAuth2PID returns the uid and provider for a given OAuth2 pid
|
||||
func ParseOAuth2PID(pid string) (provider, uid string) {
|
||||
func ParseOAuth2PID(pid string) (provider, uid string, err error) {
|
||||
splits := strings.Split(pid, ";;")
|
||||
if len(splits) != 3 {
|
||||
panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid))
|
||||
return "", "", errors.Errorf("failed to parse oauth2 pid, too many segments: %s", pid)
|
||||
}
|
||||
if splits[0] != "oauth2" {
|
||||
panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid))
|
||||
return "", "", errors.Errorf("invalid oauth2 pid, did not start with oauth2: %s", pid)
|
||||
}
|
||||
|
||||
return splits[1], splits[2]
|
||||
return splits[1], splits[2], nil
|
||||
}
|
||||
|
||||
// ParseOAuth2PIDP returns the uid and provider for a given OAuth2 pid
|
||||
func ParseOAuth2PIDP(pid string) (provider, uid string) {
|
||||
var err error
|
||||
provider, uid, err = ParseOAuth2PID(pid)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return provider, uid
|
||||
}
|
||||
|
32
user_test.go
32
user_test.go
@ -13,11 +13,41 @@ func TestOAuth2PIDs(t *testing.T) {
|
||||
t.Error("pid was wrong:", pid)
|
||||
}
|
||||
|
||||
gotProvider, gotUID := ParseOAuth2PID(pid)
|
||||
gotProvider, gotUID := ParseOAuth2PIDP(pid)
|
||||
if gotUID != uid {
|
||||
t.Error("uid was wrong:", gotUID)
|
||||
}
|
||||
if gotProvider != provider {
|
||||
t.Error("provider was wrong:", gotProvider)
|
||||
}
|
||||
|
||||
notEnoughSegments, didntStartWithOAuth2 := false, false
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
notEnoughSegments = true
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = ParseOAuth2PIDP("nope")
|
||||
}()
|
||||
|
||||
if !notEnoughSegments {
|
||||
t.Error("expected a panic when there's not enough segments")
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
didntStartWithOAuth2 = true
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = ParseOAuth2PIDP("notoauth2;;but;;restisgood")
|
||||
}()
|
||||
|
||||
if !didntStartWithOAuth2 {
|
||||
t.Error("expected a panic when the pid doesn't start with oauth2")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user