diff --git a/echo.go b/echo.go index 24eadaeb..51939bea 100644 --- a/echo.go +++ b/echo.go @@ -162,6 +162,7 @@ const ( HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXForwardedFor = "X-Forwarded-For" HeaderXRealIP = "X-Real-IP" HeaderServer = "Server" diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go new file mode 100644 index 00000000..24afbd2f --- /dev/null +++ b/middleware/basic_auth.go @@ -0,0 +1,69 @@ +package middleware + +import ( + "encoding/base64" + "net/http" + + "github.com/labstack/echo" +) + +type ( + // BasicAuthConfig defines the config for HTTP basic auth middleware. + BasicAuthConfig struct { + // Validator is the function to validate basic auth credentials. + Validator BasicAuthValidator + } + + // BasicAuthValidator defines a function to validate basic auth credentials. + BasicAuthValidator func(string, string) bool +) + +const ( + basic = "Basic" +) + +var ( + // DefaultBasicAuthConfig is the default basic auth middleware config. + DefaultBasicAuthConfig = BasicAuthConfig{} +) + +// BasicAuth returns an HTTP basic auth middleware. +// +// For valid credentials it calls the next handler. +// For invalid credentials, it sends "401 - Unauthorized" response. +// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. +func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { + c := DefaultBasicAuthConfig + c.Validator = fn + return BasicAuthWithConfig(c) +} + +// BasicAuthWithConfig returns an HTTP basic auth middleware from config. +// See `BasicAuth()`. +func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + auth := c.Request().Header().Get(echo.HeaderAuthorization) + l := len(basic) + + if len(auth) > l+1 && auth[:l] == basic { + b, err := base64.StdEncoding.DecodeString(auth[l+1:]) + if err != nil { + return err + } + cred := string(b) + for i := 0; i < len(cred); i++ { + if cred[i] == ':' { + // Verify credentials + if config.Validator(cred[:i], cred[i+1:]) { + return next(c) + } + c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") + return echo.ErrUnauthorized + } + } + } + return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth) + } + } +} diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go new file mode 100644 index 00000000..f2da0ecf --- /dev/null +++ b/middleware/basic_auth_test.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "encoding/base64" + "net/http" + "testing" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestBasicAuth(t *testing.T) { + e := echo.New() + req := test.NewRequest(echo.GET, "/", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + f := func(u, p string) bool { + if u == "joe" && p == "secret" { + return true + } + return false + } + h := BasicAuth(f)(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Valid credentials + auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) + req.Header().Set(echo.HeaderAuthorization, auth) + assert.NoError(t, h(c)) + + // Incorrect password + auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) + req.Header().Set(echo.HeaderAuthorization, auth) + he := h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusUnauthorized, he.Code) + assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) + + // Empty Authorization header + req.Header().Set(echo.HeaderAuthorization, "") + he = h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusBadRequest, he.Code) + + // Invalid Authorization header + auth = base64.StdEncoding.EncodeToString([]byte("invalid")) + req.Header().Set(echo.HeaderAuthorization, auth) + he = h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusBadRequest, he.Code) +} diff --git a/middleware/auth.go b/middleware/jwt_auth.go similarity index 65% rename from middleware/auth.go rename to middleware/jwt_auth.go index 53911854..14d1912f 100644 --- a/middleware/auth.go +++ b/middleware/jwt_auth.go @@ -1,7 +1,6 @@ package middleware import ( - "encoding/base64" "fmt" "net/http" @@ -10,15 +9,6 @@ import ( ) type ( - // BasicAuthConfig defines the config for HTTP basic auth middleware. - BasicAuthConfig struct { - // AuthFunc is the function to validate basic auth credentials. - AuthFunc BasicAuthFunc - } - - // BasicAuthFunc defines a function to validate basic auth credentials. - BasicAuthFunc func(string, string) bool - // JWTAuthConfig defines the config for JWT auth middleware. JWTAuthConfig struct { // SigningKey is the key to validate token. @@ -45,7 +35,6 @@ type ( ) const ( - basic = "Basic" bearer = "Bearer" ) @@ -55,9 +44,6 @@ const ( ) var ( - // DefaultBasicAuthConfig is the default basic auth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{} - // DefaultJWTAuthConfig is the default JWT auth middleware config. DefaultJWTAuthConfig = JWTAuthConfig{ SigningMethod: AlgorithmHS256, @@ -66,47 +52,6 @@ var ( } ) -// BasicAuth returns an HTTP basic auth middleware. -// -// For valid credentials it calls the next handler. -// For invalid credentials, it sends "401 - Unauthorized" response. -// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. -func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.AuthFunc = fn - return BasicAuthWithConfig(c) -} - -// BasicAuthWithConfig returns an HTTP basic auth middleware from config. -// See `BasicAuth()`. -func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - auth := c.Request().Header().Get(echo.HeaderAuthorization) - l := len(basic) - - if len(auth) > l+1 && auth[:l] == basic { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err - } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - if config.AuthFunc(cred[:i], cred[i+1:]) { - return next(c) - } - c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") - return echo.ErrUnauthorized - } - } - } - return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth) - } - } -} - // JWTAuth returns a JSON Web Token (JWT) auth middleware. // // For valid token, it sets the user in context and calls next handler. diff --git a/middleware/auth_test.go b/middleware/jwt_auth_test.go similarity index 55% rename from middleware/auth_test.go rename to middleware/jwt_auth_test.go index 7638fc0e..23ec6635 100644 --- a/middleware/auth_test.go +++ b/middleware/jwt_auth_test.go @@ -1,7 +1,6 @@ package middleware import ( - "encoding/base64" "net/http" "testing" @@ -11,45 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestBasicAuth(t *testing.T) { - e := echo.New() - req := test.NewRequest(echo.GET, "/", nil) - res := test.NewResponseRecorder() - c := e.NewContext(req, res) - f := func(u, p string) bool { - if u == "joe" && p == "secret" { - return true - } - return false - } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header().Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - // Incorrect password - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) - req.Header().Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) - - // Empty Authorization header - req.Header().Set(echo.HeaderAuthorization, "") - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header().Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) -} - func TestJWTAuth(t *testing.T) { e := echo.New() req := test.NewRequest(echo.GET, "/", nil) diff --git a/middleware/method_override.go b/middleware/method_override.go index 94379614..a28a1035 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -4,36 +4,70 @@ import ( "github.com/labstack/echo" ) -const ( - HttpMethodOverrideHeader = "X-HTTP-Method-Override" +type ( + // MethodOverrideConfig defines the config for method override middleware. + MethodOverrideConfig struct { + Getter MethodOverrideGetter + } + + // MethodOverrideGetter is a function that gets overridden method from the request + // Optional, with default values as `MethodFromHeader(echo.HeaderXHTTPMethodOverride)`. + MethodOverrideGetter func(echo.Context) string ) -func OverrideMethod() echo.MiddlewareFunc { - return Override() +var ( + // DefaultMethodOverrideConfig is the default method override middleware config. + DefaultMethodOverrideConfig = MethodOverrideConfig{ + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), + } +) + +// MethodOverride returns a method override middleware. +// Method override middleware checks for the overriden method from the request and +// uses it instead of the original method. +// +// For security reasons, only `POST` method can be overriden. +func MethodOverride() echo.MiddlewareFunc { + return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// Override checks for the X-HTTP-Method-Override header -// or the body for parameter, `_method` -// and uses the http method instead of Request.Method. -// It isn't secure to override e.g a GET to a POST, -// so only Request.Method which are POSTs are considered. -func Override() echo.MiddlewareFunc { +// MethodOverrideWithConfig returns a method override middleware from config. +// See `MethodOverride()`. +func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - originalMethod := c.Request().Method() - - if originalMethod == "POST" { - m := c.FormValue("_method") - if m != "" { - c.Request().SetMethod(m) - } - m = c.Request().Header().Get(HttpMethodOverrideHeader) + req := c.Request() + if req.Method() == echo.POST { + m := config.Getter(c) if m != "" { c.Request().SetMethod(m) } } - return next(c) } } } + +// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from +// the request header. +func MethodFromHeader(header string) MethodOverrideGetter { + return func(c echo.Context) string { + return c.Request().Header().Get(header) + } +} + +// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the +// form parameter. +func MethodFromForm(param string) MethodOverrideGetter { + return func(c echo.Context) string { + return c.FormValue(param) + } +} + +// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from +// the query parameter. +func MethodFromQuery(param string) MethodOverrideGetter { + return func(c echo.Context) string { + return c.QueryParam(param) + } +} diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 28c00762..0fccd4eb 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -10,34 +10,40 @@ import ( "github.com/stretchr/testify/assert" ) -func TestOverrideMtrhod(t *testing.T) { +func TestMethodOverride(t *testing.T) { e := echo.New() - methodOverride := OverrideMethod() - h := methodOverride(func(c echo.Context) error { - return c.String(http.StatusOK, c.Request().Method()) - }) + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "Okay") + } // Override with http header - rq := test.NewRequest(echo.POST, "/", nil) - rq.Header().Set(HttpMethodOverrideHeader, "DELETE") - rc := test.NewResponseRecorder() - c := e.NewContext(rq, rc) - h(c) - assert.Equal(t, "DELETE", rc.Body.String()) + req := test.NewRequest(echo.POST, "/", nil) + rec := test.NewResponseRecorder() + req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) + c := e.NewContext(req, rec) + m(h)(c) + assert.Equal(t, echo.DELETE, req.Method()) - // Override with body parameter - rq = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method=DELETE"))) - rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm) - rc = test.NewResponseRecorder() - c = e.NewContext(rq, rc) - h(c) - assert.Equal(t, "DELETE", rc.Body.String()) + // Override with form parameter + m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) + req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE))) + rec = test.NewResponseRecorder() + req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm) + c = e.NewContext(req, rec) + m(h)(c) + assert.Equal(t, echo.DELETE, req.Method()) - // Ignore GET - rq = test.NewRequest(echo.GET, "/", nil) - rq.Header().Set(HttpMethodOverrideHeader, "DELETE") - rc = test.NewResponseRecorder() - c = e.NewContext(rq, rc) - h(c) - assert.Equal(t, "GET", rc.Body.String()) + // Override with query paramter + m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) + req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + m(h)(c) + assert.Equal(t, echo.DELETE, req.Method()) + + // Ignore `GET` + req = test.NewRequest(echo.GET, "/", nil) + req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) + assert.Equal(t, echo.GET, req.Method()) }