mirror of
https://github.com/go-kratos/kratos.git
synced 2025-02-21 19:19:32 +02:00
feat(middleware/auth/jwt): add customer header (#1752)
This commit is contained in:
parent
1c3185f9e5
commit
76ab0baa56
@ -22,21 +22,24 @@ const (
|
||||
// bearerFormat authorization token format
|
||||
bearerFormat string = "Bearer %s"
|
||||
|
||||
// authorizationKey holds the key used to store the JWT Token in the request header.
|
||||
// authorizationKey holds the key used to store the JWT Token in the request tokenHeader.
|
||||
authorizationKey string = "Authorization"
|
||||
|
||||
// reason holds the error reason.
|
||||
reason string = "UNAUTHORIZED"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingJwtToken = errors.Unauthorized("UNAUTHORIZED", "JWT token is missing")
|
||||
ErrMissingKeyFunc = errors.Unauthorized("UNAUTHORIZED", "keyFunc is missing")
|
||||
ErrTokenInvalid = errors.Unauthorized("UNAUTHORIZED", "Token is invalid")
|
||||
ErrTokenExpired = errors.Unauthorized("UNAUTHORIZED", "JWT token has expired")
|
||||
ErrTokenParseFail = errors.Unauthorized("UNAUTHORIZED", "Fail to parse JWT token ")
|
||||
ErrUnSupportSigningMethod = errors.Unauthorized("UNAUTHORIZED", "Wrong signing method")
|
||||
ErrWrongContext = errors.Unauthorized("UNAUTHORIZED", "Wrong context for middleware")
|
||||
ErrNeedTokenProvider = errors.Unauthorized("UNAUTHORIZED", "Token provider is missing")
|
||||
ErrSignToken = errors.Unauthorized("UNAUTHORIZED", "Can not sign token.Is the key correct?")
|
||||
ErrGetKey = errors.Unauthorized("UNAUTHORIZED", "Can not get key while signing token")
|
||||
ErrMissingJwtToken = errors.Unauthorized(reason, "JWT token is missing")
|
||||
ErrMissingKeyFunc = errors.Unauthorized(reason, "keyFunc is missing")
|
||||
ErrTokenInvalid = errors.Unauthorized(reason, "Token is invalid")
|
||||
ErrTokenExpired = errors.Unauthorized(reason, "JWT token has expired")
|
||||
ErrTokenParseFail = errors.Unauthorized(reason, "Fail to parse JWT token ")
|
||||
ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method")
|
||||
ErrWrongContext = errors.Unauthorized(reason, "Wrong context for middleware")
|
||||
ErrNeedTokenProvider = errors.Unauthorized(reason, "Token provider is missing")
|
||||
ErrSignToken = errors.Unauthorized(reason, "Can not sign token.Is the key correct?")
|
||||
ErrGetKey = errors.Unauthorized(reason, "Can not get key while signing token")
|
||||
)
|
||||
|
||||
// Option is jwt option.
|
||||
@ -46,6 +49,7 @@ type Option func(*options)
|
||||
type options struct {
|
||||
signingMethod jwt.SigningMethod
|
||||
claims jwt.Claims
|
||||
tokenHeader map[string]interface{}
|
||||
}
|
||||
|
||||
// WithSigningMethod with signing method option.
|
||||
@ -62,6 +66,13 @@ func WithClaims(claims jwt.Claims) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenHeader withe customer tokenHeader for client side
|
||||
func WithTokenHeader(header map[string]interface{}) Option {
|
||||
return func(o *options) {
|
||||
o.tokenHeader = header
|
||||
}
|
||||
}
|
||||
|
||||
// Server is a server auth middleware. Check the token and extract the info from token.
|
||||
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
o := &options{
|
||||
@ -93,7 +104,7 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
return nil, ErrTokenParseFail
|
||||
}
|
||||
}
|
||||
return nil, errors.Unauthorized("UNAUTHORIZED", err.Error())
|
||||
return nil, errors.Unauthorized(reason, err.Error())
|
||||
} else if !tokenInfo.Valid {
|
||||
return nil, ErrTokenInvalid
|
||||
} else if tokenInfo.Method != o.signingMethod {
|
||||
@ -122,6 +133,11 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
return nil, ErrNeedTokenProvider
|
||||
}
|
||||
token := jwt.NewWithClaims(o.signingMethod, o.claims)
|
||||
if o.tokenHeader != nil {
|
||||
for k, v := range o.tokenHeader {
|
||||
token.Header[k] = v
|
||||
}
|
||||
}
|
||||
key, err := keyProvider(token)
|
||||
if err != nil {
|
||||
return nil, ErrGetKey
|
||||
|
@ -294,6 +294,34 @@ func TestClientWithClaims(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientWithHeader(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
mapClaims["name"] = "xiaoli"
|
||||
tokenHeader := map[string]interface{}{
|
||||
"test": "test",
|
||||
}
|
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
|
||||
for k, v := range tokenHeader {
|
||||
claims.Header[k] = v
|
||||
}
|
||||
token, err := claims.SignedString([]byte(testKey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tProvider := func(*jwt.Token) (interface{}, error) {
|
||||
return []byte(testKey), nil
|
||||
}
|
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "reply", nil
|
||||
}
|
||||
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next)
|
||||
header := &headerCarrier{}
|
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
|
||||
assert.Equal(t, nil, err2)
|
||||
assert.Equal(t, fmt.Sprintf(bearerFormat, token), header.Get(authorizationKey))
|
||||
}
|
||||
|
||||
func TestClientMissKey(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
|
Loading…
x
Reference in New Issue
Block a user