From 7d1819e5b1bce5a4e73c68d339ef09c9826a5e85 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Fri, 13 May 2016 08:18:00 -0700 Subject: [PATCH] Extractor for csrf token Signed-off-by: Vishal Rana --- middleware/csrf.go | 66 ++++++++++++++++++++++++++++++++++++++++------ middleware/jwt.go | 9 +++++-- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 70e20528..5d1a96f9 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/sha1" "encoding/hex" + "errors" "fmt" "net/http" "strings" @@ -19,8 +20,13 @@ type ( // Key to create CSRF token. Secret []byte - // Name of the request header to extract CSRF token. - HeaderName string + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Extractor is a function that extracts token from the request. + // Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken). + Extractor CSRFTokenExtractor // Name of the CSRF cookie. This cookie will store CSRF token. // Optional. Default value "csrf". @@ -46,12 +52,17 @@ type ( // Optional. Default value false. CookieHTTPOnly bool } + + // CSRFTokenExtractor defines a function that takes `echo.Context` and returns + // either a token or an error. + CSRFTokenExtractor func(echo.Context) (string, error) ) var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ - HeaderName: echo.HeaderXCSRFToken, + ContextKey: "csrf", + Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken), CookieName: "csrf", CookieExpires: time.Now().Add(24 * time.Hour), } @@ -70,10 +81,13 @@ func CSRF(secret []byte) echo.MiddlewareFunc { func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // Defaults if config.Secret == nil { - panic("csrf: secret must be provided") + panic("csrf secret must be provided") } - if config.HeaderName == "" { - config.HeaderName = DefaultCSRFConfig.HeaderName + if config.ContextKey == "" { + config.ContextKey = DefaultCSRFConfig.ContextKey + } + if config.Extractor == nil { + config.Extractor = DefaultCSRFConfig.Extractor } if config.CookieName == "" { config.CookieName = DefaultCSRFConfig.CookieName @@ -92,6 +106,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return err } token := generateCSRFToken(config.Secret, salt) + c.Set(config.ContextKey, token) cookie := new(echo.Cookie) cookie.SetName(config.CookieName) cookie.SetValue(token) @@ -109,13 +124,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { switch req.Method() { case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: default: - token := req.Header().Get(config.HeaderName) + token, err := config.Extractor(c) + if err != nil { + return err + } ok, err := validateCSRFToken(token, config.Secret) if err != nil { return err } if !ok { - return echo.NewHTTPError(http.StatusForbidden, "csrf: invalid token") + return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") } } return next(c) @@ -123,6 +141,38 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } } +// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the +// provided request header. +func CSRFTokenFromHeader(header string) CSRFTokenExtractor { + return func(c echo.Context) (string, error) { + return c.Request().Header().Get(header), nil + } +} + +// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the +// provided form parameter. +func CSRFTokenFromForm(param string) CSRFTokenExtractor { + return func(c echo.Context) (string, error) { + token := c.FormValue(param) + if token == "" { + return "", errors.New("empty csrf token in form param") + } + return token, nil + } +} + +// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the +// provided query parameter. +func CSRFTokenFromQuery(param string) CSRFTokenExtractor { + return func(c echo.Context) (string, error) { + token := c.QueryParam(param) + if token == "" { + return "", errors.New("empty csrf token in query param") + } + return token, nil + } +} + func generateCSRFToken(secret, salt []byte) string { h := hmac.New(sha1.New, secret) h.Write(salt) diff --git a/middleware/jwt.go b/middleware/jwt.go index 935d660a..feae2bf8 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "net/http" @@ -113,13 +114,17 @@ func JWTFromHeader(c echo.Context) (string, error) { if len(auth) > l+1 && auth[:l] == bearer { return auth[l+1:], nil } - return "", echo.NewHTTPError(http.StatusBadRequest, "empty or invalid authorization header="+auth) + return "", errors.New("empty or invalid jwt in authorization header") } // JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query // parameter. func JWTFromQuery(param string) JWTExtractor { return func(c echo.Context) (string, error) { - return c.QueryParam(param), nil + token := c.QueryParam(param) + if token == "" { + return "", errors.New("empty jwt in query param") + } + return token, nil } }