mirror of
https://github.com/volatiletech/authboss.git
synced 2024-11-24 08:42:17 +02:00
91403280ec
Also fixes behavior for sql.NullInt64, sql.NullFloat64, sql.NullBool.
394 lines
9.5 KiB
Go
394 lines
9.5 KiB
Go
package authboss
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
)
|
|
|
|
// Attribute names under which the core user fields are kept in the data store.
const (
	StoreEmail    = "email"
	StoreUsername = "username"
	StorePassword = "password"
)

// Attribute names under which OAuth2-related user data is kept in the data store.
const (
	StoreOAuth2UID      = "oauth2_uid"
	StoreOAuth2Provider = "oauth2_provider"
	StoreOAuth2Token    = "oauth2_token"
	StoreOAuth2Refresh  = "oauth2_refresh"
	StoreOAuth2Expiry   = "oauth2_expiry"
)
|
|
|
|
// Sentinel errors that storage implementations return so callers can tell
// "not found" / "already exists" apart from other failures.
var (
	// ErrUserNotFound should be returned from Get when the record is not found.
	ErrUserNotFound = errors.New("User not found")
	// ErrTokenNotFound should be returned from UseToken when the record is not found.
	ErrTokenNotFound = errors.New("Token not found")
	// ErrUserFound should be returned from Create when the primaryID of the record is found.
	ErrUserFound = errors.New("User found")
)
|
|
|
|
// StorageOptions is a map depicting the things a module must be able to store.
|
|
type StorageOptions map[string]DataType
|
|
|
|
// Storer must be implemented in order to store the user's attributes somewhere.
|
|
// The type of store is up to the developer implementing it, and all it has to
|
|
// do is be able to store several simple types.
|
|
type Storer interface {
|
|
// Put is for storing the attributes passed in using the key. This is an
|
|
// update only method and should not store if it does not find the key.
|
|
Put(key string, attr Attributes) error
|
|
// Get is for retrieving attributes for a given key. The return value
|
|
// must be a struct that contains all fields with the correct types as shown
|
|
// by attrMeta. If the key is not found in the data store simply
|
|
// return nil, ErrUserNotFound.
|
|
Get(key string) (interface{}, error)
|
|
}
|
|
|
|
// OAuth2Storer is a replacement (or addition) to the Storer interface.
|
|
// It allows users to be stored and fetched via a uid/provider combination.
|
|
type OAuth2Storer interface {
|
|
// PutOAuth creates or updates an existing record (unlike Storer.Put)
|
|
// because in the OAuth flow there is no separate create/update.
|
|
PutOAuth(uid, provider string, attr Attributes) error
|
|
GetOAuth(uid, provider string) (interface{}, error)
|
|
}
|
|
|
|
// DataType identifies one of the value kinds a client store must support.
type DataType int

// DataType constants
const (
	Integer DataType = iota
	String
	Bool
	DateTime
)

var (
	dateTimeType = reflect.TypeOf(time.Time{})
)

// String returns the constant's name ("Integer", "String", "Bool" or
// "DateTime"), or the empty string for any value outside that set.
func (d DataType) String() string {
	labels := [...]string{
		Integer:  "Integer",
		String:   "String",
		Bool:     "Bool",
		DateTime: "DateTime",
	}
	if d < 0 || int(d) >= len(labels) {
		return ""
	}
	return labels[d]
}
|
|
|
|
// AttributeMeta stores type information for attributes.
|
|
type AttributeMeta map[string]DataType
|
|
|
|
// Names returns the names of all the attributes.
|
|
func (a AttributeMeta) Names() []string {
|
|
names := make([]string, len(a))
|
|
i := 0
|
|
for n := range a {
|
|
names[i] = n
|
|
i++
|
|
}
|
|
return names
|
|
}
|
|
|
|
// Attributes is just a key-value mapping of data.
type Attributes map[string]interface{}

// AttributesFromRequest converts the request's parsed form values (r.Form,
// populated by r.ParseForm) into an Attributes map. Only the first value of
// each field is used, and fields whose first value is empty are skipped.
//
// Field names with a type suffix are converted and stored without the suffix:
//   - "<name>_int":  parsed with strconv.Atoi and stored as an int
//   - "<name>_date": parsed as RFC 3339 and stored as a UTC time.Time
//
// Every other field is stored as its raw string. An error is returned when
// the form cannot be parsed or a suffixed value fails to convert.
func AttributesFromRequest(r *http.Request) (Attributes, error) {
	if err := r.ParseForm(); err != nil {
		return nil, err
	}

	attr := make(Attributes)
	for name, values := range r.Form {
		if len(values) == 0 || len(values[0]) == 0 {
			continue
		}
		val := values[0]

		// TrimSuffix (not TrimRight, which treats its argument as a character
		// set) so that e.g. "count_int" becomes "count" rather than "cou".
		switch {
		case strings.HasSuffix(name, "_int"):
			integer, err := strconv.Atoi(val)
			if err != nil {
				return nil, fmt.Errorf("%q (%q): could not be converted to an integer: %v", name, val, err)
			}
			// Stored as a Go int; Bind requires the target field to be int too.
			attr[strings.TrimSuffix(name, "_int")] = integer
		case strings.HasSuffix(name, "_date"):
			date, err := time.Parse(time.RFC3339, val)
			if err != nil {
				return nil, fmt.Errorf("%q (%q): could not be converted to a datetime: %v", name, val, err)
			}
			attr[strings.TrimSuffix(name, "_date")] = date.UTC()
		default:
			attr[name] = val
		}
	}

	return attr, nil
}
|
|
|
|
// Names returns the names of all the attributes.
|
|
func (a Attributes) Names() []string {
|
|
names := make([]string, len(a))
|
|
i := 0
|
|
for n := range a {
|
|
names[i] = n
|
|
i++
|
|
}
|
|
return names
|
|
}
|
|
|
|
// String returns a single value as a string
|
|
func (a Attributes) String(key string) (string, bool) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
val, ok := inter.(string)
|
|
return val, ok
|
|
}
|
|
|
|
// Int64 returns a single value as a int64
|
|
func (a Attributes) Int64(key string) (int64, bool) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return 0, false
|
|
}
|
|
val, ok := inter.(int64)
|
|
return val, ok
|
|
}
|
|
|
|
// Bool returns a single value as a bool.
|
|
func (a Attributes) Bool(key string) (val bool, ok bool) {
|
|
var inter interface{}
|
|
inter, ok = a[key]
|
|
if !ok {
|
|
return val, ok
|
|
}
|
|
|
|
val, ok = inter.(bool)
|
|
return val, ok
|
|
}
|
|
|
|
// DateTime returns a single value as a time.Time
|
|
func (a Attributes) DateTime(key string) (time.Time, bool) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
var time time.Time
|
|
return time, false
|
|
}
|
|
val, ok := inter.(time.Time)
|
|
return val, ok
|
|
}
|
|
|
|
// StringErr returns a single value as a string
|
|
func (a Attributes) StringErr(key string) (val string, err error) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return "", AttributeErr{Name: key}
|
|
}
|
|
val, ok = inter.(string)
|
|
if !ok {
|
|
return val, NewAttributeErr(key, String, inter)
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// Int64Err returns a single value as a int
|
|
func (a Attributes) Int64Err(key string) (val int64, err error) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return val, AttributeErr{Name: key}
|
|
}
|
|
val, ok = inter.(int64)
|
|
if !ok {
|
|
return val, NewAttributeErr(key, Integer, inter)
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// BoolErr returns a single value as a bool.
|
|
func (a Attributes) BoolErr(key string) (val bool, err error) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return val, AttributeErr{Name: key}
|
|
}
|
|
val, ok = inter.(bool)
|
|
if !ok {
|
|
return val, NewAttributeErr(key, Integer, inter)
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// DateTimeErr returns a single value as a time.Time
|
|
func (a Attributes) DateTimeErr(key string) (val time.Time, err error) {
|
|
inter, ok := a[key]
|
|
if !ok {
|
|
return val, AttributeErr{Name: key}
|
|
}
|
|
val, ok = inter.(time.Time)
|
|
if !ok {
|
|
return val, NewAttributeErr(key, DateTime, inter)
|
|
}
|
|
return val, nil
|
|
}
|
|
|
|
// Bind the data in the attributes to the given struct. This means the
|
|
// struct creator must have read the documentation and decided what fields
|
|
// will be needed ahead of time. Ignore missing ignores attributes for
|
|
// which a struct attribute equivalent can not be found.
|
|
func (a Attributes) Bind(strct interface{}, ignoreMissing bool) error {
|
|
structType := reflect.TypeOf(strct)
|
|
if structType.Kind() != reflect.Ptr {
|
|
return errors.New("Bind: Must pass in a struct pointer.")
|
|
}
|
|
|
|
structVal := reflect.ValueOf(strct).Elem()
|
|
structType = structVal.Type()
|
|
for k, v := range a {
|
|
|
|
k = underToCamel(k)
|
|
|
|
if _, has := structType.FieldByName(k); !has && ignoreMissing {
|
|
continue
|
|
} else if !has {
|
|
return fmt.Errorf("Bind: Struct was missing %s field, type: %v", k, reflect.TypeOf(v).String())
|
|
}
|
|
|
|
field := structVal.FieldByName(k)
|
|
if !field.CanSet() {
|
|
return fmt.Errorf("Bind: Found field %s, but was not writeable", k)
|
|
}
|
|
|
|
fieldType := field.Type()
|
|
fieldPtr := field.Addr()
|
|
|
|
if _, ok := fieldPtr.Interface().(sql.Scanner); ok {
|
|
method := fieldPtr.MethodByName("Scan")
|
|
if !method.IsValid() {
|
|
return errors.New("Bind: Was a scanner without a Scan method")
|
|
}
|
|
|
|
if v != nil {
|
|
rvals := method.Call([]reflect.Value{reflect.ValueOf(v)})
|
|
if err, ok := rvals[0].Interface().(error); ok && err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
if valType := reflect.TypeOf(v); fieldType != valType {
|
|
return fmt.Errorf("Bind: Field %s's type should be %s but was %s", k, valType, fieldType)
|
|
}
|
|
field.Set(reflect.ValueOf(v))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// StoreMaker is used to create a storer from an http request.
|
|
type StoreMaker func(http.ResponseWriter, *http.Request) Storer
|
|
|
|
// OAuth2StoreMaker is used to create an oauth2 storer from an http request.
|
|
type OAuth2StoreMaker func(http.ResponseWriter, *http.Request) OAuth2Storer
|
|
|
|
// Unbind is the opposite of Bind, taking a struct in and producing a list of attributes.
|
|
func Unbind(intf interface{}) Attributes {
|
|
structValue := reflect.ValueOf(intf)
|
|
if structValue.Kind() == reflect.Ptr {
|
|
structValue = structValue.Elem()
|
|
}
|
|
|
|
structType := structValue.Type()
|
|
attr := make(Attributes)
|
|
for i := 0; i < structValue.NumField(); i++ {
|
|
field := structValue.Field(i)
|
|
|
|
name := structType.Field(i).Name
|
|
if unicode.IsLower(rune(name[0])) {
|
|
continue // Unexported
|
|
}
|
|
|
|
name = camelToUnder(name)
|
|
|
|
fieldPtr := field.Addr()
|
|
if _, ok := fieldPtr.Interface().(driver.Valuer); ok {
|
|
method := fieldPtr.MethodByName("Value")
|
|
if !method.IsValid() {
|
|
panic("Unbind: Was a valuer without a Value method")
|
|
}
|
|
|
|
rvals := method.Call([]reflect.Value{})
|
|
if err, ok := rvals[1].Interface().(error); ok && err != nil {
|
|
panic(fmt.Errorf("Unbind: Failed to get value out of Valuer: %s, %v", name, err))
|
|
}
|
|
attr[name] = rvals[0].Interface()
|
|
|
|
continue
|
|
}
|
|
|
|
attr[name] = field.Interface()
|
|
}
|
|
|
|
return attr
|
|
}
|
|
|
|
// camelToUnder converts a CamelCase field name into an under_score attribute
// key (e.g. "RecoverToken" -> "recover_token"). Each ASCII upper-case letter
// is lower-cased and, unless it is the first byte, preceded by '_'. Runs of
// capitals are split letter by letter ("UserID" -> "user_i_d"); all other
// bytes are copied unchanged.
func camelToUnder(in string) string {
	var out bytes.Buffer
	for idx, c := range []byte(in) {
		if c < 'A' || c > 'Z' {
			out.WriteByte(c)
			continue
		}
		if idx != 0 {
			out.WriteByte('_')
		}
		out.WriteByte(c - 'A' + 'a')
	}
	return out.String()
}
|
|
|
|
// underToCamel converts an under_score attribute key into the CamelCase
// struct field name that Bind looks up (e.g. "oauth2_uid" -> "Oauth2Uid").
// Underscores are dropped. The first character, and each character that
// follows an underscore, is upper-cased when it is an ASCII lower-case letter.
// Other bytes are copied unchanged, so leading, trailing or repeated
// underscores no longer panic or emit garbage bytes.
func underToCamel(in string) string {
	var out bytes.Buffer
	upperNext := true
	for i := 0; i < len(in); i++ {
		chr := in[i]
		if chr == '_' {
			upperNext = true
			continue
		}
		// Only shift genuine lower-case letters; applying -'a'+'A' to digits,
		// capitals or '_' is what used to corrupt the output.
		if upperNext && chr >= 'a' && chr <= 'z' {
			chr = chr - 'a' + 'A'
		}
		upperNext = false
		out.WriteByte(chr)
	}
	return out.String()
}
|