package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/golang-jwt/jwt/v4"
	"github.com/labstack/echo/v4"
	"github.com/stretchr/testify/assert"
)

// createTestParseTokenFuncForJWTGo returns a token-parsing callback with the
// signature expected by JWTConfig.ParseTokenFunc, implemented on top of
// github.com/golang-jwt/jwt. It is a deliberately minimal implementation,
// good enough to keep the older tests running.
//
// signingMethod is the only algorithm name (as reported by the token's
// signing method) that is accepted; signingKey is the key handed to the JWT
// library for signature verification. The returned callback yields the
// parsed *jwt.Token (with jwt.MapClaims) on success.
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
	// resolveKey refuses tokens signed with any algorithm other than the
	// requested one and otherwise supplies the verification key.
	resolveKey := func(tok *jwt.Token) (interface{}, error) {
		if tok.Method.Alg() == signingMethod {
			return signingKey, nil
		}
		return nil, fmt.Errorf("unexpected jwt signing method=%v", tok.Header["alg"])
	}

	return func(c echo.Context, auth string) (interface{}, error) {
		parsed, parseErr := jwt.ParseWithClaims(auth, jwt.MapClaims{}, resolveKey)
		switch {
		case parseErr != nil:
			return nil, parseErr
		case !parsed.Valid:
			return nil, errors.New("invalid token")
		}
		return parsed, nil
	}
}

// jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct {
	Name  string `json:"name"`
	Admin bool   `json:"admin"`
}

// jwtCustomClaims are custom claims expanding default ones.
type jwtCustomClaims struct { *jwt.StandardClaims jwtCustomInfo } func TestJWT_combinations(t *testing.T) { e := echo.New() handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" validKey := []byte("secret") invalidKey := []byte("invalid-key") validAuth := "Bearer " + token var testCases = []struct { name string config JWTConfig reqURL string // "/" if empty hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val formValues map[string]string expPanic bool expErrCode int // 0 for Success }{ { expPanic: true, name: "No signing key provided", }, { expErrCode: http.StatusUnauthorized, hdrAuth: validAuth, config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), }, name: "Unexpected signing method", }, { expErrCode: http.StatusUnauthorized, hdrAuth: validAuth, config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey), }, name: "Invalid key", }, { hdrAuth: validAuth, config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), }, name: "Valid JWT", }, { hdrAuth: "Token" + " " + token, config: JWTConfig{ TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ", ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), }, name: "Valid JWT with custom AuthScheme", }, { hdrAuth: validAuth, config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), }, name: "Valid JWT with custom claims", }, { hdrAuth: "invalid-auth", expErrCode: http.StatusUnauthorized, config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), }, name: "Invalid Authorization header", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), }, 
expErrCode: http.StatusUnauthorized, name: "Empty header auth field", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, name: "Valid query method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwtxyz=" + token, expErrCode: http.StatusUnauthorized, name: "Invalid query param name", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=invalid-token", expErrCode: http.StatusUnauthorized, name: "Invalid query param value", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "query:jwt", }, reqURL: "/?a=b", expErrCode: http.StatusUnauthorized, name: "Empty query", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "param:jwt", }, reqURL: "/" + token, name: "Valid param method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, name: "Valid cookie method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, name: "Multiple jwt lookuop", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "cookie:jwt", }, expErrCode: http.StatusUnauthorized, hdrCookie: "jwt=invalid", name: "Invalid token with cookie method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "cookie:jwt", }, expErrCode: http.StatusUnauthorized, name: "Empty cookie", }, { config: JWTConfig{ ParseTokenFunc: 
createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, name: "Valid form method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, formValues: map[string]string{"jwt": "invalid"}, name: "Invalid token with form method", }, { config: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, name: "Empty form field", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { if tc.reqURL == "" { tc.reqURL = "/" } var req *http.Request if len(tc.formValues) > 0 { form := url.Values{} for k, v := range tc.formValues { form.Set(k, v) } req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") req.ParseForm() } else { req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) } res := httptest.NewRecorder() req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderCookie, tc.hdrCookie) c := e.NewContext(req, res) if tc.reqURL == "/"+token { cc := c.(echo.EditableContext) cc.SetPathParams(echo.PathParams{ {Name: "jwt", Value: token}, }) } if tc.expPanic { assert.Panics(t, func() { JWTWithConfig(tc.config) }, tc.name) return } if tc.expErrCode != 0 { h := JWTWithConfig(tc.config)(handler) he := h(c).(*echo.HTTPError) assert.Equal(t, tc.expErrCode, he.Code) return } h := JWTWithConfig(tc.config)(handler) if assert.NoError(t, h(c), tc.name) { user := c.Get("user").(*jwt.Token) switch claims := user.Claims.(type) { case jwt.MapClaims: assert.Equal(t, claims["name"], "John Doe") case *jwtCustomClaims: assert.Equal(t, claims.Name, "John Doe") assert.Equal(t, claims.Admin, true) default: panic("unexpected type of claims") } } }) } } func 
TestJWTConfig_skipper(t *testing.T) { e := echo.New() e.Use(JWTWithConfig(JWTConfig{ Skipper: func(context echo.Context) bool { return true // skip everything }, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) isCalled := false e.GET("/", func(c echo.Context) error { isCalled = true return c.String(http.StatusTeapot, "test") }) req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) assert.True(t, isCalled) } func TestJWTConfig_BeforeFunc(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { return c.String(http.StatusTeapot, "test") }) isCalled := false e.Use(JWTWithConfig(JWTConfig{ BeforeFunc: func(context echo.Context) { isCalled = true }, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) assert.True(t, isCalled) } func TestJWTConfig_extractorErrorHandling(t *testing.T) { var testCases = []struct { name string given JWTConfig expectStatusCode int }{ { name: "ok, ErrorHandler is executed", given: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "custom_error") }, }, expectStatusCode: http.StatusTeapot, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) e.Use(JWTWithConfig(tc.given)) req := httptest.NewRequest(http.MethodGet, "/", nil) res := 
httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, tc.expectStatusCode, res.Code) }) } } func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { var testCases = []struct { name string given JWTConfig expectErr string }{ { name: "ok, ErrorHandler is executed", given: JWTConfig{ ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, }, expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() //e.Debug = true e.GET("/", func(c echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) config := tc.given parseTokenCalled := false config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, tc.expectErr, res.Body.String()) assert.True(t, parseTokenCalled) }) } } func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { return c.String(http.StatusTeapot, "test") }) // example of minimal custom ParseTokenFunc implementation. 
Allows you to use different versions of `github.com/golang-jwt/jwt` // with current JWT middleware signingKey := []byte("secret") config := JWTConfig{ ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) } return signingKey, nil } // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` token, err := jwt.Parse(auth, keyFunc) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("invalid token") } return token, nil }, } e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) } func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { success := c.Get("success").(string) user := c.Get("user").(string) return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) }) mw, err := JWTConfig{ ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { return auth, nil }, SuccessHandler: func(c echo.Context) { c.Set("success", "yes") }, }.ToMiddleware() assert.NoError(t, err) e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, "yes:valid_token_base64", res.Body.String()) assert.Equal(t, http.StatusTeapot, res.Code) } func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) { var testCases = []struct { name string givenCallNext bool givenErrorHandler JWTErrorHandlerWithContext givenTokenLookup string 
whenAuthHeaders []string whenCookies []string whenParseReturn string whenParseError error expectHandlerCalled bool expect string expectCode int }{ { name: "ok, with valid JWT from auth header", givenCallNext: true, givenErrorHandler: func(c echo.Context, err error) error { return nil }, whenAuthHeaders: []string{"Bearer valid_token_base64"}, whenParseReturn: "valid_token", expectCode: http.StatusTeapot, expect: "valid_token", }, { name: "ok, missing header, callNext and set public_token from error handler", givenCallNext: true, givenErrorHandler: func(c echo.Context, err error) error { if err != ErrJWTMissing { panic("must get ErrJWTMissing") } c.Set("user", "public_token") return nil }, whenAuthHeaders: []string{}, // no JWT header expectCode: http.StatusTeapot, expect: "public_token", }, { name: "ok, invalid token, callNext and set public_token from error handler", givenCallNext: true, givenErrorHandler: func(c echo.Context, err error) error { // this is probably not realistic usecase. 
on parse error you probably want to return error if err.Error() != "parser_error" { panic("must get parser_error") } c.Set("user", "public_token") return nil }, whenAuthHeaders: []string{"Bearer invalid_header"}, whenParseError: errors.New("parser_error"), expectCode: http.StatusTeapot, expect: "public_token", }, { name: "nok, invalid token, return error from error handler", givenCallNext: true, givenErrorHandler: func(c echo.Context, err error) error { if err.Error() != "parser_error" { panic("must get parser_error") } return err }, whenAuthHeaders: []string{"Bearer invalid_header"}, whenParseError: errors.New("parser_error"), expectCode: http.StatusInternalServerError, expect: "{\"message\":\"Internal Server Error\"}\n", }, { name: "nok, callNext but return error from error handler", givenCallNext: true, givenErrorHandler: func(c echo.Context, err error) error { return err }, whenAuthHeaders: []string{}, // no JWT header expectCode: http.StatusUnauthorized, expect: "{\"message\":\"missing or malformed jwt\"}\n", }, { name: "nok, callNext=false", givenCallNext: false, givenErrorHandler: func(c echo.Context, err error) error { return err }, whenAuthHeaders: []string{}, // no JWT header expectCode: http.StatusUnauthorized, expect: "{\"message\":\"missing or malformed jwt\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { token := c.Get("user").(string) return c.String(http.StatusTeapot, token) }) mw, err := JWTConfig{ TokenLookup: tc.givenTokenLookup, ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { return tc.whenParseReturn, tc.whenParseError }, ErrorHandler: tc.givenErrorHandler, }.ToMiddleware() assert.NoError(t, err) e.Use(mw) req := httptest.NewRequest(http.MethodGet, "/", nil) for _, a := range tc.whenAuthHeaders { req.Header.Add(echo.HeaderAuthorization, a) } res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, tc.expect, 
res.Body.String()) assert.Equal(t, tc.expectCode, res.Code) }) } }