1
0
mirror of https://github.com/labstack/echo.git synced 2025-06-08 23:56:20 +02:00

WIP: logger examples

WIP: make default logger implemented custom writer for jsonlike logs
WIP: improve examples
WIP: defaultErrorHandler use errors.As to unwrap errors. Update readme
WIP: default logger logs json, restore e.Start method
WIP: clean router.Match a bit
WIP: func types/fields have echo.Context has first element
WIP: remove yaml tags as functions etc can not be serialized anyway
WIP: change BindPathParams,BindQueryParams,BindHeaders from methods to functions and reverse arguments to be like DefaultBinder.Bind is
WIP: improved comments, logger now extracts status from error
WIP: go mod tidy
WIP: rebase with 4.5.0
WIP:
* removed todos.
* removed StartAutoTLS and StartH2CServer methods from `StartConfig`
* KeyAuth middleware errorhandler can swallow the error and resume next middleware
WIP: add RouterConfig.UseEscapedPathForMatching to use escaped path for matching request against routes
WIP: FIXMEs
WIP: upgrade golang-jwt/jwt to `v4`
WIP: refactor http methods to return RouteInfo
WIP: refactor static not creating multiple routes
WIP: refactor route and middleware adding functions not to return error directly
WIP: Use 401 for problematic/missing headers for key auth and JWT middleware (#1552, #1402).
> In summary, a 401 Unauthorized response should be used for missing or bad authentication
WIP: replace `HTTPError.SetInternal` with `HTTPError.WithInternal` so we could not mutate global error variables
WIP: add RouteInfo and RouteMatchType into Context what we could know from in middleware what route was matched and/or type of that match (200/404/405)
WIP: make notFoundHandler and methodNotAllowedHandler private. encourage that all errors be handled in Echo.HTTPErrorHandler
WIP: server cleanup ideas
WIP: routable.ForGroup
WIP: note about logger middleware
WIP: bind should not default values on second try. use crypto rand for better randomness
WIP: router add route as interface and returns info as interface
WIP: improve flaky test (remains still flaky)
WIP: add notes about bind default values
WIP: every route can have their own path params names
WIP: routerCreator and different tests
WIP: different things
WIP: remove route implementation
WIP: support custom method types
WIP: extractor tests
WIP: v5.0.x proposal
over v4.4.0
This commit is contained in:
toimtoimtoim 2021-07-15 23:34:01 +03:00
parent c6f0c667f1
commit 6ef5f77bf2
80 changed files with 9216 additions and 4922 deletions

View File

@ -27,7 +27,8 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases # Echo tests with last four major releases
go: [1.14, 1.15, 1.16, 1.17] # except v5 starts from 1.16 until there is last four major releases after that
go: [1.16, 1.17]
name: ${{ matrix.os }} @ Go ${{ matrix.go }} name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:

View File

@ -1,21 +0,0 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x
- 1.15.x
- tip
env:
- GO111MODULE=on
install:
- go get -v golang.org/x/lint/golint
script:
- golint -set_exit_status ./...
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
after_success:
- bash <(curl -s https://codecov.io/bash)
matrix:
allow_failures:
- go: tip

View File

@ -24,11 +24,11 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST} @go test -race ${PKG_LIST}
benchmark: ## Run benchmarks benchmark: ## Run benchmarks
@go test -run="-" -bench=".*" ${PKG_LIST} @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.15" goversion ?= "1.16"
test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

View File

@ -12,6 +12,8 @@
## Supported Go versions ## Supported Go versions
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required: Therefore a Go version capable of understanding /vN suffixed imports is required:
@ -67,8 +69,8 @@ package main
import ( import (
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v5"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v5/middleware"
) )
func main() { func main() {
@ -83,7 +85,9 @@ func main() {
e.GET("/", hello) e.GET("/", hello)
// Start server // Start server
e.Logger.Fatal(e.Start(":1323")) if err := e.Start(":1323"); err != http.ErrServerClosed {
log.Fatal(err)
}
} }
// Handler // Handler

70
bind.go
View File

@ -11,42 +11,38 @@ import (
"strings" "strings"
) )
type (
// Binder is the interface that wraps the Bind method. // Binder is the interface that wraps the Bind method.
Binder interface { type Binder interface {
Bind(i interface{}, c Context) error Bind(c Context, i interface{}) error
} }
// DefaultBinder is the default implementation of the Binder interface. // DefaultBinder is the default implementation of the Binder interface.
DefaultBinder struct{} type DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler // Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead. // will use that interface instead.
BindUnmarshaler interface { type BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param. // UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error UnmarshalParam(param string) error
} }
)
// BindPathParams binds path params to bindable object // BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { func BindPathParams(c Context, i interface{}) error {
names := c.ParamNames()
values := c.ParamValues()
params := map[string][]string{} params := map[string][]string{}
for i, name := range names { for _, param := range c.PathParams() {
params[name] = []string{values[i]} params[param.Name] = []string{param.Value}
} }
if err := b.bindData(i, params, "param"); err != nil { if err := bindData(i, params, "param"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
// BindQueryParams binds query params to bindable object // BindQueryParams binds query params to bindable object
func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { func BindQueryParams(c Context, i interface{}) error {
if err := b.bindData(i, c.QueryParams(), "query"); err != nil { if err := bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { func BindBody(c Context, i interface{}) (err error) {
req := c.Request() req := c.Request()
if req.ContentLength == 0 { if req.ContentLength == 0 {
return return
@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
case *HTTPError: case *HTTPError:
return err return err
default: default:
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
} }
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok { if ute, ok := err.(*xml.UnsupportedTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok { } else if se, ok := err.(*xml.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
} }
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
params, err := c.FormParams() params, err := c.FormParams()
if err != nil { if err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
if err = b.bindData(i, params, "form"); err != nil { if err = bindData(i, params, "form"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
default: default:
return ErrUnsupportedMediaType return ErrUnsupportedMediaType
@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
// BindHeaders binds HTTP headers to a bindable object // BindHeaders binds HTTP headers to a bindable object
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
if err := b.bindData(i, c.Request().Header, "header"); err != nil { if err := bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
// Bind implements the `Binder#Bind` function. // Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. // step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
if err := b.BindPathParams(c, i); err != nil { if err := BindPathParams(c, i); err != nil {
return err return err
} }
// Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH)
@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
// i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding.
// This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body
if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete {
if err = b.BindQueryParams(c, i); err != nil { if err = BindQueryParams(c, i); err != nil {
return err return err
} }
} }
return b.BindBody(c, i) return BindBody(c, i)
} }
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { func bindData(destination interface{}, data map[string][]string, tag string) error {
if destination == nil || len(data) == 0 { if destination == nil || len(data) == 0 {
return nil return nil
} }
@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag // structs that implement BindUnmarshaler are binded only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
return err return err
} }
} }
@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
func setIntField(value string, bitSize int, field reflect.Value) error { func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0" return nil
} }
intVal, err := strconv.ParseInt(value, 10, bitSize) intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil { if err == nil {
@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
func setUintField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0" return nil
} }
uintVal, err := strconv.ParseUint(value, 10, bitSize) uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil { if err == nil {
@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
func setBoolField(value string, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error {
if value == "" { if value == "" {
value = "false" return nil
} }
boolVal, err := strconv.ParseBool(value) boolVal, err := strconv.ParseBool(value)
if err == nil { if err == nil {
@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error {
func setFloatField(value string, bitSize int, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0.0" return nil
} }
floatVal, err := strconv.ParseFloat(value, bitSize) floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil { if err == nil {

View File

@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
} }
} }
func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
req.Header.Set("language", "de")
req.Header.Set("length", "99")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPathParams(PathParams{{
Name: "id",
Value: "999",
}})
type SearchOpts struct {
ID int `param:"id"`
Length int `query:"length"`
Page int `query:"page"`
Search string `query:"search"`
Language string `query:"language" header:"language"`
}
opts := SearchOpts{
Length: 100,
Page: 0,
Search: "default value",
Language: "en",
}
err := c.Bind(&opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // bind from query
assert.Equal(t, 10, opts.Page) // bind from query
assert.Equal(t, 999, opts.ID) // bind from path param
assert.Equal(t, "et", opts.Language) // bind from query
assert.Equal(t, "default value", opts.Search) // default value stays
// make sure another bind will not mess already set values unless there are new values
err = (&DefaultBinder{}).BindHeaders(c, &opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
assert.Equal(t, 10, opts.Page)
assert.Equal(t, 999, opts.ID)
assert.Equal(t, "de", opts.Language) // header overwrites now this value
assert.Equal(t, "default value", opts.Search)
}
func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
func TestBindUnmarshalText(t *testing.T) { func TestBindUnmarshalText(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
func TestBindUnmarshalTextPtr(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
func TestBindbindData(t *testing.T) { func TestBindbindData(t *testing.T) {
a := assert.New(t) a := assert.New(t)
ts := new(bindTestStruct) ts := new(bindTestStruct)
b := new(DefaultBinder) err := bindData(ts, values, "form")
err := b.bindData(ts, values, "form")
a.NoError(err) a.NoError(err)
a.Equal(0, ts.I) a.Equal(0, ts.I)
@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
func TestBindParam(t *testing.T) { func TestBindParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
c.SetPath("/users/:id/:name") cc := c.(EditableContext)
c.SetParamNames("id", "name") cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
c.SetParamValues("1", "Jon Snow") cc.SetPathParams(PathParams{
{Name: "id", Value: "1"},
{Name: "name", Value: "Jon Snow"},
})
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)
@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param // Second test for the absence of a param
c2 := e.NewContext(req, rec) c2 := e.NewContext(req, rec)
c2.SetPath("/users/:id") cc2 := c2.(EditableContext)
c2.SetParamNames("id") cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
c2.SetParamValues("1") cc2.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user) u = new(user)
err = c2.Bind(u) err = c2.Bind(u)
@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload // Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New() e2 := New()
req2 := httptest.NewRequest(POST, "/", body) req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON) req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder() rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2) c3 := e2.NewContext(req2, rec2)
c3.SetPath("/users/:id") cc3 := c3.(EditableContext)
c3.SetParamNames("id") cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
c3.SetParamValues("1") cc3.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user) u = new(user)
err = c3.Bind(u) err = c3.Bind(u)
@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
} }
func TestBindSetFields(t *testing.T) { func TestSetIntField(t *testing.T) {
assert := assert.New(t) ts := new(bindTestStruct)
ts.I = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 100, ts.I)
// second set with value sets the value
err = setIntField("5", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
}
func TestSetUintField(t *testing.T) {
ts := new(bindTestStruct)
ts.UI = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(100), ts.UI)
// second set with value sets the value
err = setUintField("5", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
}
func TestSetFloatField(t *testing.T) {
ts := new(bindTestStruct)
ts.F32 = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(100), ts.F32)
// second set with value sets the value
err = setFloatField("15.5", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
}
func TestSetBoolField(t *testing.T) {
ts := new(bindTestStruct)
ts.B = true
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// second set with value sets the value
err = setBoolField("true", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// fourth set to false
err = setBoolField("false", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, false, ts.B)
}
func TestUnmarshalFieldNonPtr(t *testing.T) {
ts := new(bindTestStruct) ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem() val := reflect.ValueOf(ts).Elem()
// Int
if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
assert.Equal(5, ts.I)
}
if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
assert.Equal(0, ts.I)
}
// Uint
if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
assert.Equal(uint(10), ts.UI)
}
if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
assert.Equal(uint(0), ts.UI)
}
// Float
if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
assert.Equal(float32(15.5), ts.F32)
}
if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
assert.Equal(float32(0.0), ts.F32)
}
// Bool
if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
assert.Equal(true, ts.B)
}
if assert.NoError(setBoolField("", val.FieldByName("B"))) {
assert.Equal(false, ts.B)
}
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(ok, true) assert.True(t, ok)
assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
} }
} }
@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
assert := assert.New(b) assert := assert.New(b)
ts := new(bindTestStructWithTags) ts := new(bindTestStructWithTags)
binder := new(DefaultBinder)
var err error var err error
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err = binder.bindData(ts, values, "form") err = bindData(ts, values, "form")
} }
assert.NoError(err) assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts)) assertBindTestStruct(assert, (*bindTestStruct)(ts))
@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if !tc.whenNoPathParams { if !tc.whenNoPathParams {
c.SetParamNames("node") cc := c.(EditableContext)
c.SetParamValues("node_from_path") cc.SetPathParams(PathParams{
{Name: "node", Value: "node_from_path"},
})
} }
var bindTarget interface{} var bindTarget interface{}
@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
} }
b := new(DefaultBinder) b := new(DefaultBinder)
err := b.Bind(bindTarget, c) err := b.Bind(c, bindTarget)
if tc.expectError != "" { if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError) assert.EqualError(t, err, tc.expectError)
} else { } else {
@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if !tc.whenNoPathParams { if !tc.whenNoPathParams {
c.SetParamNames("node") cc := c.(EditableContext)
c.SetParamValues("real_node") cc.SetPathParams(PathParams{
{Name: "node", Value: "real_node"},
})
} }
var bindTarget interface{} var bindTarget interface{}
@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else { } else {
bindTarget = &Node{} bindTarget = &Node{}
} }
b := new(DefaultBinder)
err := b.BindBody(c, bindTarget) err := BindBody(c, bindTarget)
if tc.expectError != "" { if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError) assert.EqualError(t, err, tc.expectError)
} else { } else {

View File

@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
func PathParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{ return &ValueBinder{
failFast: true, failFast: true,
ValueFunc: c.Param, ValueFunc: c.PathParam,
ValuesFunc: func(sourceParam string) []string { ValuesFunc: func(sourceParam string) []string {
// path parameter should not have multiple values so getting values does not make sense but lets not error out here // path parameter should not have multiple values so getting values does not make sense but lets not error out here
value := c.Param(sourceParam) value := c.PathParam(sourceParam)
if value == "" { if value == "" {
return nil return nil
} }

View File

@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if len(pathParams) > 0 { if len(pathParams) > 0 {
names := make([]string, 0) params := make(PathParams, 0)
values := make([]string, 0)
for name, value := range pathParams { for name, value := range pathParams {
names = append(names, name) params = append(params, PathParam{
values = append(values, value) Name: name,
Value: value,
})
} }
c.SetParamNames(names...) cc := c.(EditableContext)
c.SetParamValues(values...) cc.SetPathParams(params)
} }
return c return c

View File

@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if len(pathParams) > 0 { if len(pathParams) > 0 {
names := make([]string, 0) params := make(PathParams, 0)
values := make([]string, 0)
for name, value := range pathParams { for name, value := range pathParams {
names = append(names, name) params = append(params, PathParam{
values = append(values, value) Name: name,
Value: value,
})
} }
c.SetParamNames(names...) cc := c.(EditableContext)
c.SetParamValues(values...) cc.SetPathParams(params)
} }
return c return c
@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder) binder := new(DefaultBinder)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var dest Opts var dest Opts
_ = binder.Bind(&dest, c) _ = binder.Bind(c, &dest)
} }
} }
@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder) binder := new(DefaultBinder)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var dest Opts var dest Opts
_ = binder.Bind(&dest, c) _ = binder.Bind(c, &dest)
if dest.Int64 != 1 { if dest.Int64 != 1 {
b.Fatalf("int64!=1") b.Fatalf("int64!=1")
} }

View File

@ -3,22 +3,22 @@ package echo
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"mime/multipart" "mime/multipart"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
) )
type (
// Context represents the context of the current HTTP request. It holds request and // Context represents the context of the current HTTP request. It holds request and
// response objects, path, path parameters, data and registered handler. // response objects, path, path parameters, data and registered handler.
Context interface { type Context interface {
// Request returns `*http.Request`. // Request returns `*http.Request`.
Request() *http.Request Request() *http.Request
@ -45,26 +45,32 @@ type (
// The behavior can be configured using `Echo#IPExtractor`. // The behavior can be configured using `Echo#IPExtractor`.
RealIP() string RealIP() string
// RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type
// of match router found and how this request context handler chain could end:
// * route match - this path + method had matching route.
// * not found - this path did not match any routes enough to be considered match
// * method not allowed - path had routes registered but for other method types then current request is
// * unknown - initial state for fresh context before router tries to do routing
//
// Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried
// to match request to route.
RouteMatchType() RouteMatchType
// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
RouteInfo() RouteInfo
// Path returns the registered path for the handler. // Path returns the registered path for the handler.
Path() string Path() string
// SetPath sets the registered path for the handler. // PathParam returns path parameter by name.
SetPath(p string) PathParam(name string) string
// Param returns path parameter by name. // PathParams returns path parameter values.
Param(name string) string PathParams() PathParams
// ParamNames returns path parameter names. // SetPathParams set path parameter for during current request lifecycle.
ParamNames() []string SetPathParams(params PathParams)
// SetParamNames sets path parameter names.
SetParamNames(names ...string)
// ParamValues returns path parameter values.
ParamValues() []string
// SetParamValues sets path parameter values.
SetParamValues(values ...string)
// QueryParam returns the query param for the provided name. // QueryParam returns the query param for the provided name.
QueryParam(name string) string QueryParam(name string) string
@ -171,23 +177,33 @@ type (
// Redirect redirects the request to a provided URL with status code. // Redirect redirects the request to a provided URL with status code.
Redirect(code int, url string) error Redirect(code int, url string) error
// Error invokes the registered HTTP error handler. Generally used by middleware. // Error invokes the registered HTTP error handler.
// NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error.
Error(err error) Error(err error)
// Handler returns the matched handler by router.
Handler() HandlerFunc
// SetHandler sets the matched handler by router.
SetHandler(h HandlerFunc)
// Logger returns the `Logger` instance.
Logger() Logger
// Set the logger
SetLogger(l Logger)
// Echo returns the `Echo` instance. // Echo returns the `Echo` instance.
Echo() *Echo Echo() *Echo
}
// EditableContext is additional interface that structure implementing Context must implement. Methods inside this
// interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares.
type EditableContext interface {
Context
// RawPathParams returns raw path pathParams value.
RawPathParams() *PathParams
// SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
SetRawPathParams(params *PathParams)
// SetPath sets the registered path for the handler.
SetPath(p string)
// SetRouteMatchType sets the RouteMatchType of router match for this request.
SetRouteMatchType(t RouteMatchType)
// SetRouteInfo sets the route info of this request to the context.
SetRouteInfo(ri RouteInfo)
// Reset resets the context after request completes. It must be called along // Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
@ -195,20 +211,25 @@ type (
Reset(r *http.Request, w http.ResponseWriter) Reset(r *http.Request, w http.ResponseWriter)
} }
context struct { type context struct {
request *http.Request request *http.Request
response *Response response *Response
matchType RouteMatchType
route RouteInfo
path string path string
pnames []string
pvalues []string // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
pathParams *PathParams
// currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
// Lifecycle is not handle by Echo and could have excess allocations per served Request
currentParams PathParams
query url.Values query url.Values
handler HandlerFunc
store Map store Map
echo *Echo echo *Echo
logger Logger
lock sync.RWMutex lock sync.RWMutex
} }
)
const ( const (
defaultMemory = 32 << 20 // 32 MB defaultMemory = 32 << 20 // 32 MB
@ -296,52 +317,50 @@ func (c *context) SetPath(p string) {
c.path = p c.path = p
} }
func (c *context) Param(name string) string { func (c *context) RouteMatchType() RouteMatchType {
for i, n := range c.pnames { return c.matchType
if i < len(c.pvalues) {
if n == name {
return c.pvalues[i]
}
}
}
return ""
} }
func (c *context) ParamNames() []string { func (c *context) SetRouteMatchType(t RouteMatchType) {
return c.pnames c.matchType = t
} }
func (c *context) SetParamNames(names ...string) { func (c *context) RouteInfo() RouteInfo {
c.pnames = names return c.route
l := len(names)
if *c.echo.maxParam < l {
*c.echo.maxParam = l
} }
if len(c.pvalues) < l { func (c *context) SetRouteInfo(ri RouteInfo) {
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, c.route = ri
// probably those values will be overriden in a Context#SetParamValues
newPvalues := make([]string, l)
copy(newPvalues, c.pvalues)
c.pvalues = newPvalues
}
} }
func (c *context) ParamValues() []string { func (c *context) RawPathParams() *PathParams {
return c.pvalues[:len(c.pnames)] return c.pathParams
} }
func (c *context) SetParamValues(values ...string) { func (c *context) SetRawPathParams(params *PathParams) {
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times c.pathParams = params
// It will brake the Router#Find code
limit := len(values)
if limit > *c.echo.maxParam {
limit = *c.echo.maxParam
} }
for i := 0; i < limit; i++ {
c.pvalues[i] = values[i] func (c *context) PathParam(name string) string {
if c.currentParams != nil {
return c.currentParams.Get(name, "")
} }
return c.pathParams.Get(name, "")
}
func (c *context) PathParams() PathParams {
if c.currentParams != nil {
return c.currentParams
}
result := make(PathParams, len(*c.pathParams))
copy(result, *c.pathParams)
return result
}
func (c *context) SetPathParams(params PathParams) {
c.currentParams = params
} }
func (c *context) QueryParam(name string) string { func (c *context) QueryParam(name string) string {
@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) {
} }
func (c *context) Bind(i interface{}) error { func (c *context) Bind(i interface{}) error {
return c.echo.Binder.Bind(i, c) return c.echo.Binder.Bind(c, i)
} }
func (c *context) Validate(i interface{}) error { func (c *context) Validate(i interface{}) error {
@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return return
} }
func (c *context) File(file string) (err error) { func (c *context) File(file string) error {
f, err := os.Open(file) return c.FsFile(file, c.echo.Filesystem)
}
func (c *context) FsFile(file string, filesystem fs.FS) error {
// FIXME: should we add this method into echo.Context interface?
f, err := filesystem.Open(file)
if err != nil { if err != nil {
return NotFoundHandler(c) return ErrNotFound
} }
defer f.Close() defer f.Close()
fi, _ := f.Stat() fi, _ := f.Stat()
if fi.IsDir() { if fi.IsDir() {
file = filepath.Join(file, indexPage) file = filepath.Join(file, indexPage)
f, err = os.Open(file) f, err = filesystem.Open(file)
if err != nil { if err != nil {
return NotFoundHandler(c) return ErrNotFound
} }
defer f.Close() defer f.Close()
if fi, err = f.Stat(); err != nil { if fi, err = f.Stat(); err != nil {
return return err
} }
} }
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) ff, ok := f.(io.ReadSeeker)
return if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
} }
func (c *context) Attachment(file, name string) error { func (c *context) Attachment(file, name string) error {
@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error {
} }
func (c *context) Error(err error) { func (c *context) Error(err error) {
c.echo.HTTPErrorHandler(err, c) c.echo.HTTPErrorHandler(c, err)
} }
func (c *context) Echo() *Echo { func (c *context) Echo() *Echo {
return c.echo return c.echo
} }
func (c *context) Handler() HandlerFunc {
return c.handler
}
func (c *context) SetHandler(h HandlerFunc) {
c.handler = h
}
func (c *context) Logger() Logger {
res := c.logger
if res != nil {
return res
}
return c.echo.Logger
}
func (c *context) SetLogger(l Logger) {
c.logger = l
}
func (c *context) Reset(r *http.Request, w http.ResponseWriter) { func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r c.request = r
c.response.reset(w) c.response.reset(w)
c.query = nil c.query = nil
c.handler = NotFoundHandler
c.store = nil c.store = nil
c.matchType = RouteMatchUnknown
c.route = nil
c.path = "" c.path = ""
c.pnames = nil // NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times
c.logger = nil *c.pathParams = (*c.pathParams)[:0]
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times c.currentParams = nil
for i := 0; i < *c.echo.maxParam; i++ {
c.pvalues[i] = ""
}
} }

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -18,21 +19,19 @@ import (
"text/template" "text/template"
"time" "time"
"github.com/labstack/gommon/log" "github.com/stretchr/testify/assert"
testify "github.com/stretchr/testify/assert"
) )
type ( type Template struct {
Template struct {
templates *template.Template templates *template.Template
} }
)
var testUser = user{1, "Jon Snow"} var testUser = user{1, "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSONP(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) { func BenchmarkAllocXML(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -106,16 +107,14 @@ func TestContext(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Echo // Echo
assert.Equal(e, c.Echo()) assert.Equal(t, e, c.Echo())
// Request // Request
assert.NotNil(c.Request()) assert.NotNil(t, c.Request())
// Response // Response
assert.NotNil(c.Response()) assert.NotNil(t, c.Response())
//-------- //--------
// Render // Render
@ -126,23 +125,23 @@ func TestContext(t *testing.T) {
} }
c.echo.Renderer = tmpl c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow") err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("Hello, Jon Snow!", rec.Body.String()) assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
} }
c.echo.Renderer = nil c.echo.Renderer = nil
err = c.Render(http.StatusOK, "hello", "Jon Snow") err = c.Render(http.StatusOK, "hello", "Jon Snow")
assert.Error(err) assert.Error(t, err)
// JSON // JSON
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON+"\n", rec.Body.String()) assert.Equal(t, userJSON+"\n", rec.Body.String())
} }
// JSON with "?pretty" // JSON with "?pretty"
@ -150,10 +149,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String()) assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
} }
req = httptest.NewRequest(http.MethodGet, "/", nil) // reset req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
@ -161,37 +160,37 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String()) assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
} }
// JSON (error) // JSON (error)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, make(chan bool)) err = c.JSON(http.StatusOK, make(chan bool))
assert.Error(err) assert.Error(t, err)
// JSONP // JSONP
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
callback := "callback" callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
} }
// XML // XML
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"}) err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rec.Body.String())
} }
// XML with "?pretty" // XML with "?pretty"
@ -199,10 +198,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"}) err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
} }
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
@ -210,22 +209,22 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, make(chan bool)) err = c.XML(http.StatusOK, make(chan bool))
assert.Error(err) assert.Error(t, err)
// XML response write error // XML response write error
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.response.Writer = responseWriterErr{} c.response.Writer = responseWriterErr{}
err = c.XML(0, 0) err = c.XML(0, 0)
testify.Error(t, err) assert.Error(t, err)
// XMLPretty // XMLPretty
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
} }
t.Run("empty indent", func(t *testing.T) { t.Run("empty indent", func(t *testing.T) {
@ -237,7 +236,6 @@ func TestContext(t *testing.T) {
t.Run("json", func(t *testing.T) { t.Run("json", func(t *testing.T) {
buf.Reset() buf.Reset()
assert := testify.New(t)
// New JSONBlob with empty indent // New JSONBlob with empty indent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -246,16 +244,15 @@ func TestContext(t *testing.T) {
enc.SetIndent(emptyIndent, emptyIndent) enc.SetIndent(emptyIndent, emptyIndent)
err = enc.Encode(u) err = enc.Encode(u)
err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(buf.String(), rec.Body.String()) assert.Equal(t, buf.String(), rec.Body.String())
} }
}) })
t.Run("xml", func(t *testing.T) { t.Run("xml", func(t *testing.T) {
buf.Reset() buf.Reset()
assert := testify.New(t)
// New XMLBlob with empty indent // New XMLBlob with empty indent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -264,10 +261,10 @@ func TestContext(t *testing.T) {
enc.Indent(emptyIndent, emptyIndent) enc.Indent(emptyIndent, emptyIndent)
err = enc.Encode(u) err = enc.Encode(u)
err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+buf.String(), rec.Body.String()) assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
} }
}) })
}) })
@ -276,12 +273,12 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
data, err := json.Marshal(user{1, "Jon Snow"}) data, err := json.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data) err = c.JSONBlob(http.StatusOK, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON, rec.Body.String()) assert.Equal(t, userJSON, rec.Body.String())
} }
// Legacy JSONPBlob // Legacy JSONPBlob
@ -289,44 +286,44 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
callback = "callback" callback = "callback"
data, err = json.Marshal(user{1, "Jon Snow"}) data, err = json.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data) err = c.JSONPBlob(http.StatusOK, callback, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+");", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
} }
// Legacy XMLBlob // Legacy XMLBlob
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
data, err = xml.Marshal(user{1, "Jon Snow"}) data, err = xml.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data) err = c.XMLBlob(http.StatusOK, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rec.Body.String())
} }
// String // String
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.String(http.StatusOK, "Hello, World!") err = c.String(http.StatusOK, "Hello, World!")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, World!", rec.Body.String()) assert.Equal(t, "Hello, World!", rec.Body.String())
} }
// HTML // HTML
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>") err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, <strong>World!</strong>", rec.Body.String()) assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
} }
// Stream // Stream
@ -334,55 +331,55 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
r := strings.NewReader("response from a stream") r := strings.NewReader("response from a stream")
err = c.Stream(http.StatusOK, "application/octet-stream", r) err = c.Stream(http.StatusOK, "application/octet-stream", r)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
assert.Equal("response from a stream", rec.Body.String()) assert.Equal(t, "response from a stream", rec.Body.String())
} }
// Attachment // Attachment
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.Attachment("_fixture/images/walle.png", "walle.png") err = c.Attachment("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
// Inline // Inline
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.Inline("_fixture/images/walle.png", "walle.png") err = c.Inline("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
// NoContent // NoContent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.NoContent(http.StatusOK) c.NoContent(http.StatusOK)
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
// Error // Error
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.Error(errors.New("error")) c.Error(errors.New("error"))
assert.Equal(http.StatusInternalServerError, rec.Code) assert.Equal(t, http.StatusInternalServerError, rec.Code)
// Reset // Reset
c.SetParamNames("foo") c.pathParams = &PathParams{
c.SetParamValues("bar") {Name: "foo", Value: "bar"},
}
c.Set("foe", "ban") c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder()) c.Reset(req, httptest.NewRecorder())
assert.Equal(0, len(c.ParamValues())) assert.Equal(t, 0, len(c.PathParams()))
assert.Equal(0, len(c.ParamNames())) assert.Equal(t, 0, len(c.store))
assert.Equal(0, len(c.store)) assert.Equal(t, nil, c.RouteInfo())
assert.Equal("", c.Path()) assert.Equal(t, 0, len(c.QueryParams()))
assert.Equal(0, len(c.QueryParams()))
} }
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
assert := testify.New(t) if assert.NoError(t, err) {
if assert.NoError(err) { assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(http.StatusCreated, rec.Code) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String())
assert.Equal(userJSON+"\n", rec.Body.String())
} }
} }
@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
assert := testify.New(t) if assert.Error(t, err) {
if assert.Error(err) { assert.False(t, c.response.Committed)
assert.False(c.response.Committed)
} }
} }
@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Read single // Read single
cookie, err := c.Cookie("theme") cookie, err := c.Cookie("theme")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal("theme", cookie.Name) assert.Equal(t, "theme", cookie.Name)
assert.Equal("light", cookie.Value) assert.Equal(t, "light", cookie.Value)
} }
// Read multiple // Read multiple
for _, cookie := range c.Cookies() { for _, cookie := range c.Cookies() {
switch cookie.Name { switch cookie.Name {
case "theme": case "theme":
assert.Equal("light", cookie.Value) assert.Equal(t, "light", cookie.Value)
case "user": case "user":
assert.Equal("Jon Snow", cookie.Value) assert.Equal(t, "Jon Snow", cookie.Value)
} }
} }
@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true, HttpOnly: true,
} }
c.SetCookie(cookie) c.SetCookie(cookie)
assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPath(t *testing.T) {
e := New()
r := e.Router()
handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
r.Add(http.MethodGet, "/users/:id", handler)
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1", c)
assert := testify.New(t)
assert.Equal("/users/:id", c.Path())
r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
c = e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1/files/1", c)
assert.Equal("/users/:uid/files/:fid", c.Path())
} }
func TestContextPathParam(t *testing.T) { func TestContextPathParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil).(*context)
params := &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
}
// ParamNames // ParamNames
c.SetParamNames("uid", "fid") c.pathParams = params
testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) assert.EqualValues(t, *params, c.PathParams())
// ParamValues
c.SetParamValues("101", "501")
testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
// Param // Param
testify.Equal(t, "501", c.Param("fid")) assert.Equal(t, "501", c.PathParam("fid"))
testify.Equal(t, "", c.Param("undefined")) assert.Equal(t, "", c.PathParam("undefined"))
} }
func TestContextGetAndSetParam(t *testing.T) { func TestContextGetAndSetParam(t *testing.T) {
e := New() e := New()
r := e.Router() r := e.Router()
r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) _, err := r.Add(Route{
Method: http.MethodGet,
Path: "/:foo",
Name: "",
Handler: func(Context) error { return nil },
Middlewares: nil,
})
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/:foo", nil) req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
c.SetParamNames("foo")
params := &PathParams{{Name: "foo", Value: "101"}}
// ParamNames
c.(*context).pathParams = params
// round-trip param values with modification // round-trip param values with modification
paramVals := c.ParamValues() paramVals := c.PathParams()
testify.EqualValues(t, []string{""}, c.ParamValues()) assert.Equal(t, *params, c.PathParams())
paramVals[0] = "bar"
c.SetParamValues(paramVals...) paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context
testify.EqualValues(t, []string{"bar"}, c.ParamValues()) assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams())
pathParams := PathParams{
{Name: "aaa", Value: "bbb"},
{Name: "ccc", Value: "ddd"},
}
c.SetPathParams(pathParams)
assert.Equal(t, pathParams, c.PathParams())
// shouldn't explode during Reset() afterwards! // shouldn't explode during Reset() afterwards!
testify.NotPanics(t, func() { assert.NotPanics(t, func() {
c.Reset(nil, nil) c.(EditableContext).Reset(nil, nil)
}) })
assert.Equal(t, PathParams{}, c.PathParams())
assert.Len(t, *c.(*context).pathParams, 0)
assert.Equal(t, cap(*c.(*context).pathParams), 1)
} }
// Issue #1655 // Issue #1655
func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) {
assert := testify.New(t)
e := New() e := New()
assert.Equal(0, *e.maxParam) c := e.NewContext(nil, nil).(*context)
expectedOneParam := []string{"one"} assert.Equal(t, 0, e.contextPathParamAllocSize)
expectedTwoParams := []string{"one", "two"} expectedTwoParams := &PathParams{
expectedThreeParams := []string{"one", "two", ""} {Name: "1", Value: "one"},
expectedABCParams := []string{"A", "B", "C"} {Name: "2", Value: "two"},
}
c.SetRawPathParams(expectedTwoParams)
assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(t, *expectedTwoParams, c.PathParams())
c := e.NewContext(nil, nil) expectedThreeParams := PathParams{
c.SetParamNames("1", "2") {Name: "1", Value: "one"},
c.SetParamValues(expectedTwoParams...) {Name: "2", Value: "two"},
assert.Equal(2, *e.maxParam) {Name: "3", Value: "three"},
assert.EqualValues(expectedTwoParams, c.ParamValues()) }
c.SetPathParams(expectedThreeParams)
c.SetParamNames("1") assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(2, *e.maxParam) assert.Equal(t, expectedThreeParams, c.PathParams())
// Here for backward compatibility the ParamValues remains as they are
assert.EqualValues(expectedOneParam, c.ParamValues())
c.SetParamNames("1", "2", "3")
assert.Equal(3, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
assert.EqualValues(expectedThreeParams, c.ParamValues())
c.SetParamValues("A", "B", "C", "D")
assert.Equal(3, *e.maxParam)
// Here D shouldn't be returned
assert.EqualValues(expectedABCParams, c.ParamValues())
} }
func TestContextFormValue(t *testing.T) { func TestContextFormValue(t *testing.T) {
@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// FormValue // FormValue
testify.Equal(t, "Jon Snow", c.FormValue("name")) assert.Equal(t, "Jon Snow", c.FormValue("name"))
testify.Equal(t, "jon@labstack.com", c.FormValue("email")) assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams // FormParams
params, err := c.FormParams() params, err := c.FormParams()
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.Equal(t, url.Values{ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"}, "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"}, "email": []string{"jon@labstack.com"},
}, params) }, params)
@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEMultipartForm) req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil) c = e.NewContext(req, nil)
params, err = c.FormParams() params, err = c.FormParams()
testify.Nil(t, params) assert.Nil(t, params)
testify.Error(t, err) assert.Error(t, err)
} }
func TestContextQueryParam(t *testing.T) { func TestContextQueryParam(t *testing.T) {
@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) {
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// QueryParam // QueryParam
testify.Equal(t, "Jon Snow", c.QueryParam("name")) assert.Equal(t, "Jon Snow", c.QueryParam("name"))
testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams // QueryParams
testify.Equal(t, url.Values{ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"}, "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"}, "email": []string{"jon@labstack.com"},
}, c.QueryParams()) }, c.QueryParams())
@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf) mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test") w, err := mr.CreateFormFile("file", "test")
if testify.NoError(t, err) { if assert.NoError(t, err) {
w.Write([]byte("test")) w.Write([]byte("test"))
} }
mr.Close() mr.Close()
@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.FormFile("file") f, err := c.FormFile("file")
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.Equal(t, "test", f.Filename) assert.Equal(t, "test", f.Filename)
} }
} }
@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.MultipartForm() f, err := c.MultipartForm()
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.NotNil(t, f) assert.NotNil(t, f)
} }
} }
@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
testify.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, http.StatusMovedPermanently, rec.Code)
testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
} }
func TestContextStore(t *testing.T) { func TestContextStore(t *testing.T) {
var c Context = new(context) var c Context = new(context)
c.Set("name", "Jon Snow") c.Set("name", "Jon Snow")
testify.Equal(t, "Jon Snow", c.Get("name")) assert.Equal(t, "Jon Snow", c.Get("name"))
} }
func BenchmarkContext_Store(b *testing.B) { func BenchmarkContext_Store(b *testing.B) {
@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) {
} }
} }
func TestContextHandler(t *testing.T) {
e := New()
r := e.Router()
b := new(bytes.Buffer)
r.Add(http.MethodGet, "/handler", func(Context) error {
_, err := b.Write([]byte("handler"))
return err
})
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/handler", c)
err := c.Handler()(c)
testify.Equal(t, "handler", b.String())
testify.NoError(t, err)
}
func TestContext_SetHandler(t *testing.T) {
var c Context = new(context)
testify.Nil(t, c.Handler())
c.SetHandler(func(c Context) error {
return nil
})
testify.NotNil(t, c.Handler())
}
func TestContext_Path(t *testing.T) {
path := "/pa/th"
var c Context = new(context)
c.SetPath(path)
testify.Equal(t, path, c.Path())
}
type validator struct{} type validator struct{}
func (*validator) Validate(i interface{}) error { func (*validator) Validate(i interface{}) error {
@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) {
e := New() e := New()
c := e.NewContext(nil, nil) c := e.NewContext(nil, nil)
testify.Error(t, c.Validate(struct{}{})) assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{} e.Validator = &validator{}
testify.NoError(t, c.Validate(struct{}{})) assert.NoError(t, c.Validate(struct{}{}))
} }
func TestContext_QueryString(t *testing.T) { func TestContext_QueryString(t *testing.T) {
@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) {
queryString := "query=string&var=val" queryString := "query=string&var=val"
req := httptest.NewRequest(GET, "/?"+queryString, nil) req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
testify.Equal(t, queryString, c.QueryString()) assert.Equal(t, queryString, c.QueryString())
} }
func TestContext_Request(t *testing.T) { func TestContext_Request(t *testing.T) {
var c Context = new(context) var c Context = new(context)
testify.Nil(t, c.Request()) assert.Nil(t, c.Request())
req := httptest.NewRequest(GET, "/path", nil) req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req) c.SetRequest(req)
testify.Equal(t, req, c.Request()) assert.Equal(t, req, c.Request())
} }
func TestContext_Scheme(t *testing.T) { func TestContext_Scheme(t *testing.T) {
@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.Scheme()) assert.Equal(t, tt.s, tt.c.Scheme())
} }
} }
func TestContext_IsWebSocket(t *testing.T) { func TestContext_IsWebSocket(t *testing.T) {
tests := []struct { tests := []struct {
c Context c Context
ws testify.BoolAssertionFunc ws assert.BoolAssertionFunc
}{ }{
{ {
&context{ &context{
@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"websocket"}}, Header: http.Header{HeaderUpgrade: []string{"websocket"}},
}, },
}, },
testify.True, assert.True,
}, },
{ {
&context{ &context{
@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
}, },
}, },
testify.True, assert.True,
}, },
{ {
&context{ &context{
request: &http.Request{}, request: &http.Request{},
}, },
testify.False, assert.False,
}, },
{ {
&context{ &context{
@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"other"}}, Header: http.Header{HeaderUpgrade: []string{"other"}},
}, },
}, },
testify.False, assert.False,
}, },
} }
@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) {
func TestContext_Bind(t *testing.T) { func TestContext_Bind(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
u := new(user) u := new(user)
req.Header.Add(HeaderContentType, MIMEApplicationJSON) req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u) err := c.Bind(u)
testify.NoError(t, err) assert.NoError(t, err)
testify.Equal(t, &user{1, "Jon Snow"}, u) assert.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_Logger(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
log1 := c.Logger()
testify.NotNil(t, log1)
log2 := log.New("echo2")
c.SetLogger(log2)
testify.Equal(t, log2, c.Logger())
// Resetting the context returns the initial logger
c.Reset(nil, nil)
testify.Equal(t, log1, c.Logger())
} }
func TestContext_RealIP(t *testing.T) { func TestContext_RealIP(t *testing.T) {
@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.RealIP()) assert.Equal(t, tt.s, tt.c.RealIP())
} }
} }

890
echo.go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

16
go.mod
View File

@ -1,17 +1,13 @@
module github.com/labstack/echo/v4 module github.com/labstack/echo/v4
go 1.15 go 1.16
require ( require (
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/davecgh/go-spew v1.1.1 // indirect
github.com/labstack/gommon v0.3.0 github.com/golang-jwt/jwt/v4 v4.0.0
github.com/mattn/go-colorable v0.1.8 // indirect github.com/stretchr/testify v1.7.0
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/stretchr/testify v1.4.0
github.com/valyala/fasttemplate v1.2.1 github.com/valyala/fasttemplate v1.2.1
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/net v0.0.0-20210913180222-943fd674d43e
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )

48
go.sum
View File

@ -1,51 +1,29 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

150
group.go
View File

@ -4,95 +4,117 @@ import (
"net/http" "net/http"
) )
type (
// Group is a set of sub-routes for a specified route. It can be used for inner // Group is a set of sub-routes for a specified route. It can be used for inner
// routes that share a common middleware or functionality that should be separate // routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it. // from the parent echo instance while still inheriting from it.
Group struct { type Group struct {
common
host string host string
prefix string prefix string
middleware []MiddlewareFunc middleware []MiddlewareFunc
echo *Echo echo *Echo
} }
)
// Use implements `Echo#Use()` for sub-routes within the Group. // Use implements `Echo#Use()` for sub-routes within the Group.
// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) { func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...) g.middleware = append(g.middleware, middleware...)
if len(g.middleware) == 0 {
return
}
// Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process.
g.Any("", NotFoundHandler)
g.Any("/*", NotFoundHandler)
} }
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...) return g.Add(http.MethodConnect, path, h, m...)
} }
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. // DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...) return g.Add(http.MethodDelete, path, h, m...)
} }
// GET implements `Echo#GET()` for sub-routes within the Group. // GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...) return g.Add(http.MethodGet, path, h, m...)
} }
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. // HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...) return g.Add(http.MethodHead, path, h, m...)
} }
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...) return g.Add(http.MethodOptions, path, h, m...)
} }
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. // PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...) return g.Add(http.MethodPatch, path, h, m...)
} }
// POST implements `Echo#POST()` for sub-routes within the Group. // POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...) return g.Add(http.MethodPost, path, h, m...)
} }
// PUT implements `Echo#PUT()` for sub-routes within the Group. // PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...) return g.Add(http.MethodPut, path, h, m...)
} }
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. // TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...) return g.Add(http.MethodTrace, path, h, m...)
} }
// Any implements `Echo#Any()` for sub-routes within the Group. // Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
routes := make([]*Route, len(methods)) errs := make([]error, 0)
for i, m := range methods { ris := make(Routes, 0)
routes[i] = g.Add(m, path, handler, middleware...) for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
} }
return routes ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
} }
// Match implements `Echo#Match()` for sub-routes within the Group. // Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
routes := make([]*Route, len(methods)) errs := make([]error, 0)
for i, m := range methods { ris := make(Routes, 0)
routes[i] = g.Add(m, path, handler, middleware...) for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
} }
return routes ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
} }
// Group creates a new sub-group with prefix and optional sub-group-level middleware. // Group creates a new sub-group with prefix and optional sub-group-level middleware.
// Important! Group middlewares are only executed in case there was exact route match and not
// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...) m = append(m, g.middleware...)
@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
return return
} }
// Static implements `Echo#Static()` for sub-routes within the Group. // Static implements `Echo#Static()` for sub-routes within the Group. Panics on error.
func (g *Group) Static(prefix, root string) { func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
g.static(prefix, root, g.GET) return g.Add(
http.MethodGet,
prefix+"*",
StaticDirectoryHandler(root, false),
middleware...,
)
} }
// File implements `Echo#File()` for sub-routes within the Group. // File implements `Echo#File()` for sub-routes within the Group. Panics on error.
func (g *Group) File(path, file string) { func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
g.file(path, file, g.GET) handler := func(c Context) error {
return c.File(file)
}
return g.Add(http.MethodGet, path, handler, middleware...)
} }
// Add implements `Echo#Add()` for sub-routes within the Group. // Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
// Combine into a new slice to avoid accidentally passing the same slice for ri, err := g.AddRoute(Route{
Method: method,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ri
}
// AddRoute registers a new Routable with Router
func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
// Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the // multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls. // middleware from earlier calls.
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
m = append(m, g.middleware...) return g.echo.add(g.host, groupRoute)
m = append(m, middleware...)
return g.echo.add(g.host, method, g.prefix+path, handler, m...)
} }

View File

@ -1,31 +1,68 @@
package echo package echo
import ( import (
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
// TODO: Fix me func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
func TestGroup(t *testing.T) { e := New()
g := New().Group("/group")
h := func(Context) error { return nil } called := false
g.CONNECT("/", h) mw := func(next HandlerFunc) HandlerFunc {
g.DELETE("/", h) return func(c Context) error {
g.GET("/", h) called = true
g.HEAD("/", h) return c.NoContent(http.StatusTeapot)
g.OPTIONS("/", h) }
g.PATCH("/", h) }
g.POST("/", h) // even though group has middleware it will not be executed when there are no routes under that group
g.PUT("/", h) _ = e.Group("/group", mw)
g.TRACE("/", h)
g.Any("/", h) status, body := request(http.MethodGet, "/group/nope", e)
g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) assert.Equal(t, http.StatusNotFound, status)
g.Static("/static", "/tmp") assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
g.File("/walle", "_fixture/images//walle.png")
assert.False(t, called)
}
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware and routes when we have no match on some route the middlewares for that
// group will not be executed
g := e.Group("/group", mw)
g.GET("/yes", handlerFunc)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_multiLevelGroup(t *testing.T) {
e := New()
api := e.Group("/api")
users := api.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
status, body := request(http.MethodGet, "/api/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
} }
func TestGroupFile(t *testing.T) { func TestGroupFile(t *testing.T) {
@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
} }
m2 := func(next HandlerFunc) HandlerFunc { m2 := func(next HandlerFunc) HandlerFunc {
return func(c Context) error { return func(c Context) error {
return c.String(http.StatusOK, c.Path()) return c.String(http.StatusOK, c.RouteInfo().Path())
} }
} }
h := func(c Context) error { h := func(c Context) error {
return c.String(http.StatusOK, c.Path()) return c.String(http.StatusOK, c.RouteInfo().Path())
} }
g.Use(m1) g.Use(m1)
g.GET("/help", h, m2) g.GET("/help", h, m2)
@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m) assert.Equal(t, "/*", m)
} }
func TestGroup_CONNECT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.CONNECT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodConnect, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodConnect, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_DELETE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.DELETE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodDelete, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodDelete, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_HEAD(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.HEAD("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodHead, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodHead, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_OPTIONS(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.OPTIONS("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodOptions, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodOptions, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PATCH(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PATCH("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPatch, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPatch, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_POST(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.POST("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPost, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPost, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PUT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PUT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPut, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPut, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_TRACE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.TRACE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodTrace, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodTrace, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_Any(t *testing.T) {
e := New()
users := e.Group("/users")
ris := users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 11)
for _, m := range methods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_AnyWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range methods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Match(t *testing.T) {
e := New()
myMethods := []string{http.MethodGet, http.MethodPost}
users := e.Group("/users")
ris := users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 2)
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_MatchWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
myMethods := []string{http.MethodGet, http.MethodPost}
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Static(t *testing.T) {
e := New()
g := e.Group("/books")
ri := g.Static("/download", "_fixture")
assert.Equal(t, http.MethodGet, ri.Method())
assert.Equal(t, "/books/download*", ri.Path())
assert.Equal(t, "GET:/books/download*", ri.Name())
assert.Equal(t, []string{"*"}, ri.Params())
req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
body := rec.Body.String()
assert.True(t, strings.HasPrefix(body, "<!doctype html>"))
}
func TestGroup_StaticMultiTest(t *testing.T) {
var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "ok, without prefix",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "nok, without prefix does not serve dir index",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/test/", // `/test` + `*` creates route `/test*`
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/test/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenRoot: "_fixture",
whenURL: "/test/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/static/",
expectBodyStartsWith: "",
},
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/test")
g.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}

74
httperror.go Normal file
View File

@ -0,0 +1,74 @@
package echo
import (
"errors"
"fmt"
"net/http"
)
// Errors
var (
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
ErrBadRequest = NewHTTPError(http.StatusBadRequest)
ErrBadGateway = NewHTTPError(http.StatusBadGateway)
ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// HTTPError represents an error that occurred while handling a request.
type HTTPError struct {
Code int `json:"-"`
Message interface{} `json:"message"`
Internal error `json:"-"` // Stores the error returned by an external dependency
}
// NewHTTPError creates a new HTTPError instance.
func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
he := &HTTPError{Code: code, Message: http.StatusText(code)}
if len(message) > 0 {
he.Message = message[0]
}
return he
}
// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
he := NewHTTPError(code, message...)
he.Internal = internalError
return he
}
// Error makes it compatible with `error` interface.
func (he *HTTPError) Error() string {
if he.Internal == nil {
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
}
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
}
// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
func (he *HTTPError) WithInternal(err error) *HTTPError {
return &HTTPError{
Code: he.Code,
Message: he.Message,
Internal: err,
}
}
// Unwrap satisfies the Go 1.13 error wrapper interface.
func (he *HTTPError) Unwrap() error {
return he.Internal
}

52
httperror_test.go Normal file
View File

@ -0,0 +1,52 @@
package echo
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestHTTPError(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
})
}
func TestNewHTTPErrorWithInternal(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
}
func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
}
func TestHTTPError_Unwrap(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Nil(t, errors.Unwrap(err))
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
})
}

11
json.go
View File

@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
err := json.NewDecoder(c.Request().Body).Decode(i) err := json.NewDecoder(c.Request().Body).Decode(i)
if ute, ok := err.(*json.UnmarshalTypeError); ok { if ute, ok := err.(*json.UnmarshalTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) return NewHTTPErrorWithInternal(
http.StatusBadRequest,
err,
fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
)
} else if se, ok := err.(*json.SyntaxError); ok { } else if se, ok := err.(*json.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest,
err,
fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
)
} }
return err return err
} }

168
log.go
View File

@ -1,41 +1,141 @@
package echo package echo
import ( import (
"bytes"
"io" "io"
"strconv"
"github.com/labstack/gommon/log" "sync"
"time"
) )
type ( //-----------------------------------------------------------------------------
// Logger defines the logging interface. // Example for Zap (https://github.com/uber-go/zap)
Logger interface { //func main() {
Output() io.Writer // e := echo.New()
SetOutput(w io.Writer) // logger, _ := zap.NewProduction()
Prefix() string // e.Logger = &ZapLogger{logger: logger}
SetPrefix(p string) //}
Level() log.Lvl //type ZapLogger struct {
SetLevel(v log.Lvl) // logger *zap.Logger
SetHeader(h string) //}
Print(i ...interface{}) //
Printf(format string, args ...interface{}) //func (l *ZapLogger) Write(p []byte) (n int, err error) {
Printj(j log.JSON) // // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
Debug(i ...interface{}) // l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
Debugf(format string, args ...interface{}) // return len(p), nil
Debugj(j log.JSON) //}
Info(i ...interface{}) //
Infof(format string, args ...interface{}) //func (l *ZapLogger) Error(err error) {
Infoj(j log.JSON) // l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
Warn(i ...interface{}) //}
Warnf(format string, args ...interface{})
Warnj(j log.JSON) //-----------------------------------------------------------------------------
Error(i ...interface{}) // Example for Zerolog (https://github.com/rs/zerolog)
Errorf(format string, args ...interface{}) //func main() {
Errorj(j log.JSON) // e := echo.New()
Fatal(i ...interface{}) // logger := zerolog.New(os.Stdout)
Fatalj(j log.JSON) // e.Logger = &ZeroLogger{logger: &logger}
Fatalf(format string, args ...interface{}) //}
Panic(i ...interface{}) //
Panicj(j log.JSON) //type ZeroLogger struct {
Panicf(format string, args ...interface{}) // logger *zerolog.Logger
//}
//
//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZeroLogger) Error(err error) {
// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
//}
//-----------------------------------------------------------------------------
// Example for Logrus (https://github.com/sirupsen/logrus)
//func main() {
// e := echo.New()
// e.Logger = &LogrusLogger{logger: logrus.New()}
//}
//
//type LogrusLogger struct {
// logger *logrus.Logger
//}
//
//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *LogrusLogger) Error(err error) {
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
//}
// Logger defines the logging interface that Echo uses internally in few places.
// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
type Logger interface {
// Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
// `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
// and underlying FileSystem errors.
// `logger` middleware will use this method to write its JSON payload.
Write(p []byte) (n int, err error)
// Error logs the error
Error(err error)
}
// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
type jsonLogger struct {
writer io.Writer
bufferPool sync.Pool
timeNow func() time.Time
}
func newJSONLogger(writer io.Writer) *jsonLogger {
return &jsonLogger{
writer: writer,
bufferPool: sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
},
},
timeNow: time.Now,
}
}
func (l *jsonLogger) Write(p []byte) (n int, err error) {
pLen := len(p)
if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
(p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
(p[0] == '{' && p[pLen-1] == '}') {
return l.writer.Write(p)
}
// we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
// called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
// deserves at least WARN level.
return l.printf("WARN", string(p))
}
func (l *jsonLogger) Error(err error) {
_, _ = l.printf("ERROR", err.Error())
}
func (l *jsonLogger) printf(level string, message string) (n int, err error) {
buf := l.bufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer l.bufferPool.Put(buf)
buf.WriteString(`{"time":"`)
buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
buf.WriteString(`","level":"`)
buf.WriteString(level)
buf.WriteString(`","prefix":"echo","message":`)
buf.WriteString(strconv.Quote(message))
buf.WriteString("}\n")
return l.writer.Write(buf.Bytes())
} }
)

77
log_test.go Normal file
View File

@ -0,0 +1,77 @@
package echo
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestJsonLogger_Write(t *testing.T) {
var testCases = []struct {
name string
when []byte
expect string
}{
{
name: "ok, write non JSONlike message",
when: []byte("version: %v, build: %v"),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
},
{
name: "ok, write quoted message",
when: []byte(`version: "%v"`),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n",
},
{
name: "ok, write JSON",
when: []byte(`{"version": 123}` + "\n"),
expect: `{"version": 123}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
_, err := logger.Write(tc.when)
result := buf.String()
assert.Equal(t, tc.expect, result)
assert.NoError(t, err)
})
}
}
func TestJsonLogger_Error(t *testing.T) {
var testCases = []struct {
name string
whenError error
expect string
}{
{
name: "ok",
whenError: ErrForbidden,
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
logger.Error(tc.whenError)
result := buf.String()
assert.Equal(t, tc.expect, result)
})
}
}

13
middleware/DEVELOPMENT.md Normal file
View File

@ -0,0 +1,13 @@
# Development Guidelines for middlewares
// FIXME: add info about `MiddlewareConfigurator` interface
## Best practices:
* Do not use `panic` in middleware creator functions in case of invalid configuration.
* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
because previous middlewares up in call chain could have logic for dealing with returned errors.
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
want to create middleware with panics or with returning errors on configuration errors.
* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.

View File

@ -1,64 +1,59 @@
package middleware package middleware
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
// BasicAuthConfig defines the config for BasicAuth middleware. type BasicAuthConfig struct {
BasicAuthConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Validator is a function to validate BasicAuth credentials. // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
// this function would be called once for each header until first valid result is returned
// Required. // Required.
Validator BasicAuthValidator Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth. // Realm is a string to define realm attribute of BasicAuthWithConfig.
// Default value "Restricted". // Default value "Restricted".
Realm string Realm string
} }
// BasicAuthValidator defines a function to validate BasicAuth credentials. // BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error) type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
)
const ( const (
basic = "basic" basic = "basic"
defaultRealm = "Restricted" defaultRealm = "Restricted"
) )
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
}
)
// BasicAuth returns an BasicAuth middleware. // BasicAuth returns an BasicAuth middleware.
// //
// For valid credentials it calls the next handler. // For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response. // For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
c.Validator = fn
return BasicAuthWithConfig(c)
} }
// BasicAuthWithConfig returns an BasicAuth middleware with config. // BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil { if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function") return nil, errors.New("echo basic-auth middleware requires a validator function")
} }
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Realm == "" { if config.Realm == "" {
config.Realm = defaultRealm config.Realm = defaultRealm
@ -70,27 +65,31 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
auth := c.Request().Header.Get(echo.HeaderAuthorization) var lastError error
l := len(basic) l := len(basic)
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:]) continue
if err != nil {
return err
} }
cred := string(b)
for i := 0; i < len(cred); i++ { b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if cred[i] == ':' { if errDecode != nil {
// Verify credentials lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
valid, err := config.Validator(cred[:i], cred[i+1:], c) continue
if err != nil { }
return err idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
if errValidate != nil {
lastError = errValidate
} else if valid { } else if valid {
return next(c) return next(c)
} }
break
} }
} }
if lastError != nil {
return lastError
} }
realm := defaultRealm realm := defaultRealm
@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized return echo.ErrUnauthorized
} }
} }, nil
} }

View File

@ -2,6 +2,7 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -12,60 +13,146 @@ import (
) )
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, multiple",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
basic + " NOT_BASE64",
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" { config := tc.givenConfig
mw, err := config.ToMiddleware()
assert.NoError(t, err)
h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err = h(c)
if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}
func TestBasicAuth_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuth(nil)
assert.NotNil(t, mw)
})
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
return true, nil return true, nil
})
assert.NotNil(t, mw)
} }
return false, nil
} func TestBasicAuthWithConfig_panic(t *testing.T) {
h := BasicAuth(f)(func(c echo.Context) error { assert.Panics(t, func() {
return c.String(http.StatusOK, "test") mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
assert.NotNil(t, mw)
}) })
assert := assert.New(t) mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
return true, nil
// Valid credentials }})
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) assert.NotNil(t, mw)
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
} }

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -11,9 +12,8 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// BodyDumpConfig defines the config for BodyDump middleware. // BodyDumpConfig defines the config for BodyDump middleware.
BodyDumpConfig struct { type BodyDumpConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -23,51 +23,45 @@ type (
} }
// BodyDumpHandler receives the request and response payload. // BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte) type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
bodyDumpResponseWriter struct { type bodyDumpResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter http.ResponseWriter
} }
)
var (
// DefaultBodyDumpConfig is the default BodyDump middleware config.
DefaultBodyDumpConfig = BodyDumpConfig{
Skipper: DefaultSkipper,
}
)
// BodyDump returns a BodyDump middleware. // BodyDump returns a BodyDump middleware.
// //
// BodyDump middleware captures the request and response payload and calls the // BodyDump middleware captures the request and response payload and calls the
// registered handler. // registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
c := DefaultBodyDumpConfig return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
c.Handler = handler
return BodyDumpWithConfig(c)
} }
// BodyDumpWithConfig returns a BodyDump middleware with config. // BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`. // See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil { if config.Handler == nil {
panic("echo: body-dump middleware requires a handler function") return nil, errors.New("echo body-dump middleware requires a handler function")
} }
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultSkipper
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
// Request // Request
reqBody := []byte{} reqBody := []byte{}
if c.Request().Body != nil { // Read if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body) reqBody, _ = ioutil.ReadAll(c.Request().Body)
} }
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer c.Response().Writer = writer
if err = next(c); err != nil { err := next(c)
c.Error(err)
}
// Callback // Callback
config.Handler(c, reqBody, resBody.Bytes()) config.Handler(c, reqBody, resBody.Bytes())
return return err
}
} }
}, nil
} }
func (w *bodyDumpResponseWriter) WriteHeader(code int) { func (w *bodyDumpResponseWriter) WriteHeader(code int) {

View File

@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
requestBody := "" requestBody := ""
responseBody := "" responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody) requestBody = string(reqBody)
responseBody = string(resBody) responseBody = string(resBody)
}) }}.ToMiddleware()
assert.NoError(t, err)
assert := assert.New(t) if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
if assert.NoError(mw(h)(c)) { assert.Equal(t, responseBody, hw)
assert.Equal(requestBody, hw) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(responseBody, hw) assert.Equal(t, hw, rec.Body.String())
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
} }
// Must set default skipper }
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil, func TestBodyDump_skipper(t *testing.T) {
Handler: func(c echo.Context, reqBody, resBody []byte) { e := echo.New()
requestBody = string(reqBody)
responseBody = string(resBody) isCalled := false
mw, err := BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
}, },
}) Handler: func(c echo.Context, reqBody, resBody []byte) {
isCalled = true
},
}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
} }
func TestBodyDumpFails(t *testing.T) { err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.False(t, isCalled)
}
func TestBodyDump_fails(t *testing.T) {
e := echo.New() e := echo.New()
hw := "Hello, World!" hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
return errors.New("some error") return errors.New("some error")
} }
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.Equal(t, http.StatusOK, rec.Code)
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
} }
func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() { assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{ mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil, Skipper: nil,
Handler: nil, Handler: nil,
}) })
assert.NotNil(t, mw)
}) })
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{ mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
Skipper: func(c echo.Context) bool { assert.NotNil(t, mw)
return true })
}, }
Handler: func(c echo.Context, reqBody, resBody []byte) {
}, func TestBodyDump_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BodyDump(nil)
assert.NotNil(t, mw)
}) })
if !assert.Error(t, mw(h)(c)) { assert.NotPanics(t, func() {
t.FailNow() BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
}
}) })
} }

View File

@ -1,98 +1,83 @@
package middleware package middleware
import ( import (
"fmt"
"io" "io"
"sync" "sync"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
) )
type ( // BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
// BodyLimitConfig defines the config for BodyLimit middleware. type BodyLimitConfig struct {
BodyLimitConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Maximum allowed size for a request body, it can be specified // LimitBytes is maximum allowed size in bytes for a request body
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. LimitBytes int64
Limit string `yaml:"limit"`
limit int64
} }
limitedReader struct { type limitedReader struct {
BodyLimitConfig BodyLimitConfig
reader io.ReadCloser reader io.ReadCloser
read int64 read int64
context echo.Context context echo.Context
} }
)
var (
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: DefaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware. // BodyLimit returns a BodyLimit middleware.
// //
// BodyLimit middleware sets the maximum allowed size for a request body, if the // BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
// size exceeds the configured limit, it sends "413 - Request Entity Too Large" // sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure. // header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
// G, T or P. return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
func BodyLimit(limit string) echo.MiddlewareFunc {
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
} }
// BodyLimitWithConfig returns a BodyLimit middleware with config. // BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
// See: `BodyLimit()`. // a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
if config.Skipper == nil {
config.Skipper = DefaultBodyLimitConfig.Skipper
} }
limit, err := bytes.Parse(config.Limit) // ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
if err != nil { func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
pool := sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: config}
},
} }
config.limit = limit
pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
req := c.Request() req := c.Request()
// Based on content length // Based on content length
if req.ContentLength > config.limit { if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge return echo.ErrStatusRequestEntityTooLarge
} }
// Based on content read // Based on content read
r := pool.Get().(*limitedReader) r := pool.Get().(*limitedReader)
r.Reset(req.Body, c) r.Reset(c, req.Body)
defer pool.Put(r) defer pool.Put(r)
req.Body = r req.Body = r
return next(c) return next(c)
} }
} }, nil
} }
func (r *limitedReader) Read(b []byte) (n int, err error) { func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b) n, err = r.reader.Read(b)
r.read += int64(n) r.read += int64(n)
if r.read > r.limit { if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge return n, echo.ErrStatusRequestEntityTooLarge
} }
return return
@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close() return r.reader.Close()
} }
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
r.reader = reader r.reader = reader
r.context = context r.context = context
r.read = 0 r.read = 0
} }
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}

View File

@ -11,6 +11,137 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
// Based on content length (within limit)
mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he := mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he = mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
LimitBytes: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
}
func TestBodyLimit_skipper(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw, err := BodyLimitConfig{
Skipper: func(c echo.Context) bool {
return true
},
LimitBytes: 2,
}.ToMiddleware()
assert.NoError(t, err)
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimitWithConfig(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimit(t *testing.T) { func TestBodyLimit(t *testing.T) {
e := echo.New() e := echo.New()
hw := []byte("Hello, World!") hw := []byte("Hello, World!")
@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body)) return c.String(http.StatusOK, string(body))
} }
assert := assert.New(t) mw := BodyLimit(2 * MB)
// Based on content length (within limit) err := mw(h)(c)
if assert.NoError(BodyLimit("2M")(h)(c)) { assert.NoError(t, err)
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes()) assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
Limit: "2B",
limit: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
} }

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bufio" "bufio"
"compress/gzip" "compress/gzip"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -13,50 +14,45 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( const (
gzipScheme = "gzip"
)
// GzipConfig defines the config for Gzip middleware. // GzipConfig defines the config for Gzip middleware.
GzipConfig struct { type GzipConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Gzip compression level. // Gzip compression level.
// Optional. Default value -1. // Optional. Default value -1.
Level int `yaml:"level"` Level int
} }
gzipResponseWriter struct { type gzipResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter http.ResponseWriter
} }
)
const ( // Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
gzipScheme = "gzip"
)
var (
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
Skipper: DefaultSkipper,
Level: -1,
}
)
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc { func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig) return GzipWithConfig(GzipConfig{})
} }
// GzipWithConfig return Gzip middleware with config. // GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
// See: `Gzip()`.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper config.Skipper = DefaultSkipper
}
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
return nil, errors.New("invalid gzip level")
} }
if config.Level == 0 { if config.Level == 0 {
config.Level = DefaultGzipConfig.Level config.Level = -1
} }
pool := gzipCompressPool(config) pool := gzipCompressPool(config)
@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }
func (w *gzipResponseWriter) WriteHeader(code int) { func (w *gzipResponseWriter) WriteHeader(code int) {

View File

@ -3,94 +3,128 @@ package middleware
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"testing" "testing"
"time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGzip(t *testing.T) { func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Accept-Encoding header // Skip if no Accept-Encoding header
h := Gzip()(func(c echo.Context) error { h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t) e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
assert.Equal("test", rec.Body.String()) err := h(c)
assert.NoError(t, err)
// Gzip assert.Equal(t, "test", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/", nil) }
func TestMustGzipWithConfig_panics(t *testing.T) {
assert.Panics(t, func() {
GzipWithConfig(GzipConfig{Level: 999})
})
}
func TestGzip_AcceptEncodingHeader(t *testing.T) {
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec) rec := httptest.NewRecorder()
h(c) c := e.NewContext(req, rec)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) err := h(c)
assert.NoError(t, err)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body) r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) { assert.NoError(t, err)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
defer r.Close() defer r.Close()
buf.ReadFrom(r) buf.ReadFrom(r)
assert.Equal("test", buf.String()) assert.Equal(t, "test", buf.String())
} }
chunkBuf := make([]byte, 5) func TestGzip_chunked(t *testing.T) {
e := echo.New()
// Gzip chunked req := httptest.NewRequest(http.MethodGet, "/", nil)
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c = e.NewContext(req, rec) chunkChan := make(chan struct{})
Gzip()(func(c echo.Context) error { waitChan := make(chan struct{})
h := Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked") c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data // Write and flush the first part of the data
c.Response().Write([]byte("test\n")) c.Response().Write([]byte("first\n"))
c.Response().Flush() c.Response().Flush()
// Read the first part of the data chunkChan <- struct{}{}
assert.True(rec.Flushed) <-waitChan
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
// Write and flush the second part of the data // Write and flush the second part of the data
c.Response().Write([]byte("test\n")) c.Response().Write([]byte("second\n"))
c.Response().Flush() c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf) chunkChan <- struct{}{}
assert.NoError(err) <-waitChan
assert.Equal("test\n", string(chunkBuf))
// Write the final part of the data and return // Write the final part of the data and return
c.Response().Write([]byte("test")) c.Response().Write([]byte("third"))
return nil
})(c)
chunkChan <- struct{}{}
return nil
})
go func() {
err := h(c)
chunkChan <- struct{}{}
assert.NoError(t, err)
}()
<-chunkChan // wait for first write
waitChan <- struct{}{}
<-chunkChan // wait for second write
waitChan <- struct{}{}
<-chunkChan // wait for final write in handler
<-chunkChan // wait for return from handler
time.Sleep(5 * time.Millisecond) // to have time for flushing
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
assert.NoError(t, err)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r) buf.ReadFrom(r)
assert.Equal("test", buf.String()) assert.Equal(t, "first\nsecond\nthird", buf.String())
} }
func TestGzipNoContent(t *testing.T) { func TestGzip_NoContent(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
} }
} }
func TestGzipErrorReturned(t *testing.T) { func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New() e := echo.New()
e.Use(Gzip()) e.Use(Gzip())
e.GET("/", func(c echo.Context) error { e.GET("/", func(c echo.Context) error {
@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
func TestGzipErrorReturnedInvalidConfig(t *testing.T) { func TestGzipWithConfig_invalidLevel(t *testing.T) {
e := echo.New() mw, err := GzipConfig{Level: 12}.ToMiddleware()
// Invalid level assert.EqualError(t, err, "invalid gzip level")
e.Use(GzipWithConfig(GzipConfig{Level: 12})) assert.Nil(t, mw)
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
} }
// Issue #806 // Issue #806
func TestGzipWithStatic(t *testing.T) { func TestGzipWithStatic(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../")
e.Use(Gzip()) e.Use(Gzip())
e.Static("/test", "../_fixture/images") e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only // Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set. // validate the content length if it's not set.

View File

@ -9,60 +9,56 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// CORSConfig defines the config for CORS middleware. // CORSConfig defines the config for CORS middleware.
CORSConfig struct { type CORSConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource. // AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the // AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If // origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is // an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored. // set, AllowOrigins is ignored.
// Optional. // Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
AllowMethods []string `yaml:"allow_methods"` AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when // AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request. // making the actual request. This is in response to a preflight request.
// Optional. Default value []string{}. // Optional. Default value []string{}.
AllowHeaders []string `yaml:"allow_headers"` AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request // AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of // can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the // a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. // actual request can be made using credentials.
// Optional. Default value false. // Optional. Default value false.
AllowCredentials bool `yaml:"allow_credentials"` AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to // ExposeHeaders defines a whitelist headers that clients are allowed to
// access. // access.
// Optional. Default value []string{}. // Optional. Default value []string{}.
ExposeHeaders []string `yaml:"expose_headers"` ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request // MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached. // can be cached.
// Optional. Default value 0. // Optional. Default value 0.
MaxAge int `yaml:"max_age"` MaxAge int
} }
)
var (
// DefaultCORSConfig is the default CORS middleware config. // DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{ var DefaultCORSConfig = CORSConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
} }
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig) return CORSWithConfig(DefaultCORSConfig)
} }
// CORSWithConfig returns a CORS middleware with config. // CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`. // See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper config.Skipper = DefaultCORSConfig.Skipper
@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
} }, nil
} }

View File

@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler) h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler) h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
AllowOrigins: []string{"localhost"}, AllowOrigins: []string{"localhost"},
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
})(echo.NotFoundHandler) })(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{ cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{ cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"}, AllowOrigins: []string{"http://*.example.com"},
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{ cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern}, AllowOrigins: []string{tt.pattern},
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
if tt.expected { if tt.expected {
@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
cors := CORSWithConfig(CORSConfig{ cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern}, AllowOrigins: []string{tt.pattern},
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
if tt.expected { if tt.expected {
@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) {
//AllowCredentials: true, //AllowCredentials: true,
//MaxAge: 3600, //MaxAge: 3600,
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin) req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{ cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
AllowOriginFunc: allowOriginFunc, assert.NoError(t, err)
})
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
err := h(c) err = h(c)
expected, expectedErr := allowOriginFunc(origin) expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil { if expectedErr != nil {

View File

@ -8,19 +8,21 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type (
// CSRFConfig defines the config for CSRF middleware. // CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct { type CSRFConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// TokenLength is the length of the generated token. // TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"` TokenLength uint8
// Optional. Default value 32. // Optional. Default value 32.
// Generator defines a function to generate token.
// Optional. Defaults tp randomString(TokenLength).
Generator func() string
// TokenLookup is a string in the form of "<source>:<key>" that is used // TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request. // to extract token from the request.
// Optional. Default value "header:X-CSRF-Token". // Optional. Default value "header:X-CSRF-Token".
@ -28,49 +30,46 @@ type (
// - "header:<name>" // - "header:<name>"
// - "form:<name>" // - "form:<name>"
// - "query:<name>" // - "query:<name>"
TokenLookup string `yaml:"token_lookup"` TokenLookup string
// Context key to store generated CSRF token into context. // Context key to store generated CSRF token into context.
// Optional. Default value "csrf". // Optional. Default value "csrf".
ContextKey string `yaml:"context_key"` ContextKey string
// Name of the CSRF cookie. This cookie will store CSRF token. // Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf". // Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"` CookieName string
// Domain of the CSRF cookie. // Domain of the CSRF cookie.
// Optional. Default value none. // Optional. Default value none.
CookieDomain string `yaml:"cookie_domain"` CookieDomain string
// Path of the CSRF cookie. // Path of the CSRF cookie.
// Optional. Default value none. // Optional. Default value none.
CookiePath string `yaml:"cookie_path"` CookiePath string
// Max age (in seconds) of the CSRF cookie. // Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr). // Optional. Default value 86400 (24hr).
CookieMaxAge int `yaml:"cookie_max_age"` CookieMaxAge int
// Indicates if CSRF cookie is secure. // Indicates if CSRF cookie is secure.
// Optional. Default value false. // Optional. Default value false.
CookieSecure bool `yaml:"cookie_secure"` CookieSecure bool
// Indicates if CSRF cookie is HTTP only. // Indicates if CSRF cookie is HTTP only.
// Optional. Default value false. // Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"` CookieHTTPOnly bool
// Indicates SameSite mode of the CSRF cookie. // Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode. // Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"` CookieSameSite http.SameSite
} }
// csrfTokenExtractor defines a function that takes `echo.Context` and returns // csrfTokenExtractor defines a function that takes `echo.Context` and returns either a token or an error.
// either a token or an error. type csrfTokenExtractor func(echo.Context) (string, error)
csrfTokenExtractor func(echo.Context) (string, error)
)
var (
// DefaultCSRFConfig is the default CSRF middleware config. // DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{ var DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
TokenLength: 32, TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken, TokenLookup: "header:" + echo.HeaderXCSRFToken,
@ -79,18 +78,20 @@ var (
CookieMaxAge: 86400, CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode, CookieSameSite: http.SameSiteDefaultMode,
} }
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc { func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig return CSRFWithConfig(DefaultCSRFConfig)
return CSRFWithConfig(c)
} }
// CSRFWithConfig returns a CSRF middleware with config. // CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper config.Skipper = DefaultCSRFConfig.Skipper
@ -98,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 { if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength config.TokenLength = DefaultCSRFConfig.TokenLength
} }
if config.Generator == nil {
config.Generator = createRandomStringGenerator(config.TokenLength)
}
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup config.TokenLookup = DefaultCSRFConfig.TokenLookup
} }
@ -136,7 +140,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Generate token // Generate token
if err != nil { if err != nil {
token = random.String(config.TokenLength) token = config.Generator()
} else { } else {
// Reuse token // Reuse token
token = k.Value token = k.Value
@ -181,7 +185,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
} }, nil
} }
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the

View File

@ -9,11 +9,26 @@ import (
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCSRF(t *testing.T) { func TestCSRF(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRF()
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Generate CSRF token
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
}
func TestMustCSRFWithConfig(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -43,7 +58,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
token := random.String(16) token := randomString(16)
req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token) req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {
@ -145,9 +160,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{ csrf, err := CSRFConfig{
CookieSameSite: http.SameSiteNoneMode, CookieSameSite: http.SameSiteNoneMode,
}) }.ToMiddleware()
assert.NoError(t, err)
h := csrf(func(c echo.Context) error { h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")

View File

@ -11,16 +11,14 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// DecompressConfig defines the config for Decompress middleware. // DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct { type DecompressConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor GzipDecompressPool Decompressor
} }
)
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. // GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip" const GZIPEncoding string = "gzip"
@ -30,14 +28,6 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool gzipDecompressPool() sync.Pool
} }
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface // DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct { type DefaultGzipDecompressPool struct {
} }
@ -67,17 +57,21 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
// Decompress decompresses request body based if content encoding type is set to "gzip" with default config // Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc { func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig) return DecompressWithConfig(DecompressConfig{})
} }
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config // DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.GzipDecompressPool == nil { if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool config.GzipDecompressPool = &DefaultGzipDecompressPool{}
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -116,5 +110,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -17,6 +17,31 @@ import (
func TestDecompress(t *testing.T) { func TestDecompress(t *testing.T) {
e := echo.New() e := echo.New()
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
// Decompress request body
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestDecompress_skippedIfNoHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t) err := h(c)
assert.Equal("test", rec.Body.String()) assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
} }
func TestDecompressDefaultConfig(t *testing.T) { func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { h, err := DecompressConfig{}.ToMiddleware()
assert.NoError(t, err)
err = h(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
}
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress // Decompress
body := `{"name": "echo"}` body := `{"name": "echo"}`
@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := ioutil.ReadAll(req.Body)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(body, string(b)) assert.Equal(t, body, string(b))
} }
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.NewContext(req, rec) e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) {
h := Decompress()(func(c echo.Context) error { h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
}) })
if assert.NoError(t, h(c)) {
err := h(c)
if assert.NoError(t, err) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes())) assert.Equal(t, 0, len(rec.Body.Bytes()))
@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code) assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)

148
middleware/extractor.go Normal file
View File

@ -0,0 +1,148 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"net/textproto"
"strings"
)
// ErrExtractionValueMissing denotes an error raised when value could not be extracted from request
var ErrExtractionValueMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed value")
// ExtractorType is enum type for where extractor will take its data
type ExtractorType string
const (
// HeaderExtractor tells extractor to take values from request header
HeaderExtractor ExtractorType = "header"
// QueryExtractor tells extractor to take values from request query parameters
QueryExtractor ExtractorType = "query"
// ParamExtractor tells extractor to take values from request route parameters
ParamExtractor ExtractorType = "param"
// CookieExtractor tells extractor to take values from request cookie
CookieExtractor ExtractorType = "cookie"
// FormExtractor tells extractor to take values from request form fields
FormExtractor ExtractorType = "form"
)
func createExtractors(lookups string) ([]valuesExtractor, error) {
sources := strings.Split(lookups, ",")
var extractors []valuesExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
}
switch ExtractorType(parts[0]) {
case QueryExtractor:
extractors = append(extractors, valuesFromQuery(parts[1]))
case ParamExtractor:
extractors = append(extractors, valuesFromParam(parts[1]))
case CookieExtractor:
extractors = append(extractors, valuesFromCookie(parts[1]))
case FormExtractor:
extractors = append(extractors, valuesFromForm(parts[1]))
case HeaderExtractor:
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
func valuesFromHeader(header string, valuePrefix string) valuesExtractor {
prefixLen := len(valuePrefix)
return func(c echo.Context) ([]string, ExtractorType, error) {
values := textproto.MIMEHeader(c.Request().Header).Values(header)
if len(values) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, value := range values {
if prefixLen == 0 {
result = append(result, value)
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
}
}
if len(result) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
return result, HeaderExtractor, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, QueryExtractor, ErrExtractionValueMissing
}
return result, QueryExtractor, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := make([]string, 0)
for _, p := range c.PathParams() {
if param == p.Name {
result = append(result, p.Value)
}
}
if len(result) == 0 {
return nil, ParamExtractor, ErrExtractionValueMissing
}
return result, ParamExtractor, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
}
}
if len(result) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
return result, CookieExtractor, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
if err := c.Request().ParseForm(); err != nil {
return nil, FormExtractor, fmt.Errorf("valuesFromForm parse form failed: %w", err)
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, FormExtractor, ErrExtractionValueMissing
}
result := append([]string{}, values...)
return result, FormExtractor, nil
}
}

View File

@ -0,0 +1,498 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectExtractorType ExtractorType
expectCreateError string
expectError string
}{
{
name: "ok, header",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
expectExtractorType: HeaderExtractor,
},
{
name: "ok, form",
givenRequest: func() *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
expectExtractorType: FormExtractor,
},
{
name: "ok, cookie",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
expectExtractorType: CookieExtractor,
},
{
name: "ok, param",
givenPathParams: echo.PathParams{
{Name: "id", Value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
expectExtractorType: ParamExtractor,
},
{
name: "ok, query",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
whenLoopups: "query:id",
expectValues: []string{"999"},
expectExtractorType: QueryExtractor,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups)
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
}
assert.NoError(t, err)
for _, e := range extractors {
values, eType, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, tc.expectExtractorType, eType)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
}
assert.NoError(t, eErr)
}
})
}
}
func TestValuesFromHeader(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, single value, case insensitive",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
name: "ok, empty prefix",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Bearer ",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderWWWAuthenticate,
whenValuePrefix: "",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no headers",
givenRequest: nil,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, HeaderExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromQuery(t *testing.T) {
var testCases = []struct {
name string
givenQueryPart string
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenQueryPart: "?id=123&name=test",
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
expectValues: []string{"123", "456"},
},
{
name: "nok, missing value",
givenQueryPart: "?id=123&name=test",
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromQuery(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, QueryExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := echo.PathParams{
{Name: "id", Value: "123"},
{Name: "gid", Value: "456"},
{Name: "gid", Value: "789"},
}
var testCases = []struct {
name string
givenPathParams echo.PathParams
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenPathParams: examplePathParams,
whenName: "gid",
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching value",
givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ParamExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromCookie(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: "_csrf",
expectValues: []string{"token"},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Add(echo.HeaderCookie, "_csrf=token")
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
expectValues: []string{"token", "token2"},
},
{
name: "nok, no matching cookie",
givenRequest: exampleRequest,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no cookies at all",
givenRequest: nil,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromCookie(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, CookieExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromForm(t *testing.T) {
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
}
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
whenName string
expectValues []string
expectError string
}{
{
name: "ok, POST form, single value",
givenRequest: examplePostFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, POST form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, GET form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "nok, POST form, value missing",
givenRequest: examplePostFormRequest(nil),
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := tc.givenRequest
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromForm(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, FormExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -1,21 +1,14 @@
// +build go1.15
package middleware package middleware
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
) )
type (
// JWTConfig defines the config for JWT middleware. // JWTConfig defines the config for JWT middleware.
JWTConfig struct { type JWTConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -25,110 +18,67 @@ type (
// SuccessHandler defines a function which is executed for a valid token. // SuccessHandler defines a function which is executed for a valid token.
SuccessHandler JWTSuccessHandler SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token. // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom JWT error. // It may be used to define a custom JWT error.
ErrorHandler JWTErrorHandler //
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
ErrorHandlerWithContext JWTErrorHandlerWithContext // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain.
ErrorHandler JWTErrorHandlerWithContext
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{}
// Map of signing keys to validate token with kid field usage.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{}
// Signing method used to check the token's signing algorithm.
// Optional. Default value HS256.
SigningMethod string
// Context key to store user information from the token into context. // Context key to store user information from the token into context.
// Optional. Default value "user". // Optional. Default value "user".
ContextKey string ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request. // to extract token(s) from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization:Bearer ".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>"
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>" // - "form:<name>"
// Multiply sources example: // Multiple sources example:
// - "header:Authorization,cookie:myowncookie" // - "header:Authorization,cookie:myowncookie"
TokenLookup string TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid. // parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library // NB: could be called multiple times per request when token lookup is able to extract multiple token values (i.e. multiple Authorization headers)
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) // See `jwt_external_test.go` for example implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
} }
// JWTSuccessHandler defines a function which is executed for a valid token. // JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context) type JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token. // JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error type JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error type JWTErrorHandlerWithContext func(c echo.Context, err error) error
jwtExtractor func(echo.Context) (string, error) type valuesExtractor func(c echo.Context) ([]string, ExtractorType, error)
)
// Algorithms
const ( const (
// AlgorithmHS256 is token signing algorithm
AlgorithmHS256 = "HS256" AlgorithmHS256 = "HS256"
) )
// Errors // ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request
var ( var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt")
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") // ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired
) var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
var (
// DefaultJWTConfig is the default JWT auth middleware config. // DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{ var DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user", ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization, TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
} }
)
// JWT returns a JSON Web Token (JWT) auth middleware. // JWT returns a JSON Web Token (JWT) auth middleware.
// //
@ -137,64 +87,43 @@ var (
// For missing token, it returns "400 - Bad Request" error. // For missing token, it returns "400 - Bad Request" error.
// //
// See: https://jwt.io/introduction // See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup` func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
func JWT(key interface{}) echo.MiddlewareFunc {
c := DefaultJWTConfig c := DefaultJWTConfig
c.SigningKey = key c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c) return JWTWithConfig(c)
} }
// JWTWithConfig returns a JWT auth middleware with config. // JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid.
// See: `JWT()`. //
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration
func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper config.Skipper = DefaultJWTConfig.Skipper
} }
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { if config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key") return nil, errors.New("echo jwt middleware requires parse token function")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
} }
if config.ContextKey == "" { if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey config.ContextKey = DefaultJWTConfig.ContextKey
} }
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup config.TokenLookup = DefaultJWTConfig.TokenLookup
} }
if config.AuthScheme == "" { extractors, err := createExtractors(config.TokenLookup)
config.AuthScheme = DefaultJWTConfig.AuthScheme if err != nil {
} return nil, fmt.Errorf("echo jwt middleware could not create token extractor: %w", err)
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
} }
if len(extractors) == 0 {
return nil, errors.New("echo jwt middleware could not create extractors from TokenLookup string")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -206,30 +135,20 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil { if config.BeforeFunc != nil {
config.BeforeFunc(c) config.BeforeFunc(c)
} }
var auth string var lastExtractorErr error
var err error var lastTokenErr error
for _, extractor := range extractors { for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and auths, _, extrErr := extractor(c)
// set auth if extrErr != nil {
auth, err = extractor(c) lastExtractorErr = extrErr
if err == nil { continue
break
} }
} for _, auth := range auths {
// If none of extractor has a token, handle error token, err := config.ParseTokenFunc(c, auth)
if err != nil { if err != nil {
if config.ErrorHandler != nil { lastTokenErr = err
return config.ErrorHandler(err) continue
} }
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return err
}
token, err := config.ParseTokenFunc(auth, c)
if err == nil {
// Store user information from token into context. // Store user information from token into context.
c.Set(config.ContextKey, token) c.Set(config.ContextKey, token)
if config.SuccessHandler != nil { if config.SuccessHandler != nil {
@ -237,111 +156,34 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
}
// prioritize token errors over extracting errors
err := lastTokenErr
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil { if config.ErrorHandler != nil {
return config.ErrorHandler(err) if err == ErrExtractionValueMissing {
err = ErrJWTMissing
} }
if config.ErrorHandlerWithContext != nil { // Allow error handler to swallow the error and continue handler chain execution
return config.ErrorHandlerWithContext(err, c) // Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public token to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
} }
return next(c)
}
if err == ErrExtractionValueMissing {
return ErrJWTMissing
}
// everything else goes under http.StatusUnauthorized to avoid exposing JWT internals with generic error
return &echo.HTTPError{ return &echo.HTTPError{
Code: ErrJWTInvalid.Code, Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message, Message: ErrJWTInvalid.Message,
Internal: err, Internal: err,
} }
} }
} }, nil
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return config.SigningKey, nil
}
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
} }

View File

@ -0,0 +1,76 @@
package middleware_test
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
"net/http/httptest"
)
// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc
//
// signingKey is signing key to validate token.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKeys is not provided.
//
// signingKeys is Map of signing keys to validate token with kid field usage.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != middleware.AlgorithmHS256 {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(signingKeys) == 0 {
return signingKey, nil
}
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := signingKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
func ExampleJWTConfig_withJWTGoAsTokenParser() {
mw := middleware.JWTWithConfig(middleware.JWTConfig{
ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil),
})
e := echo.New()
e.Use(mw)
e.GET("/", func(c echo.Context) error {
user := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusTeapot, user.Claims)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
fmt.Printf("status: %v, body: %v", res.Code, res.Body.String())
// Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"}
}

View File

@ -1,5 +1,3 @@
// +build go1.15
package middleware package middleware
import ( import (
@ -11,11 +9,32 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != signingMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return signingKey, nil
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
// jwtCustomInfo defines some custom types we're going to use within our tokens. // jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct { type jwtCustomInfo struct {
Name string `json:"name"` Name string `json:"name"`
@ -28,43 +47,7 @@ type jwtCustomClaims struct {
jwtCustomInfo jwtCustomInfo
} }
func TestJWTRace(t *testing.T) { func TestJWT_combinations(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) {
e := echo.New() e := echo.New()
handler := func(c echo.Context) error { handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
@ -72,201 +55,180 @@ func TestJWT(t *testing.T) {
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
validKey := []byte("secret") validKey := []byte("secret")
invalidKey := []byte("invalid-key") invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token validAuth := "Bearer " + token
for _, tc := range []struct { var testCases = []struct {
expPanic bool name string
expErrCode int // 0 for Success
config JWTConfig config JWTConfig
reqURL string // "/" if empty reqURL string // "/" if empty
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string formValues map[string]string
info string expPanic bool
expErrCode int // 0 for Success
}{ }{
{ {
expPanic: true, expPanic: true,
info: "No signing key provided", name: "No signing key provided",
},
{
expErrCode: http.StatusBadRequest,
config: JWTConfig{
SigningKey: validKey,
SigningMethod: "RS256",
},
info: "Unexpected signing method",
}, },
{ {
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey}, config: JWTConfig{
info: "Invalid key", ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
},
name: "Unexpected signing method",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey),
},
name: "Invalid key",
}, },
{ {
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{
info: "Valid JWT", ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT",
}, },
{ {
hdrAuth: "Token" + " " + token, hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, config: JWTConfig{
info: "Valid JWT with custom AuthScheme", TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ",
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT with custom AuthScheme",
}, },
{ {
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
Claims: &jwtCustomClaims{}, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
SigningKey: []byte("secret"),
}, },
info: "Valid JWT with custom claims", name: "Valid JWT with custom claims",
}, },
{ {
hdrAuth: "invalid-auth", hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{
info: "Invalid Authorization header", ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
}, },
{ name: "Invalid Authorization header",
config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest,
info: "Empty header auth field",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
expErrCode: http.StatusUnauthorized,
name: "Empty header auth field",
},
{
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=" + token, reqURL: "/?a=b&jwt=" + token,
info: "Valid query method", name: "Valid query method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwtxyz=" + token, reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Invalid query param name", name: "Invalid query param name",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=invalid-token", reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Invalid query param value", name: "Invalid query param value",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b", reqURL: "/?a=b",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Empty query", name: "Empty query",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "param:jwt", TokenLookup: "param:jwt",
}, },
reqURL: "/" + token, reqURL: "/" + token,
info: "Valid param method", name: "Valid param method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Valid cookie method", name: "Valid cookie method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt,cookie:jwt", TokenLookup: "query:jwt,cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Multiple jwt lookuop", name: "Multiple jwt lookuop",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid", hdrCookie: "jwt=invalid",
info: "Invalid token with cookie method", name: "Invalid token with cookie method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Empty cookie", name: "Empty cookie",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
formValues: map[string]string{"jwt": token}, formValues: map[string]string{"jwt": token},
info: "Valid form method", name: "Valid form method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"}, formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method", name: "Invalid token with form method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
info: "Valid JWT with a valid key using a user-defined KeyFunc",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return invalidKey, nil
},
},
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc", name: "Empty form field",
}, },
{ }
hdrAuth: validAuth,
config: JWTConfig{ for _, tc := range testCases {
KeyFunc: func(*jwt.Token) (interface{}, error) { t.Run(tc.name, func(t *testing.T) {
return nil, errors.New("faulty KeyFunc")
},
},
expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
},
{
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT with lower case AuthScheme",
},
} {
if tc.reqURL == "" { if tc.reqURL == "" {
tc.reqURL = "/" tc.reqURL = "/"
} }
@ -289,127 +251,40 @@ func TestJWT(t *testing.T) {
c := e.NewContext(req, res) c := e.NewContext(req, res)
if tc.reqURL == "/"+token { if tc.reqURL == "/"+token {
c.SetParamNames("jwt") cc := c.(echo.EditableContext)
c.SetParamValues(token) cc.SetPathParams(echo.PathParams{
{Name: "jwt", Value: token},
})
} }
if tc.expPanic { if tc.expPanic {
assert.Panics(t, func() { assert.Panics(t, func() {
JWTWithConfig(tc.config) JWTWithConfig(tc.config)
}, tc.info) }, tc.name)
continue return
} }
if tc.expErrCode != 0 { if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler) h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError) he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.info) assert.Equal(t, tc.expErrCode, he.Code)
continue return
} }
h := JWTWithConfig(tc.config)(handler) h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.info) { if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token) user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) { switch claims := user.Claims.(type) {
case jwt.MapClaims: case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.info) assert.Equal(t, claims["name"], "John Doe")
case *jwtCustomClaims: case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.info) assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true, tc.info) assert.Equal(t, claims.Admin, true)
default:
panic("unexpected type of claims")
}
}
}
}
func TestJWTwithKID(t *testing.T) {
test := assert.New(t)
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
staticSecret := []byte("static_secret")
invalidStaticSecret := []byte("invalid_secret")
for _, tc := range []struct {
expErrCode int // 0 for Success
config JWTConfig
hdrAuth string
info string
}{
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: validKeys},
info: "First token valid",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Second token valid",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Wrong key id token",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: staticSecret},
info: "Valid static secret token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: invalidStaticSecret},
info: "Invalid static secret",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys first token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys second token",
},
} {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
c := e.NewContext(req, res)
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
continue
}
h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
default: default:
panic("unexpected type of claims") panic("unexpected type of claims")
} }
} }
})
} }
} }
@ -420,7 +295,7 @@ func TestJWTConfig_skipper(t *testing.T) {
Skipper: func(context echo.Context) bool { Skipper: func(context echo.Context) bool {
return true // skip everything return true // skip everything
}, },
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
})) }))
isCalled := false isCalled := false
@ -448,11 +323,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) {
BeforeFunc: func(context echo.Context) { BeforeFunc: func(context echo.Context) {
isCalled = true isCalled = true
}, },
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
})) }))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
@ -469,18 +344,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) {
{ {
name: "ok, ErrorHandler is executed", name: "ok, ErrorHandler is executed",
given: JWTConfig{ given: JWTConfig{
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(err error) error { ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error") return echo.NewHTTPError(http.StatusTeapot, "custom_error")
}, },
}, },
@ -515,23 +380,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
{ {
name: "ok, ErrorHandler is executed", name: "ok, ErrorHandler is executed",
given: JWTConfig{ given: JWTConfig{
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(err error) error { ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
}, },
}, },
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
}, },
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -544,14 +399,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
config := tc.given config := tc.given
parseTokenCalled := false parseTokenCalled := false
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
parseTokenCalled = true parseTokenCalled = true
return nil, errors.New("parsing failed") return nil, errors.New("parsing failed")
} }
e.Use(JWTWithConfig(config)) e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
@ -574,7 +429,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
signingKey := []byte("secret") signingKey := []byte("secret")
config := JWTConfig{ config := JWTConfig{
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" { if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
@ -597,9 +452,161 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
e.Use(JWTWithConfig(config)) e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, http.StatusTeapot, res.Code)
} }
func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
success := c.Get("success").(string)
user := c.Get("user").(string)
return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user))
})
mw, err := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return auth, nil
},
SuccessHandler: func(c echo.Context) {
c.Set("success", "yes")
},
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, "yes:valid_token_base64", res.Body.String())
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) {
var testCases = []struct {
name string
givenCallNext bool
givenErrorHandler JWTErrorHandlerWithContext
givenTokenLookup string
whenAuthHeaders []string
whenCookies []string
whenParseReturn string
whenParseError error
expectHandlerCalled bool
expect string
expectCode int
}{
{
name: "ok, with valid JWT from auth header",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return nil
},
whenAuthHeaders: []string{"Bearer valid_token_base64"},
whenParseReturn: "valid_token",
expectCode: http.StatusTeapot,
expect: "valid_token",
},
{
name: "ok, missing header, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err != ErrJWTMissing {
panic("must get ErrJWTMissing")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "ok, invalid token, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
// this is probably not realistic usecase. on parse error you probably want to return error
if err.Error() != "parser_error" {
panic("must get parser_error")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "nok, invalid token, return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err.Error() != "parser_error" {
panic("must get parser_error")
}
return err
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusInternalServerError,
expect: "{\"message\":\"Internal Server Error\"}\n",
},
{
name: "nok, callNext but return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
{
name: "nok, callNext=false",
givenCallNext: false,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(string)
return c.String(http.StatusTeapot, token)
})
mw, err := JWTConfig{
TokenLookup: tc.givenTokenLookup,
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return tc.whenParseReturn, tc.whenParseError
},
ErrorHandler: tc.givenErrorHandler,
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
for _, a := range tc.whenAuthHeaders {
req.Header.Add(echo.HeaderAuthorization, a)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expect, res.Body.String())
assert.Equal(t, tc.expectCode, res.Code)
})
}
}

View File

@ -3,58 +3,59 @@ package middleware
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
) )
type (
// KeyAuthConfig defines the config for KeyAuth middleware. // KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct { type KeyAuthConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used // KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request. // to extract key(s) from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization:Bearer ".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>:<value prefix>"
// - "query:<name>" // - "query:<name>"
// - "form:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
KeyLookup string `yaml:"key_lookup"` // - "form:<name>"
// Multiple sources example:
// AuthScheme to be used in the Authorization header. // - "header:Authorization:Bearer ,cookie:myowncookie"
// Optional. Default value "Bearer". KeyLookup string
AuthScheme string
// Validator is a function to validate key. // Validator is a function to validate key.
// Required. // Required.
Validator KeyAuthValidator Validator KeyAuthValidator
// ErrorHandler defines a function which is executed for an invalid key. // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom error. // It may be used to define a custom error.
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler ErrorHandler KeyAuthErrorHandler
} }
// KeyAuthValidator defines a function to validate KeyAuth credentials. // KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error) type KeyAuthValidator func(c echo.Context, key string, keyType ExtractorType) (bool, error)
keyExtractor func(echo.Context) (string, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key. // KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error type KeyAuthErrorHandler func(c echo.Context, err error) error
)
// ErrKeyMissing denotes an error raised when key value could not be extracted from request
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
// ErrInvalidKey denotes an error raised when key value is invalid by validator
var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config. // DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{ var DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization, KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
AuthScheme: "Bearer",
} }
)
// KeyAuth returns an KeyAuth middleware. // KeyAuth returns an KeyAuth middleware.
// //
@ -67,34 +68,32 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c) return KeyAuthWithConfig(c)
} }
// KeyAuthWithConfig returns an KeyAuth middleware with config. // KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
// See `KeyAuth()`. //
// For first valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper config.Skipper = DefaultKeyAuthConfig.Skipper
} }
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" { if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
} }
if config.Validator == nil { if config.Validator == nil {
panic("echo: key-auth middleware requires a validator function") return nil, errors.New("echo key-auth middleware requires a validator function")
} }
extractors, err := createExtractors(config.KeyLookup)
// Initialize if err != nil {
parts := strings.Split(config.KeyLookup, ":") return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
extractor := keyFromHeader(parts[1], config.AuthScheme) }
switch parts[0] { if len(extractors) == 0 {
case "query": return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
case "cookie":
extractor = keyFromCookie(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -103,79 +102,50 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
// Extract and verify key var lastExtractorErr error
key, err := extractor(c) var lastValidatorErr error
if err != nil { for _, extractor := range extractors {
if config.ErrorHandler != nil { keys, keyType, extrErr := extractor(c)
return config.ErrorHandler(err, c) if extrErr != nil {
lastExtractorErr = extrErr
continue
} }
return echo.NewHTTPError(http.StatusBadRequest, err.Error()) for _, key := range keys {
} valid, err := config.Validator(c, key, keyType)
valid, err := config.Validator(key, c)
if err != nil { if err != nil {
lastValidatorErr = err
continue
}
if !valid {
lastValidatorErr = ErrInvalidKey
continue
}
return next(c)
}
}
// prioritize validator errors over extracting errors
err := lastValidatorErr
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil { if config.ErrorHandler != nil {
return config.ErrorHandler(err, c) // Allow error handler to swallow the error and continue handler chain execution
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
}
return next(c)
}
if err == ErrExtractionValueMissing {
return ErrKeyMissing // do not wrap extractor errors
} }
return &echo.HTTPError{ return &echo.HTTPError{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
Message: "invalid key", Message: "Unauthorized",
Internal: err, Internal: err,
} }
} else if valid {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
func keyFromCookie(cookieName string) keyExtractor {
return func(c echo.Context) (string, error) {
key, err := c.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("missing key in cookies: %w", err)
}
return key.Value, nil
} }
}, nil
} }

View File

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func testKeyValidator(key string, c echo.Context) (bool, error) { func testKeyValidator(c echo.Context, key string, keyType ExtractorType) (bool, error) {
switch key { switch key {
case "valid-key": case "valid-key":
return true, nil return true, nil
@ -28,7 +28,7 @@ func TestKeyAuth(t *testing.T) {
handlerCalled = true handlerCalled = true
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
middlewareChain := KeyAuth(testKeyValidator)(handler) middlewareChain := KeyAuthWithConfig(KeyAuthConfig{Validator: testKeyValidator})(handler)
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized", expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key",
}, },
{ {
name: "nok, defaults, invalid scheme in header", name: "nok, defaults, invalid scheme in header",
@ -84,13 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "nok, defaults, missing header", name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {}, givenRequest: func(req *http.Request) {},
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, header", name: "ok, custom key lookup, header",
@ -110,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "header:API-Key" conf.KeyLookup = "header:API-Key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, query", name: "ok, custom key lookup, query",
@ -130,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "query:key" conf.KeyLookup = "query:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, form", name: "ok, custom key lookup, form",
@ -155,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "form:key" conf.KeyLookup = "form:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, cookie", name: "ok, custom key lookup, cookie",
@ -179,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key" conf.KeyLookup = "cookie:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies: http: named cookie not present", expectError: "code=401, message=missing key",
}, },
{ {
name: "nok, custom errorHandler, error from extractor", name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) { whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token" conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error { conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err httpError.Internal = err
return httpError return httpError
} }
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header", expectError: "code=418, message=custom, internal=code=400, message=missing or malformed value",
}, },
{ {
name: "nok, custom errorHandler, error from validator", name: "nok, custom errorHandler, error from validator",
@ -200,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
}, },
whenConfig: func(conf *KeyAuthConfig) { whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error { conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err httpError.Internal = err
return httpError return httpError
@ -216,7 +216,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
}, },
whenConfig: func(conf *KeyAuthConfig) {}, whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error", expectError: "code=401, message=Unauthorized, internal=some user defined error",
}, },
} }
@ -257,3 +257,96 @@ func TestKeyAuthWithConfig(t *testing.T) {
}) })
} }
} }
func TestKeyAuthWithConfig_errors(t *testing.T) {
var testCases = []struct {
name string
whenConfig KeyAuthConfig
expectError string
}{
{
name: "ok, no error",
whenConfig: KeyAuthConfig{
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
},
{
name: "ok, missing validator func",
whenConfig: KeyAuthConfig{
Validator: nil,
},
expectError: "echo key-auth middleware requires a validator function",
},
{
name: "ok, extractor source can not be split",
whenConfig: KeyAuthConfig{
KeyLookup: "nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
},
{
name: "ok, no extractors",
whenConfig: KeyAuthConfig{
KeyLookup: "nope:nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mw, err := tc.whenConfig.ToMiddleware()
if tc.expectError != "" {
assert.Nil(t, mw)
assert.EqualError(t, err, tc.expectError)
} else {
assert.NotNil(t, mw)
assert.NoError(t, err)
}
})
}
}
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
KeyAuthWithConfig(KeyAuthConfig{})
})
}
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
handlerCalled := false
var authValue string
handler := func(c echo.Context) error {
handlerCalled = true
authValue = c.Get("auth").(string)
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(c echo.Context, err error) error {
// could check error to decide if we can swallow the error
c.Set("auth", "public")
return nil
},
})(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
// no auth header this time
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := middlewareChain(c)
assert.NoError(t, err)
assert.True(t, handlerCalled)
assert.Equal(t, "public", authValue)
}

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"strconv" "strconv"
"strings" "strings"
@ -10,13 +11,11 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/color"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
type (
// LoggerConfig defines the config for Logger middleware. // LoggerConfig defines the config for Logger middleware.
LoggerConfig struct { type LoggerConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -49,42 +48,41 @@ type (
// Example "${remote_ip} ${status}" // Example "${remote_ip} ${status}"
// //
// Optional. Default value DefaultLoggerConfig.Format. // Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"` Format string
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat. // Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"` CustomTimeFormat string
// Output is a writer where logs in JSON format are written. // Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout. // Optional. Default destination `echo.Logger.Infof()`
Output io.Writer Output io.Writer
template *fasttemplate.Template template *fasttemplate.Template
colorer *color.Color
pool *sync.Pool pool *sync.Pool
} }
)
var (
// DefaultLoggerConfig is the default Logger middleware config. // DefaultLoggerConfig is the default Logger middleware config.
DefaultLoggerConfig = LoggerConfig{ var DefaultLoggerConfig = LoggerConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat: "2006-01-02 15:04:05.00000", CustomTimeFormat: "2006-01-02 15:04:05.00000",
colorer: color.New(),
} }
)
// Logger returns a middleware that logs HTTP requests. // Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc { func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig) return LoggerWithConfig(DefaultLoggerConfig)
} }
// LoggerWithConfig returns a Logger middleware with config. // LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration.
// See: `Logger()`.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration
func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper config.Skipper = DefaultLoggerConfig.Skipper
@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
if config.Format == "" { if config.Format == "" {
config.Format = DefaultLoggerConfig.Format config.Format = DefaultLoggerConfig.Format
} }
if config.Output == nil {
config.Output = DefaultLoggerConfig.Output
}
config.template = fasttemplate.New(config.Format, "${", "}") config.template = fasttemplate.New(config.Format, "${", "}")
config.colorer = color.New()
config.colorer.SetOutput(config.Output)
config.pool = &sync.Pool{ config.pool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256)) return bytes.NewBuffer(make([]byte, 256))
@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
start := time.Now() start := time.Now()
if err = next(c); err != nil { err := next(c)
c.Error(err)
}
stop := time.Now() stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer) buf := config.pool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
defer config.pool.Put(buf) defer config.pool.Put(buf)
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "time_unix": case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "user_agent": case "user_agent":
return buf.WriteString(req.UserAgent()) return buf.WriteString(req.UserAgent())
case "status": case "status":
n := res.Status status := res.Status
s := config.colorer.Green(n) if err != nil {
switch { if httpErr, ok := err.(*echo.HTTPError); ok {
case n >= 500: status = httpErr.Code
s = config.colorer.Red(n)
case n >= 400:
s = config.colorer.Yellow(n)
case n >= 300:
s = config.colorer.Cyan(n)
} }
return buf.WriteString(s) }
return buf.WriteString(strconv.Itoa(status))
case "error": case "error":
if err != nil { if err != nil {
// Error may contain invalid JSON e.g. `"` // Error may contain invalid JSON e.g. `"`
@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case strings.HasPrefix(tag, "form:"): case strings.HasPrefix(tag, "form:"):
return buf.Write([]byte(c.FormValue(tag[5:]))) return buf.Write([]byte(c.FormValue(tag[5:])))
case strings.HasPrefix(tag, "cookie:"): case strings.HasPrefix(tag, "cookie:"):
cookie, err := c.Cookie(tag[7:]) cookie, cookieErr := c.Cookie(tag[7:])
if err == nil { if cookieErr == nil {
return buf.Write([]byte(cookie.Value)) return buf.Write([]byte(cookie.Value))
} }
} }
} }
return 0, nil return 0, nil
}); err != nil { })
return if tmplErr != nil {
if err != nil {
return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err)
}
return fmt.Errorf("failed to create log from template: %w", tmplErr)
} }
if config.Output == nil { if config.Output != nil {
_, err = c.Logger().Output().Write(buf.Bytes()) if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil {
return return lErr
} }
_, err = config.Output.Write(buf.Bytes()) } else {
return if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil {
return lErr
} }
} }
return err
}
}, nil
} }

View File

@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
ip := "127.0.0.1" ip := "127.0.0.1"
h := Logger()(func(c echo.Context) error { h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")

View File

@ -6,9 +6,8 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// MethodOverrideConfig defines the config for MethodOverride middleware. // MethodOverrideConfig defines the config for MethodOverride middleware.
MethodOverrideConfig struct { type MethodOverrideConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -18,16 +17,13 @@ type (
} }
// MethodOverrideGetter is a function that gets overridden method from the request // MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string type MethodOverrideGetter func(echo.Context) string
)
var (
// DefaultMethodOverrideConfig is the default MethodOverride middleware config. // DefaultMethodOverrideConfig is the default MethodOverride middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{ var DefaultMethodOverrideConfig = MethodOverrideConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
} }
)
// MethodOverride returns a MethodOverride middleware. // MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and // MethodOverride middleware checks for the overridden method from the request and
@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig) return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
} }
// MethodOverrideWithConfig returns a MethodOverride middleware with config. // MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
// See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper config.Skipper = DefaultMethodOverrideConfig.Skipper
@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from

View File

@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
m(h)(c)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_formParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with form parameter // Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) assert.NoError(t, err)
rec = httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
m(h)(c)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_queryParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with query parameter // Override with query parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) assert.NoError(t, err)
rec = httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
c = e.NewContext(req, rec) rec := httptest.NewRecorder()
m(h)(c) c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_ignoreGet(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Ignore `GET` // Ignore `GET`
req = httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodGet, req.Method) assert.Equal(t, http.MethodGet, req.Method)
} }

View File

@ -9,14 +9,11 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // Skipper defines a function to skip middleware. Returning true skips processing the middleware.
// Skipper defines a function to skip middleware. Returning true skips processing type Skipper func(c echo.Context) bool
// the middleware.
Skipper func(echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context) type BeforeFunc func(c echo.Context)
)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1) groups := pattern.FindAllStringSubmatch(input, -1)
@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
func DefaultSkipper(echo.Context) bool { func DefaultSkipper(echo.Context) bool {
return false return false
} }
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

View File

@ -2,6 +2,7 @@ package middleware
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
@ -20,9 +21,8 @@ import (
// TODO: Handle TLS proxy // TODO: Handle TLS proxy
type (
// ProxyConfig defines the config for Proxy middleware. // ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct { type ProxyConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -59,46 +59,43 @@ type (
} }
// ProxyTarget defines the upstream target. // ProxyTarget defines the upstream target.
ProxyTarget struct { type ProxyTarget struct {
Name string Name string
URL *url.URL URL *url.URL
Meta echo.Map Meta echo.Map
} }
// ProxyBalancer defines an interface to implement a load balancing technique. // ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface { type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget Next(echo.Context) *ProxyTarget
} }
commonBalancer struct { type commonBalancer struct {
targets []*ProxyTarget targets []*ProxyTarget
mutex sync.RWMutex mutex sync.RWMutex
} }
// RandomBalancer implements a random load balancing technique. // RandomBalancer implements a random load balancing technique.
randomBalancer struct { type randomBalancer struct {
*commonBalancer *commonBalancer
random *rand.Rand random *rand.Rand
} }
// RoundRobinBalancer implements a round-robin load balancing technique. // RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct { type roundRobinBalancer struct {
*commonBalancer *commonBalancer
i uint32 i uint32
} }
)
var (
// DefaultProxyConfig is the default Proxy middleware config. // DefaultProxyConfig is the default Proxy middleware config.
DefaultProxyConfig = ProxyConfig{ var DefaultProxyConfig = ProxyConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
ContextKey: "target", ContextKey: "target",
} }
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack() in, _, err := c.Response().Hijack()
if err != nil { if err != nil {
@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c) return ProxyWithConfig(c)
} }
// ProxyWithConfig returns a Proxy middleware with config. // ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
// See: `Proxy()` //
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper config.Skipper = DefaultProxyConfig.Skipper
} }
if config.ContextKey == "" {
config.ContextKey = DefaultProxyConfig.ContextKey
}
if config.Balancer == nil { if config.Balancer == nil {
panic("echo: proxy middleware requires balancer") return nil, errors.New("echo proxy middleware requires balancer")
} }
if config.Rewrite != nil { if config.Rewrite != nil {
@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy // Proxy
switch { switch {
case c.IsWebSocket(): case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req) proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream": case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default: default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req) proxyHTTP(c, tgt, config).ServeHTTP(res, req)
} }
if e, ok := c.Get("_error").(error); ok { if e, ok := c.Get("_error").(error); ok {
err = e err = e
@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return return
} }
} }, nil
} }
// StatusCodeContextCanceled is a custom HTTP status code for situations // StatusCodeContextCanceled is a custom HTTP status code for situations
@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation // 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499 const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String() desc := tgt.URL.String()

View File

@ -55,7 +55,7 @@ func TestProxy(t *testing.T) {
// Random // Random
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
@ -77,7 +77,7 @@ func TestProxy(t *testing.T) {
// Round-robin // Round-robin
rrb := NewRoundRobinBalancer(targets) rrb := NewRoundRobinBalancer(targets)
e = echo.New() e = echo.New()
e.Use(Proxy(rrb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
@ -113,15 +113,20 @@ func TestProxy(t *testing.T) {
return nil return nil
} }
} }
rrb1 := NewRoundRobinBalancer(targets)
e = echo.New() e = echo.New()
e.Use(contextObserver) e.Use(contextObserver)
e.Use(Proxy(rrb1)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
} }
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
assert.Panics(t, func() {
ProxyWithConfig(ProxyConfig{Balancer: nil})
})
}
func TestProxyRealIPHeader(t *testing.T) { func TestProxyRealIPHeader(t *testing.T) {
// Setup // Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) {
url, _ := url.Parse(upstream.URL) url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
e := echo.New() e := echo.New()
e.Use(Proxy(rrb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) {
// Random // Random
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
rb := NewRandomBalancer(nil) rb := NewRandomBalancer(nil)
assert.True(t, rb.AddTarget(target)) assert.True(t, rb.AddTarget(target))
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(req.Context()) ctx, cancel := context.WithCancel(req.Context())

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -9,17 +10,13 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
type (
// RateLimiterStore is the interface to be implemented by custom stores. // RateLimiterStore is the interface to be implemented by custom stores.
RateLimiterStore interface { type RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
Allow(identifier string) (bool, error) Allow(identifier string) (bool, error)
} }
)
type (
// RateLimiterConfig defines the configuration for the rate limiter // RateLimiterConfig defines the configuration for the rate limiter
RateLimiterConfig struct { type RateLimiterConfig struct {
Skipper Skipper Skipper Skipper
BeforeFunc BeforeFunc BeforeFunc BeforeFunc
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor // IdentifierExtractor uses echo.Context to extract the identifier for a visitor
@ -31,17 +28,15 @@ type (
// DenyHandler provides a handler to be called when RateLimiter denies access // DenyHandler provides a handler to be called when RateLimiter denies access
DenyHandler func(context echo.Context, identifier string, err error) error DenyHandler func(context echo.Context, identifier string, err error) error
} }
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors // Extractor is used to extract data from echo.Context
var ( type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
// ErrExtractorError denotes an error raised when extractor function is unsuccessful // ErrExtractorError denotes an error raised when extractor function is unsuccessful
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
)
// DefaultRateLimiterConfig defines default values for RateLimiterConfig // DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{ var DefaultRateLimiterConfig = RateLimiterConfig{
@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware
}, middleware.RateLimiterWithConfig(config)) }, middleware.RateLimiterWithConfig(config))
*/ */
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper config.Skipper = DefaultRateLimiterConfig.Skipper
} }
@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
} }
if config.Store == nil { if config.Store == nil {
panic("Store configuration must be provided") return nil, errors.New("echo rate limiter store configuration must be provided")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -137,22 +137,19 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
identifier, err := config.IdentifierExtractor(c) identifier, err := config.IdentifierExtractor(c)
if err != nil { if err != nil {
c.Error(config.ErrorHandler(c, err)) return config.ErrorHandler(c, err)
return nil
} }
if allow, err := config.Store.Allow(identifier); !allow { if allow, allowErr := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err)) return config.DenyHandler(c, identifier, allowErr)
return nil
} }
return next(c) return next(c)
} }
} }, nil
} }
type (
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter // RateLimiterMemoryStore is the built-in store implementation for RateLimiter
RateLimiterMemoryStore struct { type RateLimiterMemoryStore struct {
visitors map[string]*Visitor visitors map[string]*Visitor
mutex sync.Mutex mutex sync.Mutex
rate rate.Limit rate rate.Limit
@ -160,12 +157,12 @@ type (
expiresIn time.Duration expiresIn time.Duration
lastCleanup time.Time lastCleanup time.Time
} }
// Visitor signifies a unique user's limiter details // Visitor signifies a unique user's limiter details
Visitor struct { type Visitor struct {
*rate.Limiter *rate.Limiter
lastSeen time.Time lastSeen time.Time
} }
)
/* /*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
mw := RateLimiter(inMemoryStore) mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
assert.Equal(t, tc.code, rec.Code) if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
func TestRateLimiter_panicBehaviour(t *testing.T) { func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() { assert.Panics(t, func() {
RateLimiter(nil) RateLimiterWithConfig(RateLimiterConfig{})
}) })
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
RateLimiter(inMemoryStore) RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
}) })
} }
@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) { IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP) id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" { if id == "" {
@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
return ctx.JSON(http.StatusBadRequest, nil) return ctx.JSON(http.StatusBadRequest, nil)
}, },
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code) assert.Equal(t, tc.code, rec.Code)
} }
} }
@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) { IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP) id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" { if id == "" {
@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return id, nil return id, nil
}, },
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"", http.StatusForbidden}, {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
if tc.expectErr != "" {
assert.Equal(t, tc.code, rec.Code) assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
if tc.expectErr != "" {
assert.Equal(t, tc.code, rec.Code) assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
} }
@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
return true return true
}, },
@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan) assert.Equal(t, false, beforeFuncRan)
} }
@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
return false return false
}, },
@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) _ = mw(handler)(c)
@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
BeforeFunc: func(c echo.Context) { BeforeFunc: func(c echo.Context) {
beforeRan = true beforeRan = true
}, },
@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, true, beforeRan) assert.Equal(t, true, beforeRan)
} }
@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string { func generateAddressList(count int) []string {
addrs := make([]string, count) addrs := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
addrs[i] = random.String(15) addrs[i] = randomString(15)
} }
return addrs return addrs
} }

View File

@ -5,44 +5,34 @@ import (
"runtime" "runtime"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
) )
type (
// RecoverConfig defines the config for Recover middleware. // RecoverConfig defines the config for Recover middleware.
RecoverConfig struct { type RecoverConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Size of the stack to be printed. // Size of the stack to be printed.
// Optional. Default value 4KB. // Optional. Default value 4KB.
StackSize int `yaml:"stack_size"` StackSize int
// DisableStackAll disables formatting stack traces of all other goroutines // DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine. // into buffer after the trace for the current goroutine.
// Optional. Default value false. // Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"` DisableStackAll bool
// DisablePrintStack disables printing stack trace. // DisablePrintStack disables printing stack trace.
// Optional. Default value as false. // Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"` DisablePrintStack bool
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
} }
)
var (
// DefaultRecoverConfig is the default Recover middleware config. // DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{ var DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB StackSize: 4 << 10, // 4 KB
DisableStackAll: false, DisableStackAll: false,
DisablePrintStack: false, DisablePrintStack: false,
LogLevel: 0,
} }
)
// Recover returns a middleware which recovers from panics anywhere in the chain // Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler. // and handles the control to the centralized HTTPErrorHandler.
@ -50,9 +40,13 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig) return RecoverWithConfig(DefaultRecoverConfig)
} }
// RecoverWithConfig returns a Recover middleware with config. // RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
// See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper config.Skipper = DefaultRecoverConfig.Skipper
@ -62,40 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) (err error) {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err, ok := r.(error) tmpErr, ok := r.(error)
if !ok { if !ok {
err = fmt.Errorf("%v", r) tmpErr = fmt.Errorf("%v", r)
} }
if !config.DisablePrintStack {
stack := make([]byte, config.StackSize) stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll) length := runtime.Stack(stack, !config.DisableStackAll)
if !config.DisablePrintStack { tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length])
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
switch config.LogLevel {
case log.DEBUG:
c.Logger().Debug(msg)
case log.INFO:
c.Logger().Info(msg)
case log.WARN:
c.Logger().Warn(msg)
case log.ERROR:
c.Logger().Error(msg)
case log.OFF:
// None.
default:
c.Logger().Print(msg)
} }
} err = tmpErr
c.Error(err)
} }
}() }()
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -2,82 +2,109 @@ package middleware
import ( import (
"bytes" "bytes"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
e := echo.New() e := echo.New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Recover()(echo.HandlerFunc(func(c echo.Context) error { h := Recover()(func(c echo.Context) error {
panic("test") panic("test")
})) })
h(c) err := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
assert.Contains(t, buf.String(), "PANIC RECOVER") assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
assert.Contains(t, buf.String(), "") // nothing is logged
} }
func TestRecoverWithConfig_LogLevel(t *testing.T) { func TestRecover_skipper(t *testing.T) {
tests := []struct {
logLevel log.Lvl
levelName string
}{{
logLevel: log.DEBUG,
levelName: "DEBUG",
}, {
logLevel: log.INFO,
levelName: "INFO",
}, {
logLevel: log.WARN,
levelName: "WARN",
}, {
logLevel: log.ERROR,
levelName: "ERROR",
}, {
logLevel: log.OFF,
levelName: "OFF",
}}
for _, tt := range tests {
tt := tt
t.Run(tt.levelName, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Logger.SetLevel(log.DEBUG)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
config := DefaultRecoverConfig config := RecoverConfig{
config.LogLevel = tt.logLevel Skipper: func(c echo.Context) bool {
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { return true
panic("test") },
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
output := buf.String()
if tt.logLevel == log.OFF {
assert.Empty(t, output)
} else {
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
} }
h := RecoverWithConfig(config)(func(c echo.Context) error {
panic("testPANIC")
})
var err error
assert.Panics(t, func() {
err = h(c)
})
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
func TestRecoverWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenNoPanic bool
whenConfig RecoverConfig
expectErrContain string
expectErr string
}{
{
name: "ok, default config",
whenConfig: DefaultRecoverConfig,
expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
},
{
name: "ok, no panic",
givenNoPanic: true,
whenConfig: DefaultRecoverConfig,
expectErrContain: "",
},
{
name: "ok, DisablePrintStack",
whenConfig: RecoverConfig{
DisablePrintStack: true,
},
expectErr: "testPANIC",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := tc.whenConfig
h := RecoverWithConfig(config)(func(c echo.Context) error {
if tc.givenNoPanic {
return nil
}
panic("testPANIC")
})
err := h(c)
if tc.expectErrContain != "" {
assert.Contains(t, err.Error(), tc.expectErrContain)
} else if tc.expectErr != "" {
assert.Contains(t, err.Error(), tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}) })
} }
} }

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"strings" "strings"
@ -14,7 +15,9 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request. // Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently. // Optional. Default value http.StatusMovedPermanently.
Code int `yaml:"code"` Code int
redirect redirectLogic
} }
// redirectLogic represents a function that given a scheme, host and uri // redirectLogic represents a function that given a scheme, host and uri
@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www." const www = "www."
// DefaultRedirectConfig is the default Redirect middleware config. // RedirectHTTPSConfig is the HTTPS Redirect middleware config.
var DefaultRedirectConfig = RedirectConfig{ var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
Skipper: DefaultSkipper,
Code: http.StatusMovedPermanently, // RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
} var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
// RedirectWWWConfig is the WWW Redirect middleware config.
var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
// RedirectNonWWWConfig is the non WWW Redirect middleware config.
var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
// HTTPSRedirect redirects http requests to https. // HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com. // For example, http://labstack.com will be redirect to https://labstack.com.
// //
// Usage `Echo#Pre(HTTPSRedirect())` // Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc { func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(DefaultRedirectConfig) return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
} }
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
// See `HTTPSRedirect()`.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectHTTPS
if scheme != "https" { return toMiddlewareOrPanic(config)
return true, "https://" + host + uri
}
return false, ""
})
} }
// HTTPSWWWRedirect redirects http requests to https www. // HTTPSWWWRedirect redirects http requests to https www.
@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(HTTPSWWWRedirect())` // Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc { func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
} }
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
// See `HTTPSWWWRedirect()`.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectHTTPSWWW
if scheme != "https" && !strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, "https://www." + host + uri
}
return false, ""
})
} }
// HTTPSNonWWWRedirect redirects http requests to https non www. // HTTPSNonWWWRedirect redirects http requests to https non www.
@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(HTTPSNonWWWRedirect())` // Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc { func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
} }
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
// See `HTTPSNonWWWRedirect()`.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) { config.redirect = redirectNonHTTPSWWW
if scheme != "https" { return toMiddlewareOrPanic(config)
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
})
} }
// WWWRedirect redirects non www requests to www. // WWWRedirect redirects non www requests to www.
@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(WWWRedirect())` // Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc { func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(DefaultRedirectConfig) return WWWRedirectWithConfig(RedirectWWWConfig)
} }
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
// See `WWWRedirect()`.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectWWW
if !strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, scheme + "://www." + host + uri
}
return false, ""
})
} }
// NonWWWRedirect redirects www requests to non www. // NonWWWRedirect redirects www requests to non www.
@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(NonWWWRedirect())` // Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc { func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(DefaultRedirectConfig) return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
} }
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
// See `NonWWWRedirect()`.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectNonWWW
if strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, scheme + "://" + host[4:] + uri
}
return false, ""
})
} }
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { // ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRedirectConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Code == 0 { if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code config.Code = http.StatusMovedPermanently
}
if config.redirect == nil {
return nil, errors.New("redirectConfig is missing redirect function")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
req, scheme := c.Request(), c.Scheme() req, scheme := c.Request(), c.Scheme()
host := req.Host host := req.Host
if ok, url := cb(scheme, host, req.RequestURI); ok { if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url) return c.Redirect(config.Code, url)
} }
return next(c) return next(c)
} }
}, nil
} }
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
}
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
}
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
}
var redirectWWW = func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
}
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
} }

View File

@ -2,12 +2,10 @@ package middleware
import ( import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type (
// RequestIDConfig defines the config for RequestID middleware. // RequestIDConfig defines the config for RequestID middleware.
RequestIDConfig struct { type RequestIDConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -16,31 +14,26 @@ type (
Generator func() string Generator func() string
// RequestIDHandler defines a function which is executed for a request id. // RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string) RequestIDHandler func(c echo.Context, requestID string)
} }
)
var (
// DefaultRequestIDConfig is the default RequestID middleware config.
DefaultRequestIDConfig = RequestIDConfig{
Skipper: DefaultSkipper,
Generator: generator,
}
)
// RequestID returns a X-Request-ID middleware. // RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc { func RequestID() echo.MiddlewareFunc {
return RequestIDWithConfig(DefaultRequestIDConfig) return RequestIDWithConfig(RequestIDConfig{})
} }
// RequestIDWithConfig returns a X-Request-ID middleware with config. // RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRequestIDConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Generator == nil { if config.Generator == nil {
config.Generator = generator config.Generator = createRandomStringGenerator(32)
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -62,9 +55,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
} }, nil
}
func generator() string {
return random.String(32)
} }

View File

@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
rid := RequestIDWithConfig(RequestIDConfig{}) rid := RequestID()
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
}
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
generatorCalled := false
e.Use(RequestIDWithConfig(RequestIDConfig{
Skipper: func(c echo.Context) bool {
return true
},
Generator: func() string {
generatorCalled = true
return "customGenerator"
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "test", res.Body.String())
assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
assert.False(t, generatorCalled)
}
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
called := false
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
RequestIDHandler: func(c echo.Context, s string) {
called = true
},
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, called)
}
func TestRequestIDWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid, err := RequestIDConfig{}.ToMiddleware()
assert.NoError(t, err)
h := rid(handler) h := rid(handler)
h(c) h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
// Custom generator and handler // Custom generator
customID := "customGenerator"
calledHandler := false
rid = RequestIDWithConfig(RequestIDConfig{ rid = RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return customID }, Generator: func() string { return "customGenerator" },
RequestIDHandler: func(_ echo.Context, id string) {
calledHandler = true
assert.Equal(t, customID, id)
},
}) })
h = rid(handler) h = rid(handler)
h(c) h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, calledHandler)
} }
func TestRequestID_IDNotAltered(t *testing.T) { func TestRequestID_IDNotAltered(t *testing.T) {

View File

@ -24,6 +24,7 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info(). // logger.Info().
// Date("request_start", v.StartTime).
// Str("URI", v.URI). // Str("URI", v.URI).
// Int("status", v.Status). // Int("status", v.Status).
// Msg("request") // Msg("request")
@ -39,6 +40,7 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info("request", // logger.Info("request",
// zap.Time("request_start", v.StartTime),
// zap.String("URI", v.URI), // zap.String("URI", v.URI),
// zap.Int("status", v.Status), // zap.Int("status", v.Status),
// ) // )
@ -54,6 +56,7 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error {
// log.WithFields(logrus.Fields{ // log.WithFields(logrus.Fields{
// "request_start": values.StartTime,
// "URI": values.URI, // "URI": values.URI,
// "status": values.Status, // "status": values.Status,
// }).Info("request") // }).Info("request")
@ -158,15 +161,15 @@ type RequestLoggerValues struct {
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
ResponseSize int64 ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice // Headers are list of headers from request. Note: request can contain more than one header with same value so slice
// of values is been logger for each given header. // of values is what will be returned/logged for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
// with same name so slice of values is been logger for each given query param name. // with same name so slice of values is what will be returned/logged for each given query param name.
QueryParams map[string][]string QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
// same name so slice of values is been logger for each given form value name. // same name so slice of values is what will be returned/logged for each given form value name.
FormValues map[string][]string FormValues map[string][]string
} }

View File

@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) {
req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec).(echo.EditableContext)
c.SetPath("/test*") c.SetPath("/test*")

View File

@ -1,14 +1,14 @@
package middleware package middleware
import ( import (
"errors"
"regexp" "regexp"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// RewriteConfig defines the config for Rewrite middleware. // RewriteConfig defines the config for Rewrite middleware.
RewriteConfig struct { type RewriteConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -20,43 +20,39 @@ type (
// "/js/*": "/public/javascripts/$1", // "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2", // "/users/*/orders/*": "/user/$1/order/$2",
// Required. // Required.
Rules map[string]string `yaml:"rules"` Rules map[string]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "^/old/[0.9]+/": "/new", // "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1", // "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` RegexRules map[*regexp.Regexp]string
} }
)
var (
// DefaultRewriteConfig is the default Rewrite middleware config.
DefaultRewriteConfig = RewriteConfig{
Skipper: DefaultSkipper,
}
)
// Rewrite returns a Rewrite middleware. // Rewrite returns a Rewrite middleware.
// //
// Rewrite middleware rewrites the URL path based on the provided rules. // Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc { func Rewrite(rules map[string]string) echo.MiddlewareFunc {
c := DefaultRewriteConfig c := RewriteConfig{}
c.Rules = rules c.Rules = rules
return RewriteWithConfig(c) return RewriteWithConfig(c)
} }
// RewriteWithConfig returns a Rewrite middleware with config. // RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
// See: `Rewrite()`. //
// Rewrite middleware rewrites the URL path based on the provided rules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
} }
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultSkipper
}
if config.Rules == nil && config.RegexRules == nil {
return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
} }
if config.RegexRules == nil { if config.RegexRules == nil {
@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) {
}, },
})) }))
e.GET("/public/*", func(c echo.Context) error { e.GET("/public/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*")) return c.String(http.StatusOK, c.PathParam("*"))
}) })
e.GET("/*", func(c echo.Context) error { e.GET("/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*")) return c.String(http.StatusOK, c.PathParam("*"))
}) })
var testCases = []struct { var testCases = []struct {
@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) {
} }
} }
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
assert.Panics(t, func() {
RewriteWithConfig(RewriteConfig{})
})
}
func TestMustRewriteWithConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
givenSkipper func(c echo.Context) bool
whenURL string
expectURL string
expectStatus int
}{
{
name: "not skipped",
whenURL: "/old",
expectURL: "/new",
expectStatus: http.StatusOK,
},
{
name: "skipped",
givenSkipper: func(c echo.Context) bool {
return true
},
whenURL: "/old",
expectURL: "/old",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Pre(RewriteWithConfig(
RewriteConfig{
Skipper: tc.givenSkipper,
Rules: map[string]string{"/old": "/new"}},
))
e.GET("/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
assert.Equal(t, tc.expectStatus, rec.Code)
})
}
}
// Issue #1086 // Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) { func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New() e := echo.New()
r := e.Router()
// Rewrite old url to new one // Rewrite old url to new one
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(Rewrite(map[string]string{ e.Pre(RewriteWithConfig(RewriteConfig{
"/old": "/new", Rules: map[string]string{"/old": "/new"}}),
}, )
))
// Route // Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error { e.Add(http.MethodGet, "/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
}) })
@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
// Issue #1143 // Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New() e := echo.New()
r := e.Router()
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{ e.Pre(RewriteWithConfig(RewriteConfig{
@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
}, },
})) }))
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
return c.String(http.StatusOK, "hosts") return c.String(http.StatusOK, "hosts")
}) })
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
return c.String(http.StatusOK, "eng") return c.String(http.StatusOK, "eng")
}) })

View File

@ -6,21 +6,20 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// SecureConfig defines the config for Secure middleware. // SecureConfig defines the config for Secure middleware.
SecureConfig struct { type SecureConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS) // XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header. // by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block". // Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"` XSSProtection string
// ContentTypeNosniff provides protection against overriding Content-Type // ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header. // header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff". // Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"` ContentTypeNosniff string
// XFrameOptions can be used to indicate whether or not a browser should // XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> . // be allowed to render a page in a <frame>, <iframe> or <object> .
@ -32,58 +31,55 @@ type (
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself. // - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so. // - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin. // - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"` XFrameOptions string
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how // HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to // long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping // be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks. // man-in-the-middle (MITM) attacks.
// Optional. Default value 0. // Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"` HSTSMaxAge int
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security` // HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect // header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value. // unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false. // Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"` HSTSExcludeSubdomains bool
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing // ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code // security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the // injection attacks resulting from execution of malicious content in the
// trusted web page context. // trusted web page context.
// Optional. Default value "". // Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"` ContentSecurityPolicy string
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead // CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the // of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would // content security policy by only reporting the violations that would
// have occurred instead of blocking the resource. // have occurred instead of blocking the resource.
// Optional. Default value false. // Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"` CSPReportOnly bool
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security` // HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list // header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/ // maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false. // Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"` HSTSPreloadEnabled bool
// ReferrerPolicy sets the `Referrer-Policy` header providing security against // ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties. // leaking potentially sensitive request paths to third parties.
// Optional. Default value "". // Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"` ReferrerPolicy string
} }
)
var (
// DefaultSecureConfig is the default Secure middleware config. // DefaultSecureConfig is the default Secure middleware config.
DefaultSecureConfig = SecureConfig{ var DefaultSecureConfig = SecureConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
XSSProtection: "1; mode=block", XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff", ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN", XFrameOptions: "SAMEORIGIN",
HSTSPreloadEnabled: false, HSTSPreloadEnabled: false,
} }
)
// Secure returns a Secure middleware. // Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack, // Secure middleware provides protection against cross-site scripting (XSS) attack,
@ -93,9 +89,13 @@ func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig) return SecureWithConfig(DefaultSecureConfig)
} }
// SecureWithConfig returns a Secure middleware with config. // SecureWithConfig returns a Secure middleware with config or panics on invalid configuration.
// See: `Secure()`.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts SecureConfig to middleware or returns an error for invalid configuration
func (config SecureConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper config.Skipper = DefaultSecureConfig.Skipper
@ -141,5 +141,5 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -19,26 +19,40 @@ func TestSecure(t *testing.T) {
} }
// Default // Default
Secure()(h)(c) err := Secure()(h)(c)
assert.NoError(t, err)
assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy))
}
// Custom func TestSecureWithConfig(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{ mw, err := SecureConfig{
XSSProtection: "", XSSProtection: "",
ContentTypeNosniff: "", ContentTypeNosniff: "",
XFrameOptions: "", XFrameOptions: "",
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
ContentSecurityPolicy: "default-src 'self'", ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "origin", ReferrerPolicy: "origin",
})(h)(c) }.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -47,11 +61,21 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_CSPReportOnly(t *testing.T) {
// Custom with CSPReportOnly flag // Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
XSSProtection: "", XSSProtection: "",
ContentTypeNosniff: "", ContentTypeNosniff: "",
XFrameOptions: "", XFrameOptions: "",
@ -60,6 +84,8 @@ func TestSecure(t *testing.T) {
CSPReportOnly: true, CSPReportOnly: true,
ReferrerPolicy: "origin", ReferrerPolicy: "origin",
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -67,25 +93,51 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly)) assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_HSTSPreloadEnabled(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled // Custom, with preload option enabled
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
HSTSPreloadEnabled: true, HSTSPreloadEnabled: true,
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}
func TestSecureWithConfig_HSTSExcludeSubdomains(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled and subdomains excluded // Custom, with preload option enabled and subdomains excluded
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
HSTSPreloadEnabled: true, HSTSPreloadEnabled: true,
HSTSExcludeSubdomains: true, HSTSExcludeSubdomains: true,
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
} }

View File

@ -1,44 +1,45 @@
package middleware package middleware
import ( import (
"errors"
"net/http"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // AddTrailingSlashConfig is the middleware config for adding trailing slash to the request.
// TrailingSlashConfig defines the config for TrailingSlash middleware. type AddTrailingSlashConfig struct {
TrailingSlashConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Status code to be used when redirecting the request. // Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code. // Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"` // Valid status codes: [300...308]
RedirectCode int
} }
)
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: DefaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a // AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`. // trailing slash to the request `URL#Path`.
// //
// Usage `Echo#Pre(AddTrailingSlash())` // Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc { func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig) return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
} }
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config. // AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration.
// See `AddTrailingSlash()`. func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config)
// Defaults }
// ToMiddleware converts AddTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config AddTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for add trailing slash middleware")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -69,7 +70,17 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
} }
return next(c) return next(c)
} }
}, nil
} }
// RemoveTrailingSlashConfig is the middleware config for removing trailing slash from the request.
type RemoveTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int
} }
// RemoveTrailingSlash returns a root level (before router) middleware which removes // RemoveTrailingSlash returns a root level (before router) middleware which removes
@ -77,15 +88,22 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
// //
// Usage `Echo#Pre(RemoveTrailingSlash())` // Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc { func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{}) return RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{})
} }
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config. // RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config or panics on invalid configuration.
// See `RemoveTrailingSlash()`. func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc {
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config)
// Defaults }
// ToMiddleware converts RemoveTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config RemoveTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for remove trailing slash middleware")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -117,7 +135,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
} }
return next(c) return next(c)
} }
} }, nil
} }
func sanitizeURI(uri string) string { func sanitizeURI(uri string) string {

View File

@ -67,7 +67,7 @@ func TestAddTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) { t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New() e := echo.New()
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{ mw := AddTrailingSlashWithConfig(AddTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
}) })
h := mw(func(c echo.Context) error { h := mw(func(c echo.Context) error {
@ -203,7 +203,7 @@ func TestRemoveTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) { t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New() e := echo.New()
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{ mw := RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
}) })
h := mw(func(c echo.Context) error { h := mw(func(c echo.Context) error {

View File

@ -1,55 +1,65 @@
package middleware package middleware
import ( import (
"errors"
"fmt" "fmt"
"html/template" "html/template"
"io"
"io/fs"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
) )
type (
// StaticConfig defines the config for Static middleware. // StaticConfig defines the config for Static middleware.
StaticConfig struct { type StaticConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// Root directory from where the static content is served. // Root directory from where the static content is served (relative to given Filesystem).
// `Root: "."` means root folder from Filesystem.
// Required. // Required.
Root string `yaml:"root"` Root string
// Filesystem provides access to the static content.
// Optional. Defaults to echo.Filesystem (serves files from `.` folder where executable is started)
Filesystem fs.FS
// Index file for serving a directory. // Index file for serving a directory.
// Optional. Default value "index.html". // Optional. Default value "index.html".
Index string `yaml:"index"` Index string
// Enable HTML5 mode by forwarding all not-found requests to root so that // Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing. // SPA (single-page application) can handle the routing.
// Optional. Default value false. // Optional. Default value false.
HTML5 bool `yaml:"html5"` HTML5 bool
// Enable directory browsing. // Enable directory browsing.
// Optional. Default value false. // Optional. Default value false.
Browse bool `yaml:"browse"` Browse bool
// Enable ignoring of the base of the URL path. // Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group, // Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled // the filesystem path is not doubled
// Optional. Default value false. // Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"` IgnoreBase bool
// Filesystem provides access to the static content. // DisablePathUnescaping disables path parameter (param: *) unescaping. This is useful when router is set to unescape
// Optional. Defaults to http.Dir(config.Root) // all parameter and doing it again in this middleware would corrupt filename that is requested.
Filesystem http.FileSystem `yaml:"-"` DisablePathUnescaping bool
// DirectoryListTemplate is template to list directory contents
// Optional. Default to `directoryListHTMLTemplate` constant below.
DirectoryListTemplate string
} }
)
const html = ` const directoryListHTMLTemplate = `
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
@ -121,25 +131,26 @@ const html = `
</html> </html>
` `
var (
// DefaultStaticConfig is the default Static middleware config. // DefaultStaticConfig is the default Static middleware config.
DefaultStaticConfig = StaticConfig{ var DefaultStaticConfig = StaticConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
Index: "index.html", Index: "index.html",
} }
)
// Static returns a Static middleware to serves static content from the provided // Static returns a Static middleware to serves static content from the provided root directory.
// root directory.
func Static(root string) echo.MiddlewareFunc { func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig c := DefaultStaticConfig
c.Root = root c.Root = root
return StaticWithConfig(c) return StaticWithConfig(c)
} }
// StaticWithConfig returns a Static middleware with config. // StaticWithConfig returns a Static middleware to serves static content or panics on invalid configuration.
// See `Static()`.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts StaticConfig to middleware or returns an error for invalid configuration
func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Root == "" { if config.Root == "" {
config.Root = "." // For security we want to restrict to CWD. config.Root = "." // For security we want to restrict to CWD.
@ -150,30 +161,32 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if config.Index == "" { if config.Index == "" {
config.Index = DefaultStaticConfig.Index config.Index = DefaultStaticConfig.Index
} }
if config.Filesystem == nil { if config.DirectoryListTemplate == "" {
config.Filesystem = http.Dir(config.Root) config.DirectoryListTemplate = directoryListHTMLTemplate
config.Root = "."
} }
// Index template dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
t, err := template.New("index").Parse(html)
if err != nil { if err != nil {
panic(fmt.Sprintf("echo: %v", err)) return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
p := c.Request().URL.Path p := c.Request().URL.Path
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`. pathUnescape := true
p = c.Param("*") if c.RouteMatchType() == echo.RouteMatchFound && strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.PathParam("*")
pathUnescape = !config.DisablePathUnescaping // because router could already do PathUnescape
} }
if pathUnescape {
p, err = url.PathUnescape(p) p, err = url.PathUnescape(p)
if err != nil { if err != nil {
return return err
}
} }
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
@ -186,22 +199,29 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
} }
file, err := openFile(config.Filesystem, name) currentFS := config.Filesystem
if currentFS == nil {
currentFS = c.Echo().Filesystem
}
file, err := openFile(currentFS, name)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return err return err
} }
if err = next(c); err == nil { // when file does not exist let handler to handle that request. if it succeeds then we are done
return err err = next(c)
if err == nil {
return nil
} }
he, ok := err.(*echo.HTTPError) he, ok := err.(*echo.HTTPError)
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
return err return err
} }
// is case HTML5 mode is enabled + echo 404 we serve index to the client
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index)) file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
if err != nil { if err != nil {
return err return err
} }
@ -215,10 +235,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
if info.IsDir() { if info.IsDir() {
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index)) index, err := openFile(currentFS, filepath.Join(name, config.Index))
if err != nil { if err != nil {
if config.Browse { if config.Browse {
return listDir(t, name, file, c.Response()) return listDir(dirListTemplate, name, currentFS, file, c.Response())
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -238,25 +258,24 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return serveFile(c, file, info) return serveFile(c, file, info)
} }
} }, nil
} }
func openFile(fs http.FileSystem, name string) (http.File, error) { func openFile(fs fs.FS, name string) (fs.File, error) {
pathWithSlashes := filepath.ToSlash(name) pathWithSlashes := filepath.ToSlash(name)
return fs.Open(pathWithSlashes) return fs.Open(pathWithSlashes)
} }
func serveFile(c echo.Context, file http.File, info os.FileInfo) error { func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file) ff, ok := file.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), ff)
return nil return nil
} }
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) { func listDir(t *template.Template, name string, filesystem fs.FS, dir fs.File, res *echo.Response) error {
files, err := dir.Readdir(-1)
if err != nil {
return
}
// Create directory index // Create directory index
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8) res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
data := struct { data := struct {
@ -265,12 +284,60 @@ func listDir(t *template.Template, name string, dir http.File, res *echo.Respons
}{ }{
Name: name, Name: name,
} }
for _, f := range files { err := fs.WalkDir(filesystem, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
info, infoErr := d.Info()
if infoErr != nil {
return fmt.Errorf("static middleware list dir error when getting file info: %w", err)
}
data.Files = append(data.Files, struct { data.Files = append(data.Files, struct {
Name string Name string
Dir bool Dir bool
Size string Size string
}{f.Name(), f.IsDir(), bytes.Format(f.Size())}) }{d.Name(), d.IsDir(), format(info.Size())})
return nil
})
if err != nil {
return err
} }
return t.Execute(res, data) return t.Execute(res, data)
} }
// format formats bytes integer to human readable string.
// For example, 31323 bytes will return 30.59KB.
func format(b int64) string {
multiple := ""
value := float64(b)
switch {
case b >= EB:
value /= float64(EB)
multiple = "EB"
case b >= PB:
value /= float64(PB)
multiple = "PB"
case b >= TB:
value /= float64(TB)
multiple = "TB"
case b >= GB:
value /= float64(GB)
multiple = "GB"
case b >= MB:
value /= float64(MB)
multiple = "MB"
case b >= KB:
value /= float64(KB)
multiple = "KB"
case b == 0:
return "0"
default:
return strconv.FormatInt(b, 10) + "B"
}
return fmt.Sprintf("%.2f%s", value, multiple)
}

View File

@ -81,14 +81,16 @@ func TestStatic_CustomFS(t *testing.T) {
config := StaticConfig{ config := StaticConfig{
Root: ".", Root: ".",
Filesystem: http.FS(tc.filesystem), Filesystem: tc.filesystem,
} }
if tc.root != "" { if tc.root != "" {
config.Root = tc.root config.Root = tc.root
} }
middlewareFunc := StaticWithConfig(config) middlewareFunc, err := config.ToMiddleware()
assert.NoError(t, err)
e.Use(middlewareFunc) e.Use(middlewareFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"testing" "testing"
@ -10,6 +11,37 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestStatic_useCaseForApiAndSPAs(t *testing.T) {
e := echo.New()
// serve single page application (SPA) files from server root
e.Use(StaticWithConfig(StaticConfig{
Root: ".",
// by default Echo filesystem is fixed to `./` but this does not allow `../` (moving up in folder structure past filesystem root)
Filesystem: os.DirFS("../_fixture"),
}))
// all requests to `/api/*` will end up in echo handlers (assuming there is not `api` folder and files)
api := e.Group("/api")
users := api.Group("/users")
users.GET("/info", func(c echo.Context) error {
return c.String(http.StatusOK, "users info")
})
req := httptest.NewRequest(http.MethodGet, "/api/users/info", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "users info", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/index.html", nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "<title>Echo</title>")
}
func TestStatic(t *testing.T) { func TestStatic(t *testing.T) {
var testCases = []struct { var testCases = []struct {
name string name string
@ -35,7 +67,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, when html5 mode serve index for any static file that does not exist", name: "ok, when html5 mode serve index for any static file that does not exist",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture", Root: "_fixture",
HTML5: true, HTML5: true,
}, },
whenURL: "/random", whenURL: "/random",
@ -45,7 +77,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve index as directory index listing files directory", name: "ok, serve index as directory index listing files directory",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/certs", Root: "_fixture/certs",
Browse: true, Browse: true,
}, },
whenURL: "/", whenURL: "/",
@ -55,7 +87,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve directory index with IgnoreBase and browse", name: "ok, serve directory index with IgnoreBase and browse",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true, IgnoreBase: true,
Browse: true, Browse: true,
}, },
@ -67,7 +99,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve file with IgnoreBase", name: "ok, serve file with IgnoreBase",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true, IgnoreBase: true,
Browse: true, Browse: true,
}, },
@ -95,15 +127,27 @@ func TestStatic(t *testing.T) {
expectContains: "{\"message\":\"Not Found\"}\n", expectContains: "{\"message\":\"Not Found\"}\n",
}, },
{ {
name: "ok, do not serve file, when a handler took care of the request", name: "ok, when no file then a handler will care of the request",
whenURL: "/regular-handler", whenURL: "/regular-handler",
expectCode: http.StatusOK, expectCode: http.StatusOK,
expectContains: "ok", expectContains: "ok",
}, },
{
name: "ok, skip middleware and serve handler",
givenConfig: &StaticConfig{
Root: "_fixture/images/",
Skipper: func(c echo.Context) bool {
return true
},
},
whenURL: "/walle.png",
expectCode: http.StatusTeapot,
expectContains: "walle",
},
{ {
name: "nok, when html5 fail if the index file does not exist", name: "nok, when html5 fail if the index file does not exist",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture", Root: "_fixture",
HTML5: true, HTML5: true,
Index: "missing.html", Index: "missing.html",
}, },
@ -114,7 +158,7 @@ func TestStatic(t *testing.T) {
name: "ok, serve from http.FileSystem", name: "ok, serve from http.FileSystem",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "_fixture", Root: "_fixture",
Filesystem: http.Dir(".."), Filesystem: os.DirFS(".."),
}, },
whenURL: "/", whenURL: "/",
expectCode: http.StatusOK, expectCode: http.StatusOK,
@ -125,8 +169,9 @@ func TestStatic(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../")
config := StaticConfig{Root: "../_fixture"} config := StaticConfig{Root: "_fixture"}
if tc.givenConfig != nil { if tc.givenConfig != nil {
config = *tc.givenConfig config = *tc.givenConfig
} }
@ -136,14 +181,17 @@ func TestStatic(t *testing.T) {
subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc) subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc)
// group without http handlers (routes) does not do anything. // group without http handlers (routes) does not do anything.
// Request is matched against http handlers (routes) that have group middleware attached to them // Request is matched against http handlers (routes) that have group middleware attached to them
subGroup.GET("", echo.NotFoundHandler) subGroup.GET("", func(c echo.Context) error { return echo.ErrNotFound })
subGroup.GET("/*", echo.NotFoundHandler) subGroup.GET("/*", func(c echo.Context) error { return echo.ErrNotFound })
} else { } else {
// middleware is on root level // middleware is on root level
e.Use(middlewareFunc) e.Use(middlewareFunc)
e.GET("/regular-handler", func(c echo.Context) error { e.GET("/regular-handler", func(c echo.Context) error {
return c.String(http.StatusOK, "ok") return c.String(http.StatusOK, "ok")
}) })
e.GET("/walle.png", func(c echo.Context) error {
return c.String(http.StatusTeapot, "walle")
})
} }
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
@ -177,7 +225,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "ok", name: "ok",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/images", givenRoot: "_fixture/images",
whenURL: "/group/images/walle.png", whenURL: "/group/images/walle.png",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
@ -185,7 +233,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "No file", name: "No file",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/scripts", givenRoot: "_fixture/scripts",
whenURL: "/group/images/bolt.png", whenURL: "/group/images/bolt.png",
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -193,7 +241,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory not found (no trailing slash)", name: "Directory not found (no trailing slash)",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/images", givenRoot: "_fixture/images",
whenURL: "/group/images/", whenURL: "/group/images/",
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -201,7 +249,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory redirect", name: "Directory redirect",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/folder", whenURL: "/group/folder",
expectStatus: http.StatusMovedPermanently, expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/group/folder/", expectHeaderLocation: "/group/folder/",
@ -211,7 +259,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory 404 (request URL without slash)", name: "Prefixed directory 404 (request URL without slash)",
givenGroup: "_fixture", givenGroup: "_fixture",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -220,7 +268,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory redirect (without slash redirect to slash)", name: "Prefixed directory redirect (without slash redirect to slash)",
givenGroup: "_fixture", givenGroup: "_fixture",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently, expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/_fixture/folder/", expectHeaderLocation: "/_fixture/folder/",
@ -229,7 +277,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory with index.html", name: "Directory with index.html",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/", whenURL: "/group/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -237,7 +285,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Prefixed directory with index.html (prefix ending with slash)", name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/", givenPrefix: "/assets/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/assets/", whenURL: "/group/assets/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -245,7 +293,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Prefixed directory with index.html (prefix ending without slash)", name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets", givenPrefix: "/assets",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/assets/", whenURL: "/group/assets/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -253,7 +301,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Sub-directory with index.html", name: "Sub-directory with index.html",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/folder/", whenURL: "/group/folder/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -261,7 +309,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "do not allow directory traversal (backslash - windows separator)", name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture/", givenRoot: "_fixture/",
whenURL: `/group/..\\middleware/basic_auth.go`, whenURL: `/group/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -269,7 +317,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "do not allow directory traversal (slash - unix separator)", name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture/", givenRoot: "_fixture/",
whenURL: `/group/../middleware/basic_auth.go`, whenURL: `/group/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -279,6 +327,8 @@ func TestStatic_GroupWithStatic(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../") // so we can access test files
group := "/group" group := "/group"
if tc.givenGroup != "" { if tc.givenGroup != "" {
group = tc.givenGroup group = tc.givenGroup
@ -288,7 +338,9 @@ func TestStatic_GroupWithStatic(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code) assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String() body := rec.Body.String()
if tc.expectBodyStartsWith != "" { if tc.expectBodyStartsWith != "" {
@ -306,3 +358,73 @@ func TestStatic_GroupWithStatic(t *testing.T) {
}) })
} }
} }
func TestMustStaticWithConfig_panicsInvalidDirListTemplate(t *testing.T) {
assert.Panics(t, func() {
StaticWithConfig(StaticConfig{DirectoryListTemplate: `{{}`})
})
}
func TestFormat(t *testing.T) {
var testCases = []struct {
name string
when int64
expect string
}{
{
name: "byte",
when: 0,
expect: "0",
},
{
name: "bytes",
when: 515,
expect: "515B",
},
{
name: "KB",
when: 31323,
expect: "30.59KB",
},
{
name: "MB",
when: 13231323,
expect: "12.62MB",
},
{
name: "GB",
when: 7323232398,
expect: "6.82GB",
},
{
name: "TB",
when: 1_099_511_627_776,
expect: "1.00TB",
},
{
name: "PB",
when: 9923232398434432,
expect: "8.81PB",
},
{
// test with 7EB because of https://github.com/labstack/gommon/pull/38 and https://github.com/labstack/gommon/pull/43
//
// 8 exbi equals 2^64, therefore it cannot be stored in int64. The tests use
// the fact that on x86_64 the following expressions holds true:
// int64(0) - 1 == math.MaxInt64.
//
// However, this is not true for other platforms, specifically aarch64, s390x
// and ppc64le.
name: "EB",
when: 8070450532247929000,
expect: "7.00EB",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := format(tc.when)
assert.Equal(t, tc.expect, result)
})
}
}

View File

@ -2,10 +2,9 @@ package middleware
import ( import (
"context" "context"
"github.com/labstack/echo/v4"
"net/http" "net/http"
"time" "time"
"github.com/labstack/echo/v4"
) )
// --------------------------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------------------------
@ -55,9 +54,8 @@ import (
// }) // })
// //
type (
// TimeoutConfig defines the config for Timeout middleware. // TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct { type TimeoutConfig struct {
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
@ -69,7 +67,7 @@ type (
// request timeouted and we already had sent the error code (503) and message response to the client. // request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context) OnTimeoutRouteErrorHandler func(c echo.Context, err error)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
@ -77,29 +75,22 @@ type (
// difference over 500microseconds (0.5millisecond) response seems to be reliable // difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration Timeout time.Duration
} }
)
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler // Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution. // call runs for longer than its time limit. NB: timeout does not stop handler execution.
func Timeout() echo.MiddlewareFunc { func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig) return TimeoutWithConfig(TimeoutConfig{})
} }
// TimeoutWithConfig returns a Timeout middleware with config. // TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts TimeoutConfig to middleware or returns an error for invalid configuration
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper config.Skipper = DefaultSkipper
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -108,29 +99,30 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
errChan := make(chan error, 1)
handlerWrapper := echoHandlerFuncWrapper{ handlerWrapper := echoHandlerFuncWrapper{
ctx: c, ctx: c,
handler: next, handler: next,
errChan: make(chan error, 1), errChan: errChan,
errHandler: config.OnTimeoutRouteErrorHandler, errHandler: config.OnTimeoutRouteErrorHandler,
} }
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request()) handler.ServeHTTP(c.Response().Writer, c.Request())
select { select {
case err := <-handlerWrapper.errChan: case err := <-errChan:
return err return err
default: default:
return nil return nil
} }
} }
} }, nil
} }
type echoHandlerFuncWrapper struct { type echoHandlerFuncWrapper struct {
ctx echo.Context ctx echo.Context
handler echo.HandlerFunc handler echo.HandlerFunc
errHandler func(err error, c echo.Context) errHandler func(c echo.Context, err error)
errChan chan error errChan chan error
} }
@ -156,7 +148,7 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
err := t.handler(t.ctx) err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil { if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx) t.errHandler(t.ctx, err)
} }
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
} }

View File

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
@ -14,9 +16,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
) )
func TestTimeoutSkipper(t *testing.T) { func TestTimeoutSkipper(t *testing.T) {
@ -111,7 +110,7 @@ func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
actualErrChan := make(chan error, 1) actualErrChan := make(chan error, 1)
m := TimeoutWithConfig(TimeoutConfig{ m := TimeoutWithConfig(TimeoutConfig{
Timeout: 1 * time.Millisecond, Timeout: 1 * time.Millisecond,
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) { OnTimeoutRouteErrorHandler: func(c echo.Context, err error) {
actualErrChan <- err actualErrChan <- err
}, },
}) })
@ -360,7 +359,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) {
e := echo.New() e := echo.New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
// NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first
// FIXME: I have no idea how to fix this without adding mutexes. // FIXME: I have no idea how to fix this without adding mutexes.
@ -424,7 +423,7 @@ func startServer(e *echo.Echo) (*http.Server, string, error) {
s := http.Server{ s := http.Server{
Handler: e, Handler: e,
ErrorLog: log.New(e.Logger.Output(), "echo:", 0), ErrorLog: log.New(e.Logger, "echo:", 0),
} }
errCh := make(chan error) errCh := make(chan error)

View File

@ -1,9 +1,27 @@
package middleware package middleware
import ( import (
"crypto/rand"
"fmt"
"strings" "strings"
) )
const (
_ = int64(1 << (10 * iota)) // ignore first value by assigning to blank identifier
// KB is 1 KiloByte = 1024 bytes
KB
// MB is 1 Megabyte = 1_048_576 bytes
MB
// GB is 1 Gigabyte = 1_073_741_824 bytes
GB
// TB is 1 Terabyte = 1_099_511_627_776 bytes
TB
// PB is 1 Petabyte = 1_125_899_906_842_624 bytes
PB
// EB is 1 Exabyte = 1_152_921_504_606_847_000 bytes
EB
)
func matchScheme(domain, pattern string) bool { func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":") didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":") pidx := strings.Index(pattern, ":")
@ -52,3 +70,24 @@ func matchSubdomain(domain, pattern string) bool {
} }
return false return false
} }
func createRandomStringGenerator(length uint8) func() string {
return func() string {
return randomString(length)
}
}
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -1,11 +1,23 @@
package middleware package middleware
import ( import (
"testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io"
"testing"
) )
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}
func Test_matchScheme(t *testing.T) { func Test_matchScheme(t *testing.T) {
tests := []struct { tests := []struct {
domain, pattern string domain, pattern string
@ -93,3 +105,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern)) assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
} }
} }
func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}

View File

@ -2,15 +2,15 @@ package echo
import ( import (
"bufio" "bufio"
"errors"
"net" "net"
"net/http" "net/http"
) )
type (
// Response wraps an http.ResponseWriter and implements its interface to be used // Response wraps an http.ResponseWriter and implements its interface to be used
// by an HTTP handler to construct an HTTP response. // by an HTTP handler to construct an HTTP response.
// See: https://golang.org/pkg/net/http/#ResponseWriter // See: https://golang.org/pkg/net/http/#ResponseWriter
Response struct { type Response struct {
echo *Echo echo *Echo
beforeFuncs []func() beforeFuncs []func()
afterFuncs []func() afterFuncs []func()
@ -19,7 +19,6 @@ type (
Size int64 Size int64
Committed bool Committed bool
} }
)
// NewResponse creates a new instance of Response. // NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) { func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
@ -47,13 +46,15 @@ func (r *Response) After(fn func()) {
r.afterFuncs = append(r.afterFuncs, fn) r.afterFuncs = append(r.afterFuncs, fn)
} }
var errHeaderAlreadyCommitted = errors.New("response already committed")
// WriteHeader sends an HTTP response header with status code. If WriteHeader is // WriteHeader sends an HTTP response header with status code. If WriteHeader is
// not called explicitly, the first call to Write will trigger an implicit // not called explicitly, the first call to Write will trigger an implicit
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly // WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
// used to send error codes. // used to send error codes.
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {
if r.Committed { if r.Committed {
r.echo.Logger.Warn("response already committed") r.echo.Logger.Error(errHeaderAlreadyCommitted)
return return
} }
r.Status = code r.Status = code

182
route.go Normal file
View File

@ -0,0 +1,182 @@
package echo
import (
"bytes"
"errors"
"fmt"
"reflect"
"runtime"
)
// Route contains information to adding/registering new route with the router.
// Method+Path pair uniquely identifies the Route. It is mandatory to provide Method+Path+Handler fields.
type Route struct {
Method string
Path string
Handler HandlerFunc
Middlewares []MiddlewareFunc
Name string
}
// ToRouteInfo converts Route to RouteInfo
func (r Route) ToRouteInfo(params []string) RouteInfo {
name := r.Name
if name == "" {
name = r.Method + ":" + r.Path
}
return routeInfo{
method: r.Method,
path: r.Path,
params: append([]string(nil), params...),
name: name,
}
}
// ToRoute returns Route which Router uses to register the method handler for path.
func (r Route) ToRoute() Route {
return r
}
// ForGroup recreates Route with added group prefix and group middlewares it is grouped to.
func (r Route) ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable {
r.Path = pathPrefix + r.Path
if len(middlewares) > 0 {
m := make([]MiddlewareFunc, 0, len(middlewares)+len(r.Middlewares))
m = append(m, middlewares...)
m = append(m, r.Middlewares...)
r.Middlewares = m
}
return r
}
type routeInfo struct {
method string
path string
params []string
name string
}
func (r routeInfo) Method() string {
return r.method
}
func (r routeInfo) Path() string {
return r.path
}
func (r routeInfo) Params() []string {
return append([]string(nil), r.params...)
}
func (r routeInfo) Name() string {
return r.name
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r routeInfo) Reverse(params ...interface{}) string {
uri := new(bytes.Buffer)
ln := len(params)
n := 0
for i, l := 0, len(r.path); i < l; i++ {
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
for ; i < l && r.path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(r.path[i])
}
}
return uri.String()
}
// HandlerName returns string name for given function.
func HandlerName(h HandlerFunc) string {
t := reflect.ValueOf(h).Type()
if t.Kind() == reflect.Func {
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
}
return t.String()
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r Routes) Reverse(name string, params ...interface{}) (string, error) {
for _, rr := range r {
if rr.Name() == name {
return rr.Reverse(params...), nil
}
}
return "", errors.New("route not found")
}
// FindByMethodPath searched for matching route info by method and path
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) {
if r == nil {
return nil, errors.New("route not found by method and path")
}
for _, rr := range r {
if rr.Method() == method && rr.Path() == path {
return rr, nil
}
}
return nil, errors.New("route not found by method and path")
}
// FilterByMethod searched for matching route info by method
func (r Routes) FilterByMethod(method string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by method")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Method() == method {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by method")
}
return result, nil
}
// FilterByPath searched for matching route info by path
func (r Routes) FilterByPath(path string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by path")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Path() == path {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by path")
}
return result, nil
}
// FilterByName searched for matching route info by name
func (r Routes) FilterByName(name string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by name")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Name() == name {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by name")
}
return result, nil
}

423
route_test.go Normal file
View File

@ -0,0 +1,423 @@
package echo
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
var myNamedHandler = func(c Context) error {
return nil
}
type NameStruct struct {
}
func (n *NameStruct) getUsers(c Context) error {
return nil
}
func TestHandlerName(t *testing.T) {
myNameFuncVar := func(c Context) error {
return nil
}
tmp := NameStruct{}
var testCases = []struct {
name string
whenHandlerFunc HandlerFunc
expect string
}{
{
name: "ok, func as anonymous func",
whenHandlerFunc: func(c Context) error {
return nil
},
expect: "github.com/labstack/echo/v4.TestHandlerName.func2",
},
{
name: "ok, func as named package variable",
whenHandlerFunc: myNamedHandler,
expect: "github.com/labstack/echo/v4.glob..func3",
},
{
name: "ok, func as named function variable",
whenHandlerFunc: myNameFuncVar,
expect: "github.com/labstack/echo/v4.TestHandlerName.func1",
},
{
name: "ok, func as struct method",
whenHandlerFunc: tmp.getUsers,
expect: "github.com/labstack/echo/v4.(*NameStruct).getUsers-fm",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
name := HandlerName(tc.whenHandlerFunc)
assert.Equal(t, tc.expect, name)
})
}
}
func TestHandlerName_differentFuncSameName(t *testing.T) {
handlerCreator := func(name string) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusTeapot, name)
}
}
h1 := handlerCreator("name1")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func2", HandlerName(h1))
h2 := handlerCreator("name2")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func3", HandlerName(h2))
}
func TestRoute_ToRouteInfo(t *testing.T) {
var testCases = []struct {
name string
given Route
whenParams []string
expect RouteInfo
}{
{
name: "ok, no params, with name",
given: Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
},
expect: routeInfo{
method: http.MethodGet,
path: "/test",
params: nil,
name: "test route",
},
},
{
name: "ok, params",
given: Route{
Method: http.MethodGet,
Path: "users/:id/:file", // no slash prefix
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "",
},
whenParams: []string{"id", "file"},
expect: routeInfo{
method: http.MethodGet,
path: "users/:id/:file",
params: []string{"id", "file"},
name: "GET:users/:id/:file",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri := tc.given.ToRouteInfo(tc.whenParams)
assert.Equal(t, tc.expect, ri)
})
}
}
func TestRoute_ToRoute(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
r := route.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/test")
assert.NotNil(t, r.Handler)
assert.Nil(t, r.Middlewares)
assert.Equal(t, r.Name, "test route")
}
func TestRoute_ForGroup(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return next(c)
}
}
gr := route.ForGroup("/users", []MiddlewareFunc{mw})
r := gr.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/users/test")
assert.NotNil(t, r.Handler)
assert.Len(t, r.Middlewares, 1)
assert.Equal(t, r.Name, "test route")
}
func exampleRoutes() Routes {
return Routes{
routeInfo{
method: http.MethodGet,
path: "/users",
params: nil,
name: "GET:/users",
},
routeInfo{
method: http.MethodGet,
path: "/users/:id",
params: []string{"id"},
name: "GET:/users/:id",
},
routeInfo{
method: http.MethodPost,
path: "/users/:id",
params: []string{"id"},
name: "POST:/users/:id",
},
routeInfo{
method: http.MethodDelete,
path: "/groups",
params: nil,
name: "non_unique_name",
},
routeInfo{
method: http.MethodPost,
path: "/groups",
params: nil,
name: "non_unique_name",
},
}
}
func TestRoutes_FindByMethodPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
whenPath string
expectName string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "GET:/users/:id",
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri, err := tc.given.FindByMethodPath(tc.whenMethod, tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
assert.Nil(t, ri)
} else {
assert.NoError(t, err)
}
if tc.expectName != "" {
assert.Equal(t, tc.expectName, ri.Name())
}
})
}
}
func TestRoutes_FilterByMethod(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
expectNames: []string{"GET:/users", "GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
expectNames: nil,
expectError: "route not found by method",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
expectNames: nil,
expectError: "route not found by method",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByMethod(tc.whenMethod)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenPath string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenPath: "/users/:id",
expectNames: []string{"GET:/users/:id", "POST:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenPath: "/",
expectNames: nil,
expectError: "route not found by path",
},
{
name: "nok, not found from nil",
given: nil,
whenPath: "/users/:id",
expectNames: nil,
expectError: "route not found by path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByPath(tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByName(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenName string
expectMethodPath []string
expectError string
}{
{
name: "ok, found multiple",
given: exampleRoutes(),
whenName: "non_unique_name",
expectMethodPath: []string{"DELETE:/groups", "POST:/groups"},
},
{
name: "ok, found single",
given: exampleRoutes(),
whenName: "GET:/users/:id",
expectMethodPath: []string{"GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenName: "/",
expectMethodPath: nil,
expectError: "route not found by name",
},
{
name: "nok, not found from nil",
given: nil,
whenName: "/users/:id",
expectMethodPath: nil,
expectError: "route not found by name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByName(tc.whenName)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectMethodPath) > 0 {
assert.Len(t, ris, len(tc.expectMethodPath))
for _, ri := range ris {
assert.Contains(t, tc.expectMethodPath, fmt.Sprintf("%v:%v", ri.Method(), ri.Path()))
}
} else {
assert.Nil(t, ris)
}
})
}
}

743
router.go
View File

@ -1,49 +1,133 @@
package echo package echo
import ( import (
"errors"
"net/http" "net/http"
"net/url"
) )
type ( // Router is interface for routing requests to registered routes.
// Router is the registry of all registered routes for an `Echo` instance for type Router interface {
// request matching and URL path parameter parsing. // Add registers Routable with the Router and returns registered RouteInfo
Router struct { Add(routable Routable) (RouteInfo, error)
tree *node // Remove removes route from the Router
routes map[string]*Route Remove(method string, path string) error
echo *Echo // Routes returns information about all registered routes
Routes() Routes
// Match searches Router for matching route and applies it to result fields.
Match(req *http.Request, params *PathParams) RouteMatch
} }
node struct {
// Routable is interface for registering Route with Router. During route registration process the Router will
// convert Routable to RouteInfo with ToRouteInfo method. By creating custom implementation of Routable additional
// information about registered route can be stored in Routes (i.e. privileges used with route etc.)
type Routable interface {
// ToRouteInfo converts Routable to RouteInfo
//
// This method is meant to be used by Router after it parses url for path parameters, to store information about
// route just added.
ToRouteInfo(params []string) RouteInfo
// ToRoute converts Routable to Route which Router uses to register the method handler for path.
//
// This method is meant to be used by Router to get fields (including handler and middleware functions) needed to
// add Route to Router.
ToRoute() Route
// ForGroup recreates routable with added group prefix and group middlewares it is grouped to.
//
// Is necessary for Echo.Group to be able to add/register Routable with Router and having group prefix and group
// middlewares included in actually registered Route.
ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable
}
// Routes is collection of RouteInfo instances with various helper methods.
type Routes []RouteInfo
// RouteInfo describes registered route base fields.
// Method+Path pair uniquely identifies the Route. Name can have duplicates.
type RouteInfo interface {
Method() string
Path() string
Name() string
Params() []string
Reverse(params ...interface{}) string
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
// it is not always 100% known if handler function already wraps middlewares or not. In Echo handler could be one
// function or several functions wrapping each other.
}
// RouteMatchType describes possible states that request could be in perspective of routing
type RouteMatchType uint8
const (
// RouteMatchUnknown is state before routing is done. Default state for fresh context.
RouteMatchUnknown RouteMatchType = iota
// RouteMatchNotFound is state when router did not find matching route for current request
RouteMatchNotFound
// RouteMatchMethodNotAllowed is state when router did not find route with matching path + method for current request.
// Although router had matching route with that path but different method.
RouteMatchMethodNotAllowed
// RouteMatchFound is state when router found exact match for path + method combination
RouteMatchFound
)
// RouteMatch is result object for Router.Match. Its main purpose is to avoid allocating memory for PathParams inside router.
type RouteMatch struct {
// Type contains result as enumeration of Router.Match and helps to understand did Router actually matched Route or
// what kind of error case (404/405) we have at the end of the handler chain.
Type RouteMatchType
// RoutePath contains original path with what matched route was registered with (including placeholders etc).
RoutePath string
// Handler is function(chain) that was matched by router. In case of no match could result to ErrNotFound or ErrMethodNotAllowed.
Handler HandlerFunc
// RouteInfo is information about route we just matched
RouteInfo RouteInfo
}
// PathParams is collections of PathParam instances with various helper methods
type PathParams []PathParam
// PathParam is tuple pf path parameter name and its value in request path
type PathParam struct {
Name string
Value string
}
// DefaultRouter is the registry of all registered routes for an `Echo` instance for
// request matching and URL path parameter parsing.
// Note: DefaultRouter is not coroutine-safe. Do not Add/Remove routes after HTTP server has been started with Echo.
type DefaultRouter struct {
tree *node
routes Routes
echo *Echo
allowOverwritingRoute bool
unescapePathParamValues bool
useEscapedPathForRouting bool
}
type children []*node
type node struct {
kind kind kind kind
label byte label byte
prefix string prefix string
parent *node parent *node
staticChildren children staticChildren children
ppath string originalPath string
pnames []string methods *routeMethods
methodHandler *methodHandler
paramChild *node paramChild *node
anyChild *node anyChild *node
paramsCount int
// isLeaf indicates that node does not have child routes // isLeaf indicates that node does not have child routes
isLeaf bool isLeaf bool
// isHandler indicates that node has at least one handler registered to it // isHandler indicates that node has at least one handler registered to it
isHandler bool isHandler bool
} }
kind uint8
children []*node type kind uint8
methodHandler struct {
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
}
)
const ( const (
staticKind kind = iota staticKind kind = iota
@ -54,90 +138,362 @@ const (
anyLabel = byte('*') anyLabel = byte('*')
) )
func (m *methodHandler) isHandler() bool { type routeMethod struct {
return m.connect != nil || *routeInfo
m.delete != nil || handler HandlerFunc
m.get != nil || orgRouteInfo RouteInfo
m.head != nil || }
m.options != nil ||
m.patch != nil || type routeMethods struct {
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
}
func (m *routeMethods) set(method string, r *routeMethod) {
switch method {
case http.MethodConnect:
m.connect = r
case http.MethodDelete:
m.delete = r
case http.MethodGet:
m.get = r
case http.MethodHead:
m.head = r
case http.MethodOptions:
m.options = r
case http.MethodPatch:
m.patch = r
case http.MethodPost:
m.post = r
case PROPFIND:
m.propfind = r
case http.MethodPut:
m.put = r
case http.MethodTrace:
m.trace = r
case REPORT:
m.report = r
default:
if m.anyOther == nil {
m.anyOther = make(map[string]*routeMethod)
}
if r.handler == nil {
delete(m.anyOther, method)
} else {
m.anyOther[method] = r
}
}
}
func (m *routeMethods) find(method string) *routeMethod {
switch method {
case http.MethodConnect:
return m.connect
case http.MethodDelete:
return m.delete
case http.MethodGet:
return m.get
case http.MethodHead:
return m.head
case http.MethodOptions:
return m.options
case http.MethodPatch:
return m.patch
case http.MethodPost:
return m.post
case PROPFIND:
return m.propfind
case http.MethodPut:
return m.put
case http.MethodTrace:
return m.trace
case REPORT:
return m.report
default:
return m.anyOther[method]
}
}
func (m *routeMethods) isHandler() bool {
return m.get != nil ||
m.post != nil || m.post != nil ||
m.propfind != nil || m.options != nil ||
m.put != nil || m.put != nil ||
m.delete != nil ||
m.connect != nil ||
m.head != nil ||
m.patch != nil ||
m.propfind != nil ||
m.trace != nil || m.trace != nil ||
m.report != nil m.report != nil ||
len(m.anyOther) != 0
}
// RouterConfig is configuration options for (default) router
type RouterConfig struct {
// AllowOverwritingRoute instructs Router NOT to return error when new route is registered with the same method+path
// and replaces matching route with the new one.
AllowOverwritingRoute bool
// UnescapePathParamValues instructs Router to unescape path parameter value when request if matched to the routes
UnescapePathParamValues bool
// UseEscapedPathForMatching instructs Router to use escaped request URL path (req.URL.Path) for matching the request.
UseEscapedPathForMatching bool
} }
// NewRouter returns a new Router instance. // NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router { func NewRouter(e *Echo, config RouterConfig) *DefaultRouter {
return &Router{ r := &DefaultRouter{
tree: &node{ tree: &node{
methodHandler: new(methodHandler), methods: new(routeMethods),
isLeaf: true,
isHandler: false,
}, },
routes: map[string]*Route{}, routes: make(Routes, 0),
echo: e, echo: e,
allowOverwritingRoute: config.AllowOverwritingRoute,
unescapePathParamValues: config.UnescapePathParamValues,
useEscapedPathForRouting: config.UseEscapedPathForMatching,
} }
return r
}
// Routes returns all registered routes
func (r *DefaultRouter) Routes() Routes {
return r.routes
}
// Remove unregisters registered route
func (r *DefaultRouter) Remove(method string, path string) error {
currentNode := r.tree
if currentNode == nil || (currentNode.isLeaf && !currentNode.isHandler) {
return errors.New("router has no routes to remove")
} }
// Add registers a new route for method and path with matching handler.
func (r *Router) Add(method, path string, h HandlerFunc) {
// Validate path
if path == "" { if path == "" {
path = "/" path = "/"
} }
if path[0] != '/' { if path[0] != '/' {
path = "/" + path path = "/" + path
} }
pnames := []string{} // Param names
ppath := path // Pristine path
if h == nil && r.echo.Logger != nil { var nodeToRemove *node
// FIXME: in future we should return error prefixLen := 0
r.echo.Logger.Errorf("Adding route without handler function: %v:%v", method, path) for {
if currentNode.originalPath == path && currentNode.isHandler {
nodeToRemove = currentNode
break
}
if currentNode.kind == staticKind {
prefixLen = prefixLen + len(currentNode.prefix)
} else {
prefixLen = len(currentNode.originalPath)
} }
if prefixLen >= len(path) {
break
}
next := path[prefixLen]
switch next {
case paramLabel:
currentNode = currentNode.paramChild
case anyLabel:
currentNode = currentNode.anyChild
default:
currentNode = currentNode.findStaticChild(next)
}
if currentNode == nil {
break
}
}
if nodeToRemove == nil {
return errors.New("could not find route to remove by given path")
}
if !nodeToRemove.isHandler {
return errors.New("could not find route to remove by given path")
}
if mh := nodeToRemove.methods.find(method); mh == nil {
return errors.New("could not find route to remove by given path and method")
}
nodeToRemove.setHandler(method, nil)
var rIndex int
for i, rr := range r.routes {
if rr.Method() == method && rr.Path() == path {
rIndex = i
break
}
}
r.routes = append(r.routes[:rIndex], r.routes[rIndex+1:]...)
if !nodeToRemove.isHandler && nodeToRemove.isLeaf {
// TODO: if !nodeToRemove.isLeaf and has at least 2 children merge paths for remaining nodes?
current := nodeToRemove
for {
parent := current.parent
if parent == nil {
break
}
switch current.kind {
case staticKind:
var index int
for i, c := range parent.staticChildren {
if c == current {
index = i
break
}
}
parent.staticChildren = append(parent.staticChildren[:index], parent.staticChildren[index+1:]...)
case paramKind:
parent.paramChild = nil
case anyKind:
parent.anyChild = nil
}
parent.isLeaf = parent.anyChild == nil && parent.paramChild == nil && len(parent.staticChildren) == 0
if !parent.isLeaf || parent.isHandler {
break
}
current = parent
}
}
return nil
}
// AddRouteError is error returned by Router.Add containing information what actual route adding failed. Useful for
// mass adding (i.e. Any() routes)
type AddRouteError struct {
Method string
Path string
Err error
}
func (e *AddRouteError) Error() string { return e.Method + " " + e.Path + ": " + e.Err.Error() }
func (e *AddRouteError) Unwrap() error { return e.Err }
func newAddRouteError(route Route, err error) *AddRouteError {
return &AddRouteError{
Method: route.Method,
Path: route.Path,
Err: err,
}
}
// Add registers a new route for method and path with matching handler.
func (r *DefaultRouter) Add(routable Routable) (RouteInfo, error) {
route := routable.ToRoute()
if route.Handler == nil {
return nil, newAddRouteError(route, errors.New("adding route without handler function"))
}
method := route.Method
path := route.Path
h := applyMiddleware(route.Handler, route.Middlewares...)
if !r.allowOverwritingRoute {
for _, rr := range r.routes {
if route.Method == rr.Method() && route.Path == rr.Path() {
return nil, newAddRouteError(route, errors.New("adding duplicate route (same method+path) is not allowed"))
}
}
}
if path == "" {
path = "/"
}
if path[0] != '/' {
path = "/" + path
}
paramNames := make([]string, 0)
originalPath := path
wasAdded := false
var ri RouteInfo
for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { for i, lcpIndex := 0, len(path); i < lcpIndex; i++ {
if path[i] == ':' { if path[i] == paramLabel {
if i > 0 && path[i-1] == '\\' { if i > 0 && path[i-1] == '\\' {
continue continue
} }
j := i + 1 j := i + 1
r.insert(method, path[:i], nil, staticKind, "", nil) r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
for ; i < lcpIndex && path[i] != '/'; i++ { for ; i < lcpIndex && path[i] != '/'; i++ {
} }
pnames = append(pnames, path[j:i]) paramNames = append(paramNames, path[j:i])
path = path[:j] + path[i:] path = path[:j] + path[i:]
i, lcpIndex = j, len(path) i, lcpIndex = j, len(path)
if i == lcpIndex { if i == lcpIndex {
// path node is last fragment of route path. ie. `/users/:id` // path node is last fragment of route path. ie. `/users/:id`
r.insert(method, path[:i], h, paramKind, ppath, pnames) ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(paramKind, path[:i], method, rm)
wasAdded = true
break
} else { } else {
r.insert(method, path[:i], nil, paramKind, "", nil) r.insert(paramKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
} }
} else if path[i] == '*' { } else if path[i] == anyLabel {
r.insert(method, path[:i], nil, staticKind, "", nil) r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
pnames = append(pnames, "*") paramNames = append(paramNames, "*")
r.insert(method, path[:i+1], h, anyKind, ppath, pnames) ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(anyKind, path[:i+1], method, rm)
wasAdded = true
break
} }
} }
r.insert(method, path, h, staticKind, ppath, pnames) if !wasAdded {
ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(staticKind, path, method, rm)
} }
func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { r.storeRouteInfo(ri)
// Adjust max param
paramLen := len(pnames) return ri, nil
if *r.echo.maxParam < paramLen {
*r.echo.maxParam = paramLen
} }
func (r *DefaultRouter) storeRouteInfo(ri RouteInfo) {
for i, rr := range r.routes {
if ri.Method() == rr.Method() && ri.Path() == rr.Path() {
r.routes[i] = ri
return
}
}
r.routes = append(r.routes, ri)
}
func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMethod) {
currentNode := r.tree // Current node as root currentNode := r.tree // Current node as root
if currentNode == nil {
panic("echo: invalid method")
}
search := path search := path
for { for {
@ -157,11 +513,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
// At root node // At root node
currentNode.label = search[0] currentNode.label = search[0]
currentNode.prefix = search currentNode.prefix = search
if h != nil { if ri.handler != nil {
currentNode.kind = t currentNode.kind = t
currentNode.addHandler(method, h) currentNode.setHandler(method, &ri)
currentNode.ppath = ppath currentNode.paramsCount = len(ri.params)
currentNode.pnames = pnames currentNode.originalPath = ri.path
} }
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen { } else if lcpLen < prefixLen {
@ -171,9 +527,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix[lcpLen:], currentNode.prefix[lcpLen:],
currentNode, currentNode,
currentNode.staticChildren, currentNode.staticChildren,
currentNode.methodHandler, currentNode.methods,
currentNode.ppath, currentNode.paramsCount,
currentNode.pnames, currentNode.originalPath,
currentNode.paramChild, currentNode.paramChild,
currentNode.anyChild, currentNode.anyChild,
) )
@ -193,9 +549,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.label = currentNode.prefix[0] currentNode.label = currentNode.prefix[0]
currentNode.prefix = currentNode.prefix[:lcpLen] currentNode.prefix = currentNode.prefix[:lcpLen]
currentNode.staticChildren = nil currentNode.staticChildren = nil
currentNode.methodHandler = new(methodHandler) currentNode.methods = new(routeMethods)
currentNode.ppath = "" currentNode.originalPath = ""
currentNode.pnames = nil currentNode.paramsCount = 0
currentNode.paramChild = nil currentNode.paramChild = nil
currentNode.anyChild = nil currentNode.anyChild = nil
currentNode.isLeaf = false currentNode.isLeaf = false
@ -207,13 +563,18 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
if lcpLen == searchLen { if lcpLen == searchLen {
// At parent node // At parent node
currentNode.kind = t currentNode.kind = t
currentNode.addHandler(method, h) if ri.handler != nil {
currentNode.ppath = ppath currentNode.setHandler(method, &ri)
currentNode.pnames = pnames currentNode.paramsCount = len(ri.params)
currentNode.originalPath = ri.path
}
} else { } else {
// Create child node // Create child node
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n = newNode(t, search[lcpLen:], currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
n.addHandler(method, h) if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
// Only Static children could reach here // Only Static children could reach here
currentNode.addStaticChild(n) currentNode.addStaticChild(n)
} }
@ -227,8 +588,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue continue
} }
// Create child node // Create child node
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n := newNode(t, search, currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
n.addHandler(method, h) if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
switch t { switch t {
case staticKind: case staticKind:
currentNode.addStaticChild(n) currentNode.addStaticChild(n)
@ -240,28 +604,26 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else { } else {
// Node already exists // Node already exists
if h != nil { if ri.handler != nil {
currentNode.addHandler(method, h) currentNode.setHandler(method, &ri)
currentNode.ppath = ppath currentNode.paramsCount = len(ri.params)
if len(currentNode.pnames) == 0 { // Issue #729 currentNode.originalPath = ri.path
currentNode.pnames = pnames
}
} }
} }
return return
} }
} }
func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { func newNode(t kind, pre string, p *node, sc children, mh *routeMethods, paramsCount int, ppath string, paramChildren, anyChildren *node) *node {
return &node{ return &node{
kind: t, kind: t,
label: pre[0], label: pre[0],
prefix: pre, prefix: pre,
parent: p, parent: p,
staticChildren: sc, staticChildren: sc,
ppath: ppath, originalPath: ppath,
pnames: pnames, paramsCount: paramsCount,
methodHandler: mh, methods: mh,
paramChild: paramChildren, paramChild: paramChildren,
anyChild: anyChildren, anyChild: anyChildren,
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
@ -297,99 +659,77 @@ func (n *node) findChildWithLabel(l byte) *node {
return nil return nil
} }
func (n *node) addHandler(method string, h HandlerFunc) { func (n *node) setHandler(method string, r *routeMethod) {
switch method { n.methods.set(method, r)
case http.MethodConnect: if r != nil && r.handler != nil {
n.methodHandler.connect = h
case http.MethodDelete:
n.methodHandler.delete = h
case http.MethodGet:
n.methodHandler.get = h
case http.MethodHead:
n.methodHandler.head = h
case http.MethodOptions:
n.methodHandler.options = h
case http.MethodPatch:
n.methodHandler.patch = h
case http.MethodPost:
n.methodHandler.post = h
case PROPFIND:
n.methodHandler.propfind = h
case http.MethodPut:
n.methodHandler.put = h
case http.MethodTrace:
n.methodHandler.trace = h
case REPORT:
n.methodHandler.report = h
}
if h != nil {
n.isHandler = true n.isHandler = true
} else { } else {
n.isHandler = n.methodHandler.isHandler() n.isHandler = n.methods.isHandler()
} }
} }
func (n *node) findHandler(method string) HandlerFunc { const (
switch method { // NotFoundRouteName is name of RouteInfo returned when router did not find matching route (404: not found).
case http.MethodConnect: NotFoundRouteName = "EchoRouteNotFound"
return n.methodHandler.connect // MethodNotAllowedRouteName is name of RouteInfo returned when router did not find matching method for route (404: method not allowed).
case http.MethodDelete: MethodNotAllowedRouteName = "EchoRouteMethodNotAllowed"
return n.methodHandler.delete )
case http.MethodGet:
return n.methodHandler.get // Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to RouteMatch
case http.MethodHead: var notFoundRouteInfo = &routeInfo{
return n.methodHandler.head method: "",
case http.MethodOptions: path: "",
return n.methodHandler.options params: nil,
case http.MethodPatch: name: NotFoundRouteName,
return n.methodHandler.patch
case http.MethodPost:
return n.methodHandler.post
case PROPFIND:
return n.methodHandler.propfind
case http.MethodPut:
return n.methodHandler.put
case http.MethodTrace:
return n.methodHandler.trace
case REPORT:
return n.methodHandler.report
default:
return nil
}
} }
func (n *node) checkMethodNotAllowed() HandlerFunc { // Note: methodNotAllowedRouteInfo exists to avoid allocations when setting 405 RouteInfo to RouteMatch
for _, m := range methods { var methodNotAllowedRouteInfo = &routeInfo{
if h := n.findHandler(m); h != nil { method: "",
return MethodNotAllowedHandler path: "",
} params: nil,
} name: MethodNotAllowedRouteName,
return NotFoundHandler
} }
// Find lookup a handler registered for method and path. It also parses URL for path // notFoundHandler is handler for 404 cases
// parameters and load them into context. // Handle returned ErrNotFound errors in Echo.HTTPErrorHandler
var notFoundHandler = func(c Context) error {
return ErrNotFound
}
// methodNotAllowedHandler is handler for case when route for path+method match was not found (http code 405)
// Handle returned ErrMethodNotAllowed errors in Echo.HTTPErrorHandler
var methodNotAllowedHandler = func(c Context) error {
return ErrMethodNotAllowed
}
// Match looks up a handler registered for method and path. It also parses URL for path parameters and loads them
// into context.
// //
// For performance: // For performance:
// //
// - Get context from `Echo#AcquireContext()` // - Get context from `Echo#AcquireContext()`
// - Reset it `Context#Reset()` // - Reset it `Context#Reset()`
// - Return it `Echo#ReleaseContext()`. // - Return it `Echo#ReleaseContext()`.
func (r *Router) Find(method, path string, c Context) { func (r *DefaultRouter) Match(req *http.Request, pathParams *PathParams) RouteMatch {
ctx := c.(*context) *pathParams = (*pathParams)[0:cap(*pathParams)]
ctx.path = path
currentNode := r.tree // Current node as root
path := req.URL.Path
if !r.useEscapedPathForRouting && req.URL.RawPath != "" {
// Difference between URL.RawPath and URL.Path is:
// * URL.Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
// * URL.RawPath is an optional field which only gets set if the default encoding is different from Path.
path = req.URL.RawPath
}
var ( var (
currentNode = r.tree // root as current node
previousBestMatchNode *node previousBestMatchNode *node
matchedHandler HandlerFunc matchedRouteMethod *routeMethod
// search stores the remaining path to check for match. By each iteration we move from start of path to end of the path // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path
// and search value gets shorter and shorter. // and search value gets shorter and shorter.
search = path search = path
searchIndex = 0 searchIndex = 0
paramIndex int // Param counter paramIndex int // Param counter
paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice
) )
// Backtracking is needed when a dead end (leaf node) is reached in the router tree. // Backtracking is needed when a dead end (leaf node) is reached in the router tree.
@ -421,8 +761,8 @@ func (r *Router) Find(method, path string, c Context) {
paramIndex-- paramIndex--
// for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue
// for that index as it would also contain part of path we cut off before moving into node we are backtracking from // for that index as it would also contain part of path we cut off before moving into node we are backtracking from
searchIndex -= len(paramValues[paramIndex]) searchIndex -= len((*pathParams)[paramIndex].Value)
paramValues[paramIndex] = "" (*pathParams)[paramIndex].Value = ""
} }
search = path[searchIndex:] search = path[searchIndex:]
return return
@ -456,7 +796,7 @@ func (r *Router) Find(method, path string, c Context) {
// No matching prefix, let's backtrack to the first possible alternative node of the decision path // No matching prefix, let's backtrack to the first possible alternative node of the decision path
nk, ok := backtrackToNextNodeKind(staticKind) nk, ok := backtrackToNextNodeKind(staticKind)
if !ok { if !ok {
return // No other possibilities on the decision path break // No other possibilities on the decision path
} else if nk == paramKind { } else if nk == paramKind {
goto Param goto Param
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
@ -479,8 +819,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil { if previousBestMatchNode == nil {
previousBestMatchNode = currentNode previousBestMatchNode = currentNode
} }
if h := currentNode.findHandler(method); h != nil { if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedHandler = h matchedRouteMethod = rMethod
break break
} }
} }
@ -507,7 +847,7 @@ func (r *Router) Find(method, path string, c Context) {
} }
} }
paramValues[paramIndex] = search[:i] (*pathParams)[paramIndex].Value = search[:i]
paramIndex++ paramIndex++
search = search[i:] search = search[i:]
searchIndex = searchIndex + i searchIndex = searchIndex + i
@ -519,7 +859,7 @@ func (r *Router) Find(method, path string, c Context) {
if child := currentNode.anyChild; child != nil { if child := currentNode.anyChild; child != nil {
// If any node is found, use remaining path for paramValues // If any node is found, use remaining path for paramValues
currentNode = child currentNode = child
paramValues[len(currentNode.pnames)-1] = search (*pathParams)[currentNode.paramsCount-1].Value = search
// update indexes/search in case we need to backtrack when no handler match is found // update indexes/search in case we need to backtrack when no handler match is found
paramIndex++ paramIndex++
searchIndex += +len(search) searchIndex += +len(search)
@ -530,8 +870,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil { if previousBestMatchNode == nil {
previousBestMatchNode = currentNode previousBestMatchNode = currentNode
} }
if h := currentNode.findHandler(method); h != nil { if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedHandler = h matchedRouteMethod = rMethod
break break
} }
} }
@ -550,20 +890,63 @@ func (r *Router) Find(method, path string, c Context) {
} }
} }
result := RouteMatch{
Type: RouteMatchNotFound,
Handler: notFoundHandler,
RoutePath: "",
RouteInfo: notFoundRouteInfo,
}
if currentNode == nil && previousBestMatchNode == nil { if currentNode == nil && previousBestMatchNode == nil {
return // nothing matched at all *pathParams = (*pathParams)[0:0]
return result // nothing matched at all with given path
} }
if matchedHandler != nil { if matchedRouteMethod != nil {
ctx.handler = matchedHandler result.Type = RouteMatchFound
result.Handler = matchedRouteMethod.handler
result.RoutePath = matchedRouteMethod.routeInfo.path
result.RouteInfo = matchedRouteMethod.routeInfo
} else { } else {
// use previous match as basis. although we have no matching handler we have path match. // use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames
return // this here is only reason why `RouteMatch.RoutePath` exists. We do not want to create new RouteInfo just for path.
result.RoutePath = currentNode.originalPath
if currentNode.isHandler {
// TODO: in case of OPTIONS method we could respond with list of methods that node has. See https://httpwg.org/specs/rfc7231.html#OPTIONS
result.Type = RouteMatchMethodNotAllowed
result.Handler = methodNotAllowedHandler
result.RouteInfo = methodNotAllowedRouteInfo
}
}
*pathParams = (*pathParams)[0:currentNode.paramsCount]
if matchedRouteMethod != nil {
for i, name := range matchedRouteMethod.params {
(*pathParams)[i].Name = name
}
}
if r.unescapePathParamValues && currentNode.kind != staticKind {
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
for i, p := range *pathParams {
tmpVal, err := url.PathUnescape(p.Value)
if err == nil { // handle problems by ignoring them.
(*pathParams)[i].Value = tmpVal
}
}
}
return result
}
// Get returns path parameter value for given name or default value.
func (p PathParams) Get(name string, defaultValue string) string {
for _, param := range p {
if param.Name == name {
return param.Value
}
}
return defaultValue
} }

File diff suppressed because it is too large Load Diff

220
server.go Normal file
View File

@ -0,0 +1,220 @@
package echo
import (
stdContext "context"
"crypto/tls"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"time"
)
const (
website = "https://echo.labstack.com"
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
banner = `
____ __
/ __/___/ / ___
/ _// __/ _ \/ _ \
/___/\__/_//_/\___/ %s
High performance, minimalist Go web framework
%s
____________________________________O/_______
O\
`
)
// StartConfig is for creating configured http.Server instance to start serve http(s) requests with given Echo instance
type StartConfig struct {
// Address for the server to listen on (if not using custom listener)
Address string
// ListenerNetwork allows setting listener network (see net.Listen for allowed values)
// Optional: defaults to "tcp"
ListenerNetwork string
// CertFilesystem is file system used to load certificates and keys (if certs/keys are given as paths)
CertFilesystem fs.FS
// DisableHTTP2 disables supports for HTTP2 in TLS server
DisableHTTP2 bool
// HideBanner does not log Echo banner on server startup
HideBanner bool
// HidePort does not log port on server startup
HidePort bool
// GracefulContext is context that completion signals graceful shutdown start
GracefulContext stdContext.Context
// GracefulTimeout is period which server allows listeners to finish serving ongoing requests. If this time is exceeded process is exited
// Defaults to 10 seconds
GracefulTimeout time.Duration
// OnShutdownError allows customization of what happens when (graceful) server Shutdown method returns an error.
// Defaults to calling e.logger.Error(err)
OnShutdownError func(err error)
// TLSConfigFunc allows modifying TLS configuration before listener is created with it.
TLSConfigFunc func(tlsConfig *tls.Config)
// ListenerAddrFunc allows getting listener address before server starts serving requests on listener. Useful when
// address is set as random (`:0`) port.
ListenerAddrFunc func(addr net.Addr)
// BeforeServeFunc allows customizing/accessing server before server starts serving requests on listener.
BeforeServeFunc func(s *http.Server) error
}
// Start starts a HTTP server.
func (sc StartConfig) Start(e *Echo) error {
logger := e.Logger
server := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
var tlsConfig *tls.Config = nil
if sc.TLSConfigFunc != nil {
tlsConfig = &tls.Config{}
configureTLS(&sc, tlsConfig)
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &server, listener, logger)
}
// StartTLS starts a HTTPS server.
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
func (sc StartConfig) StartTLS(e *Echo, certFile, keyFile interface{}) error {
logger := e.Logger
s := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
certFs := sc.CertFilesystem
if certFs == nil {
certFs = os.DirFS(".")
}
cert, err := filepathOrContent(certFile, certFs)
if err != nil {
return err
}
key, err := filepathOrContent(keyFile, certFs)
if err != nil {
return err
}
cer, err := tls.X509KeyPair(cert, key)
if err != nil {
return err
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
configureTLS(&sc, tlsConfig)
if sc.TLSConfigFunc != nil {
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &s, listener, logger)
}
func serve(sc *StartConfig, server *http.Server, listener net.Listener, logger Logger) error {
if sc.BeforeServeFunc != nil {
if err := sc.BeforeServeFunc(server); err != nil {
return err
}
}
startupGreetings(sc, logger, listener)
if sc.GracefulContext != nil {
ctx, cancel := stdContext.WithCancel(sc.GracefulContext)
defer cancel() // make sure this graceful coroutine will end when serve returns by some other means
go gracefulShutdown(ctx, sc, server, logger)
}
return server.Serve(listener)
}
func configureTLS(sc *StartConfig, tlsConfig *tls.Config) {
if !sc.DisableHTTP2 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2")
}
}
func createListener(sc *StartConfig, tlsConfig *tls.Config) (net.Listener, error) {
listenerNetwork := sc.ListenerNetwork
if listenerNetwork == "" {
listenerNetwork = "tcp"
}
var listener net.Listener
var err error
if tlsConfig != nil {
listener, err = tls.Listen(listenerNetwork, sc.Address, tlsConfig)
} else {
listener, err = net.Listen(listenerNetwork, sc.Address)
}
if err != nil {
return nil, err
}
if sc.ListenerAddrFunc != nil {
sc.ListenerAddrFunc(listener.Addr())
}
return listener, nil
}
func startupGreetings(sc *StartConfig, logger Logger, listener net.Listener) {
if !sc.HideBanner {
bannerText := fmt.Sprintf(banner, "v"+Version, website)
logger.Write([]byte(bannerText))
}
if !sc.HidePort {
logger.Write([]byte(fmt.Sprintf("⇨ http(s) server started on %s\n", listener.Addr())))
}
}
func filepathOrContent(fileOrContent interface{}, certFilesystem fs.FS) (content []byte, err error) {
switch v := fileOrContent.(type) {
case string:
return fs.ReadFile(certFilesystem, v)
case []byte:
return v, nil
default:
return nil, ErrInvalidCertOrKeyType
}
}
func gracefulShutdown(gracefulCtx stdContext.Context, sc *StartConfig, server *http.Server, logger Logger) {
<-gracefulCtx.Done() // wait until shutdown context is closed.
// note: is server if closed by other means this method is still run but is good as no-op
timeout := sc.GracefulTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
shutdownCtx, cancel := stdContext.WithTimeout(stdContext.Background(), timeout)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
// we end up here when listeners are not shut down within given timeout
if sc.OnShutdownError != nil {
sc.OnShutdownError(err)
return
}
logger.Error(fmt.Errorf("failed to shut down server within given timeout: %w", err))
}
}

815
server_test.go Normal file
View File

@ -0,0 +1,815 @@
package echo
import (
"bytes"
stdContext "context"
"crypto/tls"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
)
func startOnRandomPort(ctx stdContext.Context, e *Echo) (string, error) {
addrChan := make(chan string)
errCh := make(chan error)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
return waitForServerStart(addrChan, errCh)
}
func waitForServerStart(addrChan <-chan string, errCh <-chan error) (string, error) {
waitCtx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer cancel()
// wait for addr to arrive
for {
select {
case <-waitCtx.Done():
return "", waitCtx.Err()
case addr := <-addrChan:
return addr, nil
case err := <-errCh:
if err == http.ErrServerClosed { // was closed normally before listener callback was called. should not be possible
return "", nil
}
// failed to start and we did not manage to get even listener part.
return "", err
}
}
}
func doGet(url string) (int, string, error) {
resp, err := http.Get(url)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, "", err
}
return resp.StatusCode, string(body), nil
}
func TestStartConfig_Start(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
// check if server is actually up
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
if err == nil {
t.Errorf("missing error")
return
}
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
func TestStartConfig_GracefulShutdown(t *testing.T) {
var testCases = []struct {
name string
whenHandlerTakesLonger bool
expectBody string
expectGracefulError string
}{
{
name: "ok, all handlers returns before graceful shutdown deadline",
whenHandlerTakesLonger: false,
expectBody: "OK",
expectGracefulError: "",
},
{
name: "nok, handlers do not returns before graceful shutdown deadline",
whenHandlerTakesLonger: true,
expectBody: "timeout",
expectGracefulError: stdContext.DeadlineExceeded.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
msg := "OK"
if tc.whenHandlerTakesLonger {
time.Sleep(150 * time.Millisecond)
msg = "timeout"
}
return c.String(http.StatusOK, msg)
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 50*time.Millisecond)
defer shutdown()
shutdownErrChan := make(chan error, 1)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 50 * time.Millisecond,
OnShutdownError: func(err error) {
shutdownErrChan <- err
},
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, tc.expectBody, body)
var shutdownErr error
select {
case shutdownErr = <-shutdownErrChan:
default:
}
if tc.expectGracefulError != "" {
assert.EqualError(t, shutdownErr, tc.expectGracefulError)
} else {
assert.NoError(t, shutdownErr)
}
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Error(t, err)
if err != nil {
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
})
}
}
func TestStartConfig_Start_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("not_implemented")
}
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_Start_createListenerError(t *testing.T) {
e := New()
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
func TestStartConfig_StartTLS(t *testing.T) {
var testCases = []struct {
name string
addr string
certFile string
keyFile string
expectError string
}{
{
name: "ok",
addr: ":0",
},
{
name: "nok, invalid certFile",
addr: ":0",
certFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, invalid keyFile",
addr: ":0",
keyFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, failed to create cert out of certFile and keyFile",
addr: ":0",
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
},
{
name: "nok, invalid tls address",
addr: "nope",
expectError: "listen tcp: address nope: missing port in address",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
if tc.certFile != "" {
certFile = tc.certFile
}
keyFile := "_fixture/certs/key.pem"
if tc.keyFile != "" {
keyFile = tc.keyFile
}
s := &StartConfig{
Address: tc.addr,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectError != "" {
if _, ok := err.(*os.PathError); ok {
assert.Error(t, err) // error messages for unix and windows are different. so name only error type here
} else {
assert.EqualError(t, err, tc.expectError)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestStartConfig_StartTLS_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
assert.Len(t, tlsConfig.Certificates, 1)
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_StartTLSAndStart(t *testing.T) {
// We name if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
e := New()
e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
tlsCtx, tlsShutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer tlsShutdown()
addrTLSChan := make(chan string)
errTLSChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: tlsCtx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrTLSChan <- addr.String()
},
}
errTLSChan <- s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
}()
tlsAddr, err := waitForServerStart(addrTLSChan, errTLSChan)
assert.NoError(t, err)
// check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer shutdown()
addrChan := make(chan string)
errChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errChan <- s.Start(e)
}()
addr, err := waitForServerStart(addrChan, errChan)
assert.NoError(t, err)
// now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
res, err = client.Get(fmt.Sprintf("http://%v", addr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// see if HTTPS works after HTTP listener is also added
res, err = client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestFilepathOrContent(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err)
testCases := []struct {
name string
cert interface{}
key interface{}
expectedErr error
}{
{
name: `ValidCertAndKeyFilePath`,
cert: "_fixture/certs/cert.pem",
key: "_fixture/certs/key.pem",
expectedErr: nil,
},
{
name: `ValidCertAndKeyByteString`,
cert: cert,
key: key,
expectedErr: nil,
},
{
name: `InvalidKeyType`,
cert: cert,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertType`,
cert: 0,
key: key,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertAndKeyTypes`,
cert: 0,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: ":0",
CertFilesystem: os.DirFS("."),
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, tc.cert, tc.key)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func supportsIPv6() bool {
addrs, _ := net.InterfaceAddrs()
for _, addr := range addrs {
// Check if any interface has local IPv6 assigned
if strings.Contains(addr.String(), "::1") {
return true
}
}
return false
}
func TestStartConfig_WithListenerNetwork(t *testing.T) {
testCases := []struct {
name string
network string
address string
}{
{
name: "tcp ipv4 address",
network: "tcp",
address: "127.0.0.1:1323",
},
{
name: "tcp ipv6 address",
network: "tcp",
address: "[::1]:1323",
},
{
name: "tcp4 ipv4 address",
network: "tcp4",
address: "127.0.0.1:1323",
},
{
name: "tcp6 ipv6 address",
network: "tcp6",
address: "[::1]:1323",
},
}
hasIPv6 := supportsIPv6()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if !hasIPv6 && strings.Contains(tc.address, "::") {
t.Skip("Skipping testing IPv6 for " + tc.address + ", not available")
}
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: tc.address,
ListenerNetwork: tc.network,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.Start(e)
}()
_, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%s/ok", tc.address))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
})
}
}
func TestStartConfig_WithHideBanner(t *testing.T) {
var testCases = []struct {
name string
hideBanner bool
}{
{
name: "hide banner on startup",
hideBanner: true,
},
{
name: "show banner on startup",
hideBanner: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HideBanner: tc.hideBanner,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
contains := strings.Contains(buf.String(), "High performance, minimalist Go web framework")
if tc.hideBanner {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithHidePort(t *testing.T) {
var testCases = []struct {
name string
hidePort bool
}{
{
name: "hide port on startup",
hidePort: true,
},
{
name: "show port on startup",
hidePort: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HidePort: tc.hidePort,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
portMsg := fmt.Sprintf("http(s) server started on")
contains := strings.Contains(buf.String(), portMsg)
if tc.hidePort {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithBeforeServeFunc(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
s := &StartConfig{
Address: ":0",
BeforeServeFunc: func(s *http.Server) error {
return errors.New("is called before serve")
},
}
err := s.Start(e)
assert.EqualError(t, err, "is called before serve")
}
func TestWithDisableHTTP2(t *testing.T) {
var testCases = []struct {
name string
disableHTTP2 bool
}{
{
name: "HTTP2 enabled",
disableHTTP2: false,
},
{
name: "HTTP2 disabled",
disableHTTP2: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
keyFile := "_fixture/certs/key.pem"
s := &StartConfig{
Address: ":0",
DisableHTTP2: tc.disableHTTP2,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
url := fmt.Sprintf("https://%v/ok", addr)
// do ordinary http(s) request
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(url)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// do HTTP2 request
client.Transport = &http2.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
resp, err := client.Get(url)
if err != nil {
if tc.disableHTTP2 {
assert.True(t, strings.Contains(err.Error(), `http2: unexpected ALPN protocol ""; want "h2"`))
return
}
log.Fatalf("Failed get: %s", err)
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed reading response body: %s", err)
}
assert.Equal(t, "OK", string(body))
})
}
}
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Printf(format string, args ...interface{}) {
_, _ = l.output.Write([]byte(fmt.Sprintf(format, args...)))
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}