mirror of
https://github.com/labstack/echo.git
synced 2025-07-13 01:30:31 +02:00
Changes to jwt and csrf middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
@ -20,14 +20,19 @@ type (
|
|||||||
// Key to create CSRF token.
|
// Key to create CSRF token.
|
||||||
Secret []byte `json:"secret"`
|
Secret []byte `json:"secret"`
|
||||||
|
|
||||||
|
// Lookup is a string in the form of "<source>:<key>" that is used to extract
|
||||||
|
// token from the request.
|
||||||
|
// Optional. Default value "header:X-CSRF-Token".
|
||||||
|
// Possible values:
|
||||||
|
// - "header:<name>"
|
||||||
|
// - "form:<name>"
|
||||||
|
// - "header:<name>"
|
||||||
|
Lookup string `json:"lookup"`
|
||||||
|
|
||||||
// Context key to store generated CSRF token into context.
|
// Context key to store generated CSRF token into context.
|
||||||
// Optional. Default value "csrf".
|
// Optional. Default value "csrf".
|
||||||
ContextKey string `json:"context_key"`
|
ContextKey string `json:"context_key"`
|
||||||
|
|
||||||
// Extractor is a function that extracts token from the request.
|
|
||||||
// Optional. Default value CSRFTokenFromHeader(echo.HeaderXCSRFToken).
|
|
||||||
Extractor CSRFTokenExtractor
|
|
||||||
|
|
||||||
// Name of the CSRF cookie. This cookie will store CSRF token.
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||||||
// Optional. Default value "csrf".
|
// Optional. Default value "csrf".
|
||||||
CookieName string `json:"cookie_name"`
|
CookieName string `json:"cookie_name"`
|
||||||
@ -53,16 +58,16 @@ type (
|
|||||||
CookieHTTPOnly bool `json:"cookie_http_only"`
|
CookieHTTPOnly bool `json:"cookie_http_only"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFTokenExtractor defines a function that takes `echo.Context` and returns
|
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
||||||
// either a token or an error.
|
// either a token or an error.
|
||||||
CSRFTokenExtractor func(echo.Context) (string, error)
|
csrfTokenExtractor func(echo.Context) (string, error)
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||||
DefaultCSRFConfig = CSRFConfig{
|
DefaultCSRFConfig = CSRFConfig{
|
||||||
|
Lookup: "header:" + echo.HeaderXCSRFToken,
|
||||||
ContextKey: "csrf",
|
ContextKey: "csrf",
|
||||||
Extractor: CSRFTokenFromHeader(echo.HeaderXCSRFToken),
|
|
||||||
CookieName: "csrf",
|
CookieName: "csrf",
|
||||||
CookieExpires: time.Now().Add(24 * time.Hour),
|
CookieExpires: time.Now().Add(24 * time.Hour),
|
||||||
}
|
}
|
||||||
@ -83,12 +88,12 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
if config.Secret == nil {
|
if config.Secret == nil {
|
||||||
panic("csrf secret must be provided")
|
panic("csrf secret must be provided")
|
||||||
}
|
}
|
||||||
|
if config.Lookup == "" {
|
||||||
|
config.Lookup = DefaultCSRFConfig.Lookup
|
||||||
|
}
|
||||||
if config.ContextKey == "" {
|
if config.ContextKey == "" {
|
||||||
config.ContextKey = DefaultCSRFConfig.ContextKey
|
config.ContextKey = DefaultCSRFConfig.ContextKey
|
||||||
}
|
}
|
||||||
if config.Extractor == nil {
|
|
||||||
config.Extractor = DefaultCSRFConfig.Extractor
|
|
||||||
}
|
|
||||||
if config.CookieName == "" {
|
if config.CookieName == "" {
|
||||||
config.CookieName = DefaultCSRFConfig.CookieName
|
config.CookieName = DefaultCSRFConfig.CookieName
|
||||||
}
|
}
|
||||||
@ -96,6 +101,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
config.CookieExpires = DefaultCSRFConfig.CookieExpires
|
config.CookieExpires = DefaultCSRFConfig.CookieExpires
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
parts := strings.Split(config.Lookup, ":")
|
||||||
|
extractor := csrfTokenFromHeader(parts[1])
|
||||||
|
switch parts[0] {
|
||||||
|
case "form":
|
||||||
|
extractor = csrfTokenFromForm(parts[1])
|
||||||
|
case "query":
|
||||||
|
extractor = csrfTokenFromQuery(parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
req := c.Request()
|
req := c.Request()
|
||||||
@ -124,7 +139,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
switch req.Method() {
|
switch req.Method() {
|
||||||
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
||||||
default:
|
default:
|
||||||
token, err := config.Extractor(c)
|
token, err := extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -141,17 +156,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFTokenFromHeader returns a `CSRFTokenExtractor` that extracts token from the
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||||||
// provided request header.
|
// provided request header.
|
||||||
func CSRFTokenFromHeader(header string) CSRFTokenExtractor {
|
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
return c.Request().Header().Get(header), nil
|
return c.Request().Header().Get(header), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFTokenFromForm returns a `CSRFTokenExtractor` that extracts token from the
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||||||
// provided form parameter.
|
// provided form parameter.
|
||||||
func CSRFTokenFromForm(param string) CSRFTokenExtractor {
|
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
token := c.FormValue(param)
|
token := c.FormValue(param)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
@ -161,9 +176,9 @@ func CSRFTokenFromForm(param string) CSRFTokenExtractor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRFTokenFromQuery returns a `CSRFTokenExtractor` that extracts token from the
|
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
||||||
// provided query parameter.
|
// provided query parameter.
|
||||||
func CSRFTokenFromQuery(param string) CSRFTokenExtractor {
|
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
token := c.QueryParam(param)
|
token := c.QueryParam(param)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
@ -24,14 +25,16 @@ type (
|
|||||||
// Optional. Default value "user".
|
// Optional. Default value "user".
|
||||||
ContextKey string `json:"context_key"`
|
ContextKey string `json:"context_key"`
|
||||||
|
|
||||||
// Extractor is a function that extracts token from the request.
|
// Lookup is a string in the form of "<source>:<key>" that is used to extract
|
||||||
// Optional. Default value JWTFromHeader.
|
// token from the request.
|
||||||
Extractor JWTExtractor
|
// Optional. Default value "header:Authorization".
|
||||||
|
// Possible values:
|
||||||
|
// - "header:<name>"
|
||||||
|
// - "form:<name>"
|
||||||
|
Lookup string `json:"lookup"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTExtractor defines a function that takes `echo.Context` and returns either
|
jwtExtractor func(echo.Context) (string, error)
|
||||||
// a token or an error.
|
|
||||||
JWTExtractor func(echo.Context) (string, error)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -48,7 +51,7 @@ var (
|
|||||||
DefaultJWTConfig = JWTConfig{
|
DefaultJWTConfig = JWTConfig{
|
||||||
SigningMethod: AlgorithmHS256,
|
SigningMethod: AlgorithmHS256,
|
||||||
ContextKey: "user",
|
ContextKey: "user",
|
||||||
Extractor: JWTFromHeader,
|
Lookup: "header:" + echo.HeaderAuthorization,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,13 +81,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
if config.ContextKey == "" {
|
if config.ContextKey == "" {
|
||||||
config.ContextKey = DefaultJWTConfig.ContextKey
|
config.ContextKey = DefaultJWTConfig.ContextKey
|
||||||
}
|
}
|
||||||
if config.Extractor == nil {
|
if config.Lookup == "" {
|
||||||
config.Extractor = DefaultJWTConfig.Extractor
|
config.Lookup = DefaultJWTConfig.Lookup
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
parts := strings.Split(config.Lookup, ":")
|
||||||
|
extractor := jwtFromHeader(parts[1])
|
||||||
|
switch parts[0] {
|
||||||
|
case "form":
|
||||||
|
extractor = jwtFromQuery(parts[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
auth, err := config.Extractor(c)
|
auth, err := extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
@ -106,20 +117,22 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request
|
// jwtFromHeader returns a `jwtExtractor` that extracts token from the provided
|
||||||
// header.
|
// request header.
|
||||||
func JWTFromHeader(c echo.Context) (string, error) {
|
func jwtFromHeader(header string) jwtExtractor {
|
||||||
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
return func(c echo.Context) (string, error) {
|
||||||
|
auth := c.Request().Header().Get(header)
|
||||||
l := len(bearer)
|
l := len(bearer)
|
||||||
if len(auth) > l+1 && auth[:l] == bearer {
|
if len(auth) > l+1 && auth[:l] == bearer {
|
||||||
return auth[l+1:], nil
|
return auth[l+1:], nil
|
||||||
}
|
}
|
||||||
return "", errors.New("empty or invalid jwt in authorization header")
|
return "", errors.New("empty or invalid jwt in authorization header")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
|
// jwtFromQuery returns a `jwtExtractor` that extracts token from the provided query
|
||||||
// parameter.
|
// parameter.
|
||||||
func JWTFromQuery(param string) JWTExtractor {
|
func jwtFromQuery(param string) jwtExtractor {
|
||||||
return func(c echo.Context) (string, error) {
|
return func(c echo.Context) (string, error) {
|
||||||
token := c.QueryParam(param)
|
token := c.QueryParam(param)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
|
@ -26,9 +26,9 @@ type (
|
|||||||
// clickjacking.
|
// clickjacking.
|
||||||
// Optional. Default value "SAMEORIGIN".
|
// Optional. Default value "SAMEORIGIN".
|
||||||
// Possible values:
|
// Possible values:
|
||||||
// `SAMEORIGIN` - The page can only be displayed in a frame on the same origin as the page itself.
|
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
|
||||||
// `DENY` - The page cannot be displayed in a frame, regardless of the site attempting to do so.
|
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
|
||||||
// `ALLOW-FROM uri` - The page can only be displayed in a frame on the specified origin.
|
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
|
||||||
XFrameOptions string `json:"x_frame_options"`
|
XFrameOptions string `json:"x_frame_options"`
|
||||||
|
|
||||||
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
|
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
|
||||||
|
Reference in New Issue
Block a user