1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Changes to jwt and csrf middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-26 14:06:30 -07:00
parent 7a66f226f2
commit 7e52ad4dd5
3 changed files with 68 additions and 40 deletions

View File

@ -20,14 +20,19 @@ type (
// Key to create CSRF token.
Secret []byte `json:"secret"`
// Lookup is a string in the form of "<source>:<key>" that is used to extract
// token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "header:<name>"
Lookup string `json:"lookup"`
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string `json:"context_key"`
// Extractor is a function that extracts token from the request.
// Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken).
Extractor CSRFTokenExtractor
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
CookieName string `json:"cookie_name"`
@ -53,16 +58,16 @@ type (
CookieHTTPOnly bool `json:"cookie_http_only"`
}
// CSRFTokenExtractor defines a function that takes `echo.Context` and returns
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
CSRFTokenExtractor func(echo.Context) (string, error)
csrfTokenExtractor func(echo.Context) (string, error)
)
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
Lookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken),
CookieName: "csrf",
CookieExpires: time.Now().Add(24 * time.Hour),
}
@ -83,12 +88,12 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.Secret == nil {
panic("csrf secret must be provided")
}
if config.Lookup == "" {
config.Lookup = DefaultCSRFConfig.Lookup
}
if config.ContextKey == "" {
config.ContextKey = DefaultCSRFConfig.ContextKey
}
if config.Extractor == nil {
config.Extractor = DefaultCSRFConfig.Extractor
}
if config.CookieName == "" {
config.CookieName = DefaultCSRFConfig.CookieName
}
@ -96,6 +101,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieExpires = DefaultCSRFConfig.CookieExpires
}
// Initialize
parts := strings.Split(config.Lookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
@ -124,7 +139,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
switch req.Method() {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default:
token, err := config.Extractor(c)
token, err := extractor(c)
if err != nil {
return err
}
@ -141,17 +156,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
}
}
// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
func CSRFTokenFromHeader(header string) CSRFTokenExtractor {
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header().Get(header), nil
}
}
// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter.
func CSRFTokenFromForm(param string) CSRFTokenExtractor {
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
@ -161,9 +176,9 @@ func CSRFTokenFromForm(param string) CSRFTokenExtractor {
}
}
// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter.
func CSRFTokenFromQuery(param string) CSRFTokenExtractor {
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/dgrijalva/jwt-go"
"github.com/labstack/echo"
@ -24,14 +25,16 @@ type (
// Optional. Default value "user".
ContextKey string `json:"context_key"`
// Extractor is a function that extracts token from the request.
// Optional. Default value JWTFromHeader.
Extractor JWTExtractor
// Lookup is a string in the form of "<source>:<key>" that is used to extract
// token from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "form:<name>"
Lookup string `json:"lookup"`
}
// JWTExtractor defines a function that takes `echo.Context` and returns either
// a token or an error.
JWTExtractor func(echo.Context) (string, error)
jwtExtractor func(echo.Context) (string, error)
)
const (
@ -48,7 +51,7 @@ var (
DefaultJWTConfig = JWTConfig{
SigningMethod: AlgorithmHS256,
ContextKey: "user",
Extractor: JWTFromHeader,
Lookup: "header:" + echo.HeaderAuthorization,
}
)
@ -78,13 +81,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey
}
if config.Extractor == nil {
config.Extractor = DefaultJWTConfig.Extractor
if config.Lookup == "" {
config.Lookup = DefaultJWTConfig.Lookup
}
// Initialize
parts := strings.Split(config.Lookup, ":")
extractor := jwtFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = jwtFromQuery(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth, err := config.Extractor(c)
auth, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
@ -106,20 +117,22 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
}
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request
// header.
func JWTFromHeader(c echo.Context) (string, error) {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(bearer)
if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil
// jwtFromHeader returns a `jwtExtractor` that extracts token from the provided
// request header.
func jwtFromHeader(header string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header().Get(header)
l := len(bearer)
if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil
}
return "", errors.New("empty or invalid jwt in authorization header")
}
return "", errors.New("empty or invalid jwt in authorization header")
}
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
// jwtFromQuery returns a `jwtExtractor` that extracts token from the provided query
// parameter.
func JWTFromQuery(param string) JWTExtractor {
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {

View File

@ -26,9 +26,9 @@ type (
// clickjacking.
// Optional. Default value "SAMEORIGIN".
// Possible values:
// `SAMEORIGIN` - The page can only be displayed in a frame on the same origin as the page itself.
// `DENY` - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// `ALLOW-FROM uri` - The page can only be displayed in a frame on the specified origin.
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `json:"x_frame_options"`
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how