1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-03-05 15:15:45 +02:00

Add a few convenience pieces

- Add helper to directly merge data into a request (common use case)
- Allow parsing of OAuth2PID without panic
- Add oauth2.* strings to the modules list in case people want to be
  able to switch on which oauth2 providers are available in their
  views.
This commit is contained in:
Aaron L 2018-05-08 20:39:39 -07:00
parent 2399b4c089
commit 48b33b0217
6 changed files with 120 additions and 8 deletions

View File

@ -1,5 +1,10 @@
package authboss
import (
"context"
"net/http"
)
// Keys for use in HTMLData that are meaningful
const (
// DataErr is for one off errors that don't really belong to
@ -66,3 +71,18 @@ func (h HTMLData) MergeKV(data ...interface{}) HTMLData {
return h
}
// MergeDataInRequest edits the request pointer to point to a new request with
// a modified context that contains the merged data.
func MergeDataInRequest(r **http.Request, other HTMLData) {
ctx := (*r).Context()
currentIntf := ctx.Value(CTXKeyData)
if currentIntf == nil {
*r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, other))
return
}
current := currentIntf.(HTMLData)
merged := current.Merge(other)
*r = (*r).WithContext(context.WithValue(ctx, CTXKeyData, merged))
}

View File

@ -1,6 +1,10 @@
package authboss
import "testing"
import (
"context"
"net/http/httptest"
"testing"
)
func TestHTMLData(t *testing.T) {
t.Parallel()
@ -51,3 +55,29 @@ func TestHTMLData_Panics(t *testing.T) {
t.Error("They all should have paniced.")
}
}
func TestHTMLDataMergeDataInRequest(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "/", nil)
MergeDataInRequest(&r, HTMLData{"hello": "world"})
val := r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string)
if val != "world" {
t.Error("expected world, got:", val)
}
r = httptest.NewRequest("GET", "/", nil)
r = r.WithContext(context.WithValue(context.Background(), CTXKeyData, HTMLData{"first": "here"}))
MergeDataInRequest(&r, HTMLData{"hello": "world"})
val = r.Context().Value(CTXKeyData).(HTMLData)["hello"].(string)
if val != "world" {
t.Error("expected world, got:", val)
}
val = r.Context().Value(CTXKeyData).(HTMLData)["first"].(string)
if val != "here" {
t.Error("expected world, got:", val)
}
}

View File

@ -91,6 +91,11 @@ func (a *Authboss) loadModule(name string) error {
// of wether or not the module is loaded.
// Data looks like:
// map[modulename] = true
//
// oauth2 providers are also listed here using the syntax:
// oauth2.google for an example. Be careful since this doesn't actually mean
// that the oauth2 module has been loaded so you should do a conditional that checks
// for both.
func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -109,6 +114,10 @@ func ModuleListMiddleware(ab *Authboss) func(http.Handler) http.Handler {
loaded[k] = true
}
for provider := range ab.Config.Modules.OAuth2Providers {
loaded["oauth2."+provider] = true
}
data[DataModules] = loaded
r = r.WithContext(context.WithValue(ctx, CTXKeyData, data))
next.ServeHTTP(w, r)

View File

@ -78,6 +78,10 @@ func TestModuleLoadedMiddleware(t *testing.T) {
ab.loadedModules = map[string]Moduler{
"recover": nil,
"auth": nil,
"oauth2": nil,
}
ab.Config.Modules.OAuth2Providers = map[string]OAuth2Provider{
"google": OAuth2Provider{},
}
var mods map[string]bool
@ -88,8 +92,8 @@ func TestModuleLoadedMiddleware(t *testing.T) {
server.ServeHTTP(nil, httptest.NewRequest("GET", "/", nil))
if len(mods) != 2 {
t.Error("want two modules, got:", len(mods))
if len(mods) != 4 {
t.Error("want 4 modules, got:", len(mods))
}
if _, ok := mods["auth"]; !ok {
@ -98,4 +102,10 @@ func TestModuleLoadedMiddleware(t *testing.T) {
if _, ok := mods["recover"]; !ok {
t.Error("recover should be loaded")
}
if _, ok := mods["oauth2"]; !ok {
t.Error("modules should include oauth2.google")
}
if _, ok := mods["oauth2.google"]; !ok {
t.Error("modules should include oauth2.google")
}
}

21
user.go
View File

@ -4,6 +4,8 @@ import (
"fmt"
"strings"
"time"
"github.com/pkg/errors"
)
// User has functions for each piece of data it requires.
@ -154,14 +156,25 @@ func MakeOAuth2PID(provider, uid string) string {
}
// ParseOAuth2PID returns the uid and provider for a given OAuth2 pid
func ParseOAuth2PID(pid string) (provider, uid string) {
func ParseOAuth2PID(pid string) (provider, uid string, err error) {
splits := strings.Split(pid, ";;")
if len(splits) != 3 {
panic(fmt.Sprintf("failed to parse oauth2 pid, too many segments: %s", pid))
return "", "", errors.Errorf("failed to parse oauth2 pid, too many segments: %s", pid)
}
if splits[0] != "oauth2" {
panic(fmt.Sprintf("invalid oauth2 pid, did not start with oauth2: %s", pid))
return "", "", errors.Errorf("invalid oauth2 pid, did not start with oauth2: %s", pid)
}
return splits[1], splits[2]
return splits[1], splits[2], nil
}
// ParseOAuth2PIDP returns the uid and provider for a given OAuth2 pid
func ParseOAuth2PIDP(pid string) (provider, uid string) {
var err error
provider, uid, err = ParseOAuth2PID(pid)
if err != nil {
panic(err)
}
return provider, uid
}

View File

@ -13,11 +13,41 @@ func TestOAuth2PIDs(t *testing.T) {
t.Error("pid was wrong:", pid)
}
gotProvider, gotUID := ParseOAuth2PID(pid)
gotProvider, gotUID := ParseOAuth2PIDP(pid)
if gotUID != uid {
t.Error("uid was wrong:", gotUID)
}
if gotProvider != provider {
t.Error("provider was wrong:", gotProvider)
}
notEnoughSegments, didntStartWithOAuth2 := false, false
func() {
defer func() {
if r := recover(); r != nil {
notEnoughSegments = true
}
}()
_, _ = ParseOAuth2PIDP("nope")
}()
if !notEnoughSegments {
t.Error("expected a panic when there's not enough segments")
}
func() {
defer func() {
if r := recover(); r != nil {
didntStartWithOAuth2 = true
}
}()
_, _ = ParseOAuth2PIDP("notoauth2;;but;;restisgood")
}()
if !didntStartWithOAuth2 {
t.Error("expected a panic when the pid doesn't start with oauth2")
}
}