1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-13 01:30:31 +02:00

Changes to jwt and csrf middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-05-26 14:06:30 -07:00
parent 7a66f226f2
commit 7e52ad4dd5
3 changed files with 68 additions and 40 deletions

View File

@ -20,14 +20,19 @@ type (
// Key to create CSRF token. // Key to create CSRF token.
Secret []byte `json:"secret"` Secret []byte `json:"secret"`
// Lookup is a string in the form of "<source>:<key>" that is used to extract
// token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "header:<name>"
Lookup string `json:"lookup"`
// Context key to store generated CSRF token into context. // Context key to store generated CSRF token into context.
// Optional. Default value "csrf". // Optional. Default value "csrf".
ContextKey string `json:"context_key"` ContextKey string `json:"context_key"`
// Extractor is a function that extracts token from the request.
// Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken).
Extractor CSRFTokenExtractor
// Name of the CSRF cookie. This cookie will store CSRF token. // Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf". // Optional. Default value "csrf".
CookieName string `json:"cookie_name"` CookieName string `json:"cookie_name"`
@ -53,16 +58,16 @@ type (
CookieHTTPOnly bool `json:"cookie_http_only"` CookieHTTPOnly bool `json:"cookie_http_only"`
} }
// CSRFTokenExtractor defines a function that takes `echo.Context` and returns // csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error. // either a token or an error.
CSRFTokenExtractor func(echo.Context) (string, error) csrfTokenExtractor func(echo.Context) (string, error)
) )
var ( var (
// DefaultCSRFConfig is the default CSRF middleware config. // DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{ DefaultCSRFConfig = CSRFConfig{
Lookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf", ContextKey: "csrf",
Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken),
CookieName: "csrf", CookieName: "csrf",
CookieExpires: time.Now().Add(24 * time.Hour), CookieExpires: time.Now().Add(24 * time.Hour),
} }
@ -83,12 +88,12 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.Secret == nil { if config.Secret == nil {
panic("csrf secret must be provided") panic("csrf secret must be provided")
} }
if config.Lookup == "" {
config.Lookup = DefaultCSRFConfig.Lookup
}
if config.ContextKey == "" { if config.ContextKey == "" {
config.ContextKey = DefaultCSRFConfig.ContextKey config.ContextKey = DefaultCSRFConfig.ContextKey
} }
if config.Extractor == nil {
config.Extractor = DefaultCSRFConfig.Extractor
}
if config.CookieName == "" { if config.CookieName == "" {
config.CookieName = DefaultCSRFConfig.CookieName config.CookieName = DefaultCSRFConfig.CookieName
} }
@ -96,6 +101,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieExpires = DefaultCSRFConfig.CookieExpires config.CookieExpires = DefaultCSRFConfig.CookieExpires
} }
// Initialize
parts := strings.Split(config.Lookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
req := c.Request() req := c.Request()
@ -124,7 +139,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
switch req.Method() { switch req.Method() {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default: default:
token, err := config.Extractor(c) token, err := extractor(c)
if err != nil { if err != nil {
return err return err
} }
@ -141,17 +156,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
} }
} }
// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header. // provided request header.
func CSRFTokenFromHeader(header string) CSRFTokenExtractor { func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
return c.Request().Header().Get(header), nil return c.Request().Header().Get(header), nil
} }
} }
// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter. // provided form parameter.
func CSRFTokenFromForm(param string) CSRFTokenExtractor { func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.FormValue(param) token := c.FormValue(param)
if token == "" { if token == "" {
@ -161,9 +176,9 @@ func CSRFTokenFromForm(param string) CSRFTokenExtractor {
} }
} }
// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter. // provided query parameter.
func CSRFTokenFromQuery(param string) CSRFTokenExtractor { func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.QueryParam(param) token := c.QueryParam(param)
if token == "" { if token == "" {

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/labstack/echo" "github.com/labstack/echo"
@ -24,14 +25,16 @@ type (
// Optional. Default value "user". // Optional. Default value "user".
ContextKey string `json:"context_key"` ContextKey string `json:"context_key"`
// Extractor is a function that extracts token from the request. // Lookup is a string in the form of "<source>:<key>" that is used to extract
// Optional. Default value JWTFromHeader. // token from the request.
Extractor JWTExtractor // Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "form:<name>"
Lookup string `json:"lookup"`
} }
// JWTExtractor defines a function that takes `echo.Context` and returns either jwtExtractor func(echo.Context) (string, error)
// a token or an error.
JWTExtractor func(echo.Context) (string, error)
) )
const ( const (
@ -48,7 +51,7 @@ var (
DefaultJWTConfig = JWTConfig{ DefaultJWTConfig = JWTConfig{
SigningMethod: AlgorithmHS256, SigningMethod: AlgorithmHS256,
ContextKey: "user", ContextKey: "user",
Extractor: JWTFromHeader, Lookup: "header:" + echo.HeaderAuthorization,
} }
) )
@ -78,13 +81,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.ContextKey == "" { if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey config.ContextKey = DefaultJWTConfig.ContextKey
} }
if config.Extractor == nil { if config.Lookup == "" {
config.Extractor = DefaultJWTConfig.Extractor config.Lookup = DefaultJWTConfig.Lookup
}
// Initialize
parts := strings.Split(config.Lookup, ":")
extractor := jwtFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = jwtFromQuery(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
auth, err := config.Extractor(c) auth, err := extractor(c)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error()) return echo.NewHTTPError(http.StatusBadRequest, err.Error())
} }
@ -106,20 +117,22 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
} }
} }
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request // jwtFromHeader returns a `jwtExtractor` that extracts token from the provided
// header. // request header.
func JWTFromHeader(c echo.Context) (string, error) { func jwtFromHeader(header string) jwtExtractor {
auth := c.Request().Header().Get(echo.HeaderAuthorization) return func(c echo.Context) (string, error) {
auth := c.Request().Header().Get(header)
l := len(bearer) l := len(bearer)
if len(auth) > l+1 && auth[:l] == bearer { if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil return auth[l+1:], nil
} }
return "", errors.New("empty or invalid jwt in authorization header") return "", errors.New("empty or invalid jwt in authorization header")
} }
}
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query // jwtFromQuery returns a `jwtExtractor` that extracts token from the provided query
// parameter. // parameter.
func JWTFromQuery(param string) JWTExtractor { func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.QueryParam(param) token := c.QueryParam(param)
if token == "" { if token == "" {

View File

@ -26,9 +26,9 @@ type (
// clickjacking. // clickjacking.
// Optional. Default value "SAMEORIGIN". // Optional. Default value "SAMEORIGIN".
// Possible values: // Possible values:
// `SAMEORIGIN` - The page can only be displayed in a frame on the same origin as the page itself. // - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// `DENY` - The page cannot be displayed in a frame, regardless of the site attempting to do so. // - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// `ALLOW-FROM uri` - The page can only be displayed in a frame on the specified origin. // - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `json:"x_frame_options"` XFrameOptions string `json:"x_frame_options"`
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how // HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how