You've already forked pocketbase
mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-11-23 22:55:37 +02:00
moved ValidateTokenSignature to jwk and added tests
This commit is contained in:
@@ -4,8 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||||
"github.com/pocketbase/pocketbase/tools/types"
|
"github.com/pocketbase/pocketbase/tools/types"
|
||||||
"github.com/spf13/cast"
|
"github.com/spf13/cast"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -108,10 +110,10 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
|
|||||||
return nil, errors.New("empty id_token")
|
return nil, errors.New("empty id_token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract the token header params and claims
|
// extract the token claims
|
||||||
// ---
|
// ---
|
||||||
claims := jwt.MapClaims{}
|
claims := jwt.MapClaims{}
|
||||||
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -136,10 +138,9 @@ func (p *Apple) parseAndVerifyIdToken(idToken string) (jwt.MapClaims, error) {
|
|||||||
// the token which is a result of direct TLS communication with the provider
|
// the token which is a result of direct TLS communication with the provider
|
||||||
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
||||||
// ---
|
// ---
|
||||||
kid, _ := t.Header["kid"].(string)
|
err = jwk.ValidateTokenSignature(p.ctx, idToken, p.jwksURL)
|
||||||
err = validateIdTokenSignature(p.ctx, idToken, p.jwksURL, kid)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("id_token validation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWK struct {
|
type JWK struct {
|
||||||
@@ -94,7 +97,7 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
|||||||
// http.Client.Get doesn't treat non 2xx responses as error
|
// http.Client.Get doesn't treat non 2xx responses as error
|
||||||
if res.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
return nil, fmt.Errorf(
|
return nil, fmt.Errorf(
|
||||||
"failed to JSON Web Key Set from %s (%d):\n%s",
|
"failed to fetch JSON Web Key Set from %s (%d):\n%s",
|
||||||
jwksURL,
|
jwksURL,
|
||||||
res.StatusCode,
|
res.StatusCode,
|
||||||
string(rawBody),
|
string(rawBody),
|
||||||
@@ -116,5 +119,45 @@ func Fetch(ctx context.Context, jwksURL string, kid string) (*JWK, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("jwk with kid %q was not found", kid)
|
return nil, fmt.Errorf("JWK with kid %q was not found", kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTokenSignature validates the signature of a token with the
|
||||||
|
// public key retrievied from a remote JWKS.
|
||||||
|
func ValidateTokenSignature(ctx context.Context, token string, jwksURL string) error {
|
||||||
|
// extract the kid token header
|
||||||
|
// ---
|
||||||
|
t, _, err := jwt.NewParser().ParseUnverified(token, jwt.MapClaims{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
kid, _ := t.Header["kid"].(string)
|
||||||
|
if kid == "" {
|
||||||
|
return errors.New("missing kid header value")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetch the public key set
|
||||||
|
// ---
|
||||||
|
key, err := Fetch(ctx, jwksURL, kid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the signature
|
||||||
|
// ---
|
||||||
|
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
|
||||||
|
|
||||||
|
parsedToken, err := parser.Parse(token, func(t *jwt.Token) (any, error) {
|
||||||
|
return key.PublicKey()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !parsedToken.Valid {
|
||||||
|
return errors.New("the parsed token is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
"github.com/pocketbase/pocketbase/tools/auth/internal/jwk"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,7 +28,7 @@ func TestJWK_PublicKey(t *testing.T) {
|
|||||||
|
|
||||||
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate test RSA private key: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
scenarios := []struct {
|
scenarios := []struct {
|
||||||
@@ -209,3 +210,105 @@ func TestFetch(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenSignature(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
rsaPrivate, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ed25519Public, ed25519Private, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonmatchingKidToken := jwt.New(&jwt.SigningMethodEd25519{})
|
||||||
|
nonmatchingKidToken.Header["kid"] = "missing"
|
||||||
|
nonmatchingKidTokenStr, err := nonmatchingKidToken.SignedString(ed25519Private)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key1Token := jwt.New(&jwt.SigningMethodEd25519{})
|
||||||
|
key1Token.Header["kid"] = "key1"
|
||||||
|
key1TokenStr, err := key1Token.SignedString(ed25519Private)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key2Token := jwt.New(jwt.SigningMethodRS256)
|
||||||
|
key2Token.Header["kid"] = "key2"
|
||||||
|
key2TokenStr, err := key2Token.SignedString(rsaPrivate)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
|
||||||
|
_ = json.NewEncoder(res).Encode(map[string]any{"keys": []*jwk.JWK{
|
||||||
|
{
|
||||||
|
Kid: "key1",
|
||||||
|
Kty: "OKP",
|
||||||
|
Alg: "EdDSA",
|
||||||
|
Crv: "Ed25519",
|
||||||
|
X: base64.RawURLEncoding.EncodeToString(ed25519Public),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Kid: "key2",
|
||||||
|
Kty: "RSA",
|
||||||
|
Alg: "RS256",
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaPrivate.E)).Bytes()),
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(rsaPrivate.N.Bytes()),
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"empty token",
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invlaid token",
|
||||||
|
"abc",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"no matching kid",
|
||||||
|
nonmatchingKidTokenStr,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid Ed25519 token",
|
||||||
|
key1TokenStr,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"valid RSA token",
|
||||||
|
key2TokenStr,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range scenarios {
|
||||||
|
t.Run(s.name, func(t *testing.T) {
|
||||||
|
err := jwk.ValidateTokenSignature(
|
||||||
|
context.Background(),
|
||||||
|
s.token,
|
||||||
|
server.URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
hasErr := err != nil
|
||||||
|
if hasErr != s.expectError {
|
||||||
|
t.Fatalf("Expected hasErr %v, got %v (%v)", s.expectError, hasErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := jwt.MapClaims{}
|
claims := jwt.MapClaims{}
|
||||||
t, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
_, _, err := jwt.NewParser().ParseUnverified(idToken, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -176,42 +176,11 @@ func (p *OIDC) parseIdToken(token *oauth2.Token) (jwt.MapClaims, error) {
|
|||||||
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
// (see also https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
|
||||||
jwksURL := cast.ToString(p.Extra()["jwksURL"])
|
jwksURL := cast.ToString(p.Extra()["jwksURL"])
|
||||||
if jwksURL != "" {
|
if jwksURL != "" {
|
||||||
kid, _ := t.Header["kid"].(string)
|
err = jwk.ValidateTokenSignature(p.ctx, idToken, jwksURL)
|
||||||
err = validateIdTokenSignature(p.ctx, idToken, jwksURL, kid)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("id_token validation failed: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateIdTokenSignature(ctx context.Context, idToken string, jwksURL string, kid string) error {
|
|
||||||
// fetch the public key set
|
|
||||||
// ---
|
|
||||||
if kid == "" {
|
|
||||||
return errors.New("missing kid header value")
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := jwk.Fetch(ctx, jwksURL, kid)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify the signiture
|
|
||||||
// ---
|
|
||||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{key.Alg}))
|
|
||||||
|
|
||||||
parsedToken, err := parser.Parse(idToken, func(t *jwt.Token) (any, error) {
|
|
||||||
return key.PublicKey()
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !parsedToken.Valid {
|
|
||||||
return errors.New("the parsed id_token is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user