diff --git a/tools/auth/apple.go b/tools/auth/apple.go index 9783c2f2..8ee92763 100644 --- a/tools/auth/apple.go +++ b/tools/auth/apple.go @@ -4,8 +4,10 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/golang-jwt/jwt/v5" + "github.com/pocketbase/pocketbase/tools/auth/internal/jwk" "github.com/pocketbase/pocketbase/tools/types" "github.com/spf13/cast" "golang.org/x/oauth2" @@ -108,10 +110,10 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) { return nil, errors.New("empty id_token") } - // extract the token header params and claims + // extract the token claims // --- claims := jwt.MapClaims{} - t, _, err := jwt.NewParser().ParseUnverified(idToken, claims) + _, _, err := jwt.NewParser().ParseUnverified(idToken, claims) if err != nil { return nil, err } @@ -136,10 +138,9 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) { // the token which is a result of direct TLS communication with the provider // (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation) // --- - kid, _ := t.Header["kid"].(string) - err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid) + err = jwk.ValidateTokenSignature(p.ctx, idToken, p.jwksURL) if err != nil { - return nil, err + return nil, fmt.Errorf("id_token validation failed: %w", err) } return claims, nil diff --git a/tools/auth/internal/jwk/jwk.go b/tools/auth/internal/jwk/jwk.go index b81bf129..c2bbfe8c 100644 --- a/tools/auth/internal/jwk/jwk.go +++ b/tools/auth/internal/jwk/jwk.go @@ -8,11 +8,14 @@ import ( "crypto/rsa" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "math/big" "net/http" "strings" + + "github.com/golang-jwt/jwt/v5" ) type JWK struct { @@ -94,7 +97,7 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) { // http.Client.Get doesn't treat non 2xx responses as error if res.StatusCode >= 400 { return nil, fmt.Errorf( - "failed to JSON Web Key Set from %s (%d):\n%s", + "failed to fetch JSON Web Key Set from %s (%d):\n%s", jwksURL, res.StatusCode, string(rawBody), @@ -116,5 +119,45 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) { } } - return nil, fmt.Errorf("jwk with kid %q was not found", kid) + return nil, fmt.Errorf("JWK with kid %q was not found", kid) +} + +// ValidateTokenSignature validates the signature of a token with the +// public key retrievied from a remote JWKS. +func ValidateTokenSignature(ctx context.Context, token string, jwksURL string) error { + // extract the kid token header + // --- + t, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{}) + if err != nil { + return err + } + + kid, _ := t.Header["kid"].(string) + if kid == "" { + return errors.New("missing kid header value") + } + + // fetch the public key set + // --- + key, err := Fetch(ctx, jwksURL, kid) + if err != nil { + return err + } + + // verify the signature + // --- + parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg})) + + parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) { + return key.PublicKey() + }) + if err != nil { + return err + } + + if !parsedToken.Valid { + return errors.New("the parsed token is invalid") + } + + return nil } diff --git a/tools/auth/internal/jwk/jwk_test.go b/tools/auth/internal/jwk/jwk_test.go index 07e780f9..df4a48ed 100644 --- a/tools/auth/internal/jwk/jwk_test.go +++ b/tools/auth/internal/jwk/jwk_test.go @@ -15,6 +15,7 @@ import ( "strings" "testing" + "github.com/golang-jwt/jwt/v5" "github.com/pocketbase/pocketbase/tools/auth/internal/jwk" ) @@ -27,7 +28,7 @@ func TestJWK_PublicKey(t *testing.T) { rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { - t.Fatalf("failed to generate test RSA private key: %v", err) + t.Fatal(err) } scenarios := []struct { @@ -209,3 +210,105 @@ func TestFetch(t *testing.T) { }) } } + +func TestValidateTokenSignature(t *testing.T) { + t.Parallel() + + rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + + ed25519Public, ed25519Private, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + nonmatchingKidToken := jwt.New(&jwt.SigningMethodEd25519{}) + nonmatchingKidToken.Header["kid"] = "missing" + nonmatchingKidTokenStr, err := nonmatchingKidToken.SignedString(ed25519Private) + if err != nil { + t.Fatal(err) + } + + key1Token := jwt.New(&jwt.SigningMethodEd25519{}) + key1Token.Header["kid"] = "key1" + key1TokenStr, err := key1Token.SignedString(ed25519Private) + if err != nil { + t.Fatal(err) + } + + key2Token := jwt.New(jwt.SigningMethodRS256) + key2Token.Header["kid"] = "key2" + key2TokenStr, err := key2Token.SignedString(rsaPrivate) + if err != nil { + t.Fatal(err) + } + + server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + _ = json.NewEncoder(res).Encode(map[string]any{"keys": []*jwk.JWK{ + { + Kid: "key1", + Kty: "OKP", + Alg: "EdDSA", + Crv: "Ed25519", + X: base64.RawURLEncoding.EncodeToString(ed25519Public), + }, + { + Kid: "key2", + Kty: "RSA", + Alg: "RS256", + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()), + N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()), + }, + }}) + })) + defer server.Close() + + scenarios := []struct { + name string + token string + expectError bool + }{ + { + "empty token", + "", + true, + }, + { + "invlaid token", + "abc", + true, + }, + { + "no matching kid", + nonmatchingKidTokenStr, + true, + }, + { + "valid Ed25519 token", + key1TokenStr, + false, + }, + { + "valid RSA token", + key2TokenStr, + false, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + err := jwk.ValidateTokenSignature( + context.Background(), + s.token, + server.URL, + ) + + hasErr := err != nil + if hasErr != s.expectError { + t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err) + } + }) + } +} diff --git a/tools/auth/oidc.go b/tools/auth/oidc.go index 43e093e3..ddfdffbb 100644 --- a/tools/auth/oidc.go +++ b/tools/auth/oidc.go @@ -135,7 +135,7 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) { } claims := jwt.MapClaims{} - t, _, err := jwt.NewParser().ParseUnverified(idToken, claims) + _, _, err := jwt.NewParser().ParseUnverified(idToken, claims) if err != nil { return nil, err } @@ -176,42 +176,11 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) { // (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation) jwksURL := cast.ToString(p.Extra()["jwksURL"]) if jwksURL != "" { - kid, _ := t.Header["kid"].(string) - err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid) + err = jwk.ValidateTokenSignature(p.ctx, idToken, jwksURL) if err != nil { - return nil, err + return nil, fmt.Errorf("id_token validation failed: %w", err) } } return claims, nil } - -func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error { - // fetch the public key set - // --- - if kid == "" { - return errors.New("missing kid header value") - } - - key, err := jwk.Fetch(ctx, jwksURL, kid) - if err != nil { - return err - } - - // verify the signiture - // --- - parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg})) - - parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) { - return key.PublicKey() - }) - if err != nil { - return err - } - - if !parsedToken.Valid { - return errors.New("the parsed id_token is invalid") - } - - return nil -}