You've already forked pocketbase
mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-11-23 22:55:37 +02:00
moved ValidateTokenSignature to jwk and added tests
This commit is contained in:
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -108,10 +110,10 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
|
||||
return nil, errors.New("empty id_token")
|
||||
}
|
||||
|
||||
// extract the token header params and claims
|
||||
// extract the token claims
|
||||
// ---
|
||||
claims := jwt.MapClaims{}
|
||||
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,10 +138,9 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
|
||||
// the token which is a result of direct TLS communication with the provider
|
||||
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
||||
// ---
|
||||
kid, _ := t.Header["kid"].(string)
|
||||
err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid)
|
||||
err = jwk.ValidateTokenSignature(p.ctx, idToken, p.jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("id_token validation failed: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type JWK struct {
|
||||
@@ -94,7 +97,7 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
||||
// http.Client.Get doesn't treat non 2xx responses as error
|
||||
if res.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to JSON Web Key Set from %s (%d):\n%s",
|
||||
"failed to fetch JSON Web Key Set from %s (%d):\n%s",
|
||||
jwksURL,
|
||||
res.StatusCode,
|
||||
string(rawBody),
|
||||
@@ -116,5 +119,45 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
|
||||
return nil, fmt.Errorf("JWK with kid %q was not found", kid)
|
||||
}
|
||||
|
||||
// ValidateTokenSignature validates the signature of a token with the
|
||||
// public key retrievied from a remote JWKS.
|
||||
func ValidateTokenSignature(ctx context.Context, token string, jwksURL string) error {
|
||||
// extract the kid token header
|
||||
// ---
|
||||
t, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kid, _ := t.Header["kid"].(string)
|
||||
if kid == "" {
|
||||
return errors.New("missing kid header value")
|
||||
}
|
||||
|
||||
// fetch the public key set
|
||||
// ---
|
||||
key, err := Fetch(ctx, jwksURL, kid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// verify the signature
|
||||
// ---
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
|
||||
|
||||
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
|
||||
return key.PublicKey()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !parsedToken.Valid {
|
||||
return errors.New("the parsed token is invalid")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||
)
|
||||
|
||||
@@ -27,7 +28,7 @@ func TestJWK_PublicKey(t *testing.T) {
|
||||
|
||||
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate test RSA private key: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
@@ -209,3 +210,105 @@ func TestFetch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ed25519Public, ed25519Private, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nonmatchingKidToken := jwt.New(&jwt.SigningMethodEd25519{})
|
||||
nonmatchingKidToken.Header["kid"] = "missing"
|
||||
nonmatchingKidTokenStr, err := nonmatchingKidToken.SignedString(ed25519Private)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key1Token := jwt.New(&jwt.SigningMethodEd25519{})
|
||||
key1Token.Header["kid"] = "key1"
|
||||
key1TokenStr, err := key1Token.SignedString(ed25519Private)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key2Token := jwt.New(jwt.SigningMethodRS256)
|
||||
key2Token.Header["kid"] = "key2"
|
||||
key2TokenStr, err := key2Token.SignedString(rsaPrivate)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||
_ = json.NewEncoder(res).Encode(map[string]any{"keys": []*jwk.JWK{
|
||||
{
|
||||
Kid: "key1",
|
||||
Kty: "OKP",
|
||||
Alg: "EdDSA",
|
||||
Crv: "Ed25519",
|
||||
X: base64.RawURLEncoding.EncodeToString(ed25519Public),
|
||||
},
|
||||
{
|
||||
Kid: "key2",
|
||||
Kty: "RSA",
|
||||
Alg: "RS256",
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
|
||||
},
|
||||
}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
token string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"empty token",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invlaid token",
|
||||
"abc",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"no matching kid",
|
||||
nonmatchingKidTokenStr,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"valid Ed25519 token",
|
||||
key1TokenStr,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid RSA token",
|
||||
key2TokenStr,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range scenarios {
|
||||
t.Run(s.name, func(t *testing.T) {
|
||||
err := jwk.ValidateTokenSignature(
|
||||
context.Background(),
|
||||
s.token,
|
||||
server.URL,
|
||||
)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{}
|
||||
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -176,42 +176,11 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
|
||||
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
||||
jwksURL := cast.ToString(p.Extra()["jwksURL"])
|
||||
if jwksURL != "" {
|
||||
kid, _ := t.Header["kid"].(string)
|
||||
err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid)
|
||||
err = jwk.ValidateTokenSignature(p.ctx, idToken, jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("id_token validation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error {
|
||||
// fetch the public key set
|
||||
// ---
|
||||
if kid == "" {
|
||||
return errors.New("missing kid header value")
|
||||
}
|
||||
|
||||
key, err := jwk.Fetch(ctx, jwksURL, kid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// verify the signiture
|
||||
// ---
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
|
||||
|
||||
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
|
||||
return key.PublicKey()
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !parsedToken.Valid {
|
||||
return errors.New("the parsed id_token is invalid")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user