diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 26640666..ecc508e6 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,7 +27,8 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.14, 1.15, 1.16, 1.17] + # except v5 starts from 1.16 until there is last four major releases after that + go: [1.16, 1.17] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d45ad7..00000000 --- a/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -arch: - - amd64 - - ppc64le - -language: go -go: - - 1.14.x - - 1.15.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip diff --git a/Makefile b/Makefile index 48061f7e..10f9c8f5 100644 --- a/Makefile +++ b/Makefile @@ -24,11 +24,11 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -bench=".*" ${PKG_LIST} + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.15" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 +goversion ?= "1.16" +test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 364f98ac..eb9369ad 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ ## Supported Go versions +Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that. + As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: @@ -67,8 +69,8 @@ package main import ( "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" ) func main() { @@ -83,7 +85,9 @@ func main() { e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":1323"); err != http.ErrServerClosed { + log.Fatal(err) + } } // Handler diff --git a/bind.go b/bind.go index fdf0524c..a650ba0b 100644 --- a/bind.go +++ b/bind.go @@ -11,42 +11,38 @@ import ( "strings" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(c Context, i interface{}) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} // BindPathParams binds path params to bindable object -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { - names := c.ParamNames() - values := c.ParamValues() +func BindPathParams(c Context, i interface{}) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathParams() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(i, params, "param"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // BindQueryParams binds query params to bindable object -func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindQueryParams(c Context, i interface{}) error { + if err := bindData(i, c.QueryParams(), "query"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } @@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { +func BindBody(c Context, i interface{}) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { case *HTTPError: return err default: - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } } case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())) } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): params, err := c.FormParams() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } - if err = b.bindData(i, params, "form"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(i, params, "form"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } default: return ErrUnsupportedMediaType @@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { // BindHeaders binds HTTP headers to a bindable object func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(i, c.Request().Header, "header"); err != nil { + return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error()) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - if err := b.BindPathParams(c, i); err != nil { +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. +func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) { + if err := BindPathParams(c, i); err != nil { return err } // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) @@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { - if err = b.BindQueryParams(c, i); err != nil { + if err = BindQueryParams(c, i); err != nil { return err } } - return b.BindBody(c, i) + return BindBody(c, i) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { +func bindData(destination interface{}, data map[string][]string, tag string) error { if destination == nil || len(data) == 0 { return nil } @@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). // structs that implement BindUnmarshaler are binded only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + if err := bindData(structField.Addr().Interface(), data, tag); err != nil { return err } } @@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } intVal, err := strconv.ParseInt(value, 10, bitSize) if err == nil { @@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0" + return nil } uintVal, err := strconv.ParseUint(value, 10, bitSize) if err == nil { @@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error { if value == "" { - value = "false" + return nil } boolVal, err := strconv.ParseBool(value) if err == nil { @@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error { if value == "" { - value = "0.0" + return nil } floatVal, err := strconv.ParseFloat(value, bitSize) if err == nil { diff --git a/bind_test.go b/bind_test.go index 4ed8dbb5..5b555248 100644 --- a/bind_test.go +++ b/bind_test.go @@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) { } } +func TestBind_CombineQueryWithHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil) + req.Header.Set("language", "de") + req.Header.Set("length", "99") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPathParams(PathParams{{ + Name: "id", + Value: "999", + }}) + + type SearchOpts struct { + ID int `param:"id"` + Length int `query:"length"` + Page int `query:"page"` + Search string `query:"search"` + Language string `query:"language" header:"language"` + } + + opts := SearchOpts{ + Length: 100, + Page: 0, + Search: "default value", + Language: "en", + } + err := c.Bind(&opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // bind from query + assert.Equal(t, 10, opts.Page) // bind from query + assert.Equal(t, 999, opts.ID) // bind from path param + assert.Equal(t, "et", opts.Language) // bind from query + assert.Equal(t, "default value", opts.Search) // default value stays + + // make sure another bind will not mess already set values unless there are new values + err = (&DefaultBinder{}).BindHeaders(c, &opts) + assert.NoError(t, err) + + assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists + assert.Equal(t, 10, opts.Page) + assert.Equal(t, 999, opts.ID) + assert.Equal(t, "de", opts.Language) // header overwrites now this value + assert.Equal(t, "default value", opts.Search) +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) @@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) { func TestBindbindData(t *testing.T) { a := assert.New(t) ts := new(bindTestStruct) - b := new(DefaultBinder) - err := b.bindData(ts, values, "form") + err := bindData(ts, values, "form") a.NoError(err) a.Equal(0, ts.I) @@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + cc := c.(EditableContext) + cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"}) + cc.SetPathParams(PathParams{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }) u := new(user) err := c.Bind(u) @@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + cc2 := c2.(EditableContext) + cc2.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc2.SetPathParams(PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c2.Bind(u) @@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + cc3 := c3.(EditableContext) + cc3.SetRouteInfo(routeInfo{path: "/users/:id"}) + cc3.SetPathParams(PathParams{ + {Name: "id", Value: "1"}, + }) u = new(user) err = c3.Bind(u) @@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) { assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - assert := assert.New(t) +func TestSetIntField(t *testing.T) { + ts := new(bindTestStruct) + ts.I = 100 + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 100, ts.I) + + // second set with value sets the value + err = setIntField("5", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setIntField("", 0, val.FieldByName("I")) + assert.NoError(t, err) + assert.Equal(t, 5, ts.I) +} + +func TestSetUintField(t *testing.T) { + ts := new(bindTestStruct) + ts.UI = 100 + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(100), ts.UI) + + // second set with value sets the value + err = setUintField("5", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setUintField("", 0, val.FieldByName("UI")) + assert.NoError(t, err) + assert.Equal(t, uint(5), ts.UI) +} + +func TestSetFloatField(t *testing.T) { + ts := new(bindTestStruct) + ts.F32 = 100 + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(100), ts.F32) + + // second set with value sets the value + err = setFloatField("15.5", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setFloatField("", 0, val.FieldByName("F32")) + assert.NoError(t, err) + assert.Equal(t, float32(15.5), ts.F32) +} + +func TestSetBoolField(t *testing.T) { + ts := new(bindTestStruct) + ts.B = true + + val := reflect.ValueOf(ts).Elem() + + // empty value does nothing to field + // in that way we can have default values by setting field value before binding + err := setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // second set with value sets the value + err = setBoolField("true", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // third set without value does nothing to the value + // in that way multiple binds (ala query + header) do not reset fields to 0s + err = setBoolField("", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, true, ts.B) + + // fourth set to false + err = setBoolField("false", val.FieldByName("B")) + assert.NoError(t, err) + assert.Equal(t, false, ts.B) +} + +func TestUnmarshalFieldNonPtr(t *testing.T) { ts := new(bindTestStruct) val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) - } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) - } - - // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) - } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) - } - - // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) - } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) - } - - // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) - } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) - } ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) + if assert.NoError(t, err) { + assert.True(t, ok) + assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) } } @@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() assert := assert.New(b) ts := new(bindTestStructWithTags) - binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") + err = bindData(ts, values, "form") } assert.NoError(err) assertBindTestStruct(assert, (*bindTestStruct)(ts)) @@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("node_from_path") + cc := c.(EditableContext) + cc.SetPathParams(PathParams{ + {Name: "node", Value: "node_from_path"}, + }) } var bindTarget interface{} @@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(bindTarget, c) + err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) { c := e.NewContext(req, rec) if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("real_node") + cc := c.(EditableContext) + cc.SetPathParams(PathParams{ + {Name: "node", Value: "real_node"}, + }) } var bindTarget interface{} @@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) { } else { bindTarget = &Node{} } - b := new(DefaultBinder) - err := b.BindBody(c, bindTarget) + err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/binder.go b/binder.go index 0900ce8d..402b80bc 100644 --- a/binder.go +++ b/binder.go @@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ failFast: true, - ValueFunc: c.Param, + ValueFunc: c.PathParam, ValuesFunc: func(sourceParam string) []string { // path parameter should not have multiple values so getting values does not make sense but lets not error out here - value := c.Param(sourceParam) + value := c.PathParam(sourceParam) if value == "" { return nil } diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go index 018628c3..4f1587bc 100644 --- a/binder_go1.15_test.go +++ b/binder_go1.15_test.go @@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin c := e.NewContext(req, rec) if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) + params := make(PathParams, 0) for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + params = append(params, PathParam{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + cc := c.(EditableContext) + cc.SetPathParams(params) } return c diff --git a/binder_test.go b/binder_test.go index 946906a9..00036863 100644 --- a/binder_test.go +++ b/binder_test.go @@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) c := e.NewContext(req, rec) if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) + params := make(PathParams, 0) for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + params = append(params, PathParam{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + cc := c.(EditableContext) + cc.SetPathParams(params) } return c @@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) } } @@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } diff --git a/context.go b/context.go index 91ab6e48..65bedcdf 100644 --- a/context.go +++ b/context.go @@ -3,212 +3,233 @@ package echo import ( "bytes" "encoding/xml" + "errors" "fmt" "io" + "io/fs" "mime/multipart" "net" "net/http" "net/url" - "os" "path/filepath" "strings" "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type + // of match router found and how this request context handler chain could end: + // * route match - this path + method had matching route. + // * not found - this path did not match any routes enough to be considered match + // * method not allowed - path had routes registered but for other method types then current request is + // * unknown - initial state for fresh context before router tries to do routing + // + // Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried + // to match request to route. + RouteMatchType() RouteMatchType - // SetPath sets the registered path for the handler. - SetPath(p string) + // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. + // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases. + RouteInfo() RouteInfo - // Param returns path parameter by name. - Param(name string) string + // Path returns the registered path for the handler. + Path() string - // ParamNames returns path parameter names. - ParamNames() []string + // PathParam returns path parameter by name. + PathParam(name string) string - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // PathParams returns path parameter values. + PathParams() PathParams - // ParamValues returns path parameter values. - ParamValues() []string + // SetPathParams set path parameter for during current request lifecycle. + SetPathParams(params PathParams) - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryString returns the URL query string. + QueryString() string - // QueryString returns the URL query string. - QueryString() string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormParams returns the form parameters as `url.Values`. + FormParams() (url.Values, error) - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // Get retrieves data from the context. + Get(key string) interface{} - // Get retrieves data from the context. - Get(key string) interface{} + // Set saves data in the context. + Set(key string, val interface{}) - // Set saves data in the context. - Set(key string, val interface{}) + // Bind binds the request body into provided type `i`. The default binder + // does it based on Content-Type header. + Bind(i interface{}) error - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. - Bind(i interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i interface{}) error - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data interface{}) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // String sends a string response with status code. + String(code int, s string) error - // String sends a string response with status code. - String(code int, s string) error + // JSON sends a JSON response with status code. + JSON(code int, i interface{}) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i interface{}, indent string) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i interface{}) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // XML sends an XML response with status code. + XML(code int, i interface{}) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i interface{}, indent string) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // File sends a response with the content of the file. + File(file string) error - // File sends a response with the content of the file. - File(file string) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // Error invokes the registered HTTP error handler. + // NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error. + Error(err error) - // Error invokes the registered HTTP error handler. Generally used by middleware. - Error(err error) + // Echo returns the `Echo` instance. + Echo() *Echo +} - // Handler returns the matched handler by router. - Handler() HandlerFunc +// EditableContext is additional interface that structure implementing Context must implement. Methods inside this +// interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares. +type EditableContext interface { + Context - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) + // RawPathParams returns raw path pathParams value. + RawPathParams() *PathParams - // Logger returns the `Logger` instance. - Logger() Logger + // SetRawPathParams replaces any existing param values with new values for this context lifetime (request). + SetRawPathParams(params *PathParams) - // Set the logger - SetLogger(l Logger) + // SetPath sets the registered path for the handler. + SetPath(p string) - // Echo returns the `Echo` instance. - Echo() *Echo + // SetRouteMatchType sets the RouteMatchType of router match for this request. + SetRouteMatchType(t RouteMatchType) - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } + // SetRouteInfo sets the route info of this request to the context. + SetRouteInfo(ri RouteInfo) - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} + +type context struct { + request *http.Request + response *Response + + matchType RouteMatchType + route RouteInfo + path string + + // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations. + pathParams *PathParams + // currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request. + // Lifecycle is not handle by Echo and could have excess allocations per served Request + currentParams PathParams + + query url.Values + store Map + echo *Echo + lock sync.RWMutex +} const ( defaultMemory = 32 << 20 // 32 MB @@ -296,52 +317,50 @@ func (c *context) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } - } - return "" +func (c *context) RouteMatchType() RouteMatchType { + return c.matchType } -func (c *context) ParamNames() []string { - return c.pnames +func (c *context) SetRouteMatchType(t RouteMatchType) { + c.matchType = t } -func (c *context) SetParamNames(names ...string) { - c.pnames = names - - l := len(names) - if *c.echo.maxParam < l { - *c.echo.maxParam = l - } - - if len(c.pvalues) < l { - // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues - newPvalues := make([]string, l) - copy(newPvalues, c.pvalues) - c.pvalues = newPvalues - } +func (c *context) RouteInfo() RouteInfo { + return c.route } -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] +func (c *context) SetRouteInfo(ri RouteInfo) { + c.route = ri } -func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - // It will brake the Router#Find code - limit := len(values) - if limit > *c.echo.maxParam { - limit = *c.echo.maxParam +func (c *context) RawPathParams() *PathParams { + return c.pathParams +} + +func (c *context) SetRawPathParams(params *PathParams) { + c.pathParams = params +} + +func (c *context) PathParam(name string) string { + if c.currentParams != nil { + return c.currentParams.Get(name, "") } - for i := 0; i < limit; i++ { - c.pvalues[i] = values[i] + + return c.pathParams.Get(name, "") +} + +func (c *context) PathParams() PathParams { + if c.currentParams != nil { + return c.currentParams } + + result := make(PathParams, len(*c.pathParams)) + copy(result, *c.pathParams) + return result +} + +func (c *context) SetPathParams(params PathParams) { + c.currentParams = params } func (c *context) QueryParam(name string) string { @@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) { } func (c *context) Bind(i interface{}) error { - return c.echo.Binder.Bind(i, c) + return c.echo.Binder.Bind(c, i) } func (c *context) Validate(i interface{}) error { @@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error) return } -func (c *context) File(file string) (err error) { - f, err := os.Open(file) +func (c *context) File(file string) error { + return c.FsFile(file, c.echo.Filesystem) +} + +func (c *context) FsFile(file string, filesystem fs.FS) error { + // FIXME: should we add this method into echo.Context interface? + f, err := filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() fi, _ := f.Stat() if fi.IsDir() { file = filepath.Join(file, indexPage) - f, err = os.Open(file) + f, err = filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() if fi, err = f.Stat(); err != nil { - return + return err } } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil } func (c *context) Attachment(file, name string) error { @@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error { } func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) + c.echo.HTTPErrorHandler(c, err) } func (c *context) Echo() *Echo { return c.echo } -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res - } - return c.echo.Logger -} - -func (c *context) SetLogger(l Logger) { - c.logger = l -} - func (c *context) Reset(r *http.Request, w http.ResponseWriter) { c.request = r c.response.reset(w) c.query = nil - c.handler = NotFoundHandler c.store = nil + + c.matchType = RouteMatchUnknown + c.route = nil c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { - c.pvalues[i] = "" - } + // NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times + *c.pathParams = (*c.pathParams)[:0] + c.currentParams = nil } diff --git a/context_test.go b/context_test.go index a8b9a994..0783f1a8 100644 --- a/context_test.go +++ b/context_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "math" "mime/multipart" "net/http" @@ -18,21 +19,19 @@ import ( "text/template" "time" - "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &jsonLogger{writer: ioutil.Discard} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &jsonLogger{writer: ioutil.Discard} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = &jsonLogger{writer: ioutil.Discard} + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -106,16 +107,14 @@ func TestContext(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Render @@ -126,23 +125,23 @@ func TestContext(t *testing.T) { } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } c.echo.Renderer = nil err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) + assert.Error(t, err) // JSON rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } // JSON with "?pretty" @@ -150,10 +149,10 @@ func TestContext(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) // reset @@ -161,37 +160,37 @@ func TestContext(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } // JSON (error) rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // JSONP rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) callback := "callback" err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) } // XML rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // XML with "?pretty" @@ -199,10 +198,10 @@ func TestContext(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } req = httptest.NewRequest(http.MethodGet, "/", nil) @@ -210,22 +209,22 @@ func TestContext(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) + assert.Error(t, err) // XML response write error c = e.NewContext(req, rec).(*context) c.response.Writer = responseWriterErr{} err = c.XML(0, 0) - testify.Error(t, err) + assert.Error(t, err) // XMLPretty rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } t.Run("empty indent", func(t *testing.T) { @@ -237,7 +236,6 @@ func TestContext(t *testing.T) { t.Run("json", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New JSONBlob with empty indent rec = httptest.NewRecorder() @@ -246,16 +244,15 @@ func TestContext(t *testing.T) { enc.SetIndent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) } }) t.Run("xml", func(t *testing.T) { buf.Reset() - assert := testify.New(t) // New XMLBlob with empty indent rec = httptest.NewRecorder() @@ -264,10 +261,10 @@ func TestContext(t *testing.T) { enc.Indent(emptyIndent, emptyIndent) err = enc.Encode(u) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } }) }) @@ -276,12 +273,12 @@ func TestContext(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) } // Legacy JSONPBlob @@ -289,44 +286,44 @@ func TestContext(t *testing.T) { c = e.NewContext(req, rec).(*context) callback = "callback" data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } // Legacy XMLBlob rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) } // String rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // Stream @@ -334,55 +331,55 @@ func TestContext(t *testing.T) { c = e.NewContext(req, rec).(*context) r := strings.NewReader("response from a stream") err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) } // Attachment rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // Inline rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) } // NoContent rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) // Error rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusInternalServerError, rec.Code) // Reset - c.SetParamNames("foo") - c.SetParamValues("bar") + c.pathParams = &PathParams{ + {Name: "foo", Value: "bar"}, + } c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + assert.Equal(t, 0, len(c.PathParams())) + assert.Equal(t, 0, len(c.store)) + assert.Equal(t, nil, c.RouteInfo()) + assert.Equal(t, 0, len(c.QueryParams())) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { @@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") -} - -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() - - handler := func(c Context) error { return c.String(http.StatusOK, "OK") } - - r.Add(http.MethodGet, "/users/:id", handler) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) - - assert := testify.New(t) - - assert.Equal("/users/:id", c.Path()) - - r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContextPathParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) + c := e.NewContext(req, nil).(*context) + params := &PathParams{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + } // ParamNames - c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) - - // ParamValues - c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.pathParams = params + assert.EqualValues(t, *params, c.PathParams()) // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.Equal(t, "501", c.PathParam("fid")) + assert.Equal(t, "", c.PathParam("undefined")) } func TestContextGetAndSetParam(t *testing.T) { e := New() r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) + _, err := r.Add(Route{ + Method: http.MethodGet, + Path: "/:foo", + Name: "", + Handler: func(Context) error { return nil }, + Middlewares: nil, + }) + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) - c.SetParamNames("foo") + + params := &PathParams{{Name: "foo", Value: "101"}} + // ParamNames + c.(*context).pathParams = params // round-trip param values with modification - paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) - paramVals[0] = "bar" - c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + paramVals := c.PathParams() + assert.Equal(t, *params, c.PathParams()) + + paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context + assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams()) + + pathParams := PathParams{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathParams(pathParams) + assert.Equal(t, pathParams, c.PathParams()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { - c.Reset(nil, nil) + assert.NotPanics(t, func() { + c.(EditableContext).Reset(nil, nil) }) + assert.Equal(t, PathParams{}, c.PathParams()) + assert.Len(t, *c.(*context).pathParams, 0) + assert.Equal(t, cap(*c.(*context).pathParams), 1) } // Issue #1655 -func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - +func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) { e := New() - assert.Equal(0, *e.maxParam) + c := e.NewContext(nil, nil).(*context) - expectedOneParam := []string{"one"} - expectedTwoParams := []string{"one", "two"} - expectedThreeParams := []string{"one", "two", ""} - expectedABCParams := []string{"A", "B", "C"} + assert.Equal(t, 0, e.contextPathParamAllocSize) + expectedTwoParams := &PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetRawPathParams(expectedTwoParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, *expectedTwoParams, c.PathParams()) - c := e.NewContext(nil, nil) - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) - - c.SetParamNames("1") - assert.Equal(2, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) - - c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) - // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) - - c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) - // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + expectedThreeParams := PathParams{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, + } + c.SetPathParams(expectedThreeParams) + assert.Equal(t, 0, e.contextPathParamAllocSize) + assert.Equal(t, expectedThreeParams, c.PathParams()) } func TestContextFormValue(t *testing.T) { @@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) // FormParams params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, params) @@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) { req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + assert.Nil(t, params) + assert.Error(t, err) } func TestContextQueryParam(t *testing.T) { @@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) { c := e.NewContext(req, nil) // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) // QueryParams - testify.Equal(t, url.Values{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) } } @@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { var c Context = new(context) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { @@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context = new(context) - - testify.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - testify.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context = new(context) - - c.SetPath(path) - testify.Equal(t, path, c.Path()) -} - type validator struct{} func (*validator) Validate(i interface{}) error { @@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { var c Context = new(context) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { &context{ @@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ @@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ request: &http.Request{}, }, - testify.False, + assert.False, }, { &context{ @@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - testify.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { @@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) } } diff --git a/echo.go b/echo.go index df5d3584..0f9b3f6f 100644 --- a/echo.go +++ b/echo.go @@ -29,7 +29,9 @@ Example: e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":8080"); err != http.ErrServerClosed { + log.Fatal(err) + } } Learn more at https://echo.labstack.com @@ -37,127 +39,81 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" "io" - "io/ioutil" - stdLog "log" - "net" + "io/fs" "net/http" "net/url" "os" + "os/signal" "path/filepath" - "reflect" - "runtime" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - Echo struct { - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex - StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - notFoundHandler HandlerFunc - pool sync.Pool - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - } +// Echo is the top-level framework instance. +// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics. +type Echo struct { + // premiddleware are middlewares that are run for every request before routing is done + premiddleware []MiddlewareFunc + // middleware are middlewares that are run after router found a matching route (not found and method not found are also matches) + middleware []MiddlewareFunc - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } + router Router + routers map[string]Router + routerCreator func(e *Echo) Router - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } + contextPool sync.Pool + contextPathParamAllocSize int - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(HandlerFunc) HandlerFunc + // NewContextFunc allows using custom context implementations, instead of default *echo.context + NewContextFunc func(pathParamAllocSize int) EditableContext + Debug bool + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor + // Filesystem is file system used by Static and File handler to access files. + // Defaults to os.DirFS(".") + Filesystem fs.FS +} - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(Context) error +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c Context, err error) - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} - // Common struct for Echo & Group. - common struct{} -) +// Renderer is the interface that wraps the Render function. +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} // MIME types const ( @@ -241,274 +197,278 @@ const ( const ( // Version of Echo - Version = "4.6.1" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` + Version = "5.0.X" ) -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) - -// Errors -var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") -) - -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } - - MethodNotAllowedHandler = func(c Context) error { - return ErrMethodNotAllowed - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, +func New() *Echo { + logger := newJSONLogger(os.Stdout) + e := &Echo{ + Logger: logger, + Filesystem: os.DirFS("."), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + + routers: make(map[string]Router), + routerCreator: func(ec *Echo) Router { + return NewRouter(ec, RouterConfig{}) }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - ListenerNetwork: "tcp", } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.JSONSerializer = &DefaultJSONSerializer{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { + + e.router = NewRouter(e, RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() interface{} { + if e.NewContextFunc != nil { + return e.NewContextFunc(e.contextPathParamAllocSize) + } return e.NewContext(nil, nil) } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return + return e } // NewContext returns a Context instance. func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { + p := make(PathParams, e.contextPathParamAllocSize) return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + request: r, + response: NewResponse(w, e), + store: make(Map), + echo: e, + pathParams: &p, + matchType: RouteMatchUnknown, + route: nil, + path: "", } } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } // Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { +func (e *Echo) Routers() map[string]Router { return e.routers } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. +// RouterFor returns Router for given host. +func (e *Echo) RouterFor(host string) Router { + return e.routers[host] +} + +// ResetRouterCreator resets callback for creating new router instances. +// Note: current (default) router is immediately replaced with router created with creator func and vhost routers are cleared. +func (e *Echo) ResetRouterCreator(creator func(e *Echo) Router) { + e.routerCreator = creator + e.router = creator(e) + e.routers = make(map[string]Router) +} + +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not // -// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "commited" the // response and status code header has been sent to the client. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - - if c.Response().Committed { - return - } - - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr - } +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c Context, err error) { + if c.Response().Committed { + return } - } else { - he = &HTTPError{ + + he := &HTTPError{ Code: http.StatusInternalServerError, Message: http.StatusText(http.StatusInternalServerError), } - } - - // Issue #1426 - code := he.Code - message := he.Message - if m, ok := he.Message.(string); ok { - if e.Debug { - message = Map{"message": m, "error": err.Error()} - } else { - message = Map{"message": m} + if errors.As(err, &he) { + if he.Internal != nil { // max 2 levels of checks even if internal could have also internal + errors.As(he.Internal, &he) + } } - } - // Send response - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) + // Issue #1426 + code := he.Code + message := he.Message + if m, ok := he.Message.(string); ok { + if exposeError { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } + } + + // Send response + var cErr error + if c.Request().Method == http.MethodHead { // Issue #608 + cErr = c.NoContent(he.Code) + } else { + cErr = c.JSON(code, message) + } + if cErr != nil { + c.Echo().Logger.Error(err) // truly rare case. ala client already disconnected + } } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler -// in the router with optional route-level middleware. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// Any registers a new route for all supported HTTP methods and path with matching handler +// in the router with optional route-level middleware. Panics on error. +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { +// Static registers a new route with path prefix to serve static files from the provided root directory. Panics on error. +func (e *Echo) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo { + return e.Add( + http.MethodGet, + prefix+"*", + StaticDirectoryHandler(root, false), + middleware..., + ) +} + +// StaticDirectoryHandler creates handler function to serve files from given root path +func StaticDirectoryHandler(root string, disablePathUnescaping bool) HandlerFunc { if root == "" { root = "." // For security we want to restrict to CWD. } - return e.static(prefix, root, e.GET) -} - -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err + return func(c Context) error { + p := c.PathParam("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath } name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) + fi, err := fs.Stat(c.Echo().Filesystem, name) if err != nil { // The access path does not exist - return NotFoundHandler(c) + return ErrNotFound } // If the request is for a directory and does not end with "/" @@ -519,56 +479,57 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl } return c.File(name) } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) - } - get(prefix, h) - } - return get(prefix+"/*", h) } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { return c.File(file) - }, m...) -} - -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) -} - -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) - router := e.findRouter(host) - router.Add(method, path, func(c Context) error { - h := applyMiddleware(handler, middleware...) - return h(c) - }) - r := &Route{ - Method: method, - Path: path, - Name: name, } - e.router.routes[method+path] = r - return r + return e.Add(http.MethodGet, path, handler, middleware...) +} + +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Routable) (RouteInfo, error) { + return e.add("", route) +} + +func (e *Echo) add(host string, route Routable) (RouteInfo, error) { + router := e.findRouter(host) + ri, err := router.Add(route) + if err != nil { + return nil, err + } + + paramsCount := len(ri.Params()) + if paramsCount > e.contextPathParamAllocSize { + e.contextPathParamAllocSize = paramsCount + } + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + "", + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Host creates a new router group for the provided host and optional host-level middleware. func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) + e.routers[name] = e.routerCreator(e) g = &Group{host: name, echo: e} g.Use(m...) return @@ -581,326 +542,95 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) -} - -// Reverse generates an URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() -} - -// Routes returns the registered routes. -func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes -} - // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) + return e.contextPool.Get().(Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) + var c EditableContext + if e.NewContextFunc != nil { + // NOTE: we are not casting always context to EditableContext because casting to interface vs pointer to struct is + // "significantly" slower. Echo Context interface has way to many methods so these checks take time. + // These are benchmarks with 1.16: + // * interface extending another interface = +24% slower (3233 ns/op vs 2605 ns/op) + // * interface (not extending any, just methods)= +14% slower + // + // Quote from https://stackoverflow.com/a/31584377 + // "it's even worse with interface-to-interface assertion, because you also need to ensure that the type implements the interface." + // + // So most of the time we do not need custom context type and simple IF + cast to pointer to struct is fast enough. + c = e.contextPool.Get().(EditableContext) + } else { + c = e.contextPool.Get().(*context) + } c.Reset(r, w) - h := NotFoundHandler + var h func(c Context) error if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + params := c.RawPathParams() + match := e.findRouter(r.Host).Match(r, params) + + c.SetRawPathParams(params) + c.SetPath(match.RoutePath) + c.SetRouteInfo(match.RouteInfo) + c.SetRouteMatchType(match.Type) + h = applyMiddleware(match.Handler, e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc Context) error { + params := c.RawPathParams() + match := e.findRouter(r.Host).Match(r, params) + // NOTE: router will be executed after pre middlewares have been run. We assume here that context we receive after pre middlewares + // is the same we began with. If not - this is use-case we do not support and is probably abuse from developer. + c.SetRawPathParams(params) + c.SetPath(match.RoutePath) + c.SetRouteInfo(match.RouteInfo) + c.SetRouteMatchType(match.Type) + h1 := applyMiddleware(match.Handler, e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) + e.HTTPErrorHandler(c, err) } - // Release context - e.pool.Put(c) + e.contextPool.Put(c) } -// Start starts an HTTP server. +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(e); err != http.ErrServerClosed { +// log.Fatal(err) +// } +// // or standard library `http.Server` +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != http.ErrServerClosed { +// log.Fatal(err) +// } func (e *Echo) Start(address string) error { - e.startupMutex.Lock() - e.Server.Addr = address - if err := e.configureServer(e.Server); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return e.Server.Serve(e.Listener) -} + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt) // start shutdown process on ctrl+c + defer cancel() + sc.GracefulContext = ctx -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMutex.Lock() - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - e.startupMutex.Unlock() - return - } - - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - e.startupMutex.Unlock() - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMutex.Unlock() - return - } - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return ioutil.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - e.startupMutex.Lock() - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func (e *Echo) configureTLS(address string) { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMutex.Lock() - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - if s.TLSConfig != nil { - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -func (e *Echo) configureServer(s *http.Server) (err error) { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if s.TLSConfig == nil { - if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return nil - } - if e.TLSListener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } - return nil -} - -// ListenerAddr returns net.Addr for Listener -func (e *Echo) ListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.Listener == nil { - return nil - } - return e.Listener.Addr() -} - -// TLSListenerAddr returns net.Addr for TLSListener -func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.TLSListener == nil { - return nil - } - return e.TLSListener.Addr() -} - -// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { - e.startupMutex.Lock() - // Setup - s := e.Server - s.Addr = address - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = h2c.NewHandler(e, h2s) - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) - if err != nil { - e.startupMutex.Unlock() - return err - } - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - if he.Internal == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) - } - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he -} - -// Unwrap satisfies the Go 1.13 error wrapper interface. -func (he *HTTPError) Unwrap() error { - return he.Internal + return sc.Start(e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. @@ -925,19 +655,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -// GetPath returns RawPath, if it's empty returns Path from URL -// Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - -func (e *Echo) findRouter(host string) *Router { +func (e *Echo) findRouter(host string) Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { return r @@ -946,50 +664,6 @@ func (e *Echo) findRouter(host string) *Router { return e.router } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() - } - return t.String() -} - -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return - } - // Ignore error from setting the KeepAlivePeriod as some systems, such as - // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP - _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) - return -} - -func newListener(address, network string) (*tcpKeepAliveListener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, ErrInvalidListenerNetwork - } - l, err := net.Listen(network, address) - if err != nil { - return nil, err - } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil -} - func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { for i := len(middleware) - 1; i >= 0; i-- { h = middleware[i](h) diff --git a/echo_test.go b/echo_test.go index f2891586..0fbd1b55 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,31 +3,23 @@ package echo import ( "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" - "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" - "os" - "reflect" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` @@ -61,8 +53,8 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } @@ -84,6 +76,22 @@ func TestEchoStatic(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, + { + name: "without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/walle.png", // `` + `*` creates route `/test*` witch matches `walle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "without prefix does not serve dir index or redirect", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/", // `/` + `*` creates route `/*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, { name: "No file", givenPrefix: "/images", @@ -104,7 +112,7 @@ func TestEchoStatic(t *testing.T) { name: "Directory Redirect", givenPrefix: "/", givenRoot: "_fixture", - whenURL: "/folder", + whenURL: "/folder", // `/folder` is subdirectory inside `/_fixture` expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", expectBodyStartsWith: "", @@ -211,42 +219,36 @@ func TestEchoStatic(t *testing.T) { } func TestEchoStaticRedirectIndex(t *testing.T) { - assert := assert.New(t) e := New() // HandlerFunc - e.Static("/static", "_fixture") + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/static*", ri.Path()) + assert.Equal(t, "GET:/static*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) - errCh := make(chan error) - - go func() { - errCh <- e.Start("127.0.0.1:1323") - }() - - time.Sleep(200 * time.Millisecond) - - if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { - defer resp.Body.Close() - assert.Equal(http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(err.Error()) - } - - } else { - assert.Fail(err.Error()) + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) } - if err := e.Close(); err != nil { - t.Fatal(err) - } + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { e := New() - e.File("/walle", "_fixture/images/walle.png") + ri := e.File("/walle", "_fixture/images/walle.png") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/walle", ri.Path()) + assert.Equal(t, "GET:/walle", ri.Name()) + assert.Nil(t, ri.Params()) + c, b := request(http.MethodGet, "/walle", e) assert.Equal(t, http.StatusOK, c) assert.NotEmpty(t, b) @@ -258,7 +260,8 @@ func TestEchoMiddleware(t *testing.T) { e.Pre(func(next HandlerFunc) HandlerFunc { return func(c Context) error { - assert.Empty(t, c.Path()) + // before route match is found RouteInfo does not exist + assert.Equal(t, nil, c.RouteInfo()) buf.WriteString("-1") return next(c) } @@ -303,7 +306,7 @@ func TestEchoMiddlewareError(t *testing.T) { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -358,128 +361,202 @@ func TestEchoWrapMiddleware(t *testing.T) { } } +func TestEchoGet_routeInfoIsImmutable(t *testing.T) { + e := New() + ri := e.GET("/test", handlerFunc) + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err := e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) + + rInfo := ri.(routeInfo) + rInfo.name = "changed" // this change should not change other returned values + + assert.Equal(t, "GET:/test", ri.Name()) + + riFromRouter, err = e.Router().Routes().FindByMethodPath(http.MethodGet, "/test") + assert.NoError(t, err) + assert.Equal(t, "GET:/test", riFromRouter.Name()) +} + func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodConnect+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodDelete+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodGet+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodHead+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodOptions+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPatch+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPost+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodPut+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) + + ri := e.TRACE("/", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/", ri.Path()) + assert.Equal(t, http.MethodTrace+":/", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoAny(t *testing.T) { // JFC e := New() - e.Any("/", func(c Context) error { + ris := e.Any("/", func(c Context) error { return c.String(http.StatusOK, "Any") }) + assert.Len(t, ris, 11) } func TestEchoMatch(t *testing.T) { // JFC e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { return c.String(http.StatusOK, "Match") }) + assert.Len(t, ris, 2) } -func TestEchoURL(t *testing.T) { - e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getAny := func(Context) error { return nil } - getFile := func(Context) error { return nil } - - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - e.GET("/documents/*", getAny) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) - - assert := assert.New(t) - - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal("/documents/*", e.URL(getAny)) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) -} - -func TestEchoRoutes(t *testing.T) { - e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } -} - -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEcho_Routers_HandleHostsProperly(t *testing.T) { e := New() h := e.Host("route.com") routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + {Method: http.MethodGet, Path: "/users/:user/events"}, + {Method: http.MethodGet, Path: "/users/:user/events/public"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/refs"}, + {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/tags"}, } for _, r := range routes { h.Add(r.Method, r.Path, func(c Context) error { @@ -487,17 +564,22 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { }) } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { + routers := e.Routers() + + routeCom, ok := routers["route.com"] + assert.True(t, ok) + + if assert.Equal(t, len(routes), len(routeCom.Routes())) { + for _, r := range routeCom.Routes() { found := false for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { + if r.Method() == rr.Method && r.Path() == rr.Path { found = true break } } if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + t.Errorf("Route %s %s not found", r.Method(), r.Path()) } } } @@ -509,7 +591,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { return c.String(http.StatusOK, "/with/slash") }) e.GET("/:id", func(c Context) error { - return c.String(http.StatusOK, c.Param("id")) + return c.String(http.StatusOK, c.PathParam("id")) }) var testCases = []struct { @@ -546,8 +628,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assert := assert.New(t) - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } @@ -642,8 +722,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(tc.expectStatus, rec.Code) - assert.Equal(tc.expectBody, rec.Body.String()) + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } @@ -732,339 +812,27 @@ func TestEchoContext(t *testing.T) { e.ReleaseContext(c) } -func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(5 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - var addr net.Addr - if isTLS { - addr = e.TLSListenerAddr() - } else { - addr = e.ListenerAddr() - } - if addr != nil && strings.Contains(addr.String(), ":") { - return nil // was started - } - case err := <-errChan: - if err == http.ErrServerClosed { - return nil - } - return err - } - } -} - -func TestEchoStart(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - err := e.Start(":0") - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - assert.NoError(t, e.Close()) -} - -func TestEcho_StartTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - certFile string - keyFile string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid certFile", - addr: ":0", - certFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, invalid keyFile", - addr: ":0", - keyFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, failed to create cert out of certFile and keyFile", - addr: ":0", - keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key - expectError: "tls: found a certificate rather than a key in the PEM for the private key", - }, - { - name: "nok, invalid tls address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - certFile := "_fixture/certs/cert.pem" - if tc.certFile != "" { - certFile = tc.certFile - } - keyFile := "_fixture/certs/key.pem" - if tc.keyFile != "" { - keyFile = tc.keyFile - } - - err := e.StartTLS(tc.addr, certFile, keyFile) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - if _, ok := err.(*os.PathError); ok { - assert.Error(t, err) // error messages for unix and windows are different. so test only error type here - } else { - assert.EqualError(t, err, tc.expectError) - } - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEchoStartTLSAndStart(t *testing.T) { - // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server +func TestEcho_Start(t *testing.T) { e := New() e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "OK") + return c.String(http.StatusTeapot, "OK") }) - - errTLSChan := make(chan error) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer rndPort.Close() + errChan := make(chan error, 1) go func() { - certFile := "_fixture/certs/cert.pem" - keyFile := "_fixture/certs/key.pem" - err := e.StartTLS("localhost:", certFile, keyFile) - if err != nil { - errTLSChan <- err - } + errChan <- e.Start(rndPort.Addr().String()) }() - err := waitForServerStart(e, errTLSChan, true) - assert.NoError(t, err) - defer func() { - if err := e.Shutdown(stdContext.Background()); err != nil { - t.Error(err) - } - }() - - // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }} - res, err := client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - errChan := make(chan error) - go func() { - err := e.Start("localhost:") - if err != nil { - errChan <- err - } - }() - err = waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS - res, err = http.Get("http://" + e.ListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - // see if HTTPS works after HTTP listener is also added - res, err = client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) -} - -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - - testCases := []struct { - cert interface{} - key interface{} - expectedErr error - name string - }{ - { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, - }, - { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, - }, - { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, - }, - { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, - }, - { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, - }, + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + assert.Contains(t, err.Error(), "bind: address already in use") } - - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { - e := New() - e.HideBanner = true - - errChan := make(chan error, 0) - - go func() { - errChan <- e.StartTLS(":0", test.cert, test.key) - }() - - err := waitForServerStart(e, errChan, true) - if test.expectedErr != nil { - assert.EqualError(t, err, test.expectedErr.Error()) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartAutoTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error, 0) - - go func() { - errChan <- e.StartAutoTLS(tc.addr) - }() - - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func TestEcho_StartH2CServer(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Debug = true - h2s := &http2.Server{} - - errChan := make(chan error) - go func() { - err := e.StartH2CServer(tc.addr, h2s) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) - } -} - -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) } func request(method, path string, e *Echo) (int, string) { @@ -1074,364 +842,131 @@ func request(method, path string, e *Echo) (int, string) { return rec.Code, rec.Body.String() } -func TestHTTPError(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) -} - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) -} - func TestDefaultHTTPErrorHandler(t *testing.T) { - e := New() - e.Debug = true - e.Any("/plain", func(c Context) error { - return errors.New("An error occurred") - }) - e.Any("/badrequest", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - e.Any("/servererror", func(c Context) error { - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - e.Any("/early-return", func(c Context) error { - c.String(http.StatusOK, "OK") - return errors.New("ERROR") - }) - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) - - // With Debug=true plain response contains error message - c, b := request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) - // and special handling for HTTPError - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) - // complex errors are serialized to pretty JSON - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) - // if the body is already set HTTPErrorHandler should not add anything to response body - c, b = request(http.MethodGet, "/early-return", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, "OK", b) - // internal error should be reflected in the message - c, b = request(http.MethodGet, "/internal-error", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) - - e.Debug = false - // With Debug=false the error response is shortened - c, b = request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) - // No difference for error response with non plain string errors - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) -} - -func TestEchoClose(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - assert.NoError(t, e.Close()) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -var listenerNetworkTests = []struct { - test string - network string - address string -}{ - {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, - {"tcp ipv6 address", "tcp", "[::1]:1323"}, - {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, - {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, -} - -func supportsIPv6() bool { - addrs, _ := net.InterfaceAddrs() - for _, addr := range addrs { - // Check if any interface has local IPv6 assigned - if strings.Contains(addr.String(), "::1") { - return true - } - } - return false -} - -func TestEchoListenerNetwork(t *testing.T) { - hasIPv6 := supportsIPv6() - for _, tt := range listenerNetworkTests { - if !hasIPv6 && strings.Contains(tt.address, "::") { - t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") - continue - } - t.Run(tt.test, func(t *testing.T) { - e := New() - e.ListenerNetwork = tt.network - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - errCh := make(chan error) - - go func() { - errCh <- e.Start(tt.address) - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(t, "OK", string(body)) - } else { - assert.Fail(t, err.Error()) - } - - } else { - assert.Fail(t, err.Error()) - } - - if err := e.Close(); err != nil { - t.Fatal(err) - } - }) - } -} - -func TestEchoListenerNetworkInvalid(t *testing.T) { - e := New() - e.ListenerNetwork = "unix" - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) -} - -func TestEchoReverse(t *testing.T) { - assert := assert.New(t) - - e := New() - dummyHandler := func(Context) error { return nil } - - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal("/params/:foo", e.Reverse("/params/:foo")) - assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) -} - -func TestEchoReverseHandleHostProperly(t *testing.T) { - assert := assert.New(t) - - dummyHandler := func(Context) error { return nil } - - e := New() - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" - - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) -} - -func TestEcho_ListenerAddr(t *testing.T) { - e := New() - - addr := e.ListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) -} - -func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - - e := New() - - addr := e.TLSListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.StartTLS(":0", cert, key) - }() - - err = waitForServerStart(e, errCh, true) - assert.NoError(t, err) -} - -func TestEcho_StartServer(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - certs, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - var testCases = []struct { - name string - addr string - TLSConfig *tls.Config - expectError string + name string + givenExposeError bool + givenLoggerFunc bool + whenMethod string + whenError error + expectBody string + expectStatus int + expectLogged string }{ { - name: "ok", - addr: ":0", + name: "ok, expose error = true, HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error","message":"my_error"}` + "\n", }, { - name: "ok, start with TLS", - addr: ":0", - TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + name: "ok, expose error = true, HTTPError + internal error", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=my_error, internal=internal_error","message":"my_error"}` + "\n", }, { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "ok, expose error = true, HTTPError + internal HTTPError", + givenExposeError: true, + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"error":"code=418, message=my_error, internal=code=425, message=early_error","message":"early_error"}` + "\n", }, { - name: "nok, invalid tls address", - addr: "nope", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - expectError: "listen tcp: address nope: missing port in address", + name: "ok, expose error = false, HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error"), + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"early_error"}` + "\n", + }, + { + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", + }, + { + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", + }, + { + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.Debug = true + e.Logger = &jsonLogger{writer: buf} + e.Any("/path", func(c Context) error { + return tc.whenError + }) - server := new(http.Server) - server.Addr = tc.addr - if tc.TLSConfig != nil { - server.TLSConfig = tc.TLSConfig + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) + + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod } + c, b := request(method, "/path", e) - errCh := make(chan error) - go func() { - errCh <- e.StartServer(server) - }() - - err := waitForServerStart(e, errCh, tc.TLSConfig != nil) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - assert.NoError(t, e.Close()) + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func benchmarkEchoRoutes(b *testing.B, routes []*Route) { +type myCustomContext struct { + context +} + +func (c *myCustomContext) QueryParam(name string) string { + return "prefix_" + name +} + +func TestEcho_customContext(t *testing.T) { + e := New() + e.NewContextFunc = func(pathParamAllocSize int) EditableContext { + p := make(PathParams, pathParamAllocSize) + return &myCustomContext{ + context{ + request: nil, + response: NewResponse(nil, e), + store: make(Map), + echo: e, + pathParams: &p, + route: nil, + }, + } + } + + e.GET("/info/:id/:file", func(c Context) error { + return c.String(http.StatusTeapot, c.QueryParam("param")) + }) + + status, body := request(http.MethodGet, "/info/1/a.csv", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "prefix_param", body) +} + +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest("GET", "/", nil) u := req.URL diff --git a/go.mod b/go.mod index 60a64317..115d56a4 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,13 @@ module github.com/labstack/echo/v4 -go 1.15 +go 1.16 require ( - github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.8 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect - github.com/stretchr/testify v1.4.0 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v4 v4.0.0 + github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20210913180222-943fd674d43e - golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 9dcac7c5..54f4aa9b 100644 --- a/go.sum +++ b/go.sum @@ -1,51 +1,29 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= -github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= +github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go index 426bef9e..d4416497 100644 --- a/group.go +++ b/group.go @@ -4,95 +4,117 @@ import ( "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) - -// Use implements `Echo#Use()` for sub-routes within the Group. -func (g *Group) Use(middleware ...MiddlewareFunc) { - g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + host string + prefix string + middleware []MiddlewareFunc + echo *Echo } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. +func (g *Group) Use(middleware ...MiddlewareFunc) { + g.middleware = append(g.middleware, middleware...) +} + +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) @@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { return } -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) +// Static implements `Echo#Static()` for sub-routes within the Group. Panics on error. +func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo { + return g.Add( + http.MethodGet, + prefix+"*", + StaticDirectoryHandler(root, false), + middleware..., + ) } -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(path, file, g.GET) +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri +} + +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Routable) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(g.host, groupRoute) } diff --git a/group_test.go b/group_test.go index c51fd91e..606105fd 100644 --- a/group_test.go +++ b/group_test.go @@ -1,31 +1,68 @@ package echo import ( + "github.com/stretchr/testify/assert" "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" - - "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { @@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } m2 := func(next HandlerFunc) HandlerFunc { return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } } h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return c.String(http.StatusOK, c.RouteInfo().Path()) } g.Use(m1) g.GET("/help", h, m2) @@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method()) + assert.Equal(t, "/users/activate", ri.Path()) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name()) + assert.Nil(t, ri.Params()) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ris := users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 11) + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_AnyWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Any("/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range methods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method()) + assert.Equal(t, "/books/download*", ri.Path()) + assert.Equal(t, "GET:/books/download*", ri.Name()) + assert.Equal(t, []string{"*"}, ri.Params()) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} diff --git a/httperror.go b/httperror.go new file mode 100644 index 00000000..5c217dac --- /dev/null +++ b/httperror.go @@ -0,0 +1,74 @@ +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// Errors +var ( + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) + ErrNotFound = NewHTTPError(http.StatusNotFound) + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) + ErrForbidden = NewHTTPError(http.StatusForbidden) + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) + ErrBadRequest = NewHTTPError(http.StatusBadRequest) + ErrBadGateway = NewHTTPError(http.StatusBadGateway) + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Code int `json:"-"` + Message interface{} `json:"message"` + Internal error `json:"-"` // Stores the error returned by an external dependency +} + +// NewHTTPError creates a new HTTPError instance. +func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used? + he := &HTTPError{Code: code, Message: http.StatusText(code)} + if len(message) > 0 { + he.Message = message[0] + } + return he +} + +// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set. +func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError { + he := NewHTTPError(code, message...) + he.Internal = internalError + return he +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } + return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) +} + +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 00000000..f9d340f1 --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,52 @@ +package echo + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHTTPError(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) +} + +func TestNewHTTPErrorWithInternal(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message") + assert.Equal(t, "code=400, message=test message, internal=test", he.Error()) +} + +func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) { + he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test")) + assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error()) +} + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/json.go b/json.go index 16b2d057..16074fa2 100644 --- a/json.go +++ b/json.go @@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { err := json.NewDecoder(c.Request().Body).Decode(i) if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + return NewHTTPErrorWithInternal( + http.StatusBadRequest, + err, + fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset), + ) } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + return NewHTTPErrorWithInternal(http.StatusBadRequest, + err, + fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()), + ) } return err } diff --git a/log.go b/log.go index 3f8de590..49442535 100644 --- a/log.go +++ b/log.go @@ -1,41 +1,141 @@ package echo import ( + "bytes" "io" - - "github.com/labstack/gommon/log" + "strconv" + "sync" + "time" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) +//----------------------------------------------------------------------------- +// Example for Zap (https://github.com/uber-go/zap) +//func main() { +// e := echo.New() +// logger, _ := zap.NewProduction() +// e.Logger = &ZapLogger{logger: logger} +//} +//type ZapLogger struct { +// logger *zap.Logger +//} +// +//func (l *ZapLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZapLogger) Error(err error) { +// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo")) +//} + +//----------------------------------------------------------------------------- +// Example for Zerolog (https://github.com/rs/zerolog) +//func main() { +// e := echo.New() +// logger := zerolog.New(os.Stdout) +// e.Logger = &ZeroLogger{logger: &logger} +//} +// +//type ZeroLogger struct { +// logger *zerolog.Logger +//} +// +//func (l *ZeroLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *ZeroLogger) Error(err error) { +// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error()) +//} + +//----------------------------------------------------------------------------- +// Example for Logrus (https://github.com/sirupsen/logrus) +//func main() { +// e := echo.New() +// e.Logger = &LogrusLogger{logger: logrus.New()} +//} +// +//type LogrusLogger struct { +// logger *logrus.Logger +//} +// +//func (l *LogrusLogger) Write(p []byte) (n int, err error) { +// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all. +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message. +// return len(p), nil +//} +// +//func (l *LogrusLogger) Error(err error) { +// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err) +//} + +// Logger defines the logging interface that Echo uses internally in few places. +// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice. +type Logger interface { + // Write provides writer interface for http.Server `ErrorLog` and for logging startup messages. + // `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers, + // and underlying FileSystem errors. + // `logger` middleware will use this method to write its JSON payload. + Write(p []byte) (n int, err error) + // Error logs the error + Error(err error) +} + +// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only +// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting. +// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases. +type jsonLogger struct { + writer io.Writer + bufferPool sync.Pool + + timeNow func() time.Time +} + +func newJSONLogger(writer io.Writer) *jsonLogger { + return &jsonLogger{ + writer: writer, + bufferPool: sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 256)) + }, + }, + timeNow: time.Now, } -) +} + +func (l *jsonLogger) Write(p []byte) (n int, err error) { + pLen := len(p) + if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message + (p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') || + (p[0] == '{' && p[pLen-1] == '}') { + return l.writer.Write(p) + } + // we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is + // called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably + // deserves at least WARN level. + return l.printf("WARN", string(p)) +} + +func (l *jsonLogger) Error(err error) { + _, _ = l.printf("ERROR", err.Error()) +} + +func (l *jsonLogger) printf(level string, message string) (n int, err error) { + buf := l.bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer l.bufferPool.Put(buf) + + buf.WriteString(`{"time":"`) + buf.WriteString(l.timeNow().Format(time.RFC3339Nano)) + buf.WriteString(`","level":"`) + buf.WriteString(level) + buf.WriteString(`","prefix":"echo","message":`) + + buf.WriteString(strconv.Quote(message)) + buf.WriteString("}\n") + + return l.writer.Write(buf.Bytes()) +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 00000000..ed635290 --- /dev/null +++ b/log_test.go @@ -0,0 +1,77 @@ +package echo + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestJsonLogger_Write(t *testing.T) { + var testCases = []struct { + name string + when []byte + expect string + }{ + { + name: "ok, write non JSONlike message", + when: []byte("version: %v, build: %v"), + expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n", + }, + { + name: "ok, write quoted message", + when: []byte(`version: "%v"`), + expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n", + }, + { + name: "ok, write JSON", + when: []byte(`{"version": 123}` + "\n"), + expect: `{"version": 123}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0) + } + + _, err := logger.Write(tc.when) + + result := buf.String() + assert.Equal(t, tc.expect, result) + assert.NoError(t, err) + }) + } +} + +func TestJsonLogger_Error(t *testing.T) { + var testCases = []struct { + name string + whenError error + expect string + }{ + { + name: "ok", + whenError: ErrForbidden, + expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + logger := newJSONLogger(buf) + logger.timeNow = func() time.Time { + return time.Unix(1631045377, 0) + } + + logger.Error(tc.whenError) + + result := buf.String() + assert.Equal(t, tc.expect, result) + }) + } +} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 00000000..68002cad --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,13 @@ +# Development Guidelines for middlewares + +// FIXME: add info about `MiddlewareConfigurator` interface + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9f..73caeaf9 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,64 +1,59 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" + "fmt" "strconv" "strings" "github.com/labstack/echo/v4" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } if config.Realm == "" { config.Realm = defaultRealm @@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) - - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = fmt.Errorf("invalid basic auth value: %w", errDecode) + continue + } + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } + if lastError != nil { + return lastError + } + realm := defaultRealm if config.Realm != defaultRealm { realm = strconv.Quote(config.Realm) @@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0..41532312 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" @@ -12,60 +13,146 @@ import ( ) func TestBasicAuth(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + validatorFunc := func(c echo.Context, u, p string) (bool, error) { if u == "joe" && p == "secret" { return true, nil } + if u == "error" { + return false, errors.New(p) + } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + defaultConfig := BasicAuthConfig{Validator: validatorFunc} - assert := assert.New(t) + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: defaultConfig, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "invalid basic auth value: illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm=Restricted`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "code=401, message=Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + config := tc.givenConfig - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + mw, err := config.ToMiddleware() + assert.NoError(t, err) - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + h := mw(func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err = h(c) - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} + +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) + }) + + mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) + + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab..d04822df 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "bytes" + "errors" "io" "io/ioutil" "net" @@ -11,63 +12,56 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte) - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) - -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } // Request reqBody := []byte{} - if c.Request().Body != nil { // Read + if c.Request().Body != nil { reqBody, _ = ioutil.ReadAll(c.Request().Body) } c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset @@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} c.Response().Writer = writer - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) // Callback config.Handler(c, reqBody, resBody.Bytes()) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f72..21ca4cac 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) { requestBody = string(reqBody) responseBody = string(resBody) - }) + }}.ToMiddleware() + assert.NoError(t, err) - assert := assert.New(t) - - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, - Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) - }, - }) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Handler: func(c echo.Context, reqBody, resBody []byte) { + isCalled = true + }, + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) +} + +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) @@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) - - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}) + assert.NotNil(t, mw) + }) +} + +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) }) } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd59..b00bbae6 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,98 +1,83 @@ package middleware import ( - "fmt" "io" "sync" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 +} - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) - -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 + context echo.Context +} // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper - } + return toMiddlewareOrPanic(config) +} - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + pool := sync.Pool{ + New: func() interface{} { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(c, req.Body) defer pool.Put(r) req.Body = r return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -102,16 +87,8 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) { r.reader = reader r.context = context r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0e8642a0..c03c6fd8 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -11,6 +11,137 @@ import ( "github.com/stretchr/testify/assert" ) +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + // Based on content length (within limit) + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) + } + + // Based on content read (overlimit) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + + // Based on content read (within limit) + req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) + + // Based on content read (overlimit) + req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) +} + +func TestBodyLimitReader(t *testing.T) { + hw := []byte("Hello, World!") + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + + config := BodyLimitConfig{ + Skipper: DefaultSkipper, + LimitBytes: 2, + } + reader := &limitedReader{ + BodyLimitConfig: config, + reader: ioutil.NopCloser(bytes.NewReader(hw)), + context: e.NewContext(req, rec), + } + + // read all should return ErrStatusRequestEntityTooLarge + _, err := ioutil.ReadAll(reader) + he := err.(*echo.HTTPError) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + + // reset reader and read two bytes must succeed + bt := make([]byte, 2) + reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw))) + n, err := reader.Read(bt) + assert.Equal(t, 2, n) + assert.Equal(t, nil, err) +} + +func TestBodyLimit_skipper(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw, err := BodyLimitConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c echo.Context) error { + body, err := ioutil.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + func TestBodyLimit(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") @@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) { return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) + mw := BodyLimit(2 * MB) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) - } - - // Based on content read (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) - - // Based on content read (within limit) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // Based on content read (overlimit) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) -} - -func TestBodyLimitReader(t *testing.T) { - hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() - - config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, - } - reader := &limitedReader{ - BodyLimitConfig: config, - reader: ioutil.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), - } - - // read all should return ErrStatusRequestEntityTooLarge - _, err := ioutil.ReadAll(reader) - he := err.(*echo.HTTPError) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) - - // reset reader and read two bytes must succeed - bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) - n, err := reader.Read(bt) - assert.Equal(t, 2, n) - assert.Equal(t, nil, err) + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } diff --git a/middleware/compress.go b/middleware/compress.go index 6ae19745..8cd66f8d 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "compress/gzip" + "errors" "io" "io/ioutil" "net" @@ -13,50 +14,45 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - } - - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - } -) - const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - } -) +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. -func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + // Gzip compression level. + // Optional. Default value -1. + Level int } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. +func Gzip() echo.MiddlewareFunc { + return GzipWithConfig(GzipConfig{}) +} + +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 } pool := gzipCompressPool(config) @@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index d16ffca4..1b4ebc8c 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,94 +3,128 @@ package middleware import ( "bytes" "compress/gzip" - "io" "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" + "time" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { // Skip if no Accept-Encoding header h := Gzip()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - assert.Equal("test", rec.Body.String()) + err := h(c) + assert.NoError(t, err) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) + r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal("test", buf.String()) - } + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} - chunkBuf := make([]byte, 5) - - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c echo.Context) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("first\n")) c.Response().Flush() - // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) - - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) + c.Response().Write([]byte("second\n")) c.Response().Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) - return nil - })(c) + c.Response().Write([]byte("third")) + chunkChan <- struct{}{} + return nil + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) } -func TestGzipNoContent(t *testing.T) { +func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) @@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) { } } -func TestGzipErrorReturned(t *testing.T) { +func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) e.GET("/", func(c echo.Context) error { @@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } -func TestGzipErrorReturnedInvalidConfig(t *testing.T) { - e := echo.New() - // Invalid level - e.Use(GzipWithConfig(GzipConfig{Level: 12})) - e.GET("/", func(c echo.Context) error { - c.Response().Write([]byte("test")) - return nil - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), "gzip") +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) } // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() + e.Filesystem = os.DirFS("../") + e.Use(Gzip()) - e.Static("/test", "../_fixture/images") + e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. diff --git a/middleware/cors.go b/middleware/cors.go index d6ef8964..c3db8749 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -9,60 +9,56 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `yaml:"allow_origins"` + // AllowOrigin defines a list of origins that may access the resource. + // Optional. Default value []string{"*"}. + AllowOrigins []string - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - AllowMethods []string `yaml:"allow_methods"` + // AllowMethods defines a list methods allowed when accessing the resource. + // This is used in response to a preflight request. + // Optional. Default value DefaultCORSConfig.AllowMethods. + AllowMethods []string - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `yaml:"allow_headers"` + // AllowHeaders defines a list of request headers that can be used when + // making the actual request. This is in response to a preflight request. + // Optional. Default value []string{}. + AllowHeaders []string - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - AllowCredentials bool `yaml:"allow_credentials"` + // AllowCredentials indicates whether or not the response to the request + // can be exposed when the credentials flag is true. When used as part of + // a response to a preflight request, this indicates whether or not the + // actual request can be made using credentials. + // Optional. Default value false. + AllowCredentials bool - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `yaml:"expose_headers"` + // ExposeHeaders defines a whitelist headers that clients are allowed to + // access. + // Optional. Default value []string{}. + ExposeHeaders []string - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `yaml:"max_age"` - } -) + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached. + // Optional. Default value 0. + MaxAge int +} -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS @@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } -// CORSWithConfig returns a CORS middleware with config. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // See: `CORS()`. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCORSConfig.Skipper @@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } return c.NoContent(http.StatusNoContent) } - } + }, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 717abe49..8a654223 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -17,7 +17,7 @@ func TestCORS(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) + h := CORS()(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -26,7 +26,7 @@ func TestCORS(t *testing.T) { req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h = CORS()(echo.NotFoundHandler) + h = CORS()(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) @@ -38,7 +38,7 @@ func TestCORS(t *testing.T) { AllowOrigins: []string{"localhost"}, AllowCredentials: true, MaxAge: 3600, - })(echo.NotFoundHandler) + })(func(c echo.Context) error { return echo.ErrNotFound }) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -55,7 +55,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -73,7 +73,7 @@ func TestCORS(t *testing.T) { AllowCredentials: true, MaxAge: 3600, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) @@ -90,7 +90,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"*"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) @@ -104,7 +104,7 @@ func TestCORS(t *testing.T) { cors = CORSWithConfig(CORSConfig{ AllowOrigins: []string{"http://*.example.com"}, }) - h = cors(echo.NotFoundHandler) + h = cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) @@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) { //AllowCredentials: true, //MaxAge: 3600, }) - h := cors(echo.NotFoundHandler) + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) h(c) assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) @@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors := CORSWithConfig(CORSConfig{ - AllowOriginFunc: allowOriginFunc, - }) - h := cors(echo.NotFoundHandler) - err := h(c) + cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c echo.Context) error { return echo.ErrNotFound }) + err = h(c) expected, expectedErr := allowOriginFunc(origin) if expectedErr != nil { diff --git a/middleware/csrf.go b/middleware/csrf.go index 7804997d..5859bd0f 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -8,89 +8,90 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. - // TokenLookup is a string in the form of ":" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" - // - "form:" - // - "query:" - TokenLookup string `yaml:"token_lookup"` + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` + // TokenLookup is a string in the form of ":" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" + // - "form:" + // - "query:" + TokenLookup string - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - } + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool - // csrfTokenExtractor defines a function that takes `echo.Context` and returns - // either a token or an error. - csrfTokenExtractor func(echo.Context) (string, error) -) + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite +} -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// csrfTokenExtractor defines a function that takes `echo.Context` and returns either a token or an error. +type csrfTokenExtractor func(echo.Context) (string, error) + +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper @@ -98,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -136,7 +140,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // Generate token if err != nil { - token = random.String(config.TokenLength) + token = config.Generator() } else { // Reuse token token = k.Value @@ -181,7 +185,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index af1d2639..427e593b 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -9,11 +9,26 @@ import ( "testing" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) func TestCSRF(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + csrf := CSRF() + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Generate CSRF token + h(c) + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") + +} + +func TestMustCSRFWithConfig(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -43,7 +58,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(16) + token := randomString(16) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -145,9 +160,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, - }) + }.ToMiddleware() + assert.NoError(t, err) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/decompress.go b/middleware/decompress.go index c046359a..546c4b00 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -11,18 +11,16 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} -//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers @@ -30,14 +28,6 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) - // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } @@ -65,19 +55,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { } } -//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DefaultDecompressConfig) + return DecompressWithConfig(DecompressConfig{}) } -//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + config.GzipDecompressPool = &DefaultGzipDecompressPool{} } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -116,5 +110,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f..d3a4008a 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -17,6 +17,31 @@ import ( func TestDecompress(t *testing.T) { e := echo.New() + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + // Decompress request body + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) +} + +func TestDecompress_skippedIfNoHeader(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) - // Decompress - body := `{"name": "echo"}` - gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) - req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) } -func TestDecompressDefaultConfig(t *testing.T) { +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + })(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := Decompress()(func(c echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) // Decompress body := `{"name": "echo"}` @@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec = httptest.NewRecorder() c = e.NewContext(req, rec) - h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { @@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) assert.NoError(t, err) @@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) { h := Decompress()(func(c echo.Context) error { return c.NoContent(http.StatusNoContent) }) - if assert.NoError(t, h(c)) { + + err := h(c) + + if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := ioutil.ReadAll(c.Request().Body) assert.NoError(t, err) diff --git a/middleware/extractor.go b/middleware/extractor.go new file mode 100644 index 00000000..2cb51afc --- /dev/null +++ b/middleware/extractor.go @@ -0,0 +1,148 @@ +package middleware + +import ( + "fmt" + "github.com/labstack/echo/v4" + "net/http" + "net/textproto" + "strings" +) + +// ErrExtractionValueMissing denotes an error raised when value could not be extracted from request +var ErrExtractionValueMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed value") + +// ExtractorType is enum type for where extractor will take its data +type ExtractorType string + +const ( + // HeaderExtractor tells extractor to take values from request header + HeaderExtractor ExtractorType = "header" + // QueryExtractor tells extractor to take values from request query parameters + QueryExtractor ExtractorType = "query" + // ParamExtractor tells extractor to take values from request route parameters + ParamExtractor ExtractorType = "param" + // CookieExtractor tells extractor to take values from request cookie + CookieExtractor ExtractorType = "cookie" + // FormExtractor tells extractor to take values from request form fields + FormExtractor ExtractorType = "form" +) + +func createExtractors(lookups string) ([]valuesExtractor, error) { + sources := strings.Split(lookups, ",") + var extractors []valuesExtractor + for _, source := range sources { + parts := strings.Split(source, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) + } + + switch ExtractorType(parts[0]) { + case QueryExtractor: + extractors = append(extractors, valuesFromQuery(parts[1])) + case ParamExtractor: + extractors = append(extractors, valuesFromParam(parts[1])) + case CookieExtractor: + extractors = append(extractors, valuesFromCookie(parts[1])) + case FormExtractor: + extractors = append(extractors, valuesFromForm(parts[1])) + case HeaderExtractor: + prefix := "" + if len(parts) > 2 { + prefix = parts[2] + } + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) + } + } + return extractors, nil +} + +// valuesFromHeader returns a functions that extracts values from the request header. +func valuesFromHeader(header string, valuePrefix string) valuesExtractor { + prefixLen := len(valuePrefix) + return func(c echo.Context) ([]string, ExtractorType, error) { + values := textproto.MIMEHeader(c.Request().Header).Values(header) + if len(values) == 0 { + return nil, HeaderExtractor, ErrExtractionValueMissing + } + + result := make([]string, 0) + for _, value := range values { + if prefixLen == 0 { + result = append(result, value) + continue + } + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + result = append(result, value[prefixLen:]) + } + } + if len(result) == 0 { + return nil, HeaderExtractor, ErrExtractionValueMissing + } + return result, HeaderExtractor, nil + } +} + +// valuesFromQuery returns a function that extracts values from the query string. +func valuesFromQuery(param string) valuesExtractor { + return func(c echo.Context) ([]string, ExtractorType, error) { + result := c.QueryParams()[param] + if len(result) == 0 { + return nil, QueryExtractor, ErrExtractionValueMissing + } + return result, QueryExtractor, nil + + } +} + +// valuesFromParam returns a function that extracts values from the url param string. +func valuesFromParam(param string) valuesExtractor { + return func(c echo.Context) ([]string, ExtractorType, error) { + result := make([]string, 0) + for _, p := range c.PathParams() { + if param == p.Name { + result = append(result, p.Value) + } + } + if len(result) == 0 { + return nil, ParamExtractor, ErrExtractionValueMissing + } + return result, ParamExtractor, nil + } +} + +// valuesFromCookie returns a function that extracts values from the named cookie. +func valuesFromCookie(name string) valuesExtractor { + return func(c echo.Context) ([]string, ExtractorType, error) { + cookies := c.Cookies() + if len(cookies) == 0 { + return nil, CookieExtractor, ErrExtractionValueMissing + } + + result := make([]string, 0) + for _, cookie := range cookies { + if name == cookie.Name { + result = append(result, cookie.Value) + } + } + if len(result) == 0 { + return nil, CookieExtractor, ErrExtractionValueMissing + } + return result, CookieExtractor, nil + } +} + +// valuesFromForm returns a function that extracts values from the form field. +func valuesFromForm(name string) valuesExtractor { + return func(c echo.Context) ([]string, ExtractorType, error) { + if err := c.Request().ParseForm(); err != nil { + return nil, FormExtractor, fmt.Errorf("valuesFromForm parse form failed: %w", err) + } + values := c.Request().Form[name] + if len(values) == 0 { + return nil, FormExtractor, ErrExtractionValueMissing + } + + result := append([]string{}, values...) + return result, FormExtractor, nil + } +} diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go new file mode 100644 index 00000000..3be9cab9 --- /dev/null +++ b/middleware/extractor_test.go @@ -0,0 +1,498 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func TestCreateExtractors(t *testing.T) { + var testCases = []struct { + name string + givenRequest func() *http.Request + givenPathParams echo.PathParams + whenLoopups string + expectValues []string + expectExtractorType ExtractorType + expectCreateError string + expectError string + }{ + { + name: "ok, header", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer token") + return req + }, + whenLoopups: "header:Authorization:Bearer ", + expectValues: []string{"token"}, + expectExtractorType: HeaderExtractor, + }, + { + name: "ok, form", + givenRequest: func() *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenLoopups: "form:name", + expectValues: []string{"Jon Snow"}, + expectExtractorType: FormExtractor, + }, + { + name: "ok, cookie", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderCookie, "_csrf=token") + return req + }, + whenLoopups: "cookie:_csrf", + expectValues: []string{"token"}, + expectExtractorType: CookieExtractor, + }, + { + name: "ok, param", + givenPathParams: echo.PathParams{ + {Name: "id", Value: "123"}, + }, + whenLoopups: "param:id", + expectValues: []string{"123"}, + expectExtractorType: ParamExtractor, + }, + { + name: "ok, query", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) + return req + }, + whenLoopups: "query:id", + expectValues: []string{"999"}, + expectExtractorType: QueryExtractor, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + req = tc.givenRequest() + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(echo.EditableContext) + if tc.givenPathParams != nil { + c.SetRawPathParams(&tc.givenPathParams) + } + + extractors, err := createExtractors(tc.whenLoopups) + if tc.expectCreateError != "" { + assert.EqualError(t, err, tc.expectCreateError) + return + } + assert.NoError(t, err) + + for _, e := range extractors { + values, eType, eErr := e(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectExtractorType, eType) + if tc.expectError != "" { + assert.EqualError(t, eErr, tc.expectError) + return + } + assert.NoError(t, eErr) + } + }) + } +} + +func TestValuesFromHeader(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + whenValuePrefix string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, single value, case insensitive", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, + }, + { + name: "ok, empty prefix", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Bearer ", + expectError: ErrExtractionValueMissing.Error(), + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderWWWAuthenticate, + whenValuePrefix: "", + expectError: ErrExtractionValueMissing.Error(), + }, + { + name: "nok, no headers", + givenRequest: nil, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectError: ErrExtractionValueMissing.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) + + values, eType, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, HeaderExtractor, eType) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromQuery(t *testing.T) { + var testCases = []struct { + name string + givenQueryPart string + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenQueryPart: "?id=123&name=test", + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenQueryPart: "?id=123&id=456&name=test", + whenName: "id", + expectValues: []string{"123", "456"}, + }, + { + name: "nok, missing value", + givenQueryPart: "?id=123&name=test", + whenName: "nope", + expectError: ErrExtractionValueMissing.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromQuery(tc.whenName) + + values, eType, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, QueryExtractor, eType) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromParam(t *testing.T) { + examplePathParams := echo.PathParams{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, + } + + var testCases = []struct { + name string + givenPathParams echo.PathParams + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenPathParams: examplePathParams, + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenPathParams: examplePathParams, + whenName: "gid", + expectValues: []string{"456", "789"}, + }, + { + name: "nok, no values", + givenPathParams: nil, + whenName: "nope", + expectValues: nil, + expectError: ErrExtractionValueMissing.Error(), + }, + { + name: "nok, no matching value", + givenPathParams: examplePathParams, + whenName: "nope", + expectValues: nil, + expectError: ErrExtractionValueMissing.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(echo.EditableContext) + if tc.givenPathParams != nil { + c.SetRawPathParams(&tc.givenPathParams) + } + + extractor := valuesFromParam(tc.whenName) + + values, eType, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ParamExtractor, eType) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromCookie(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderCookie, "_csrf=token") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: "_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Add(echo.HeaderCookie, "_csrf=token") + req.Header.Add(echo.HeaderCookie, "_csrf=token2") + }, + whenName: "_csrf", + expectValues: []string{"token", "token2"}, + }, + { + name: "nok, no matching cookie", + givenRequest: exampleRequest, + whenName: "xxx", + expectValues: nil, + expectError: ErrExtractionValueMissing.Error(), + }, + { + name: "nok, no cookies at all", + givenRequest: nil, + whenName: "xxx", + expectValues: nil, + expectError: ErrExtractionValueMissing.Error(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromCookie(tc.whenName) + + values, eType, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, CookieExtractor, eType) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromForm(t *testing.T) { + examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + return req + } + exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) + return req + } + + var testCases = []struct { + name string + givenRequest *http.Request + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, POST form, single value", + givenRequest: examplePostFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, POST form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "ok, GET form, single value", + givenRequest: exampleGetFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, GET form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "nok, POST form, value missing", + givenRequest: examplePostFormRequest(nil), + whenName: "nope", + expectError: ErrExtractionValueMissing.Error(), + }, + { + name: "nok, POST form, form parsing error", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Body = nil + return req + }(), + whenName: "name", + expectError: "valuesFromForm parse form failed: missing form body", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := tc.givenRequest + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromForm(tc.whenName) + + values, eType, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, FormExtractor, eType) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 21e33ab8..815e18f0 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,134 +1,84 @@ -// +build go1.15 - package middleware import ( "errors" "fmt" - "net/http" - "reflect" - "strings" - - "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" + "net/http" ) -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// JWTConfig defines the config for JWT middleware. +type JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc + // BeforeFunc defines a function which is executed just before the middleware. + BeforeFunc BeforeFunc - // SuccessHandler defines a function which is executed for a valid token. - SuccessHandler JWTSuccessHandler + // SuccessHandler defines a function which is executed for a valid token. + SuccessHandler JWTSuccessHandler - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom JWT error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain. + ErrorHandler JWTErrorHandlerWithContext - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext + // Context key to store user information from the token into context. + // Optional. Default value "user". + ContextKey string - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token(s) from the request. + // Optional. Default value "header:Authorization:Bearer ". + // Possible values: + // - "header:" + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization,cookie:myowncookie" + TokenLookup string - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // NB: could be called multiple times per request when token lookup is able to extract multiple token values (i.e. multiple Authorization headers) + // See `jwt_external_test.go` for example implementation using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) +} - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string +// JWTSuccessHandler defines a function which is executed for a valid token. +type JWTSuccessHandler func(c echo.Context) - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string +// JWTErrorHandler defines a function which is executed for an invalid token. +type JWTErrorHandler func(err error) error - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims +// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. +type JWTErrorHandlerWithContext func(c echo.Context, err error) error - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiply sources example: - // - "header: Authorization,cookie: myowncookie" +type valuesExtractor func(c echo.Context) ([]string, ExtractorType, error) - TokenLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } - - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(echo.Context) - - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(error) error - - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(error, echo.Context) error - - jwtExtractor func(echo.Context) (string, error) -) - -// Algorithms const ( + // AlgorithmHS256 is token signing algorithm AlgorithmHS256 = "HS256" ) -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) +// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request +var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt") -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) +// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired +var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") + +// DefaultJWTConfig is the default JWT auth middleware config. +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + ContextKey: "user", + TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", +} // JWT returns a JSON Web Token (JWT) auth middleware. // @@ -137,64 +87,43 @@ var ( // For missing token, it returns "400 - Bad Request" error. // // See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -func JWT(key interface{}) echo.MiddlewareFunc { +func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc { c := DefaultJWTConfig - c.SigningKey = key + c.ParseTokenFunc = parseTokenFunc return JWTWithConfig(c) } -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. +// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid. +// +// For valid token, it sets the user in context and calls next handler. +// For invalid token, it returns "401 - Unauthorized" error. +// For missing token, it returns "400 - Bad Request" error. +// +// See: https://jwt.io/introduction func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration +func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod + if config.ParseTokenFunc == nil { + return nil, errors.New("echo jwt middleware requires parse token function") } if config.ContextKey == "" { config.ContextKey = DefaultJWTConfig.ContextKey } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } if config.TokenLookup == "" { config.TokenLookup = DefaultJWTConfig.TokenLookup } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme + extractors, err := createExtractors(config.TokenLookup) + if err != nil { + return nil, fmt.Errorf("echo jwt middleware could not create token extractor: %w", err) } - if config.KeyFunc == nil { - config.KeyFunc = config.defaultKeyFunc - } - if config.ParseTokenFunc == nil { - config.ParseTokenFunc = config.defaultParseToken - } - - // Initialize - // Split sources - sources := strings.Split(config.TokenLookup, ",") - var extractors []jwtExtractor - for _, source := range sources { - parts := strings.Split(source, ":") - - switch parts[0] { - case "query": - extractors = append(extractors, jwtFromQuery(parts[1])) - case "param": - extractors = append(extractors, jwtFromParam(parts[1])) - case "cookie": - extractors = append(extractors, jwtFromCookie(parts[1])) - case "form": - extractors = append(extractors, jwtFromForm(parts[1])) - case "header": - extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) - } + if len(extractors) == 0 { + return nil, errors.New("echo jwt middleware could not create extractors from TokenLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -206,142 +135,55 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - var auth string - var err error + var lastExtractorErr error + var lastTokenErr error for _, extractor := range extractors { - // Extract token from extractor, if it's not fail break the loop and - // set auth - auth, err = extractor(c) - if err == nil { - break + auths, _, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr + continue + } + for _, auth := range auths { + token, err := config.ParseTokenFunc(c, auth) + if err != nil { + lastTokenErr = err + continue + } + // Store user information from token into context. + c.Set(config.ContextKey, token) + if config.SuccessHandler != nil { + config.SuccessHandler(c) + } + return next(c) } } - // If none of extractor has a token, handle error - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) - } - return err - } - - token, err := config.ParseTokenFunc(auth, c) + // prioritize token errors over extracting errors + err := lastTokenErr if err == nil { - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) + err = lastExtractorErr + } + if config.ErrorHandler != nil { + if err == ErrExtractionValueMissing { + err = ErrJWTMissing + } + // Allow error handler to swallow the error and continue handler chain execution + // Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public token to request and continue with handler chain + if handledErr := config.ErrorHandler(c, err); handledErr != nil { + return handledErr } return next(c) } - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) + if err == ErrExtractionValueMissing { + return ErrJWTMissing } + // everything else goes under http.StatusUnauthorized to avoid exposing JWT internals with generic error return &echo.HTTPError{ Code: ErrJWTInvalid.Code, Message: ErrJWTInvalid.Message, Internal: err, } } - } -} - -func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - token := new(jwt.Token) - var err error - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid token") - } - return token, nil -} - -// defaultKeyFunc returns a signing key of the given token. -func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil -} - -// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. -func jwtFromHeader(header string, authScheme string) jwtExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - l := len(authScheme) - if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) { - return auth[l+1:], nil - } - return "", ErrJWTMissing - } -} - -// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string. -func jwtFromQuery(param string) jwtExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string. -func jwtFromParam(param string) jwtExtractor { - return func(c echo.Context) (string, error) { - token := c.Param(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie. -func jwtFromCookie(name string) jwtExtractor { - return func(c echo.Context) (string, error) { - cookie, err := c.Cookie(name) - if err != nil { - return "", ErrJWTMissing - } - return cookie.Value, nil - } -} - -// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. -func jwtFromForm(name string) jwtExtractor { - return func(c echo.Context) (string, error) { - field := c.FormValue(name) - if field == "" { - return "", ErrJWTMissing - } - return field, nil - } + }, nil } diff --git a/middleware/jwt_external_test.go b/middleware/jwt_external_test.go new file mode 100644 index 00000000..172958e3 --- /dev/null +++ b/middleware/jwt_external_test.go @@ -0,0 +1,76 @@ +package middleware_test + +import ( + "errors" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "net/http" + "net/http/httptest" +) + +// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc +// +// signingKey is signing key to validate token. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKeys is not provided. +// +// signingKeys is Map of signing keys to validate token with kid field usage. +// This is one of the options to provide a token validation key. +// The order of precedence is a user-defined SigningKeys and SigningKey. +// Required if signingKey is not provided +func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) { + // keyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != middleware.AlgorithmHS256 { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + if len(signingKeys) == 0 { + return signingKey, nil + } + + if kid, ok := t.Header["kid"].(string); ok { + if key, ok := signingKeys[kid]; ok { + return key, nil + } + } + return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + +func ExampleJWTConfig_withJWTGoAsTokenParser() { + mw := middleware.JWTWithConfig(middleware.JWTConfig{ + ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil), + }) + + e := echo.New() + e.Use(mw) + + e.GET("/", func(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusTeapot, user.Claims) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + fmt.Printf("status: %v, body: %v", res.Code, res.Body.String()) + // Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"} +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 5f36ce0a..64b19ddc 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,5 +1,3 @@ -// +build go1.15 - package middleware import ( @@ -11,11 +9,32 @@ import ( "strings" "testing" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) +func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) { + // This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != signingMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return signingKey, nil + } + + return func(c echo.Context, auth string) (interface{}, error) { + token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + } +} + // jwtCustomInfo defines some custom types we're going to use within our tokens. type jwtCustomInfo struct { Name string `json:"name"` @@ -28,43 +47,7 @@ type jwtCustomClaims struct { jwtCustomInfo } -func TestJWTRace(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWT(t *testing.T) { +func TestJWT_combinations(t *testing.T) { e := echo.New() handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -72,344 +55,236 @@ func TestJWT(t *testing.T) { token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" validKey := []byte("secret") invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token + validAuth := "Bearer " + token - for _, tc := range []struct { - expPanic bool - expErrCode int // 0 for Success + var testCases = []struct { + name string config JWTConfig reqURL string // "/" if empty hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val formValues map[string]string - info string + expPanic bool + expErrCode int // 0 for Success }{ { expPanic: true, - info: "No signing key provided", - }, - { - expErrCode: http.StatusBadRequest, - config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", - }, - info: "Unexpected signing method", + name: "No signing key provided", }, { expErrCode: http.StatusUnauthorized, hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, - info: "Invalid key", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey), + }, + name: "Unexpected signing method", + }, + { + expErrCode: http.StatusUnauthorized, + hdrAuth: validAuth, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey), + }, + name: "Invalid key", }, { hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT", + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + name: "Valid JWT", }, { hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - info: "Valid JWT with custom AuthScheme", + config: JWTConfig{ + TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + name: "Valid JWT with custom AuthScheme", }, { hdrAuth: validAuth, config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), }, - info: "Valid JWT with custom claims", + name: "Valid JWT with custom claims", }, { hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, - info: "Invalid Authorization header", - }, - { - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, - info: "Empty header auth field", + expErrCode: http.StatusUnauthorized, + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + name: "Invalid Authorization header", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + }, + expErrCode: http.StatusUnauthorized, + name: "Empty header auth field", + }, + { + config: JWTConfig{ + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, - info: "Valid query method", + name: "Valid query method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, - info: "Invalid query param name", + expErrCode: http.StatusUnauthorized, + name: "Invalid query param name", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=invalid-token", expErrCode: http.StatusUnauthorized, - info: "Invalid query param value", + name: "Invalid query param value", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt", }, reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, - info: "Empty query", + expErrCode: http.StatusUnauthorized, + name: "Empty query", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "param:jwt", }, reqURL: "/" + token, - info: "Valid param method", + name: "Valid param method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Valid cookie method", + name: "Valid cookie method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt,cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Multiple jwt lookuop", + name: "Multiple jwt lookuop", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, expErrCode: http.StatusUnauthorized, hdrCookie: "jwt=invalid", - info: "Invalid token with cookie method", + name: "Invalid token with cookie method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "cookie:jwt", }, - expErrCode: http.StatusBadRequest, - info: "Empty cookie", + expErrCode: http.StatusUnauthorized, + name: "Empty cookie", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, - info: "Valid form method", + name: "Valid form method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, formValues: map[string]string{"jwt": "invalid"}, - info: "Invalid token with form method", + name: "Invalid token with form method", }, { config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "form:jwt", - }, - expErrCode: http.StatusBadRequest, - info: "Empty form field", - }, - { - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return validKey, nil - }, - }, - info: "Valid JWT with a valid key using a user-defined KeyFunc", - }, - { - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return invalidKey, nil - }, + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey), + TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, - info: "Valid JWT with an invalid key using a user-defined KeyFunc", + name: "Empty form field", }, - { - hdrAuth: validAuth, - config: JWTConfig{ - KeyFunc: func(*jwt.Token) (interface{}, error) { - return nil, errors.New("faulty KeyFunc") - }, - }, - expErrCode: http.StatusUnauthorized, - info: "Token verification does not pass using a user-defined KeyFunc", - }, - { - hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, - config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT with lower case AuthScheme", - }, - } { - if tc.reqURL == "" { - tc.reqURL = "/" - } - - var req *http.Request - if len(tc.formValues) > 0 { - form := url.Values{} - for k, v := range tc.formValues { - form.Set(k, v) - } - req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) - req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") - req.ParseForm() - } else { - req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - } - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) - - if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } - - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.info) - continue - } - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.info) - assert.Equal(t, claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") - } - } } -} -func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.reqURL == "" { + tc.reqURL = "/" } - } + + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + } + res := httptest.NewRecorder() + req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) + req.Header.Set(echo.HeaderCookie, tc.hdrCookie) + c := e.NewContext(req, res) + + if tc.reqURL == "/"+token { + cc := c.(echo.EditableContext) + cc.SetPathParams(echo.PathParams{ + {Name: "jwt", Value: token}, + }) + } + + if tc.expPanic { + assert.Panics(t, func() { + JWTWithConfig(tc.config) + }, tc.name) + return + } + + if tc.expErrCode != 0 { + h := JWTWithConfig(tc.config)(handler) + he := h(c).(*echo.HTTPError) + assert.Equal(t, tc.expErrCode, he.Code) + return + } + + h := JWTWithConfig(tc.config)(handler) + if assert.NoError(t, h(c), tc.name) { + user := c.Get("user").(*jwt.Token) + switch claims := user.Claims.(type) { + case jwt.MapClaims: + assert.Equal(t, claims["name"], "John Doe") + case *jwtCustomClaims: + assert.Equal(t, claims.Name, "John Doe") + assert.Equal(t, claims.Admin, true) + default: + panic("unexpected type of claims") + } + } + }) } } @@ -420,7 +295,7 @@ func TestJWTConfig_skipper(t *testing.T) { Skipper: func(context echo.Context) bool { return true // skip everything }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) isCalled := false @@ -448,11 +323,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) { BeforeFunc: func(context echo.Context) { isCalled = true }, - SigningKey: []byte("secret"), + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), })) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -469,18 +344,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { - return echo.NewHTTPError(http.StatusTeapot, "custom_error") - }, - }, - expectStatusCode: http.StatusTeapot, - }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "custom_error") }, }, @@ -515,23 +380,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandler: func(err error) error { + ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")), + ErrorHandler: func(c echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, }, expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", }, - { - name: "ok, ErrorHandlerWithContext is executed", - given: JWTConfig{ - SigningKey: []byte("secret"), - ErrorHandlerWithContext: func(err error, context echo.Context) error { - return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) - }, - }, - expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", - }, } for _, tc := range testCases { @@ -544,14 +399,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { config := tc.given parseTokenCalled := false - config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { + config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) @@ -574,7 +429,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { signingKey := []byte("secret") config := JWTConfig{ - ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) @@ -597,9 +452,161 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { e.Use(JWTWithConfig(config)) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, http.StatusTeapot, res.Code) } + +func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + success := c.Get("success").(string) + user := c.Get("user").(string) + return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) + }) + + mw, err := JWTConfig{ + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return auth, nil + }, + SuccessHandler: func(c echo.Context) { + c.Set("success", "yes") + }, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, "yes:valid_token_base64", res.Body.String()) + assert.Equal(t, http.StatusTeapot, res.Code) +} + +func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) { + var testCases = []struct { + name string + givenCallNext bool + givenErrorHandler JWTErrorHandlerWithContext + givenTokenLookup string + whenAuthHeaders []string + whenCookies []string + whenParseReturn string + whenParseError error + expectHandlerCalled bool + expect string + expectCode int + }{ + { + name: "ok, with valid JWT from auth header", + givenCallNext: true, + givenErrorHandler: func(c echo.Context, err error) error { + return nil + }, + whenAuthHeaders: []string{"Bearer valid_token_base64"}, + whenParseReturn: "valid_token", + expectCode: http.StatusTeapot, + expect: "valid_token", + }, + { + name: "ok, missing header, callNext and set public_token from error handler", + givenCallNext: true, + givenErrorHandler: func(c echo.Context, err error) error { + if err != ErrJWTMissing { + panic("must get ErrJWTMissing") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusTeapot, + expect: "public_token", + }, + { + name: "ok, invalid token, callNext and set public_token from error handler", + givenCallNext: true, + givenErrorHandler: func(c echo.Context, err error) error { + // this is probably not realistic usecase. on parse error you probably want to return error + if err.Error() != "parser_error" { + panic("must get parser_error") + } + c.Set("user", "public_token") + return nil + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusTeapot, + expect: "public_token", + }, + { + name: "nok, invalid token, return error from error handler", + givenCallNext: true, + givenErrorHandler: func(c echo.Context, err error) error { + if err.Error() != "parser_error" { + panic("must get parser_error") + } + return err + }, + whenAuthHeaders: []string{"Bearer invalid_header"}, + whenParseError: errors.New("parser_error"), + expectCode: http.StatusInternalServerError, + expect: "{\"message\":\"Internal Server Error\"}\n", + }, + { + name: "nok, callNext but return error from error handler", + givenCallNext: true, + givenErrorHandler: func(c echo.Context, err error) error { + return err + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"missing or malformed jwt\"}\n", + }, + { + name: "nok, callNext=false", + givenCallNext: false, + givenErrorHandler: func(c echo.Context, err error) error { + return err + }, + whenAuthHeaders: []string{}, // no JWT header + expectCode: http.StatusUnauthorized, + expect: "{\"message\":\"missing or malformed jwt\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(string) + return c.String(http.StatusTeapot, token) + }) + + mw, err := JWTConfig{ + TokenLookup: tc.givenTokenLookup, + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + return tc.whenParseReturn, tc.whenParseError + }, + ErrorHandler: tc.givenErrorHandler, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + for _, a := range tc.whenAuthHeaders { + req.Header.Add(echo.HeaderAuthorization, a) + } + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expect, res.Body.String()) + assert.Equal(t, tc.expectCode, res.Code) + }) + } +} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 54f3b47f..82c3d6b9 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -3,58 +3,59 @@ package middleware import ( "errors" "fmt" - "net/http" - "strings" - "github.com/labstack/echo/v4" + "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // KeyLookup is a string in the form of ":" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "form:" - // - "cookie:" - KeyLookup string `yaml:"key_lookup"` + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key(s) from the request. + // Optional. Default value "header:Authorization:Bearer ". + // Possible values: + // - "header::" + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization:Bearer ,cookie:myowncookie" + KeyLookup string - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. + ErrorHandler KeyAuthErrorHandler +} - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - } +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(c echo.Context, key string, keyType ExtractorType) (bool, error) - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(string, echo.Context) (bool, error) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(c echo.Context, err error) error - keyExtractor func(echo.Context) (string, error) +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(error, echo.Context) error -) +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", +} // KeyAuth returns an KeyAuth middleware. // @@ -67,34 +68,32 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - - // Initialize - parts := strings.Split(config.KeyLookup, ":") - extractor := keyFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = keyFromQuery(parts[1]) - case "form": - extractor = keyFromForm(parts[1]) - case "cookie": - extractor = keyFromCookie(parts[1]) + extractors, err := createExtractors(config.KeyLookup) + if err != nil { + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -103,79 +102,50 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { return next(c) } - // Extract and verify key - key, err := extractor(c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + var lastExtractorErr error + var lastValidatorErr error + for _, extractor := range extractors { + keys, keyType, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr + continue + } + for _, key := range keys { + valid, err := config.Validator(c, key, keyType) + if err != nil { + lastValidatorErr = err + continue + } + if !valid { + lastValidatorErr = ErrInvalidKey + continue + } + return next(c) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - valid, err := config.Validator(key, c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + + // prioritize validator errors over extracting errors + err := lastValidatorErr + if err == nil { + err = lastExtractorErr + } + if config.ErrorHandler != nil { + // Allow error handler to swallow the error and continue handler chain execution + // Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain + if handledErr := config.ErrorHandler(c, err); handledErr != nil { + return handledErr } - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid key", - Internal: err, - } - } else if valid { return next(c) } - return echo.ErrUnauthorized - } - } -} - -// keyFromHeader returns a `keyExtractor` that extracts key from the request header. -func keyFromHeader(header string, authScheme string) keyExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - if auth == "" { - return "", errors.New("missing key in request header") - } - if header == echo.HeaderAuthorization { - l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { - return auth[l+1:], nil + if err == ErrExtractionValueMissing { + return ErrKeyMissing // do not wrap extractor errors + } + return &echo.HTTPError{ + Code: http.StatusUnauthorized, + Message: "Unauthorized", + Internal: err, } - return "", errors.New("invalid key in the request header") } - return auth, nil - } -} - -// keyFromQuery returns a `keyExtractor` that extracts key from the query string. -func keyFromQuery(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.QueryParam(param) - if key == "" { - return "", errors.New("missing key in the query string") - } - return key, nil - } -} - -// keyFromForm returns a `keyExtractor` that extracts key from the form. -func keyFromForm(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.FormValue(param) - if key == "" { - return "", errors.New("missing key in the form") - } - return key, nil - } -} - -// keyFromCookie returns a `keyExtractor` that extracts key from the form. -func keyFromCookie(cookieName string) keyExtractor { - return func(c echo.Context) (string, error) { - key, err := c.Cookie(cookieName) - if err != nil { - return "", fmt.Errorf("missing key in cookies: %w", err) - } - return key.Value, nil - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 0cc513ab..e81d5baa 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func testKeyValidator(key string, c echo.Context) (bool, error) { +func testKeyValidator(c echo.Context, key string, keyType ExtractorType) (bool, error) { switch key { case "valid-key": return true, nil @@ -28,7 +28,7 @@ func TestKeyAuth(t *testing.T) { handlerCalled = true return c.String(http.StatusOK, "test") } - middlewareChain := KeyAuth(testKeyValidator)(handler) + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{Validator: testKeyValidator})(handler) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized", + expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -84,13 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=400, message=invalid key in the request header", + expectError: "code=401, message=missing key", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key", }, { name: "ok, custom key lookup, header", @@ -110,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key", }, { name: "ok, custom key lookup, query", @@ -130,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the query string", + expectError: "code=401, message=missing key", }, { name: "ok, custom key lookup, form", @@ -155,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the form", + expectError: "code=401, message=missing key", }, { name: "ok, custom key lookup, cookie", @@ -179,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies: http: named cookie not present", + expectError: "code=401, message=missing key", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=missing key in request header", + expectError: "code=418, message=custom, internal=code=400, message=missing or malformed value", }, { name: "nok, custom errorHandler, error from validator", @@ -200,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(err error, context echo.Context) error { + conf.ErrorHandler = func(c echo.Context, err error) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError @@ -216,7 +216,7 @@ func TestKeyAuthWithConfig(t *testing.T) { }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, - expectError: "code=401, message=invalid key, internal=some user defined error", + expectError: "code=401, message=Unauthorized, internal=some user defined error", }, } @@ -257,3 +257,96 @@ func TestKeyAuthWithConfig(t *testing.T) { }) } } + +func TestKeyAuthWithConfig_errors(t *testing.T) { + var testCases = []struct { + name string + whenConfig KeyAuthConfig + expectError string + }{ + { + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) { + return false, nil + }, + }, + }, + { + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", + }, + { + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", + }, + { + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} + +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) + }) +} + +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + })(handler) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middlewareChain(c) + + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) +} diff --git a/middleware/logger.go b/middleware/logger.go index 9baac476..636be52e 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "encoding/json" + "fmt" "io" "strconv" "strings" @@ -10,81 +11,78 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// LoggerConfig defines the config for Logger middleware. +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` + // Tags to construct the logger format. + // + // - time_unix + // - time_unix_nano + // - time_rfc3339 + // - time_rfc3339_nano + // - time_custom + // - id (Request ID) + // - remote_ip + // - uri + // - host + // - method + // - path + // - protocol + // - referer + // - user_agent + // - status + // - error + // - latency (In nanoseconds) + // - latency_human (Human readable) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) + // - header: + // - query: + // - form: + // + // Example "${remote_ip} ${status}" + // + // Optional. Default value DefaultLoggerConfig.Format. + Format string - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` + // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. + CustomTimeFormat string - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer + // Output is a writer where logs in JSON format are written. + // Optional. Default destination `echo.Logger.Infof()` + Output io.Writer - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) + template *fasttemplate.Template + pool *sync.Pool +} -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", +} // Logger returns a middleware that logs HTTP requests. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. +// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration +func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultLoggerConfig.Skipper @@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if config.Format == "" { config.Format = DefaultLoggerConfig.Format } - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) config.pool = &sync.Pool{ New: func() interface{} { return bytes.NewBuffer(make([]byte, 256)) @@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() + start := time.Now() - if err = next(c); err != nil { - c.Error(err) - } + err := next(c) stop := time.Now() + buf := config.pool.Get().(*bytes.Buffer) buf.Reset() defer config.pool.Put(buf) - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { + _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) @@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case "user_agent": return buf.WriteString(req.UserAgent()) case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) + status := res.Status + if err != nil { + if httpErr, ok := err.(*echo.HTTPError); ok { + status = httpErr.Code + } } - return buf.WriteString(s) + return buf.WriteString(strconv.Itoa(status)) case "error": if err != nil { // Error may contain invalid JSON e.g. `"` @@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { case strings.HasPrefix(tag, "form:"): return buf.Write([]byte(c.FormValue(tag[5:]))) case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { + cookie, cookieErr := c.Cookie(tag[7:]) + if cookieErr == nil { return buf.Write([]byte(cookie.Value)) } } } return 0, nil - }); err != nil { - return + }) + if tmplErr != nil { + if err != nil { + return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err) + } + return fmt.Errorf("failed to create log from template: %w", tmplErr) } - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return + if config.Output != nil { + if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil { + return lErr + } + } else { + if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil { + return lErr + } } - _, err = config.Output.Write(buf.Bytes()) - return + return err } - } + }, nil } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 394f6271..44fbd8dd 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} ip := "127.0.0.1" h := Logger()(func(c echo.Context) error { return c.String(http.StatusOK, "test") diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2e..7e46feb9 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -6,28 +6,24 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and @@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b158..58116168 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index a7ad73a5..9d0a0df2 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -9,14 +9,11 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error func DefaultSkipper(echo.Context) bool { return false } + +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731..f677861f 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "math/rand" @@ -20,85 +21,81 @@ import ( // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.RWMutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.RWMutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - *commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + *commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - *commonBalancer - i uint32 - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + *commonBalancer + i uint32 +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", - } -) +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { +func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { @@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + return nil, errors.New("echo proxy middleware requires balancer") } if config.Rewrite != nil { @@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) + proxyRaw(c, tgt).ServeHTTP(res, req) case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) + proxyHTTP(c, tgt, config).ServeHTTP(res, req) } if e, ok := c.Get("_error").(error); ok { err = e @@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { return } - } + }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c..4dd1b3ad 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -55,7 +55,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -77,7 +77,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -113,15 +113,20 @@ func TestProxy(t *testing.T) { return nil } } - rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) @@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { rb := NewRandomBalancer(nil) assert.True(t, rb.AddTarget(target)) e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) ctx, cancel := context.WithCancel(req.Context()) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 0291eb45..81e885ad 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "sync" "time" @@ -9,39 +10,33 @@ import ( "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - panic("Store configuration must be provided") + return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -137,35 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { identifier, err := config.IdentifierExtractor(c) if err != nil { - c.Error(config.ErrorHandler(c, err)) - return nil + return config.ErrorHandler(c, err) } - if allow, err := config.Store.Allow(identifier); !allow { - c.Error(config.DenyHandler(c, identifier, err)) - return nil + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) } return next(c) } - } + }, nil } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit - burst int - expiresIn time.Duration - lastCleanup time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit + burst int + expiresIn time.Duration + lastCleanup time.Time +} + +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 89d9a6ed..99cea87c 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiter(inMemoryStore) + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - _ = mw(handler)(c) - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } -func TestRateLimiter_panicBehaviour(t *testing.T) { +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiter(nil) + RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { - RateLimiter(inMemoryStore) + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } @@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { id string @@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) + err := mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ IdentifierExtractor: func(c echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { @@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"", http.StatusForbidden}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } } @@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return true }, @@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Skipper: func(c echo.Context) bool { return false }, @@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) _ = mw(handler)(c) @@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ BeforeFunc: func(c echo.Context) { beforeRan = true }, @@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { IdentifierExtractor: func(ctx echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/recover.go b/middleware/recover.go index 0dbe740d..6d8f16df 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,44 +5,34 @@ import ( "runtime" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" ) -type ( - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool +} - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - } -) - -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. @@ -50,9 +40,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -62,40 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } defer func() { if r := recover(); r != nil { - err, ok := r.(error) + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) if !config.DisablePrintStack { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { - case log.DEBUG: - c.Logger().Debug(msg) - case log.INFO: - c.Logger().Info(msg) - case log.WARN: - c.Logger().Warn(msg) - case log.ERROR: - c.Logger().Error(msg) - case log.OFF: - // None. - default: - c.Logger().Print(msg) - } + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length]) } - c.Error(err) + err = tmpErr } }() return next(c) } - } + }, nil } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 64433297..e38fe36b 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,82 +2,109 @@ package middleware import ( "bytes" - "fmt" "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = &testLogger{output: buf} req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c echo.Context) error { panic("test") - })) - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") + }) + err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged } -func TestRecoverWithConfig_LogLevel(t *testing.T) { - tests := []struct { - logLevel log.Lvl - levelName string - }{{ - logLevel: log.DEBUG, - levelName: "DEBUG", - }, { - logLevel: log.INFO, - levelName: "INFO", - }, { - logLevel: log.WARN, - levelName: "WARN", - }, { - logLevel: log.ERROR, - levelName: "ERROR", - }, { - logLevel: log.OFF, - levelName: "OFF", - }} +func TestRecover_skipper(t *testing.T) { + e := echo.New() - for _, tt := range tests { - tt := tt - t.Run(tt.levelName, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := RecoverConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + } + h := RecoverWithConfig(config)(func(c echo.Context) error { + panic("testPANIC") + }) + + var err error + assert.Panics(t, func() { + err = h(c) + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain +} + +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := DefaultRecoverConfig - config.LogLevel = tt.logLevel - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) - h(c) + err := h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - if tt.logLevel == log.OFF { - assert.Empty(t, output) + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) } else { - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + assert.NoError(t, err) } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain }) } } diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db3..beca1349 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "strings" @@ -14,7 +15,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" && !strings.HasPrefix(host, www) { - return true, "https://www." + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if scheme != "https" { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri - } - return false, "" - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRedirectConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } - } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri + } + return false, "" } diff --git a/middleware/request_id.go b/middleware/request_id.go index b0baeeb2..9ebeb727 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -2,45 +2,38 @@ package middleware import ( "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Default value random.String(32). + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) - } -) - -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - } -) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(c echo.Context, requestID string) +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -62,9 +55,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return next(c) } - } -} - -func generator() string { - return random.String(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 944b3b49..7bb320a3 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator and handler - customID := "customGenerator" - calledHandler := false + // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return customID }, - RequestIDHandler: func(_ echo.Context, id string) { - calledHandler = true - assert.Equal(t, customID, id) - }, + Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 1b3e3eaa..de411875 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -24,6 +24,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info(). +// Date("request_start", v.StartTime). // Str("URI", v.URI). // Int("status", v.Status). // Msg("request") @@ -39,6 +40,7 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // logger.Info("request", +// zap.Time("request_start", v.StartTime), // zap.String("URI", v.URI), // zap.Int("status", v.Status), // ) @@ -54,8 +56,9 @@ import ( // LogStatus: true, // LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // log.WithFields(logrus.Fields{ -// "URI": values.URI, -// "status": values.Status, +// "request_start": values.StartTime, +// "URI": values.URI, +// "status": values.Status, // }).Info("request") // // return nil @@ -158,15 +161,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is been logger for each given header. + // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is been logger for each given query param name. + // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is been logger for each given form value name. + // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 5118b121..9811d152 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) { req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + c := e.NewContext(req, rec).(echo.EditableContext) c.SetPath("/test*") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b5..4bd952ea 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,62 +1,58 @@ package middleware import ( + "errors" "regexp" "github.com/labstack/echo/v4" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` - } -) - -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string +} // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil && config.RegexRules == nil { - panic("echo: rewrite middleware requires url path rewrite rules or regex rules") - } + return toMiddlewareOrPanic(config) +} +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { @@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 0ac04bb2..eea155ad 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) { }, })) e.GET("/public/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) e.GET("/*", func(c echo.Context) error { - return c.String(http.StatusOK, c.Param("*")) + return c.String(http.StatusOK, c.PathParam("*")) }) var testCases = []struct { @@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) { } } +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } +} + // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, - )) + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { return c.String(http.StatusOK, "eng") }) diff --git a/middleware/secure.go b/middleware/secure.go index 6c405172..9218aa5d 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -6,84 +6,80 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// SecureConfig defines the config for Secure middleware. +type SecureConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` + // XSSProtection provides protection against cross-site scripting attack (XSS) + // by setting the `X-XSS-Protection` header. + // Optional. Default value "1; mode=block". + XSSProtection string - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` + // ContentTypeNosniff provides protection against overriding Content-Type + // header by setting the `X-Content-Type-Options` header. + // Optional. Default value "nosniff". + ContentTypeNosniff string - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,