mirror of
https://github.com/go-kratos/kratos.git
synced 2025-04-11 11:42:10 +02:00
fix(jwt): parse server custom claims (#1817)
* fix(jwt): parse server custom claims * fix(jwt): parse server custom claims & use factory pattern Co-authored-by: 王真 <zhen.wang@yo-star.com>
This commit is contained in:
parent
85800cedb9
commit
4dadafff90
@ -48,7 +48,7 @@ type Option func(*options)
|
||||
// Parser is a jwt parser
|
||||
type options struct {
|
||||
signingMethod jwt.SigningMethod
|
||||
claims jwt.Claims
|
||||
claims func() jwt.Claims
|
||||
tokenHeader map[string]interface{}
|
||||
}
|
||||
|
||||
@ -60,9 +60,11 @@ func WithSigningMethod(method jwt.SigningMethod) Option {
|
||||
}
|
||||
|
||||
// WithClaims with customer claim
|
||||
func WithClaims(claims jwt.Claims) Option {
|
||||
// If you use it in Server, f needs to return a new jwt.Claims object each time to avoid concurrent write problems
|
||||
// If you use it in Client, f only needs to return a single object to provide performance
|
||||
func WithClaims(f func() jwt.Claims) Option {
|
||||
return func(o *options) {
|
||||
o.claims = claims
|
||||
o.claims = f
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,7 +79,6 @@ func WithTokenHeader(header map[string]interface{}) Option {
|
||||
func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
o := &options{
|
||||
signingMethod: jwt.SigningMethodHS256,
|
||||
claims: jwt.RegisteredClaims{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
@ -93,7 +94,15 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
return nil, ErrMissingJwtToken
|
||||
}
|
||||
jwtToken := auths[1]
|
||||
tokenInfo, err := jwt.Parse(jwtToken, keyFunc)
|
||||
var (
|
||||
tokenInfo *jwt.Token
|
||||
err error
|
||||
)
|
||||
if o.claims != nil {
|
||||
tokenInfo, err = jwt.ParseWithClaims(jwtToken, o.claims(), keyFunc)
|
||||
} else {
|
||||
tokenInfo, err = jwt.Parse(jwtToken, keyFunc)
|
||||
}
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
@ -120,9 +129,10 @@ func Server(keyFunc jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
|
||||
// Client is a client jwt middleware.
|
||||
func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
claims := jwt.RegisteredClaims{}
|
||||
o := &options{
|
||||
signingMethod: jwt.SigningMethodHS256,
|
||||
claims: jwt.RegisteredClaims{},
|
||||
claims: func() jwt.Claims { return claims },
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
@ -132,7 +142,7 @@ func Client(keyProvider jwt.Keyfunc, opts ...Option) middleware.Middleware {
|
||||
if keyProvider == nil {
|
||||
return nil, ErrNeedTokenProvider
|
||||
}
|
||||
token := jwt.NewWithClaims(o.signingMethod, o.claims)
|
||||
token := jwt.NewWithClaims(o.signingMethod, o.claims())
|
||||
if o.tokenHeader != nil {
|
||||
for k, v := range o.tokenHeader {
|
||||
token.Header[k] = v
|
||||
|
@ -4,8 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -63,6 +66,107 @@ func (tr *Transport) ReplyHeader() transport.Header {
|
||||
return nil
|
||||
}
|
||||
|
||||
type CustomerClaims struct {
|
||||
Name string `json:"name"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func TestJWTServerParse(t *testing.T) {
|
||||
var (
|
||||
errConcurrentWrite = errors.New("concurrent write claims")
|
||||
errParseClaims = errors.New("bad result, token claims is not CustomerClaims")
|
||||
)
|
||||
|
||||
testKey := "testKey"
|
||||
tests := []struct {
|
||||
name string
|
||||
token func() string
|
||||
claims func() jwt.Claims
|
||||
exceptErr error
|
||||
key string
|
||||
goroutineNum int
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
token: func() string {
|
||||
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{}).SignedString([]byte(testKey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf(bearerFormat, token)
|
||||
},
|
||||
claims: func() jwt.Claims {
|
||||
return &CustomerClaims{}
|
||||
},
|
||||
exceptErr: nil,
|
||||
key: testKey,
|
||||
goroutineNum: 1,
|
||||
},
|
||||
{
|
||||
name: "concurrent request",
|
||||
token: func() string {
|
||||
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &CustomerClaims{
|
||||
Name: strconv.Itoa(rand.Int()),
|
||||
}).SignedString([]byte(testKey))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf(bearerFormat, token)
|
||||
},
|
||||
claims: func() jwt.Claims {
|
||||
return &CustomerClaims{}
|
||||
},
|
||||
exceptErr: nil,
|
||||
key: testKey,
|
||||
goroutineNum: 10000,
|
||||
},
|
||||
}
|
||||
|
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
testToken, _ := FromContext(ctx)
|
||||
var name string
|
||||
if customerClaims, ok := testToken.(*CustomerClaims); ok {
|
||||
name = customerClaims.Name
|
||||
} else {
|
||||
return nil, errParseClaims
|
||||
}
|
||||
|
||||
// mock biz
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if customerClaims, ok := testToken.(*CustomerClaims); ok {
|
||||
if name != customerClaims.Name {
|
||||
return nil, errConcurrentWrite
|
||||
}
|
||||
} else {
|
||||
return nil, errParseClaims
|
||||
}
|
||||
return "reply", nil
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
server := Server(
|
||||
func(token *jwt.Token) (interface{}, error) { return []byte(testKey), nil },
|
||||
WithClaims(test.claims),
|
||||
)(next)
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < test.goroutineNum; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := transport.NewServerContext(context.Background(), &Transport{reqHeader: newTokenHeader(authorizationKey, test.token())})
|
||||
_, err2 := server(ctx, test.name)
|
||||
if !errors.Is(test.exceptErr, err2) {
|
||||
t.Errorf("except error %v, but got %v", test.exceptErr, err2)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
@ -279,6 +383,7 @@ func TestClientWithClaims(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
mapClaims["name"] = "xiaoli"
|
||||
mapClaimsFunc := func() jwt.Claims { return mapClaims }
|
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
|
||||
token, err := claims.SignedString([]byte(testKey))
|
||||
if err != nil {
|
||||
@ -301,7 +406,7 @@ func TestClientWithClaims(t *testing.T) {
|
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "reply", nil
|
||||
}
|
||||
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
|
||||
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next)
|
||||
header := &headerCarrier{}
|
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
|
||||
if !errors.Is(test.expectError, err2) {
|
||||
@ -319,6 +424,7 @@ func TestClientWithHeader(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
mapClaims["name"] = "xiaoli"
|
||||
mapClaimsFunc := func() jwt.Claims { return mapClaims }
|
||||
tokenHeader := map[string]interface{}{
|
||||
"test": "test",
|
||||
}
|
||||
@ -336,7 +442,7 @@ func TestClientWithHeader(t *testing.T) {
|
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "reply", nil
|
||||
}
|
||||
handler := Client(tProvider, WithClaims(mapClaims), WithTokenHeader(tokenHeader))(next)
|
||||
handler := Client(tProvider, WithClaims(mapClaimsFunc), WithTokenHeader(tokenHeader))(next)
|
||||
header := &headerCarrier{}
|
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
|
||||
if err2 != nil {
|
||||
@ -351,6 +457,7 @@ func TestClientMissKey(t *testing.T) {
|
||||
testKey := "testKey"
|
||||
mapClaims := jwt.MapClaims{}
|
||||
mapClaims["name"] = "xiaoli"
|
||||
mapClaimsFunc := func() jwt.Claims { return mapClaims }
|
||||
claims := jwt.NewWithClaims(jwt.SigningMethodHS256, mapClaims)
|
||||
token, err := claims.SignedString([]byte(testKey))
|
||||
if err != nil {
|
||||
@ -373,7 +480,7 @@ func TestClientMissKey(t *testing.T) {
|
||||
next := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return "reply", nil
|
||||
}
|
||||
handler := Client(test.tokenProvider, WithClaims(mapClaims))(next)
|
||||
handler := Client(test.tokenProvider, WithClaims(mapClaimsFunc))(next)
|
||||
header := &headerCarrier{}
|
||||
_, err2 := handler(transport.NewClientContext(context.Background(), &Transport{reqHeader: header}), "ok")
|
||||
if !errors.Is(test.expectError, err2) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user