1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-11-23 22:55:37 +02:00

moved ValidateTokenSignature to jwk and added tests

This commit is contained in:
Gani Georgiev
2025-10-19 18:18:24 +03:00
parent 0b6157e1cc
commit 0bd712752f
4 changed files with 158 additions and 42 deletions

View File

@@ -4,8 +4,10 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
"golang.org/x/oauth2"
@@ -108,10 +110,10 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
return nil, errors.New("empty id_token")
}
// extract the token header params and claims
// extract the token claims
// ---
claims := jwt.MapClaims{}
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
if err != nil {
return nil, err
}
@@ -136,10 +138,9 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
// the token which is a result of direct TLS communication with the provider
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
// ---
kid, _ := t.Header["kid"].(string)
err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid)
err = jwk.ValidateTokenSignature(p.ctx, idToken, p.jwksURL)
if err != nil {
return nil, err
return nil, fmt.Errorf("id_token validation failed: %w", err)
}
return claims, nil

View File

@@ -8,11 +8,14 @@ import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v5"
)
type JWK struct {
@@ -94,7 +97,7 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
// http.Client.Get doesn't treat non 2xx responses as error
if res.StatusCode >= 400 {
return nil, fmt.Errorf(
"failed to JSON Web Key Set from %s (%d):\n%s",
"failed to fetch JSON Web Key Set from %s (%d):\n%s",
jwksURL,
res.StatusCode,
string(rawBody),
@@ -116,5 +119,45 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
}
}
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
return nil, fmt.Errorf("JWK with kid %q was not found", kid)
}
// ValidateTokenSignature validates the signature of a token with the
// public key retrievied from a remote JWKS.
func ValidateTokenSignature(ctx context.Context, token string, jwksURL string) error {
// extract the kid token header
// ---
t, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return err
}
kid, _ := t.Header["kid"].(string)
if kid == "" {
return errors.New("missing kid header value")
}
// fetch the public key set
// ---
key, err := Fetch(ctx, jwksURL, kid)
if err != nil {
return err
}
// verify the signature
// ---
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
return key.PublicKey()
})
if err != nil {
return err
}
if !parsedToken.Valid {
return errors.New("the parsed token is invalid")
}
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"strings"
"testing"
"github.com/golang-jwt/jwt/v5"
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
)
@@ -27,7 +28,7 @@ func TestJWK_PublicKey(t *testing.T) {
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatalf("failed to generate test RSA private key: %v", err)
t.Fatal(err)
}
scenarios := []struct {
@@ -209,3 +210,105 @@ func TestFetch(t *testing.T) {
})
}
}
func TestValidateTokenSignature(t *testing.T) {
t.Parallel()
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatal(err)
}
ed25519Public, ed25519Private, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
nonmatchingKidToken := jwt.New(&jwt.SigningMethodEd25519{})
nonmatchingKidToken.Header["kid"] = "missing"
nonmatchingKidTokenStr, err := nonmatchingKidToken.SignedString(ed25519Private)
if err != nil {
t.Fatal(err)
}
key1Token := jwt.New(&jwt.SigningMethodEd25519{})
key1Token.Header["kid"] = "key1"
key1TokenStr, err := key1Token.SignedString(ed25519Private)
if err != nil {
t.Fatal(err)
}
key2Token := jwt.New(jwt.SigningMethodRS256)
key2Token.Header["kid"] = "key2"
key2TokenStr, err := key2Token.SignedString(rsaPrivate)
if err != nil {
t.Fatal(err)
}
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
_ = json.NewEncoder(res).Encode(map[string]any{"keys": []*jwk.JWK{
{
Kid: "key1",
Kty: "OKP",
Alg: "EdDSA",
Crv: "Ed25519",
X: base64.RawURLEncoding.EncodeToString(ed25519Public),
},
{
Kid: "key2",
Kty: "RSA",
Alg: "RS256",
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
},
}})
}))
defer server.Close()
scenarios := []struct {
name string
token string
expectError bool
}{
{
"empty token",
"",
true,
},
{
"invlaid token",
"abc",
true,
},
{
"no matching kid",
nonmatchingKidTokenStr,
true,
},
{
"valid Ed25519 token",
key1TokenStr,
false,
},
{
"valid RSA token",
key2TokenStr,
false,
},
}
for _, s := range scenarios {
t.Run(s.name, func(t *testing.T) {
err := jwk.ValidateTokenSignature(
context.Background(),
s.token,
server.URL,
)
hasErr := err != nil
if hasErr != s.expectError {
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
}
})
}
}

View File

@@ -135,7 +135,7 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
}
claims := jwt.MapClaims{}
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
if err != nil {
return nil, err
}
@@ -176,42 +176,11 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
jwksURL := cast.ToString(p.Extra()["jwksURL"])
if jwksURL != "" {
kid, _ := t.Header["kid"].(string)
err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid)
err = jwk.ValidateTokenSignature(p.ctx, idToken, jwksURL)
if err != nil {
return nil, err
return nil, fmt.Errorf("id_token validation failed: %w", err)
}
}
return claims, nil
}
func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error {
// fetch the public key set
// ---
if kid == "" {
return errors.New("missing kid header value")
}
key, err := jwk.Fetch(ctx, jwksURL, kid)
if err != nil {
return err
}
// verify the signiture
// ---
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
return key.PublicKey()
})
if err != nil {
return err
}
if !parsedToken.Valid {
return errors.New("the parsed id_token is invalid")
}
return nil
}