diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml
index 26640666..ecc508e6 100644
--- a/.github/workflows/echo.yml
+++ b/.github/workflows/echo.yml
@@ -27,7 +27,8 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases
- go: [1.14, 1.15, 1.16, 1.17]
+ # except v5 starts from 1.16 until there is last four major releases after that
+ go: [1.16, 1.17]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 67d45ad7..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,21 +0,0 @@
-arch:
- - amd64
- - ppc64le
-
-language: go
-go:
- - 1.14.x
- - 1.15.x
- - tip
-env:
- - GO111MODULE=on
-install:
- - go get -v golang.org/x/lint/golint
-script:
- - golint -set_exit_status ./...
- - go test -race -coverprofile=coverage.txt -covermode=atomic ./...
-after_success:
- - bash <(curl -s https://codecov.io/bash)
-matrix:
- allow_failures:
- - go: tip
diff --git a/Makefile b/Makefile
index 48061f7e..10f9c8f5 100644
--- a/Makefile
+++ b/Makefile
@@ -24,11 +24,11 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST}
benchmark: ## Run benchmarks
- @go test -run="-" -bench=".*" ${PKG_LIST}
+ @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
-goversion ?= "1.15"
-test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15
+goversion ?= "1.16"
+test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
diff --git a/README.md b/README.md
index 364f98ac..eb9369ad 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,8 @@
## Supported Go versions
+Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
+
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required:
@@ -67,8 +69,8 @@ package main
import (
"net/http"
- "github.com/labstack/echo/v4"
- "github.com/labstack/echo/v4/middleware"
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/middleware"
)
func main() {
@@ -83,7 +85,9 @@ func main() {
e.GET("/", hello)
// Start server
- e.Logger.Fatal(e.Start(":1323"))
+ if err := e.Start(":1323"); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
}
// Handler
diff --git a/bind.go b/bind.go
index fdf0524c..a650ba0b 100644
--- a/bind.go
+++ b/bind.go
@@ -11,42 +11,38 @@ import (
"strings"
)
-type (
- // Binder is the interface that wraps the Bind method.
- Binder interface {
- Bind(i interface{}, c Context) error
- }
+// Binder is the interface that wraps the Bind method.
+type Binder interface {
+ Bind(c Context, i interface{}) error
+}
- // DefaultBinder is the default implementation of the Binder interface.
- DefaultBinder struct{}
+// DefaultBinder is the default implementation of the Binder interface.
+type DefaultBinder struct{}
- // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
- // Types that don't implement this, but do implement encoding.TextUnmarshaler
- // will use that interface instead.
- BindUnmarshaler interface {
- // UnmarshalParam decodes and assigns a value from an form or query param.
- UnmarshalParam(param string) error
- }
-)
+// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
+// Types that don't implement this, but do implement encoding.TextUnmarshaler
+// will use that interface instead.
+type BindUnmarshaler interface {
+ // UnmarshalParam decodes and assigns a value from an form or query param.
+ UnmarshalParam(param string) error
+}
// BindPathParams binds path params to bindable object
-func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
- names := c.ParamNames()
- values := c.ParamValues()
+func BindPathParams(c Context, i interface{}) error {
params := map[string][]string{}
- for i, name := range names {
- params[name] = []string{values[i]}
+ for _, param := range c.PathParams() {
+ params[param.Name] = []string{param.Value}
}
- if err := b.bindData(i, params, "param"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ if err := bindData(i, params, "param"); err != nil {
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// BindQueryParams binds query params to bindable object
-func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
- if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+func BindQueryParams(c Context, i interface{}) error {
+ if err := bindData(i, c.QueryParams(), "query"); err != nil {
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
@@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
-func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
+func BindBody(c Context, i interface{}) (err error) {
req := c.Request()
if req.ContentLength == 0 {
return
@@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
case *HTTPError:
return err
default:
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
}
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
}
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
params, err := c.FormParams()
if err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
- if err = b.bindData(i, params, "form"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ if err = bindData(i, params, "form"); err != nil {
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
default:
return ErrUnsupportedMediaType
@@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
// BindHeaders binds HTTP headers to a bindable object
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
- if err := b.bindData(i, c.Request().Header, "header"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ if err := bindData(i, c.Request().Header, "header"); err != nil {
+ return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
-// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
-func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
- if err := b.BindPathParams(c, i); err != nil {
+// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
+func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
+ if err := BindPathParams(c, i); err != nil {
return err
}
// Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH)
@@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
// i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding.
// This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body
if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete {
- if err = b.BindQueryParams(c, i); err != nil {
+ if err = BindQueryParams(c, i); err != nil {
return err
}
}
- return b.BindBody(c, i)
+ return BindBody(c, i)
}
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
-func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
+func bindData(destination interface{}, data map[string][]string, tag string) error {
if destination == nil || len(data) == 0 {
return nil
}
@@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
- if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
+ if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
return err
}
}
@@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" {
- value = "0"
+ return nil
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil {
@@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
func setUintField(value string, bitSize int, field reflect.Value) error {
if value == "" {
- value = "0"
+ return nil
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil {
@@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
func setBoolField(value string, field reflect.Value) error {
if value == "" {
- value = "false"
+ return nil
}
boolVal, err := strconv.ParseBool(value)
if err == nil {
@@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error {
func setFloatField(value string, bitSize int, field reflect.Value) error {
if value == "" {
- value = "0.0"
+ return nil
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil {
diff --git a/bind_test.go b/bind_test.go
index 4ed8dbb5..5b555248 100644
--- a/bind_test.go
+++ b/bind_test.go
@@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
}
}
+func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
+ req.Header.Set("language", "de")
+ req.Header.Set("length", "99")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ c.SetPathParams(PathParams{{
+ Name: "id",
+ Value: "999",
+ }})
+
+ type SearchOpts struct {
+ ID int `param:"id"`
+ Length int `query:"length"`
+ Page int `query:"page"`
+ Search string `query:"search"`
+ Language string `query:"language" header:"language"`
+ }
+
+ opts := SearchOpts{
+ Length: 100,
+ Page: 0,
+ Search: "default value",
+ Language: "en",
+ }
+ err := c.Bind(&opts)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 50, opts.Length) // bind from query
+ assert.Equal(t, 10, opts.Page) // bind from query
+ assert.Equal(t, 999, opts.ID) // bind from path param
+ assert.Equal(t, "et", opts.Language) // bind from query
+ assert.Equal(t, "default value", opts.Search) // default value stays
+
+ // make sure another bind will not mess already set values unless there are new values
+ err = (&DefaultBinder{}).BindHeaders(c, &opts)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
+ assert.Equal(t, 10, opts.Page)
+ assert.Equal(t, 999, opts.ID)
+ assert.Equal(t, "de", opts.Language) // header overwrites now this value
+ assert.Equal(t, "default value", opts.Search)
+}
+
func TestBindUnmarshalParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
@@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
func TestBindUnmarshalText(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
func TestBindUnmarshalTextPtr(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
func TestBindbindData(t *testing.T) {
a := assert.New(t)
ts := new(bindTestStruct)
- b := new(DefaultBinder)
- err := b.bindData(ts, values, "form")
+ err := bindData(ts, values, "form")
a.NoError(err)
a.Equal(0, ts.I)
@@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
func TestBindParam(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/", nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- c.SetPath("/users/:id/:name")
- c.SetParamNames("id", "name")
- c.SetParamValues("1", "Jon Snow")
+ cc := c.(EditableContext)
+ cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
+ cc.SetPathParams(PathParams{
+ {Name: "id", Value: "1"},
+ {Name: "name", Value: "Jon Snow"},
+ })
u := new(user)
err := c.Bind(u)
@@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param
c2 := e.NewContext(req, rec)
- c2.SetPath("/users/:id")
- c2.SetParamNames("id")
- c2.SetParamValues("1")
+ cc2 := c2.(EditableContext)
+ cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
+ cc2.SetPathParams(PathParams{
+ {Name: "id", Value: "1"},
+ })
u = new(user)
err = c2.Bind(u)
@@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New()
- req2 := httptest.NewRequest(POST, "/", body)
+ req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2)
- c3.SetPath("/users/:id")
- c3.SetParamNames("id")
- c3.SetParamValues("1")
+ cc3 := c3.(EditableContext)
+ cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
+ cc3.SetPathParams(PathParams{
+ {Name: "id", Value: "1"},
+ })
u = new(user)
err = c3.Bind(u)
@@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
-func TestBindSetFields(t *testing.T) {
- assert := assert.New(t)
+func TestSetIntField(t *testing.T) {
+ ts := new(bindTestStruct)
+ ts.I = 100
+ val := reflect.ValueOf(ts).Elem()
+
+ // empty value does nothing to field
+ // in that way we can have default values by setting field value before binding
+ err := setIntField("", 0, val.FieldByName("I"))
+ assert.NoError(t, err)
+ assert.Equal(t, 100, ts.I)
+
+ // second set with value sets the value
+ err = setIntField("5", 0, val.FieldByName("I"))
+ assert.NoError(t, err)
+ assert.Equal(t, 5, ts.I)
+
+ // third set without value does nothing to the value
+ // in that way multiple binds (ala query + header) do not reset fields to 0s
+ err = setIntField("", 0, val.FieldByName("I"))
+ assert.NoError(t, err)
+ assert.Equal(t, 5, ts.I)
+}
+
+func TestSetUintField(t *testing.T) {
+ ts := new(bindTestStruct)
+ ts.UI = 100
+
+ val := reflect.ValueOf(ts).Elem()
+
+ // empty value does nothing to field
+ // in that way we can have default values by setting field value before binding
+ err := setUintField("", 0, val.FieldByName("UI"))
+ assert.NoError(t, err)
+ assert.Equal(t, uint(100), ts.UI)
+
+ // second set with value sets the value
+ err = setUintField("5", 0, val.FieldByName("UI"))
+ assert.NoError(t, err)
+ assert.Equal(t, uint(5), ts.UI)
+
+ // third set without value does nothing to the value
+ // in that way multiple binds (ala query + header) do not reset fields to 0s
+ err = setUintField("", 0, val.FieldByName("UI"))
+ assert.NoError(t, err)
+ assert.Equal(t, uint(5), ts.UI)
+}
+
+func TestSetFloatField(t *testing.T) {
+ ts := new(bindTestStruct)
+ ts.F32 = 100
+
+ val := reflect.ValueOf(ts).Elem()
+
+ // empty value does nothing to field
+ // in that way we can have default values by setting field value before binding
+ err := setFloatField("", 0, val.FieldByName("F32"))
+ assert.NoError(t, err)
+ assert.Equal(t, float32(100), ts.F32)
+
+ // second set with value sets the value
+ err = setFloatField("15.5", 0, val.FieldByName("F32"))
+ assert.NoError(t, err)
+ assert.Equal(t, float32(15.5), ts.F32)
+
+ // third set without value does nothing to the value
+ // in that way multiple binds (ala query + header) do not reset fields to 0s
+ err = setFloatField("", 0, val.FieldByName("F32"))
+ assert.NoError(t, err)
+ assert.Equal(t, float32(15.5), ts.F32)
+}
+
+func TestSetBoolField(t *testing.T) {
+ ts := new(bindTestStruct)
+ ts.B = true
+
+ val := reflect.ValueOf(ts).Elem()
+
+ // empty value does nothing to field
+ // in that way we can have default values by setting field value before binding
+ err := setBoolField("", val.FieldByName("B"))
+ assert.NoError(t, err)
+ assert.Equal(t, true, ts.B)
+
+ // second set with value sets the value
+ err = setBoolField("true", val.FieldByName("B"))
+ assert.NoError(t, err)
+ assert.Equal(t, true, ts.B)
+
+ // third set without value does nothing to the value
+ // in that way multiple binds (ala query + header) do not reset fields to 0s
+ err = setBoolField("", val.FieldByName("B"))
+ assert.NoError(t, err)
+ assert.Equal(t, true, ts.B)
+
+ // fourth set to false
+ err = setBoolField("false", val.FieldByName("B"))
+ assert.NoError(t, err)
+ assert.Equal(t, false, ts.B)
+}
+
+func TestUnmarshalFieldNonPtr(t *testing.T) {
ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem()
- // Int
- if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
- assert.Equal(5, ts.I)
- }
- if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
- assert.Equal(0, ts.I)
- }
-
- // Uint
- if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
- assert.Equal(uint(10), ts.UI)
- }
- if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
- assert.Equal(uint(0), ts.UI)
- }
-
- // Float
- if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
- assert.Equal(float32(15.5), ts.F32)
- }
- if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
- assert.Equal(float32(0.0), ts.F32)
- }
-
- // Bool
- if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
- assert.Equal(true, ts.B)
- }
- if assert.NoError(setBoolField("", val.FieldByName("B"))) {
- assert.Equal(false, ts.B)
- }
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
- if assert.NoError(err) {
- assert.Equal(ok, true)
- assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
+ if assert.NoError(t, err) {
+ assert.True(t, ok)
+ assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
}
}
@@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs()
assert := assert.New(b)
ts := new(bindTestStructWithTags)
- binder := new(DefaultBinder)
var err error
b.ResetTimer()
for i := 0; i < b.N; i++ {
- err = binder.bindData(ts, values, "form")
+ err = bindData(ts, values, "form")
}
assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts))
@@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
- c.SetParamNames("node")
- c.SetParamValues("node_from_path")
+ cc := c.(EditableContext)
+ cc.SetPathParams(PathParams{
+ {Name: "node", Value: "node_from_path"},
+ })
}
var bindTarget interface{}
@@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
}
b := new(DefaultBinder)
- err := b.Bind(bindTarget, c)
+ err := b.Bind(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
- c.SetParamNames("node")
- c.SetParamValues("real_node")
+ cc := c.(EditableContext)
+ cc.SetPathParams(PathParams{
+ {Name: "node", Value: "real_node"},
+ })
}
var bindTarget interface{}
@@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else {
bindTarget = &Node{}
}
- b := new(DefaultBinder)
- err := b.BindBody(c, bindTarget)
+ err := BindBody(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
diff --git a/binder.go b/binder.go
index 0900ce8d..402b80bc 100644
--- a/binder.go
+++ b/binder.go
@@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{
failFast: true,
- ValueFunc: c.Param,
+ ValueFunc: c.PathParam,
ValuesFunc: func(sourceParam string) []string {
// path parameter should not have multiple values so getting values does not make sense but lets not error out here
- value := c.Param(sourceParam)
+ value := c.PathParam(sourceParam)
if value == "" {
return nil
}
diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go
index 018628c3..4f1587bc 100644
--- a/binder_go1.15_test.go
+++ b/binder_go1.15_test.go
@@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
- names := make([]string, 0)
- values := make([]string, 0)
+ params := make(PathParams, 0)
for name, value := range pathParams {
- names = append(names, name)
- values = append(values, value)
+ params = append(params, PathParam{
+ Name: name,
+ Value: value,
+ })
}
- c.SetParamNames(names...)
- c.SetParamValues(values...)
+ cc := c.(EditableContext)
+ cc.SetPathParams(params)
}
return c
diff --git a/binder_test.go b/binder_test.go
index 946906a9..00036863 100644
--- a/binder_test.go
+++ b/binder_test.go
@@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
- names := make([]string, 0)
- values := make([]string, 0)
+ params := make(PathParams, 0)
for name, value := range pathParams {
- names = append(names, name)
- values = append(values, value)
+ params = append(params, PathParam{
+ Name: name,
+ Value: value,
+ })
}
- c.SetParamNames(names...)
- c.SetParamValues(values...)
+ cc := c.(EditableContext)
+ cc.SetPathParams(params)
}
return c
@@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
- _ = binder.Bind(&dest, c)
+ _ = binder.Bind(c, &dest)
}
}
@@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
- _ = binder.Bind(&dest, c)
+ _ = binder.Bind(c, &dest)
if dest.Int64 != 1 {
b.Fatalf("int64!=1")
}
diff --git a/context.go b/context.go
index 91ab6e48..65bedcdf 100644
--- a/context.go
+++ b/context.go
@@ -3,212 +3,233 @@ package echo
import (
"bytes"
"encoding/xml"
+ "errors"
"fmt"
"io"
+ "io/fs"
"mime/multipart"
"net"
"net/http"
"net/url"
- "os"
"path/filepath"
"strings"
"sync"
)
-type (
- // Context represents the context of the current HTTP request. It holds request and
- // response objects, path, path parameters, data and registered handler.
- Context interface {
- // Request returns `*http.Request`.
- Request() *http.Request
+// Context represents the context of the current HTTP request. It holds request and
+// response objects, path, path parameters, data and registered handler.
+type Context interface {
+ // Request returns `*http.Request`.
+ Request() *http.Request
- // SetRequest sets `*http.Request`.
- SetRequest(r *http.Request)
+ // SetRequest sets `*http.Request`.
+ SetRequest(r *http.Request)
- // SetResponse sets `*Response`.
- SetResponse(r *Response)
+ // SetResponse sets `*Response`.
+ SetResponse(r *Response)
- // Response returns `*Response`.
- Response() *Response
+ // Response returns `*Response`.
+ Response() *Response
- // IsTLS returns true if HTTP connection is TLS otherwise false.
- IsTLS() bool
+ // IsTLS returns true if HTTP connection is TLS otherwise false.
+ IsTLS() bool
- // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
- IsWebSocket() bool
+ // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
+ IsWebSocket() bool
- // Scheme returns the HTTP protocol scheme, `http` or `https`.
- Scheme() string
+ // Scheme returns the HTTP protocol scheme, `http` or `https`.
+ Scheme() string
- // RealIP returns the client's network address based on `X-Forwarded-For`
- // or `X-Real-IP` request header.
- // The behavior can be configured using `Echo#IPExtractor`.
- RealIP() string
+ // RealIP returns the client's network address based on `X-Forwarded-For`
+ // or `X-Real-IP` request header.
+ // The behavior can be configured using `Echo#IPExtractor`.
+ RealIP() string
- // Path returns the registered path for the handler.
- Path() string
+ // RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type
+ // of match router found and how this request context handler chain could end:
+ // * route match - this path + method had matching route.
+ // * not found - this path did not match any routes enough to be considered match
+ // * method not allowed - path had routes registered but for other method types then current request is
+ // * unknown - initial state for fresh context before router tries to do routing
+ //
+ // Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried
+ // to match request to route.
+ RouteMatchType() RouteMatchType
- // SetPath sets the registered path for the handler.
- SetPath(p string)
+ // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
+ // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
+ RouteInfo() RouteInfo
- // Param returns path parameter by name.
- Param(name string) string
+ // Path returns the registered path for the handler.
+ Path() string
- // ParamNames returns path parameter names.
- ParamNames() []string
+ // PathParam returns path parameter by name.
+ PathParam(name string) string
- // SetParamNames sets path parameter names.
- SetParamNames(names ...string)
+ // PathParams returns path parameter values.
+ PathParams() PathParams
- // ParamValues returns path parameter values.
- ParamValues() []string
+ // SetPathParams set path parameter for during current request lifecycle.
+ SetPathParams(params PathParams)
- // SetParamValues sets path parameter values.
- SetParamValues(values ...string)
+ // QueryParam returns the query param for the provided name.
+ QueryParam(name string) string
- // QueryParam returns the query param for the provided name.
- QueryParam(name string) string
+ // QueryParams returns the query parameters as `url.Values`.
+ QueryParams() url.Values
- // QueryParams returns the query parameters as `url.Values`.
- QueryParams() url.Values
+ // QueryString returns the URL query string.
+ QueryString() string
- // QueryString returns the URL query string.
- QueryString() string
+ // FormValue returns the form field value for the provided name.
+ FormValue(name string) string
- // FormValue returns the form field value for the provided name.
- FormValue(name string) string
+ // FormParams returns the form parameters as `url.Values`.
+ FormParams() (url.Values, error)
- // FormParams returns the form parameters as `url.Values`.
- FormParams() (url.Values, error)
+ // FormFile returns the multipart form file for the provided name.
+ FormFile(name string) (*multipart.FileHeader, error)
- // FormFile returns the multipart form file for the provided name.
- FormFile(name string) (*multipart.FileHeader, error)
+ // MultipartForm returns the multipart form.
+ MultipartForm() (*multipart.Form, error)
- // MultipartForm returns the multipart form.
- MultipartForm() (*multipart.Form, error)
+ // Cookie returns the named cookie provided in the request.
+ Cookie(name string) (*http.Cookie, error)
- // Cookie returns the named cookie provided in the request.
- Cookie(name string) (*http.Cookie, error)
+ // SetCookie adds a `Set-Cookie` header in HTTP response.
+ SetCookie(cookie *http.Cookie)
- // SetCookie adds a `Set-Cookie` header in HTTP response.
- SetCookie(cookie *http.Cookie)
+ // Cookies returns the HTTP cookies sent with the request.
+ Cookies() []*http.Cookie
- // Cookies returns the HTTP cookies sent with the request.
- Cookies() []*http.Cookie
+ // Get retrieves data from the context.
+ Get(key string) interface{}
- // Get retrieves data from the context.
- Get(key string) interface{}
+ // Set saves data in the context.
+ Set(key string, val interface{})
- // Set saves data in the context.
- Set(key string, val interface{})
+ // Bind binds the request body into provided type `i`. The default binder
+ // does it based on Content-Type header.
+ Bind(i interface{}) error
- // Bind binds the request body into provided type `i`. The default binder
- // does it based on Content-Type header.
- Bind(i interface{}) error
+ // Validate validates provided `i`. It is usually called after `Context#Bind()`.
+ // Validator must be registered using `Echo#Validator`.
+ Validate(i interface{}) error
- // Validate validates provided `i`. It is usually called after `Context#Bind()`.
- // Validator must be registered using `Echo#Validator`.
- Validate(i interface{}) error
+ // Render renders a template with data and sends a text/html response with status
+ // code. Renderer must be registered using `Echo.Renderer`.
+ Render(code int, name string, data interface{}) error
- // Render renders a template with data and sends a text/html response with status
- // code. Renderer must be registered using `Echo.Renderer`.
- Render(code int, name string, data interface{}) error
+ // HTML sends an HTTP response with status code.
+ HTML(code int, html string) error
- // HTML sends an HTTP response with status code.
- HTML(code int, html string) error
+ // HTMLBlob sends an HTTP blob response with status code.
+ HTMLBlob(code int, b []byte) error
- // HTMLBlob sends an HTTP blob response with status code.
- HTMLBlob(code int, b []byte) error
+ // String sends a string response with status code.
+ String(code int, s string) error
- // String sends a string response with status code.
- String(code int, s string) error
+ // JSON sends a JSON response with status code.
+ JSON(code int, i interface{}) error
- // JSON sends a JSON response with status code.
- JSON(code int, i interface{}) error
+ // JSONPretty sends a pretty-print JSON with status code.
+ JSONPretty(code int, i interface{}, indent string) error
- // JSONPretty sends a pretty-print JSON with status code.
- JSONPretty(code int, i interface{}, indent string) error
+ // JSONBlob sends a JSON blob response with status code.
+ JSONBlob(code int, b []byte) error
- // JSONBlob sends a JSON blob response with status code.
- JSONBlob(code int, b []byte) error
+ // JSONP sends a JSONP response with status code. It uses `callback` to construct
+ // the JSONP payload.
+ JSONP(code int, callback string, i interface{}) error
- // JSONP sends a JSONP response with status code. It uses `callback` to construct
- // the JSONP payload.
- JSONP(code int, callback string, i interface{}) error
+ // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
+ // to construct the JSONP payload.
+ JSONPBlob(code int, callback string, b []byte) error
- // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
- // to construct the JSONP payload.
- JSONPBlob(code int, callback string, b []byte) error
+ // XML sends an XML response with status code.
+ XML(code int, i interface{}) error
- // XML sends an XML response with status code.
- XML(code int, i interface{}) error
+ // XMLPretty sends a pretty-print XML with status code.
+ XMLPretty(code int, i interface{}, indent string) error
- // XMLPretty sends a pretty-print XML with status code.
- XMLPretty(code int, i interface{}, indent string) error
+ // XMLBlob sends an XML blob response with status code.
+ XMLBlob(code int, b []byte) error
- // XMLBlob sends an XML blob response with status code.
- XMLBlob(code int, b []byte) error
+ // Blob sends a blob response with status code and content type.
+ Blob(code int, contentType string, b []byte) error
- // Blob sends a blob response with status code and content type.
- Blob(code int, contentType string, b []byte) error
+ // Stream sends a streaming response with status code and content type.
+ Stream(code int, contentType string, r io.Reader) error
- // Stream sends a streaming response with status code and content type.
- Stream(code int, contentType string, r io.Reader) error
+ // File sends a response with the content of the file.
+ File(file string) error
- // File sends a response with the content of the file.
- File(file string) error
+ // Attachment sends a response as attachment, prompting client to save the
+ // file.
+ Attachment(file string, name string) error
- // Attachment sends a response as attachment, prompting client to save the
- // file.
- Attachment(file string, name string) error
+ // Inline sends a response as inline, opening the file in the browser.
+ Inline(file string, name string) error
- // Inline sends a response as inline, opening the file in the browser.
- Inline(file string, name string) error
+ // NoContent sends a response with no body and a status code.
+ NoContent(code int) error
- // NoContent sends a response with no body and a status code.
- NoContent(code int) error
+ // Redirect redirects the request to a provided URL with status code.
+ Redirect(code int, url string) error
- // Redirect redirects the request to a provided URL with status code.
- Redirect(code int, url string) error
+ // Error invokes the registered HTTP error handler.
+ // NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error.
+ Error(err error)
- // Error invokes the registered HTTP error handler. Generally used by middleware.
- Error(err error)
+ // Echo returns the `Echo` instance.
+ Echo() *Echo
+}
- // Handler returns the matched handler by router.
- Handler() HandlerFunc
+// EditableContext is additional interface that structure implementing Context must implement. Methods inside this
+// interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares.
+type EditableContext interface {
+ Context
- // SetHandler sets the matched handler by router.
- SetHandler(h HandlerFunc)
+ // RawPathParams returns raw path pathParams value.
+ RawPathParams() *PathParams
- // Logger returns the `Logger` instance.
- Logger() Logger
+ // SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
+ SetRawPathParams(params *PathParams)
- // Set the logger
- SetLogger(l Logger)
+ // SetPath sets the registered path for the handler.
+ SetPath(p string)
- // Echo returns the `Echo` instance.
- Echo() *Echo
+ // SetRouteMatchType sets the RouteMatchType of router match for this request.
+ SetRouteMatchType(t RouteMatchType)
- // Reset resets the context after request completes. It must be called along
- // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
- // See `Echo#ServeHTTP()`
- Reset(r *http.Request, w http.ResponseWriter)
- }
+ // SetRouteInfo sets the route info of this request to the context.
+ SetRouteInfo(ri RouteInfo)
- context struct {
- request *http.Request
- response *Response
- path string
- pnames []string
- pvalues []string
- query url.Values
- handler HandlerFunc
- store Map
- echo *Echo
- logger Logger
- lock sync.RWMutex
- }
-)
+ // Reset resets the context after request completes. It must be called along
+ // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
+ // See `Echo#ServeHTTP()`
+ Reset(r *http.Request, w http.ResponseWriter)
+}
+
+type context struct {
+ request *http.Request
+ response *Response
+
+ matchType RouteMatchType
+ route RouteInfo
+ path string
+
+ // pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
+ pathParams *PathParams
+ // currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
+ // Lifecycle is not handle by Echo and could have excess allocations per served Request
+ currentParams PathParams
+
+ query url.Values
+ store Map
+ echo *Echo
+ lock sync.RWMutex
+}
const (
defaultMemory = 32 << 20 // 32 MB
@@ -296,52 +317,50 @@ func (c *context) SetPath(p string) {
c.path = p
}
-func (c *context) Param(name string) string {
- for i, n := range c.pnames {
- if i < len(c.pvalues) {
- if n == name {
- return c.pvalues[i]
- }
- }
- }
- return ""
+func (c *context) RouteMatchType() RouteMatchType {
+ return c.matchType
}
-func (c *context) ParamNames() []string {
- return c.pnames
+func (c *context) SetRouteMatchType(t RouteMatchType) {
+ c.matchType = t
}
-func (c *context) SetParamNames(names ...string) {
- c.pnames = names
-
- l := len(names)
- if *c.echo.maxParam < l {
- *c.echo.maxParam = l
- }
-
- if len(c.pvalues) < l {
- // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
- // probably those values will be overriden in a Context#SetParamValues
- newPvalues := make([]string, l)
- copy(newPvalues, c.pvalues)
- c.pvalues = newPvalues
- }
+func (c *context) RouteInfo() RouteInfo {
+ return c.route
}
-func (c *context) ParamValues() []string {
- return c.pvalues[:len(c.pnames)]
+func (c *context) SetRouteInfo(ri RouteInfo) {
+ c.route = ri
}
-func (c *context) SetParamValues(values ...string) {
- // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times
- // It will brake the Router#Find code
- limit := len(values)
- if limit > *c.echo.maxParam {
- limit = *c.echo.maxParam
+func (c *context) RawPathParams() *PathParams {
+ return c.pathParams
+}
+
+func (c *context) SetRawPathParams(params *PathParams) {
+ c.pathParams = params
+}
+
+func (c *context) PathParam(name string) string {
+ if c.currentParams != nil {
+ return c.currentParams.Get(name, "")
}
- for i := 0; i < limit; i++ {
- c.pvalues[i] = values[i]
+
+ return c.pathParams.Get(name, "")
+}
+
+func (c *context) PathParams() PathParams {
+ if c.currentParams != nil {
+ return c.currentParams
}
+
+ result := make(PathParams, len(*c.pathParams))
+ copy(result, *c.pathParams)
+ return result
+}
+
+func (c *context) SetPathParams(params PathParams) {
+ c.currentParams = params
}
func (c *context) QueryParam(name string) string {
@@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) {
}
func (c *context) Bind(i interface{}) error {
- return c.echo.Binder.Bind(i, c)
+ return c.echo.Binder.Bind(c, i)
}
func (c *context) Validate(i interface{}) error {
@@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return
}
-func (c *context) File(file string) (err error) {
- f, err := os.Open(file)
+func (c *context) File(file string) error {
+ return c.FsFile(file, c.echo.Filesystem)
+}
+
+func (c *context) FsFile(file string, filesystem fs.FS) error {
+ // FIXME: should we add this method into echo.Context interface?
+ f, err := filesystem.Open(file)
if err != nil {
- return NotFoundHandler(c)
+ return ErrNotFound
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
- f, err = os.Open(file)
+ f, err = filesystem.Open(file)
if err != nil {
- return NotFoundHandler(c)
+ return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
- return
+ return err
}
}
- http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
- return
+ ff, ok := f.(io.ReadSeeker)
+ if !ok {
+ return errors.New("file does not implement io.ReadSeeker")
+ }
+ http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
+ return nil
}
func (c *context) Attachment(file, name string) error {
@@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error {
}
func (c *context) Error(err error) {
- c.echo.HTTPErrorHandler(err, c)
+ c.echo.HTTPErrorHandler(c, err)
}
func (c *context) Echo() *Echo {
return c.echo
}
-func (c *context) Handler() HandlerFunc {
- return c.handler
-}
-
-func (c *context) SetHandler(h HandlerFunc) {
- c.handler = h
-}
-
-func (c *context) Logger() Logger {
- res := c.logger
- if res != nil {
- return res
- }
- return c.echo.Logger
-}
-
-func (c *context) SetLogger(l Logger) {
- c.logger = l
-}
-
func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r
c.response.reset(w)
c.query = nil
- c.handler = NotFoundHandler
c.store = nil
+
+ c.matchType = RouteMatchUnknown
+ c.route = nil
c.path = ""
- c.pnames = nil
- c.logger = nil
- // NOTE: Don't reset because it has to have length c.echo.maxParam at all times
- for i := 0; i < *c.echo.maxParam; i++ {
- c.pvalues[i] = ""
- }
+ // NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times
+ *c.pathParams = (*c.pathParams)[:0]
+ c.currentParams = nil
}
diff --git a/context_test.go b/context_test.go
index a8b9a994..0783f1a8 100644
--- a/context_test.go
+++ b/context_test.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
+ "io/ioutil"
"math"
"mime/multipart"
"net/http"
@@ -18,21 +19,19 @@ import (
"text/template"
"time"
- "github.com/labstack/gommon/log"
- testify "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/assert"
)
-type (
- Template struct {
- templates *template.Template
- }
-)
+type Template struct {
+ templates *template.Template
+}
var testUser = user{1, "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = &jsonLogger{writer: ioutil.Discard}
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = &jsonLogger{writer: ioutil.Discard}
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = &jsonLogger{writer: ioutil.Discard}
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@@ -106,16 +107,14 @@ func TestContext(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
// Echo
- assert.Equal(e, c.Echo())
+ assert.Equal(t, e, c.Echo())
// Request
- assert.NotNil(c.Request())
+ assert.NotNil(t, c.Request())
// Response
- assert.NotNil(c.Response())
+ assert.NotNil(t, c.Response())
//--------
// Render
@@ -126,23 +125,23 @@ func TestContext(t *testing.T) {
}
c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("Hello, Jon Snow!", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
}
c.echo.Renderer = nil
err = c.Render(http.StatusOK, "hello", "Jon Snow")
- assert.Error(err)
+ assert.Error(t, err)
// JSON
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
// JSON with "?pretty"
@@ -150,10 +149,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
@@ -161,37 +160,37 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
// JSON (error)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, make(chan bool))
- assert.Error(err)
+ assert.Error(t, err)
// JSONP
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
}
// XML
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
}
// XML with "?pretty"
@@ -199,10 +198,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/", nil)
@@ -210,22 +209,22 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, make(chan bool))
- assert.Error(err)
+ assert.Error(t, err)
// XML response write error
c = e.NewContext(req, rec).(*context)
c.response.Writer = responseWriterErr{}
err = c.XML(0, 0)
- testify.Error(t, err)
+ assert.Error(t, err)
// XMLPretty
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
}
t.Run("empty indent", func(t *testing.T) {
@@ -237,7 +236,6 @@ func TestContext(t *testing.T) {
t.Run("json", func(t *testing.T) {
buf.Reset()
- assert := testify.New(t)
// New JSONBlob with empty indent
rec = httptest.NewRecorder()
@@ -246,16 +244,15 @@ func TestContext(t *testing.T) {
enc.SetIndent(emptyIndent, emptyIndent)
err = enc.Encode(u)
err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(buf.String(), rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, buf.String(), rec.Body.String())
}
})
t.Run("xml", func(t *testing.T) {
buf.Reset()
- assert := testify.New(t)
// New XMLBlob with empty indent
rec = httptest.NewRecorder()
@@ -264,10 +261,10 @@ func TestContext(t *testing.T) {
enc.Indent(emptyIndent, emptyIndent)
err = enc.Encode(u)
err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+buf.String(), rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
}
})
})
@@ -276,12 +273,12 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
data, err := json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON, rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON, rec.Body.String())
}
// Legacy JSONPBlob
@@ -289,44 +286,44 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context)
callback = "callback"
data, err = json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+");", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
}
// Legacy XMLBlob
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
data, err = xml.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
}
// String
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.String(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
}
// HTML
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.HTML(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
}
// Stream
@@ -334,55 +331,55 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context)
r := strings.NewReader("response from a stream")
err = c.Stream(http.StatusOK, "application/octet-stream", r)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType))
- assert.Equal("response from a stream", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "response from a stream", rec.Body.String())
}
// Attachment
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.Attachment("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
+ assert.Equal(t, 219885, rec.Body.Len())
}
// Inline
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.Inline("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
+ assert.Equal(t, 219885, rec.Body.Len())
}
// NoContent
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c.NoContent(http.StatusOK)
- assert.Equal(http.StatusOK, rec.Code)
+ assert.Equal(t, http.StatusOK, rec.Code)
// Error
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c.Error(errors.New("error"))
- assert.Equal(http.StatusInternalServerError, rec.Code)
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
// Reset
- c.SetParamNames("foo")
- c.SetParamValues("bar")
+ c.pathParams = &PathParams{
+ {Name: "foo", Value: "bar"},
+ }
c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder())
- assert.Equal(0, len(c.ParamValues()))
- assert.Equal(0, len(c.ParamNames()))
- assert.Equal(0, len(c.store))
- assert.Equal("", c.Path())
- assert.Equal(0, len(c.QueryParams()))
+ assert.Equal(t, 0, len(c.PathParams()))
+ assert.Equal(t, 0, len(c.store))
+ assert.Equal(t, nil, c.RouteInfo())
+ assert.Equal(t, 0, len(c.QueryParams()))
}
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
@@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
- assert := testify.New(t)
- if assert.NoError(err) {
- assert.Equal(http.StatusCreated, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
}
@@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
- assert := testify.New(t)
- if assert.Error(err) {
- assert.False(c.response.Committed)
+ if assert.Error(t, err) {
+ assert.False(t, c.response.Committed)
}
}
@@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
// Read single
cookie, err := c.Cookie("theme")
- if assert.NoError(err) {
- assert.Equal("theme", cookie.Name)
- assert.Equal("light", cookie.Value)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "theme", cookie.Name)
+ assert.Equal(t, "light", cookie.Value)
}
// Read multiple
for _, cookie := range c.Cookies() {
switch cookie.Name {
case "theme":
- assert.Equal("light", cookie.Value)
+ assert.Equal(t, "light", cookie.Value)
case "user":
- assert.Equal("Jon Snow", cookie.Value)
+ assert.Equal(t, "Jon Snow", cookie.Value)
}
}
@@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true,
}
c.SetCookie(cookie)
- assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly")
-}
-
-func TestContextPath(t *testing.T) {
- e := New()
- r := e.Router()
-
- handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
-
- r.Add(http.MethodGet, "/users/:id", handler)
- c := e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/users/1", c)
-
- assert := testify.New(t)
-
- assert.Equal("/users/:id", c.Path())
-
- r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
- c = e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/users/1/files/1", c)
- assert.Equal("/users/:uid/files/:fid", c.Path())
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPathParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, nil)
+ c := e.NewContext(req, nil).(*context)
+ params := &PathParams{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ }
// ParamNames
- c.SetParamNames("uid", "fid")
- testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
-
- // ParamValues
- c.SetParamValues("101", "501")
- testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
+ c.pathParams = params
+ assert.EqualValues(t, *params, c.PathParams())
// Param
- testify.Equal(t, "501", c.Param("fid"))
- testify.Equal(t, "", c.Param("undefined"))
+ assert.Equal(t, "501", c.PathParam("fid"))
+ assert.Equal(t, "", c.PathParam("undefined"))
}
func TestContextGetAndSetParam(t *testing.T) {
e := New()
r := e.Router()
- r.Add(http.MethodGet, "/:foo", func(Context) error { return nil })
+ _, err := r.Add(Route{
+ Method: http.MethodGet,
+ Path: "/:foo",
+ Name: "",
+ Handler: func(Context) error { return nil },
+ Middlewares: nil,
+ })
+ assert.NoError(t, err)
+
req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
c := e.NewContext(req, nil)
- c.SetParamNames("foo")
+
+ params := &PathParams{{Name: "foo", Value: "101"}}
+ // ParamNames
+ c.(*context).pathParams = params
// round-trip param values with modification
- paramVals := c.ParamValues()
- testify.EqualValues(t, []string{""}, c.ParamValues())
- paramVals[0] = "bar"
- c.SetParamValues(paramVals...)
- testify.EqualValues(t, []string{"bar"}, c.ParamValues())
+ paramVals := c.PathParams()
+ assert.Equal(t, *params, c.PathParams())
+
+ paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context
+ assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams())
+
+ pathParams := PathParams{
+ {Name: "aaa", Value: "bbb"},
+ {Name: "ccc", Value: "ddd"},
+ }
+ c.SetPathParams(pathParams)
+ assert.Equal(t, pathParams, c.PathParams())
// shouldn't explode during Reset() afterwards!
- testify.NotPanics(t, func() {
- c.Reset(nil, nil)
+ assert.NotPanics(t, func() {
+ c.(EditableContext).Reset(nil, nil)
})
+ assert.Equal(t, PathParams{}, c.PathParams())
+ assert.Len(t, *c.(*context).pathParams, 0)
+ assert.Equal(t, cap(*c.(*context).pathParams), 1)
}
// Issue #1655
-func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) {
- assert := testify.New(t)
-
+func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) {
e := New()
- assert.Equal(0, *e.maxParam)
+ c := e.NewContext(nil, nil).(*context)
- expectedOneParam := []string{"one"}
- expectedTwoParams := []string{"one", "two"}
- expectedThreeParams := []string{"one", "two", ""}
- expectedABCParams := []string{"A", "B", "C"}
+ assert.Equal(t, 0, e.contextPathParamAllocSize)
+ expectedTwoParams := &PathParams{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ }
+ c.SetRawPathParams(expectedTwoParams)
+ assert.Equal(t, 0, e.contextPathParamAllocSize)
+ assert.Equal(t, *expectedTwoParams, c.PathParams())
- c := e.NewContext(nil, nil)
- c.SetParamNames("1", "2")
- c.SetParamValues(expectedTwoParams...)
- assert.Equal(2, *e.maxParam)
- assert.EqualValues(expectedTwoParams, c.ParamValues())
-
- c.SetParamNames("1")
- assert.Equal(2, *e.maxParam)
- // Here for backward compatibility the ParamValues remains as they are
- assert.EqualValues(expectedOneParam, c.ParamValues())
-
- c.SetParamNames("1", "2", "3")
- assert.Equal(3, *e.maxParam)
- // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
- assert.EqualValues(expectedThreeParams, c.ParamValues())
-
- c.SetParamValues("A", "B", "C", "D")
- assert.Equal(3, *e.maxParam)
- // Here D shouldn't be returned
- assert.EqualValues(expectedABCParams, c.ParamValues())
+ expectedThreeParams := PathParams{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ {Name: "3", Value: "three"},
+ }
+ c.SetPathParams(expectedThreeParams)
+ assert.Equal(t, 0, e.contextPathParamAllocSize)
+ assert.Equal(t, expectedThreeParams, c.PathParams())
}
func TestContextFormValue(t *testing.T) {
@@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil)
// FormValue
- testify.Equal(t, "Jon Snow", c.FormValue("name"))
- testify.Equal(t, "jon@labstack.com", c.FormValue("email"))
+ assert.Equal(t, "Jon Snow", c.FormValue("name"))
+ assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams
params, err := c.FormParams()
- if testify.NoError(t, err) {
- testify.Equal(t, url.Values{
+ if assert.NoError(t, err) {
+ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, params)
@@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil)
params, err = c.FormParams()
- testify.Nil(t, params)
- testify.Error(t, err)
+ assert.Nil(t, params)
+ assert.Error(t, err)
}
func TestContextQueryParam(t *testing.T) {
@@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) {
c := e.NewContext(req, nil)
// QueryParam
- testify.Equal(t, "Jon Snow", c.QueryParam("name"))
- testify.Equal(t, "jon@labstack.com", c.QueryParam("email"))
+ assert.Equal(t, "Jon Snow", c.QueryParam("name"))
+ assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams
- testify.Equal(t, url.Values{
+ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, c.QueryParams())
@@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test")
- if testify.NoError(t, err) {
+ if assert.NoError(t, err) {
w.Write([]byte("test"))
}
mr.Close()
@@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.FormFile("file")
- if testify.NoError(t, err) {
- testify.Equal(t, "test", f.Filename)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "test", f.Filename)
}
}
@@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.MultipartForm()
- if testify.NoError(t, err) {
- testify.NotNil(t, f)
+ if assert.NoError(t, err) {
+ assert.NotNil(t, f)
}
}
@@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
- testify.Equal(t, http.StatusMovedPermanently, rec.Code)
- testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
- testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
+ assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
+ assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
}
func TestContextStore(t *testing.T) {
var c Context = new(context)
c.Set("name", "Jon Snow")
- testify.Equal(t, "Jon Snow", c.Get("name"))
+ assert.Equal(t, "Jon Snow", c.Get("name"))
}
func BenchmarkContext_Store(b *testing.B) {
@@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) {
}
}
-func TestContextHandler(t *testing.T) {
- e := New()
- r := e.Router()
- b := new(bytes.Buffer)
-
- r.Add(http.MethodGet, "/handler", func(Context) error {
- _, err := b.Write([]byte("handler"))
- return err
- })
- c := e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/handler", c)
- err := c.Handler()(c)
- testify.Equal(t, "handler", b.String())
- testify.NoError(t, err)
-}
-
-func TestContext_SetHandler(t *testing.T) {
- var c Context = new(context)
-
- testify.Nil(t, c.Handler())
-
- c.SetHandler(func(c Context) error {
- return nil
- })
- testify.NotNil(t, c.Handler())
-}
-
-func TestContext_Path(t *testing.T) {
- path := "/pa/th"
-
- var c Context = new(context)
-
- c.SetPath(path)
- testify.Equal(t, path, c.Path())
-}
-
type validator struct{}
func (*validator) Validate(i interface{}) error {
@@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
- testify.Error(t, c.Validate(struct{}{}))
+ assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{}
- testify.NoError(t, c.Validate(struct{}{}))
+ assert.NoError(t, c.Validate(struct{}{}))
}
func TestContext_QueryString(t *testing.T) {
@@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) {
queryString := "query=string&var=val"
- req := httptest.NewRequest(GET, "/?"+queryString, nil)
+ req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil)
- testify.Equal(t, queryString, c.QueryString())
+ assert.Equal(t, queryString, c.QueryString())
}
func TestContext_Request(t *testing.T) {
var c Context = new(context)
- testify.Nil(t, c.Request())
+ assert.Nil(t, c.Request())
- req := httptest.NewRequest(GET, "/path", nil)
+ req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req)
- testify.Equal(t, req, c.Request())
+ assert.Equal(t, req, c.Request())
}
func TestContext_Scheme(t *testing.T) {
@@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.Scheme())
+ assert.Equal(t, tt.s, tt.c.Scheme())
}
}
func TestContext_IsWebSocket(t *testing.T) {
tests := []struct {
c Context
- ws testify.BoolAssertionFunc
+ ws assert.BoolAssertionFunc
}{
{
&context{
@@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"websocket"}},
},
},
- testify.True,
+ assert.True,
},
{
&context{
@@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
},
},
- testify.True,
+ assert.True,
},
{
&context{
request: &http.Request{},
},
- testify.False,
+ assert.False,
},
{
&context{
@@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"other"}},
},
},
- testify.False,
+ assert.False,
},
}
@@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) {
func TestContext_Bind(t *testing.T) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, nil)
u := new(user)
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
- testify.NoError(t, err)
- testify.Equal(t, &user{1, "Jon Snow"}, u)
-}
-
-func TestContext_Logger(t *testing.T) {
- e := New()
- c := e.NewContext(nil, nil)
-
- log1 := c.Logger()
- testify.NotNil(t, log1)
-
- log2 := log.New("echo2")
- c.SetLogger(log2)
- testify.Equal(t, log2, c.Logger())
-
- // Resetting the context returns the initial logger
- c.Reset(nil, nil)
- testify.Equal(t, log1, c.Logger())
+ assert.NoError(t, err)
+ assert.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_RealIP(t *testing.T) {
@@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.RealIP())
+ assert.Equal(t, tt.s, tt.c.RealIP())
}
}
diff --git a/echo.go b/echo.go
index df5d3584..0f9b3f6f 100644
--- a/echo.go
+++ b/echo.go
@@ -29,7 +29,9 @@ Example:
e.GET("/", hello)
// Start server
- e.Logger.Fatal(e.Start(":1323"))
+ if err := e.Start(":8080"); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
}
Learn more at https://echo.labstack.com
@@ -37,127 +39,81 @@ Learn more at https://echo.labstack.com
package echo
import (
- "bytes"
stdContext "context"
- "crypto/tls"
"errors"
"fmt"
"io"
- "io/ioutil"
- stdLog "log"
- "net"
+ "io/fs"
"net/http"
"net/url"
"os"
+ "os/signal"
"path/filepath"
- "reflect"
- "runtime"
"sync"
- "time"
-
- "github.com/labstack/gommon/color"
- "github.com/labstack/gommon/log"
- "golang.org/x/crypto/acme"
- "golang.org/x/crypto/acme/autocert"
- "golang.org/x/net/http2"
- "golang.org/x/net/http2/h2c"
)
-type (
- // Echo is the top-level framework instance.
- Echo struct {
- common
- // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get
- // listener address info (on which interface/port was listener binded) without having data races.
- startupMutex sync.RWMutex
- StdLogger *stdLog.Logger
- colorer *color.Color
- premiddleware []MiddlewareFunc
- middleware []MiddlewareFunc
- maxParam *int
- router *Router
- routers map[string]*Router
- notFoundHandler HandlerFunc
- pool sync.Pool
- Server *http.Server
- TLSServer *http.Server
- Listener net.Listener
- TLSListener net.Listener
- AutoTLSManager autocert.Manager
- DisableHTTP2 bool
- Debug bool
- HideBanner bool
- HidePort bool
- HTTPErrorHandler HTTPErrorHandler
- Binder Binder
- JSONSerializer JSONSerializer
- Validator Validator
- Renderer Renderer
- Logger Logger
- IPExtractor IPExtractor
- ListenerNetwork string
- }
+// Echo is the top-level framework instance.
+// Note: replacing/nilling public fields is not coroutine/thread-safe and can cause data-races/panics.
+type Echo struct {
+ // premiddleware are middlewares that are run for every request before routing is done
+ premiddleware []MiddlewareFunc
+ // middleware are middlewares that are run after router found a matching route (not found and method not found are also matches)
+ middleware []MiddlewareFunc
- // Route contains a handler and information for matching against requests.
- Route struct {
- Method string `json:"method"`
- Path string `json:"path"`
- Name string `json:"name"`
- }
+ router Router
+ routers map[string]Router
+ routerCreator func(e *Echo) Router
- // HTTPError represents an error that occurred while handling a request.
- HTTPError struct {
- Code int `json:"-"`
- Message interface{} `json:"message"`
- Internal error `json:"-"` // Stores the error returned by an external dependency
- }
+ contextPool sync.Pool
+ contextPathParamAllocSize int
- // MiddlewareFunc defines a function to process middleware.
- MiddlewareFunc func(HandlerFunc) HandlerFunc
+ // NewContextFunc allows using custom context implementations, instead of default *echo.context
+ NewContextFunc func(pathParamAllocSize int) EditableContext
+ Debug bool
+ HTTPErrorHandler HTTPErrorHandler
+ Binder Binder
+ JSONSerializer JSONSerializer
+ Validator Validator
+ Renderer Renderer
+ Logger Logger
+ IPExtractor IPExtractor
+ // Filesystem is file system used by Static and File handler to access files.
+ // Defaults to os.DirFS(".")
+ Filesystem fs.FS
+}
- // HandlerFunc defines a function to serve HTTP requests.
- HandlerFunc func(Context) error
+// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
+type JSONSerializer interface {
+ Serialize(c Context, i interface{}, indent string) error
+ Deserialize(c Context, i interface{}) error
+}
- // HTTPErrorHandler is a centralized HTTP error handler.
- HTTPErrorHandler func(error, Context)
+// HTTPErrorHandler is a centralized HTTP error handler.
+type HTTPErrorHandler func(c Context, err error)
- // Validator is the interface that wraps the Validate function.
- Validator interface {
- Validate(i interface{}) error
- }
+// HandlerFunc defines a function to serve HTTP requests.
+type HandlerFunc func(c Context) error
- // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
- JSONSerializer interface {
- Serialize(c Context, i interface{}, indent string) error
- Deserialize(c Context, i interface{}) error
- }
+// MiddlewareFunc defines a function to process middleware.
+type MiddlewareFunc func(next HandlerFunc) HandlerFunc
- // Renderer is the interface that wraps the Render function.
- Renderer interface {
- Render(io.Writer, string, interface{}, Context) error
- }
+// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking.
+type MiddlewareConfigurator interface {
+ ToMiddleware() (MiddlewareFunc, error)
+}
- // Map defines a generic map of type `map[string]interface{}`.
- Map map[string]interface{}
+// Validator is the interface that wraps the Validate function.
+type Validator interface {
+ Validate(i interface{}) error
+}
- // Common struct for Echo & Group.
- common struct{}
-)
+// Renderer is the interface that wraps the Render function.
+type Renderer interface {
+ Render(io.Writer, string, interface{}, Context) error
+}
-// HTTP methods
-// NOTE: Deprecated, please use the stdlib constants directly instead.
-const (
- CONNECT = http.MethodConnect
- DELETE = http.MethodDelete
- GET = http.MethodGet
- HEAD = http.MethodHead
- OPTIONS = http.MethodOptions
- PATCH = http.MethodPatch
- POST = http.MethodPost
- // PROPFIND = "PROPFIND"
- PUT = http.MethodPut
- TRACE = http.MethodTrace
-)
+// Map defines a generic map of type `map[string]interface{}`.
+type Map map[string]interface{}
// MIME types
const (
@@ -241,274 +197,278 @@ const (
const (
// Version of Echo
- Version = "4.6.1"
- website = "https://echo.labstack.com"
- // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
- banner = `
- ____ __
- / __/___/ / ___
- / _// __/ _ \/ _ \
-/___/\__/_//_/\___/ %s
-High performance, minimalist Go web framework
-%s
-____________________________________O/_______
- O\
-`
+ Version = "5.0.X"
)
-var (
- methods = [...]string{
- http.MethodConnect,
- http.MethodDelete,
- http.MethodGet,
- http.MethodHead,
- http.MethodOptions,
- http.MethodPatch,
- http.MethodPost,
- PROPFIND,
- http.MethodPut,
- http.MethodTrace,
- REPORT,
- }
-)
-
-// Errors
-var (
- ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
- ErrNotFound = NewHTTPError(http.StatusNotFound)
- ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
- ErrForbidden = NewHTTPError(http.StatusForbidden)
- ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
- ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
- ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
- ErrBadRequest = NewHTTPError(http.StatusBadRequest)
- ErrBadGateway = NewHTTPError(http.StatusBadGateway)
- ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
- ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
- ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
- ErrValidatorNotRegistered = errors.New("validator not registered")
- ErrRendererNotRegistered = errors.New("renderer not registered")
- ErrInvalidRedirectCode = errors.New("invalid redirect status code")
- ErrCookieNotFound = errors.New("cookie not found")
- ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
- ErrInvalidListenerNetwork = errors.New("invalid listener network")
-)
-
-// Error handlers
-var (
- NotFoundHandler = func(c Context) error {
- return ErrNotFound
- }
-
- MethodNotAllowedHandler = func(c Context) error {
- return ErrMethodNotAllowed
- }
-)
+var methods = [...]string{
+ http.MethodConnect,
+ http.MethodDelete,
+ http.MethodGet,
+ http.MethodHead,
+ http.MethodOptions,
+ http.MethodPatch,
+ http.MethodPost,
+ PROPFIND,
+ http.MethodPut,
+ http.MethodTrace,
+ REPORT,
+}
// New creates an instance of Echo.
-func New() (e *Echo) {
- e = &Echo{
- Server: new(http.Server),
- TLSServer: new(http.Server),
- AutoTLSManager: autocert.Manager{
- Prompt: autocert.AcceptTOS,
+func New() *Echo {
+ logger := newJSONLogger(os.Stdout)
+ e := &Echo{
+ Logger: logger,
+ Filesystem: os.DirFS("."),
+ Binder: &DefaultBinder{},
+ JSONSerializer: &DefaultJSONSerializer{},
+
+ routers: make(map[string]Router),
+ routerCreator: func(ec *Echo) Router {
+ return NewRouter(ec, RouterConfig{})
},
- Logger: log.New("echo"),
- colorer: color.New(),
- maxParam: new(int),
- ListenerNetwork: "tcp",
}
- e.Server.Handler = e
- e.TLSServer.Handler = e
- e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
- e.Binder = &DefaultBinder{}
- e.JSONSerializer = &DefaultJSONSerializer{}
- e.Logger.SetLevel(log.ERROR)
- e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0)
- e.pool.New = func() interface{} {
+
+ e.router = NewRouter(e, RouterConfig{})
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(false)
+ e.contextPool.New = func() interface{} {
+ if e.NewContextFunc != nil {
+ return e.NewContextFunc(e.contextPathParamAllocSize)
+ }
return e.NewContext(nil, nil)
}
- e.router = NewRouter(e)
- e.routers = map[string]*Router{}
- return
+ return e
}
// NewContext returns a Context instance.
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
+ p := make(PathParams, e.contextPathParamAllocSize)
return &context{
- request: r,
- response: NewResponse(w, e),
- store: make(Map),
- echo: e,
- pvalues: make([]string, *e.maxParam),
- handler: NotFoundHandler,
+ request: r,
+ response: NewResponse(w, e),
+ store: make(Map),
+ echo: e,
+ pathParams: &p,
+ matchType: RouteMatchUnknown,
+ route: nil,
+ path: "",
}
}
// Router returns the default router.
-func (e *Echo) Router() *Router {
+func (e *Echo) Router() Router {
return e.router
}
// Routers returns the map of host => router.
-func (e *Echo) Routers() map[string]*Router {
+func (e *Echo) Routers() map[string]Router {
return e.routers
}
-// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response
-// with status code.
+// RouterFor returns Router for given host.
+func (e *Echo) RouterFor(host string) Router {
+ return e.routers[host]
+}
+
+// ResetRouterCreator resets callback for creating new router instances.
+// Note: current (default) router is immediately replaced with router created with creator func and vhost routers are cleared.
+func (e *Echo) ResetRouterCreator(creator func(e *Echo) Router) {
+ e.routerCreator = creator
+ e.router = creator(e)
+ e.routers = make(map[string]Router)
+}
+
+// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response
+// with status code. `exposeError` parameter decides if returned message will contain also error message or not
//
-// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
+// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately)
+// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from
// handler. Then the error that global error handler received will be ignored because we have already "commited" the
// response and status code header has been sent to the client.
-func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
-
- if c.Response().Committed {
- return
- }
-
- he, ok := err.(*HTTPError)
- if ok {
- if he.Internal != nil {
- if herr, ok := he.Internal.(*HTTPError); ok {
- he = herr
- }
+func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
+ return func(c Context, err error) {
+ if c.Response().Committed {
+ return
}
- } else {
- he = &HTTPError{
+
+ he := &HTTPError{
Code: http.StatusInternalServerError,
Message: http.StatusText(http.StatusInternalServerError),
}
- }
-
- // Issue #1426
- code := he.Code
- message := he.Message
- if m, ok := he.Message.(string); ok {
- if e.Debug {
- message = Map{"message": m, "error": err.Error()}
- } else {
- message = Map{"message": m}
+ if errors.As(err, &he) {
+ if he.Internal != nil { // max 2 levels of checks even if internal could have also internal
+ errors.As(he.Internal, &he)
+ }
}
- }
- // Send response
- if c.Request().Method == http.MethodHead { // Issue #608
- err = c.NoContent(he.Code)
- } else {
- err = c.JSON(code, message)
- }
- if err != nil {
- e.Logger.Error(err)
+ // Issue #1426
+ code := he.Code
+ message := he.Message
+ if m, ok := he.Message.(string); ok {
+ if exposeError {
+ message = Map{"message": m, "error": err.Error()}
+ } else {
+ message = Map{"message": m}
+ }
+ }
+
+ // Send response
+ var cErr error
+ if c.Request().Method == http.MethodHead { // Issue #608
+ cErr = c.NoContent(he.Code)
+ } else {
+ cErr = c.JSON(code, message)
+ }
+ if cErr != nil {
+ c.Echo().Logger.Error(err) // truly rare case. ala client already disconnected
+ }
}
}
-// Pre adds middleware to the chain which is run before router.
+// Pre adds middleware to the chain which is run before router tries to find matching route.
+// Meaning middleware is executed even for 404 (not found) cases.
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
e.premiddleware = append(e.premiddleware, middleware...)
}
-// Use adds middleware to the chain which is run after router.
+// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed.
func (e *Echo) Use(middleware ...MiddlewareFunc) {
e.middleware = append(e.middleware, middleware...)
}
// CONNECT registers a new CONNECT route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodConnect, path, h, m...)
}
// DELETE registers a new DELETE route for a path with matching handler in the router
-// with optional route-level middleware.
-func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// with optional route-level middleware. Panics on error.
+func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodDelete, path, h, m...)
}
// GET registers a new GET route for a path with matching handler in the router
-// with optional route-level middleware.
-func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// with optional route-level middleware. Panics on error.
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodGet, path, h, m...)
}
// HEAD registers a new HEAD route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodHead, path, h, m...)
}
// OPTIONS registers a new OPTIONS route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodOptions, path, h, m...)
}
// PATCH registers a new PATCH route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPatch, path, h, m...)
}
// POST registers a new POST route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPost, path, h, m...)
}
// PUT registers a new PUT route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPut, path, h, m...)
}
// TRACE registers a new TRACE route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodTrace, path, h, m...)
}
-// Any registers a new route for all HTTP methods and path with matching handler
-// in the router with optional route-level middleware.
-func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = e.Add(m, path, handler, middleware...)
+// Any registers a new route for all supported HTTP methods and path with matching handler
+// in the router with optional route-level middleware. Panics on error.
+func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := e.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
}
// Match registers a new route for multiple HTTP methods and path with matching
-// handler in the router with optional route-level middleware.
-func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = e.Add(m, path, handler, middleware...)
+// handler in the router with optional route-level middleware. Panics on error.
+func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := e.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
}
-// Static registers a new route with path prefix to serve static files from the
-// provided root directory.
-func (e *Echo) Static(prefix, root string) *Route {
+// Static registers a new route with path prefix to serve static files from the provided root directory. Panics on error.
+func (e *Echo) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
+ return e.Add(
+ http.MethodGet,
+ prefix+"*",
+ StaticDirectoryHandler(root, false),
+ middleware...,
+ )
+}
+
+// StaticDirectoryHandler creates handler function to serve files from given root path
+func StaticDirectoryHandler(root string, disablePathUnescaping bool) HandlerFunc {
if root == "" {
root = "." // For security we want to restrict to CWD.
}
- return e.static(prefix, root, e.GET)
-}
-
-func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route {
- h := func(c Context) error {
- p, err := url.PathUnescape(c.Param("*"))
- if err != nil {
- return err
+ return func(c Context) error {
+ p := c.PathParam("*")
+ if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
+ tmpPath, err := url.PathUnescape(p)
+ if err != nil {
+ return fmt.Errorf("failed to unescape path variable: %w", err)
+ }
+ p = tmpPath
}
name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
- fi, err := os.Stat(name)
+ fi, err := fs.Stat(c.Echo().Filesystem, name)
if err != nil {
// The access path does not exist
- return NotFoundHandler(c)
+ return ErrNotFound
}
// If the request is for a directory and does not end with "/"
@@ -519,56 +479,57 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl
}
return c.File(name)
}
- // Handle added routes based on trailing slash:
- // /prefix => exact route "/prefix" + any route "/prefix/*"
- // /prefix/ => only any route "/prefix/*"
- if prefix != "" {
- if prefix[len(prefix)-1] == '/' {
- // Only add any route for intentional trailing slash
- return get(prefix+"*", h)
- }
- get(prefix, h)
- }
- return get(prefix+"/*", h)
}
-func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route,
- m ...MiddlewareFunc) *Route {
- return get(path, func(c Context) error {
+// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error.
+func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c Context) error {
return c.File(file)
- }, m...)
-}
-
-// File registers a new route with path to serve a static file with optional route-level middleware.
-func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
- return e.file(path, file, e.GET, m...)
-}
-
-func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- name := handlerName(handler)
- router := e.findRouter(host)
- router.Add(method, path, func(c Context) error {
- h := applyMiddleware(handler, middleware...)
- return h(c)
- })
- r := &Route{
- Method: method,
- Path: path,
- Name: name,
}
- e.router.routes[method+path] = r
- return r
+ return e.Add(http.MethodGet, path, handler, middleware...)
+}
+
+// AddRoute registers a new Route with default host Router
+func (e *Echo) AddRoute(route Routable) (RouteInfo, error) {
+ return e.add("", route)
+}
+
+func (e *Echo) add(host string, route Routable) (RouteInfo, error) {
+ router := e.findRouter(host)
+ ri, err := router.Add(route)
+ if err != nil {
+ return nil, err
+ }
+
+ paramsCount := len(ri.Params())
+ if paramsCount > e.contextPathParamAllocSize {
+ e.contextPathParamAllocSize = paramsCount
+ }
+ return ri, nil
}
// Add registers a new route for an HTTP method and path with matching handler
// in the router with optional route-level middleware.
-func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- return e.add("", method, path, handler, middleware...)
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := e.add(
+ "",
+ Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ Name: "",
+ },
+ )
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
}
// Host creates a new router group for the provided host and optional host-level middleware.
func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) {
- e.routers[name] = NewRouter(e)
+ e.routers[name] = e.routerCreator(e)
g = &Group{host: name, echo: e}
g.Use(m...)
return
@@ -581,326 +542,95 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
return
}
-// URI generates a URI from handler.
-func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string {
- name := handlerName(handler)
- return e.Reverse(name, params...)
-}
-
-// URL is an alias for `URI` function.
-func (e *Echo) URL(h HandlerFunc, params ...interface{}) string {
- return e.URI(h, params...)
-}
-
-// Reverse generates an URL from route name and provided parameters.
-func (e *Echo) Reverse(name string, params ...interface{}) string {
- uri := new(bytes.Buffer)
- ln := len(params)
- n := 0
- for _, r := range e.router.routes {
- if r.Name == name {
- for i, l := 0, len(r.Path); i < l; i++ {
- if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln {
- for ; i < l && r.Path[i] != '/'; i++ {
- }
- uri.WriteString(fmt.Sprintf("%v", params[n]))
- n++
- }
- if i < l {
- uri.WriteByte(r.Path[i])
- }
- }
- break
- }
- }
- return uri.String()
-}
-
-// Routes returns the registered routes.
-func (e *Echo) Routes() []*Route {
- routes := make([]*Route, 0, len(e.router.routes))
- for _, v := range e.router.routes {
- routes = append(routes, v)
- }
- return routes
-}
-
// AcquireContext returns an empty `Context` instance from the pool.
// You must return the context by calling `ReleaseContext()`.
func (e *Echo) AcquireContext() Context {
- return e.pool.Get().(Context)
+ return e.contextPool.Get().(Context)
}
// ReleaseContext returns the `Context` instance back to the pool.
// You must call it after `AcquireContext()`.
func (e *Echo) ReleaseContext(c Context) {
- e.pool.Put(c)
+ e.contextPool.Put(c)
}
// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // Acquire context
- c := e.pool.Get().(*context)
+ var c EditableContext
+ if e.NewContextFunc != nil {
+ // NOTE: we are not casting always context to EditableContext because casting to interface vs pointer to struct is
+ // "significantly" slower. Echo Context interface has way to many methods so these checks take time.
+ // These are benchmarks with 1.16:
+ // * interface extending another interface = +24% slower (3233 ns/op vs 2605 ns/op)
+ // * interface (not extending any, just methods)= +14% slower
+ //
+ // Quote from https://stackoverflow.com/a/31584377
+ // "it's even worse with interface-to-interface assertion, because you also need to ensure that the type implements the interface."
+ //
+ // So most of the time we do not need custom context type and simple IF + cast to pointer to struct is fast enough.
+ c = e.contextPool.Get().(EditableContext)
+ } else {
+ c = e.contextPool.Get().(*context)
+ }
c.Reset(r, w)
- h := NotFoundHandler
+ var h func(c Context) error
if e.premiddleware == nil {
- e.findRouter(r.Host).Find(r.Method, GetPath(r), c)
- h = c.Handler()
- h = applyMiddleware(h, e.middleware...)
+ params := c.RawPathParams()
+ match := e.findRouter(r.Host).Match(r, params)
+
+ c.SetRawPathParams(params)
+ c.SetPath(match.RoutePath)
+ c.SetRouteInfo(match.RouteInfo)
+ c.SetRouteMatchType(match.Type)
+ h = applyMiddleware(match.Handler, e.middleware...)
} else {
- h = func(c Context) error {
- e.findRouter(r.Host).Find(r.Method, GetPath(r), c)
- h := c.Handler()
- h = applyMiddleware(h, e.middleware...)
- return h(c)
+ h = func(cc Context) error {
+ params := c.RawPathParams()
+ match := e.findRouter(r.Host).Match(r, params)
+ // NOTE: router will be executed after pre middlewares have been run. We assume here that context we receive after pre middlewares
+ // is the same we began with. If not - this is use-case we do not support and is probably abuse from developer.
+ c.SetRawPathParams(params)
+ c.SetPath(match.RoutePath)
+ c.SetRouteInfo(match.RouteInfo)
+ c.SetRouteMatchType(match.Type)
+ h1 := applyMiddleware(match.Handler, e.middleware...)
+ return h1(cc)
}
h = applyMiddleware(h, e.premiddleware...)
}
// Execute chain
if err := h(c); err != nil {
- e.HTTPErrorHandler(err, c)
+ e.HTTPErrorHandler(c, err)
}
- // Release context
- e.pool.Put(c)
+ e.contextPool.Put(c)
}
-// Start starts an HTTP server.
+// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by
+// sending os.Interrupt signal with `ctrl+c`.
+//
+// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
+// options.
+//
+// In need of customization use:
+// sc := echo.StartConfig{Address: ":8080"}
+// if err := sc.Start(e); err != http.ErrServerClosed {
+// log.Fatal(err)
+// }
+// // or standard library `http.Server`
+// s := http.Server{Addr: ":8080", Handler: e}
+// if err := s.ListenAndServe(); err != http.ErrServerClosed {
+// log.Fatal(err)
+// }
func (e *Echo) Start(address string) error {
- e.startupMutex.Lock()
- e.Server.Addr = address
- if err := e.configureServer(e.Server); err != nil {
- e.startupMutex.Unlock()
- return err
- }
- e.startupMutex.Unlock()
- return e.Server.Serve(e.Listener)
-}
+ sc := StartConfig{Address: address}
+ ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt) // start shutdown process on ctrl+c
+ defer cancel()
+ sc.GracefulContext = ctx
-// StartTLS starts an HTTPS server.
-// If `certFile` or `keyFile` is `string` the values are treated as file paths.
-// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
-func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) {
- e.startupMutex.Lock()
- var cert []byte
- if cert, err = filepathOrContent(certFile); err != nil {
- e.startupMutex.Unlock()
- return
- }
-
- var key []byte
- if key, err = filepathOrContent(keyFile); err != nil {
- e.startupMutex.Unlock()
- return
- }
-
- s := e.TLSServer
- s.TLSConfig = new(tls.Config)
- s.TLSConfig.Certificates = make([]tls.Certificate, 1)
- if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil {
- e.startupMutex.Unlock()
- return
- }
-
- e.configureTLS(address)
- if err := e.configureServer(s); err != nil {
- e.startupMutex.Unlock()
- return err
- }
- e.startupMutex.Unlock()
- return s.Serve(e.TLSListener)
-}
-
-func filepathOrContent(fileOrContent interface{}) (content []byte, err error) {
- switch v := fileOrContent.(type) {
- case string:
- return ioutil.ReadFile(v)
- case []byte:
- return v, nil
- default:
- return nil, ErrInvalidCertOrKeyType
- }
-}
-
-// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org.
-func (e *Echo) StartAutoTLS(address string) error {
- e.startupMutex.Lock()
- s := e.TLSServer
- s.TLSConfig = new(tls.Config)
- s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate
- s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto)
-
- e.configureTLS(address)
- if err := e.configureServer(s); err != nil {
- e.startupMutex.Unlock()
- return err
- }
- e.startupMutex.Unlock()
- return s.Serve(e.TLSListener)
-}
-
-func (e *Echo) configureTLS(address string) {
- s := e.TLSServer
- s.Addr = address
- if !e.DisableHTTP2 {
- s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2")
- }
-}
-
-// StartServer starts a custom http server.
-func (e *Echo) StartServer(s *http.Server) (err error) {
- e.startupMutex.Lock()
- if err := e.configureServer(s); err != nil {
- e.startupMutex.Unlock()
- return err
- }
- if s.TLSConfig != nil {
- e.startupMutex.Unlock()
- return s.Serve(e.TLSListener)
- }
- e.startupMutex.Unlock()
- return s.Serve(e.Listener)
-}
-
-func (e *Echo) configureServer(s *http.Server) (err error) {
- // Setup
- e.colorer.SetOutput(e.Logger.Output())
- s.ErrorLog = e.StdLogger
- s.Handler = e
- if e.Debug {
- e.Logger.SetLevel(log.DEBUG)
- }
-
- if !e.HideBanner {
- e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website))
- }
-
- if s.TLSConfig == nil {
- if e.Listener == nil {
- e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
- if err != nil {
- return err
- }
- }
- if !e.HidePort {
- e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
- }
- return nil
- }
- if e.TLSListener == nil {
- l, err := newListener(s.Addr, e.ListenerNetwork)
- if err != nil {
- return err
- }
- e.TLSListener = tls.NewListener(l, s.TLSConfig)
- }
- if !e.HidePort {
- e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr()))
- }
- return nil
-}
-
-// ListenerAddr returns net.Addr for Listener
-func (e *Echo) ListenerAddr() net.Addr {
- e.startupMutex.RLock()
- defer e.startupMutex.RUnlock()
- if e.Listener == nil {
- return nil
- }
- return e.Listener.Addr()
-}
-
-// TLSListenerAddr returns net.Addr for TLSListener
-func (e *Echo) TLSListenerAddr() net.Addr {
- e.startupMutex.RLock()
- defer e.startupMutex.RUnlock()
- if e.TLSListener == nil {
- return nil
- }
- return e.TLSListener.Addr()
-}
-
-// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext).
-func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
- e.startupMutex.Lock()
- // Setup
- s := e.Server
- s.Addr = address
- e.colorer.SetOutput(e.Logger.Output())
- s.ErrorLog = e.StdLogger
- s.Handler = h2c.NewHandler(e, h2s)
- if e.Debug {
- e.Logger.SetLevel(log.DEBUG)
- }
-
- if !e.HideBanner {
- e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website))
- }
-
- if e.Listener == nil {
- e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
- if err != nil {
- e.startupMutex.Unlock()
- return err
- }
- }
- if !e.HidePort {
- e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
- }
- e.startupMutex.Unlock()
- return s.Serve(e.Listener)
-}
-
-// Close immediately stops the server.
-// It internally calls `http.Server#Close()`.
-func (e *Echo) Close() error {
- e.startupMutex.Lock()
- defer e.startupMutex.Unlock()
- if err := e.TLSServer.Close(); err != nil {
- return err
- }
- return e.Server.Close()
-}
-
-// Shutdown stops the server gracefully.
-// It internally calls `http.Server#Shutdown()`.
-func (e *Echo) Shutdown(ctx stdContext.Context) error {
- e.startupMutex.Lock()
- defer e.startupMutex.Unlock()
- if err := e.TLSServer.Shutdown(ctx); err != nil {
- return err
- }
- return e.Server.Shutdown(ctx)
-}
-
-// NewHTTPError creates a new HTTPError instance.
-func NewHTTPError(code int, message ...interface{}) *HTTPError {
- he := &HTTPError{Code: code, Message: http.StatusText(code)}
- if len(message) > 0 {
- he.Message = message[0]
- }
- return he
-}
-
-// Error makes it compatible with `error` interface.
-func (he *HTTPError) Error() string {
- if he.Internal == nil {
- return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
- }
- return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
-}
-
-// SetInternal sets error to HTTPError.Internal
-func (he *HTTPError) SetInternal(err error) *HTTPError {
- he.Internal = err
- return he
-}
-
-// Unwrap satisfies the Go 1.13 error wrapper interface.
-func (he *HTTPError) Unwrap() error {
- return he.Internal
+ return sc.Start(e)
}
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
@@ -925,19 +655,7 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
}
}
-// GetPath returns RawPath, if it's empty returns Path from URL
-// Difference between RawPath and Path is:
-// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
-// * RawPath is an optional field which only gets set if the default encoding is different from Path.
-func GetPath(r *http.Request) string {
- path := r.URL.RawPath
- if path == "" {
- path = r.URL.Path
- }
- return path
-}
-
-func (e *Echo) findRouter(host string) *Router {
+func (e *Echo) findRouter(host string) Router {
if len(e.routers) > 0 {
if r, ok := e.routers[host]; ok {
return r
@@ -946,50 +664,6 @@ func (e *Echo) findRouter(host string) *Router {
return e.router
}
-func handlerName(h HandlerFunc) string {
- t := reflect.ValueOf(h).Type()
- if t.Kind() == reflect.Func {
- return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
- }
- return t.String()
-}
-
-// // PathUnescape is wraps `url.PathUnescape`
-// func PathUnescape(s string) (string, error) {
-// return url.PathUnescape(s)
-// }
-
-// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
-// connections. It's used by ListenAndServe and ListenAndServeTLS so
-// dead TCP connections (e.g. closing laptop mid-download) eventually
-// go away.
-type tcpKeepAliveListener struct {
- *net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
- if c, err = ln.AcceptTCP(); err != nil {
- return
- } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil {
- return
- }
- // Ignore error from setting the KeepAlivePeriod as some systems, such as
- // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP
- _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute)
- return
-}
-
-func newListener(address, network string) (*tcpKeepAliveListener, error) {
- if network != "tcp" && network != "tcp4" && network != "tcp6" {
- return nil, ErrInvalidListenerNetwork
- }
- l, err := net.Listen(network, address)
- if err != nil {
- return nil, err
- }
- return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil
-}
-
func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
for i := len(middleware) - 1; i >= 0; i-- {
h = middleware[i](h)
diff --git a/echo_test.go b/echo_test.go
index f2891586..0fbd1b55 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -3,31 +3,23 @@ package echo
import (
"bytes"
stdContext "context"
- "crypto/tls"
"errors"
"fmt"
- "io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
- "os"
- "reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "golang.org/x/net/http2"
)
-type (
- user struct {
- ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"`
- Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"`
- }
-)
+type user struct {
+ ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"`
+ Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"`
+}
const (
userJSON = `{"id":1,"name":"Jon Snow"}`
@@ -61,8 +53,8 @@ func TestEcho(t *testing.T) {
// Router
assert.NotNil(t, e.Router())
- // DefaultHTTPErrorHandler
- e.DefaultHTTPErrorHandler(errors.New("error"), c)
+ e.HTTPErrorHandler(c, errors.New("error"))
+
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
@@ -84,6 +76,22 @@ func TestEchoStatic(t *testing.T) {
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
+ {
+ name: "without prefix",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/walle.png", // `` + `*` creates route `/test*` witch matches `walle.png`
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "without prefix does not serve dir index or redirect",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/", // `/` + `*` creates route `/*`
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
{
name: "No file",
givenPrefix: "/images",
@@ -104,7 +112,7 @@ func TestEchoStatic(t *testing.T) {
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
- whenURL: "/folder",
+ whenURL: "/folder", // `/folder` is subdirectory inside `/_fixture`
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/folder/",
expectBodyStartsWith: "",
@@ -211,42 +219,36 @@ func TestEchoStatic(t *testing.T) {
}
func TestEchoStaticRedirectIndex(t *testing.T) {
- assert := assert.New(t)
e := New()
// HandlerFunc
- e.Static("/static", "_fixture")
+ ri := e.Static("/static", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method())
+ assert.Equal(t, "/static*", ri.Path())
+ assert.Equal(t, "GET:/static*", ri.Name())
+ assert.Equal(t, []string{"*"}, ri.Params())
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start("127.0.0.1:1323")
- }()
-
- time.Sleep(200 * time.Millisecond)
-
- if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil {
- defer resp.Body.Close()
- assert.Equal(http.StatusOK, resp.StatusCode)
-
- if body, err := ioutil.ReadAll(resp.Body); err == nil {
- assert.Equal(true, strings.HasPrefix(string(body), ""))
- } else {
- assert.Fail(err.Error())
- }
-
- } else {
- assert.Fail(err.Error())
+ ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
+ defer cancel()
+ addr, err := startOnRandomPort(ctx, e)
+ if err != nil {
+ assert.Fail(t, err.Error())
}
- if err := e.Close(); err != nil {
- t.Fatal(err)
- }
+ code, body, err := doGet(fmt.Sprintf("http://%v/static", addr))
+ assert.NoError(t, err)
+ assert.True(t, strings.HasPrefix(body, ""))
+ assert.Equal(t, http.StatusOK, code)
}
func TestEchoFile(t *testing.T) {
e := New()
- e.File("/walle", "_fixture/images/walle.png")
+ ri := e.File("/walle", "_fixture/images/walle.png")
+ assert.Equal(t, http.MethodGet, ri.Method())
+ assert.Equal(t, "/walle", ri.Path())
+ assert.Equal(t, "GET:/walle", ri.Name())
+ assert.Nil(t, ri.Params())
+
c, b := request(http.MethodGet, "/walle", e)
assert.Equal(t, http.StatusOK, c)
assert.NotEmpty(t, b)
@@ -258,7 +260,8 @@ func TestEchoMiddleware(t *testing.T) {
e.Pre(func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
- assert.Empty(t, c.Path())
+ // before route match is found RouteInfo does not exist
+ assert.Equal(t, nil, c.RouteInfo())
buf.WriteString("-1")
return next(c)
}
@@ -303,7 +306,7 @@ func TestEchoMiddlewareError(t *testing.T) {
return errors.New("error")
}
})
- e.GET("/", NotFoundHandler)
+ e.GET("/", notFoundHandler)
c, _ := request(http.MethodGet, "/", e)
assert.Equal(t, http.StatusInternalServerError, c)
}
@@ -358,128 +361,202 @@ func TestEchoWrapMiddleware(t *testing.T) {
}
}
+func TestEchoGet_routeInfoIsImmutable(t *testing.T) {
+ e := New()
+ ri := e.GET("/test", handlerFunc)
+ assert.Equal(t, "GET:/test", ri.Name())
+
+ riFromRouter, err := e.Router().Routes().FindByMethodPath(http.MethodGet, "/test")
+ assert.NoError(t, err)
+ assert.Equal(t, "GET:/test", riFromRouter.Name())
+
+ rInfo := ri.(routeInfo)
+ rInfo.name = "changed" // this change should not change other returned values
+
+ assert.Equal(t, "GET:/test", ri.Name())
+
+ riFromRouter, err = e.Router().Routes().FindByMethodPath(http.MethodGet, "/test")
+ assert.NoError(t, err)
+ assert.Equal(t, "GET:/test", riFromRouter.Name())
+}
+
func TestEchoConnect(t *testing.T) {
e := New()
- testMethod(t, http.MethodConnect, "/", e)
+
+ ri := e.CONNECT("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodConnect+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodConnect, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoDelete(t *testing.T) {
e := New()
- testMethod(t, http.MethodDelete, "/", e)
+
+ ri := e.DELETE("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodDelete+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodDelete, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoGet(t *testing.T) {
e := New()
- testMethod(t, http.MethodGet, "/", e)
+
+ ri := e.GET("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodGet, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodGet+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodGet, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoHead(t *testing.T) {
e := New()
- testMethod(t, http.MethodHead, "/", e)
+
+ ri := e.HEAD("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodHead+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodHead, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoOptions(t *testing.T) {
e := New()
- testMethod(t, http.MethodOptions, "/", e)
+
+ ri := e.OPTIONS("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodOptions+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodOptions, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPatch(t *testing.T) {
e := New()
- testMethod(t, http.MethodPatch, "/", e)
+
+ ri := e.PATCH("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodPatch+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPatch, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPost(t *testing.T) {
e := New()
- testMethod(t, http.MethodPost, "/", e)
+
+ ri := e.POST("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodPost+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPost, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPut(t *testing.T) {
e := New()
- testMethod(t, http.MethodPut, "/", e)
+
+ ri := e.PUT("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodPut+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPut, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoTrace(t *testing.T) {
e := New()
- testMethod(t, http.MethodTrace, "/", e)
+
+ ri := e.TRACE("/", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodTrace, ri.Method())
+ assert.Equal(t, "/", ri.Path())
+ assert.Equal(t, http.MethodTrace+":/", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodTrace, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoAny(t *testing.T) { // JFC
e := New()
- e.Any("/", func(c Context) error {
+ ris := e.Any("/", func(c Context) error {
return c.String(http.StatusOK, "Any")
})
+ assert.Len(t, ris, 11)
}
func TestEchoMatch(t *testing.T) { // JFC
e := New()
- e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error {
+ ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error {
return c.String(http.StatusOK, "Match")
})
+ assert.Len(t, ris, 2)
}
-func TestEchoURL(t *testing.T) {
- e := New()
- static := func(Context) error { return nil }
- getUser := func(Context) error { return nil }
- getAny := func(Context) error { return nil }
- getFile := func(Context) error { return nil }
-
- e.GET("/static/file", static)
- e.GET("/users/:id", getUser)
- e.GET("/documents/*", getAny)
- g := e.Group("/group")
- g.GET("/users/:uid/files/:fid", getFile)
-
- assert := assert.New(t)
-
- assert.Equal("/static/file", e.URL(static))
- assert.Equal("/users/:id", e.URL(getUser))
- assert.Equal("/users/1", e.URL(getUser, "1"))
- assert.Equal("/users/1", e.URL(getUser, "1"))
- assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
- assert.Equal("/documents/*", e.URL(getAny))
- assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
- assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
-}
-
-func TestEchoRoutes(t *testing.T) {
- e := New()
- routes := []*Route{
- {http.MethodGet, "/users/:user/events", ""},
- {http.MethodGet, "/users/:user/events/public", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
- }
- for _, r := range routes {
- e.Add(r.Method, r.Path, func(c Context) error {
- return c.String(http.StatusOK, "OK")
- })
- }
-
- if assert.Equal(t, len(routes), len(e.Routes())) {
- for _, r := range e.Routes() {
- found := false
- for _, rr := range routes {
- if r.Method == rr.Method && r.Path == rr.Path {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("Route %s %s not found", r.Method, r.Path)
- }
- }
- }
-}
-
-func TestEchoRoutesHandleHostsProperly(t *testing.T) {
+func TestEcho_Routers_HandleHostsProperly(t *testing.T) {
e := New()
h := e.Host("route.com")
routes := []*Route{
- {http.MethodGet, "/users/:user/events", ""},
- {http.MethodGet, "/users/:user/events/public", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
+ {Method: http.MethodGet, Path: "/users/:user/events"},
+ {Method: http.MethodGet, Path: "/users/:user/events/public"},
+ {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/refs"},
+ {Method: http.MethodPost, Path: "/repos/:owner/:repo/git/tags"},
}
for _, r := range routes {
h.Add(r.Method, r.Path, func(c Context) error {
@@ -487,17 +564,22 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) {
})
}
- if assert.Equal(t, len(routes), len(e.Routes())) {
- for _, r := range e.Routes() {
+ routers := e.Routers()
+
+ routeCom, ok := routers["route.com"]
+ assert.True(t, ok)
+
+ if assert.Equal(t, len(routes), len(routeCom.Routes())) {
+ for _, r := range routeCom.Routes() {
found := false
for _, rr := range routes {
- if r.Method == rr.Method && r.Path == rr.Path {
+ if r.Method() == rr.Method && r.Path() == rr.Path {
found = true
break
}
}
if !found {
- t.Errorf("Route %s %s not found", r.Method, r.Path)
+ t.Errorf("Route %s %s not found", r.Method(), r.Path())
}
}
}
@@ -509,7 +591,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) {
return c.String(http.StatusOK, "/with/slash")
})
e.GET("/:id", func(c Context) error {
- return c.String(http.StatusOK, c.Param("id"))
+ return c.String(http.StatusOK, c.PathParam("id"))
})
var testCases = []struct {
@@ -546,8 +628,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) {
}
func TestEchoHost(t *testing.T) {
- assert := assert.New(t)
-
okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) }
teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) }
acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) }
@@ -642,8 +722,8 @@ func TestEchoHost(t *testing.T) {
e.ServeHTTP(rec, req)
- assert.Equal(tc.expectStatus, rec.Code)
- assert.Equal(tc.expectBody, rec.Body.String())
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ assert.Equal(t, tc.expectBody, rec.Body.String())
})
}
}
@@ -732,339 +812,27 @@ func TestEchoContext(t *testing.T) {
e.ReleaseContext(c)
}
-func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error {
- ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
- defer cancel()
-
- ticker := time.NewTicker(5 * time.Millisecond)
- defer ticker.Stop()
-
- for {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-ticker.C:
- var addr net.Addr
- if isTLS {
- addr = e.TLSListenerAddr()
- } else {
- addr = e.ListenerAddr()
- }
- if addr != nil && strings.Contains(addr.String(), ":") {
- return nil // was started
- }
- case err := <-errChan:
- if err == http.ErrServerClosed {
- return nil
- }
- return err
- }
- }
-}
-
-func TestEchoStart(t *testing.T) {
- e := New()
- errChan := make(chan error)
-
- go func() {
- err := e.Start(":0")
- if err != nil {
- errChan <- err
- }
- }()
-
- err := waitForServerStart(e, errChan, false)
- assert.NoError(t, err)
-
- assert.NoError(t, e.Close())
-}
-
-func TestEcho_StartTLS(t *testing.T) {
- var testCases = []struct {
- name string
- addr string
- certFile string
- keyFile string
- expectError string
- }{
- {
- name: "ok",
- addr: ":0",
- },
- {
- name: "nok, invalid certFile",
- addr: ":0",
- certFile: "not existing",
- expectError: "open not existing: no such file or directory",
- },
- {
- name: "nok, invalid keyFile",
- addr: ":0",
- keyFile: "not existing",
- expectError: "open not existing: no such file or directory",
- },
- {
- name: "nok, failed to create cert out of certFile and keyFile",
- addr: ":0",
- keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
- expectError: "tls: found a certificate rather than a key in the PEM for the private key",
- },
- {
- name: "nok, invalid tls address",
- addr: "nope",
- expectError: "listen tcp: address nope: missing port in address",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- errChan := make(chan error)
-
- go func() {
- certFile := "_fixture/certs/cert.pem"
- if tc.certFile != "" {
- certFile = tc.certFile
- }
- keyFile := "_fixture/certs/key.pem"
- if tc.keyFile != "" {
- keyFile = tc.keyFile
- }
-
- err := e.StartTLS(tc.addr, certFile, keyFile)
- if err != nil {
- errChan <- err
- }
- }()
-
- err := waitForServerStart(e, errChan, true)
- if tc.expectError != "" {
- if _, ok := err.(*os.PathError); ok {
- assert.Error(t, err) // error messages for unix and windows are different. so test only error type here
- } else {
- assert.EqualError(t, err, tc.expectError)
- }
- } else {
- assert.NoError(t, err)
- }
-
- assert.NoError(t, e.Close())
- })
- }
-}
-
-func TestEchoStartTLSAndStart(t *testing.T) {
- // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
+func TestEcho_Start(t *testing.T) {
e := New()
e.GET("/", func(c Context) error {
- return c.String(http.StatusOK, "OK")
+ return c.String(http.StatusTeapot, "OK")
})
-
- errTLSChan := make(chan error)
+ rndPort, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rndPort.Close()
+ errChan := make(chan error, 1)
go func() {
- certFile := "_fixture/certs/cert.pem"
- keyFile := "_fixture/certs/key.pem"
- err := e.StartTLS("localhost:", certFile, keyFile)
- if err != nil {
- errTLSChan <- err
- }
+ errChan <- e.Start(rndPort.Addr().String())
}()
- err := waitForServerStart(e, errTLSChan, true)
- assert.NoError(t, err)
- defer func() {
- if err := e.Shutdown(stdContext.Background()); err != nil {
- t.Error(err)
- }
- }()
-
- // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
- client := &http.Client{Transport: &http.Transport{
- TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
- }}
- res, err := client.Get("https://" + e.TLSListenerAddr().String())
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- errChan := make(chan error)
- go func() {
- err := e.Start("localhost:")
- if err != nil {
- errChan <- err
- }
- }()
- err = waitForServerStart(e, errChan, false)
- assert.NoError(t, err)
-
- // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
- res, err = http.Get("http://" + e.ListenerAddr().String())
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, res.StatusCode)
-
- // see if HTTPS works after HTTP listener is also added
- res, err = client.Get("https://" + e.TLSListenerAddr().String())
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, res.StatusCode)
-}
-
-func TestEchoStartTLSByteString(t *testing.T) {
- cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
- require.NoError(t, err)
- key, err := ioutil.ReadFile("_fixture/certs/key.pem")
- require.NoError(t, err)
-
- testCases := []struct {
- cert interface{}
- key interface{}
- expectedErr error
- name string
- }{
- {
- cert: "_fixture/certs/cert.pem",
- key: "_fixture/certs/key.pem",
- expectedErr: nil,
- name: `ValidCertAndKeyFilePath`,
- },
- {
- cert: cert,
- key: key,
- expectedErr: nil,
- name: `ValidCertAndKeyByteString`,
- },
- {
- cert: cert,
- key: 1,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidKeyType`,
- },
- {
- cert: 0,
- key: key,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidCertType`,
- },
- {
- cert: 0,
- key: 1,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidCertAndKeyTypes`,
- },
+ select {
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("start did not error out")
+ case err := <-errChan:
+ assert.Contains(t, err.Error(), "bind: address already in use")
}
-
- for _, test := range testCases {
- test := test
- t.Run(test.name, func(t *testing.T) {
- e := New()
- e.HideBanner = true
-
- errChan := make(chan error, 0)
-
- go func() {
- errChan <- e.StartTLS(":0", test.cert, test.key)
- }()
-
- err := waitForServerStart(e, errChan, true)
- if test.expectedErr != nil {
- assert.EqualError(t, err, test.expectedErr.Error())
- } else {
- assert.NoError(t, err)
- }
-
- assert.NoError(t, e.Close())
- })
- }
-}
-
-func TestEcho_StartAutoTLS(t *testing.T) {
- var testCases = []struct {
- name string
- addr string
- expectError string
- }{
- {
- name: "ok",
- addr: ":0",
- },
- {
- name: "nok, invalid address",
- addr: "nope",
- expectError: "listen tcp: address nope: missing port in address",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- errChan := make(chan error, 0)
-
- go func() {
- errChan <- e.StartAutoTLS(tc.addr)
- }()
-
- err := waitForServerStart(e, errChan, true)
- if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NoError(t, err)
- }
-
- assert.NoError(t, e.Close())
- })
- }
-}
-
-func TestEcho_StartH2CServer(t *testing.T) {
- var testCases = []struct {
- name string
- addr string
- expectError string
- }{
- {
- name: "ok",
- addr: ":0",
- },
- {
- name: "nok, invalid address",
- addr: "nope",
- expectError: "listen tcp: address nope: missing port in address",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- e := New()
- e.Debug = true
- h2s := &http2.Server{}
-
- errChan := make(chan error)
- go func() {
- err := e.StartH2CServer(tc.addr, h2s)
- if err != nil {
- errChan <- err
- }
- }()
-
- err := waitForServerStart(e, errChan, false)
- if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NoError(t, err)
- }
-
- assert.NoError(t, e.Close())
- })
- }
-}
-
-func testMethod(t *testing.T, method, path string, e *Echo) {
- p := reflect.ValueOf(path)
- h := reflect.ValueOf(func(c Context) error {
- return c.String(http.StatusOK, method)
- })
- i := interface{}(e)
- reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h})
- _, body := request(method, path, e)
- assert.Equal(t, method, body)
}
func request(method, path string, e *Echo) (int, string) {
@@ -1074,364 +842,131 @@ func request(method, path string, e *Echo) (int, string) {
return rec.Code, rec.Body.String()
}
-func TestHTTPError(t *testing.T) {
- t.Run("non-internal", func(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
- "code": 12,
- })
-
- assert.Equal(t, "code=400, message=map[code:12]", err.Error())
- })
- t.Run("internal", func(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
- "code": 12,
- })
- err.SetInternal(errors.New("internal error"))
- assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
- })
-}
-
-func TestHTTPError_Unwrap(t *testing.T) {
- t.Run("non-internal", func(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
- "code": 12,
- })
-
- assert.Nil(t, errors.Unwrap(err))
- })
- t.Run("internal", func(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
- "code": 12,
- })
- err.SetInternal(errors.New("internal error"))
- assert.Equal(t, "internal error", errors.Unwrap(err).Error())
- })
-}
-
func TestDefaultHTTPErrorHandler(t *testing.T) {
- e := New()
- e.Debug = true
- e.Any("/plain", func(c Context) error {
- return errors.New("An error occurred")
- })
- e.Any("/badrequest", func(c Context) error {
- return NewHTTPError(http.StatusBadRequest, "Invalid request")
- })
- e.Any("/servererror", func(c Context) error {
- return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{
- "code": 33,
- "message": "Something bad happened",
- "error": "stackinfo",
- })
- })
- e.Any("/early-return", func(c Context) error {
- c.String(http.StatusOK, "OK")
- return errors.New("ERROR")
- })
- e.GET("/internal-error", func(c Context) error {
- err := errors.New("internal error message body")
- return NewHTTPError(http.StatusBadRequest).SetInternal(err)
- })
-
- // With Debug=true plain response contains error message
- c, b := request(http.MethodGet, "/plain", e)
- assert.Equal(t, http.StatusInternalServerError, c)
- assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b)
- // and special handling for HTTPError
- c, b = request(http.MethodGet, "/badrequest", e)
- assert.Equal(t, http.StatusBadRequest, c)
- assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b)
- // complex errors are serialized to pretty JSON
- c, b = request(http.MethodGet, "/servererror", e)
- assert.Equal(t, http.StatusInternalServerError, c)
- assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b)
- // if the body is already set HTTPErrorHandler should not add anything to response body
- c, b = request(http.MethodGet, "/early-return", e)
- assert.Equal(t, http.StatusOK, c)
- assert.Equal(t, "OK", b)
- // internal error should be reflected in the message
- c, b = request(http.MethodGet, "/internal-error", e)
- assert.Equal(t, http.StatusBadRequest, c)
- assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b)
-
- e.Debug = false
- // With Debug=false the error response is shortened
- c, b = request(http.MethodGet, "/plain", e)
- assert.Equal(t, http.StatusInternalServerError, c)
- assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b)
- c, b = request(http.MethodGet, "/badrequest", e)
- assert.Equal(t, http.StatusBadRequest, c)
- assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b)
- // No difference for error response with non plain string errors
- c, b = request(http.MethodGet, "/servererror", e)
- assert.Equal(t, http.StatusInternalServerError, c)
- assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b)
-}
-
-func TestEchoClose(t *testing.T) {
- e := New()
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start(":0")
- }()
-
- err := waitForServerStart(e, errCh, false)
- assert.NoError(t, err)
-
- if err := e.Close(); err != nil {
- t.Fatal(err)
- }
-
- assert.NoError(t, e.Close())
-
- err = <-errCh
- assert.Equal(t, err.Error(), "http: Server closed")
-}
-
-func TestEchoShutdown(t *testing.T) {
- e := New()
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start(":0")
- }()
-
- err := waitForServerStart(e, errCh, false)
- assert.NoError(t, err)
-
- if err := e.Close(); err != nil {
- t.Fatal(err)
- }
-
- ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second)
- defer cancel()
- assert.NoError(t, e.Shutdown(ctx))
-
- err = <-errCh
- assert.Equal(t, err.Error(), "http: Server closed")
-}
-
-var listenerNetworkTests = []struct {
- test string
- network string
- address string
-}{
- {"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
- {"tcp ipv6 address", "tcp", "[::1]:1323"},
- {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
- {"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
-}
-
-func supportsIPv6() bool {
- addrs, _ := net.InterfaceAddrs()
- for _, addr := range addrs {
- // Check if any interface has local IPv6 assigned
- if strings.Contains(addr.String(), "::1") {
- return true
- }
- }
- return false
-}
-
-func TestEchoListenerNetwork(t *testing.T) {
- hasIPv6 := supportsIPv6()
- for _, tt := range listenerNetworkTests {
- if !hasIPv6 && strings.Contains(tt.address, "::") {
- t.Skip("Skipping testing IPv6 for " + tt.address + ", not available")
- continue
- }
- t.Run(tt.test, func(t *testing.T) {
- e := New()
- e.ListenerNetwork = tt.network
-
- // HandlerFunc
- e.GET("/ok", func(c Context) error {
- return c.String(http.StatusOK, "OK")
- })
-
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start(tt.address)
- }()
-
- err := waitForServerStart(e, errCh, false)
- assert.NoError(t, err)
-
- if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
- defer resp.Body.Close()
- assert.Equal(t, http.StatusOK, resp.StatusCode)
-
- if body, err := ioutil.ReadAll(resp.Body); err == nil {
- assert.Equal(t, "OK", string(body))
- } else {
- assert.Fail(t, err.Error())
- }
-
- } else {
- assert.Fail(t, err.Error())
- }
-
- if err := e.Close(); err != nil {
- t.Fatal(err)
- }
- })
- }
-}
-
-func TestEchoListenerNetworkInvalid(t *testing.T) {
- e := New()
- e.ListenerNetwork = "unix"
-
- // HandlerFunc
- e.GET("/ok", func(c Context) error {
- return c.String(http.StatusOK, "OK")
- })
-
- assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
-}
-
-func TestEchoReverse(t *testing.T) {
- assert := assert.New(t)
-
- e := New()
- dummyHandler := func(Context) error { return nil }
-
- e.GET("/static", dummyHandler).Name = "/static"
- e.GET("/static/*", dummyHandler).Name = "/static/*"
- e.GET("/params/:foo", dummyHandler).Name = "/params/:foo"
- e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
- e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"
-
- assert.Equal("/static", e.Reverse("/static"))
- assert.Equal("/static", e.Reverse("/static", "missing param"))
- assert.Equal("/static/*", e.Reverse("/static/*"))
- assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
-
- assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
- assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
- assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
- assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
- assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
- assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
-}
-
-func TestEchoReverseHandleHostProperly(t *testing.T) {
- assert := assert.New(t)
-
- dummyHandler := func(Context) error { return nil }
-
- e := New()
- h := e.Host("the_host")
- h.GET("/static", dummyHandler).Name = "/static"
- h.GET("/static/*", dummyHandler).Name = "/static/*"
-
- assert.Equal("/static", e.Reverse("/static"))
- assert.Equal("/static", e.Reverse("/static", "missing param"))
- assert.Equal("/static/*", e.Reverse("/static/*"))
- assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
-}
-
-func TestEcho_ListenerAddr(t *testing.T) {
- e := New()
-
- addr := e.ListenerAddr()
- assert.Nil(t, addr)
-
- errCh := make(chan error)
- go func() {
- errCh <- e.Start(":0")
- }()
-
- err := waitForServerStart(e, errCh, false)
- assert.NoError(t, err)
-}
-
-func TestEcho_TLSListenerAddr(t *testing.T) {
- cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
- require.NoError(t, err)
- key, err := ioutil.ReadFile("_fixture/certs/key.pem")
- require.NoError(t, err)
-
- e := New()
-
- addr := e.TLSListenerAddr()
- assert.Nil(t, addr)
-
- errCh := make(chan error)
- go func() {
- errCh <- e.StartTLS(":0", cert, key)
- }()
-
- err = waitForServerStart(e, errCh, true)
- assert.NoError(t, err)
-}
-
-func TestEcho_StartServer(t *testing.T) {
- cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
- require.NoError(t, err)
- key, err := ioutil.ReadFile("_fixture/certs/key.pem")
- require.NoError(t, err)
- certs, err := tls.X509KeyPair(cert, key)
- require.NoError(t, err)
-
var testCases = []struct {
- name string
- addr string
- TLSConfig *tls.Config
- expectError string
+ name string
+ givenExposeError bool
+ givenLoggerFunc bool
+ whenMethod string
+ whenError error
+ expectBody string
+ expectStatus int
+ expectLogged string
}{
{
- name: "ok",
- addr: ":0",
+ name: "ok, expose error = true, HTTPError",
+ givenExposeError: true,
+ whenError: NewHTTPError(http.StatusTeapot, "my_error"),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"code=418, message=my_error","message":"my_error"}` + "\n",
},
{
- name: "ok, start with TLS",
- addr: ":0",
- TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}},
+ name: "ok, expose error = true, HTTPError + internal error",
+ givenExposeError: true,
+ whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(errors.New("internal_error")),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"code=418, message=my_error, internal=internal_error","message":"my_error"}` + "\n",
},
{
- name: "nok, invalid address",
- addr: "nope",
- expectError: "listen tcp: address nope: missing port in address",
+ name: "ok, expose error = true, HTTPError + internal HTTPError",
+ givenExposeError: true,
+ whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")),
+ expectStatus: http.StatusTooEarly,
+ expectBody: `{"error":"code=418, message=my_error, internal=code=425, message=early_error","message":"early_error"}` + "\n",
},
{
- name: "nok, invalid tls address",
- addr: "nope",
- TLSConfig: &tls.Config{InsecureSkipVerify: true},
- expectError: "listen tcp: address nope: missing port in address",
+ name: "ok, expose error = false, HTTPError",
+ whenError: NewHTTPError(http.StatusTeapot, "my_error"),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError + internal HTTPError",
+ whenError: NewHTTPError(http.StatusTeapot, "my_error").WithInternal(NewHTTPError(http.StatusTooEarly, "early_error")),
+ expectStatus: http.StatusTooEarly,
+ expectBody: `{"message":"early_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, Error",
+ givenExposeError: true,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, Error",
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"message":"Internal Server Error"}` + "\n",
+ },
+ {
+ name: "ok, http.HEAD, expose error = true, Error",
+ givenExposeError: true,
+ whenMethod: http.MethodHead,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: ``,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
+ buf := new(bytes.Buffer)
e := New()
- e.Debug = true
+ e.Logger = &jsonLogger{writer: buf}
+ e.Any("/path", func(c Context) error {
+ return tc.whenError
+ })
- server := new(http.Server)
- server.Addr = tc.addr
- if tc.TLSConfig != nil {
- server.TLSConfig = tc.TLSConfig
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError)
+
+ method := http.MethodGet
+ if tc.whenMethod != "" {
+ method = tc.whenMethod
}
+ c, b := request(method, "/path", e)
- errCh := make(chan error)
- go func() {
- errCh <- e.StartServer(server)
- }()
-
- err := waitForServerStart(e, errCh, tc.TLSConfig != nil)
- if tc.expectError != "" {
- assert.EqualError(t, err, tc.expectError)
- } else {
- assert.NoError(t, err)
- }
- assert.NoError(t, e.Close())
+ assert.Equal(t, tc.expectStatus, c)
+ assert.Equal(t, tc.expectBody, b)
+ assert.Equal(t, tc.expectLogged, buf.String())
})
}
}
-func benchmarkEchoRoutes(b *testing.B, routes []*Route) {
+type myCustomContext struct {
+ context
+}
+
+func (c *myCustomContext) QueryParam(name string) string {
+ return "prefix_" + name
+}
+
+func TestEcho_customContext(t *testing.T) {
+ e := New()
+ e.NewContextFunc = func(pathParamAllocSize int) EditableContext {
+ p := make(PathParams, pathParamAllocSize)
+ return &myCustomContext{
+ context{
+ request: nil,
+ response: NewResponse(nil, e),
+ store: make(Map),
+ echo: e,
+ pathParams: &p,
+ route: nil,
+ },
+ }
+ }
+
+ e.GET("/info/:id/:file", func(c Context) error {
+ return c.String(http.StatusTeapot, c.QueryParam("param"))
+ })
+
+ status, body := request(http.MethodGet, "/info/1/a.csv", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "prefix_param", body)
+}
+
+func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
e := New()
req := httptest.NewRequest("GET", "/", nil)
u := req.URL
diff --git a/go.mod b/go.mod
index 60a64317..115d56a4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,17 +1,13 @@
module github.com/labstack/echo/v4
-go 1.15
+go 1.16
require (
- github.com/golang-jwt/jwt v3.2.2+incompatible
- github.com/labstack/gommon v0.3.0
- github.com/mattn/go-colorable v0.1.8 // indirect
- github.com/mattn/go-isatty v0.0.14 // indirect
- github.com/stretchr/testify v1.4.0
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/golang-jwt/jwt/v4 v4.0.0
+ github.com/stretchr/testify v1.7.0
github.com/valyala/fasttemplate v1.2.1
- golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
- golang.org/x/net v0.0.0-20210913180222-943fd674d43e
- golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
- golang.org/x/text v0.3.7 // indirect
+ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
+ gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
diff --git a/go.sum b/go.sum
index 9dcac7c5..54f4aa9b 100644
--- a/go.sum
+++ b/go.sum
@@ -1,51 +1,29 @@
-github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
-github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
-github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
-github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
-github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
-github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
-github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
-github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
-github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
-github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
-github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
-github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
+github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
-github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
-golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ=
-golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8=
-golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
+golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg=
-golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
-golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
+gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/group.go b/group.go
index 426bef9e..d4416497 100644
--- a/group.go
+++ b/group.go
@@ -4,95 +4,117 @@ import (
"net/http"
)
-type (
- // Group is a set of sub-routes for a specified route. It can be used for inner
- // routes that share a common middleware or functionality that should be separate
- // from the parent echo instance while still inheriting from it.
- Group struct {
- common
- host string
- prefix string
- middleware []MiddlewareFunc
- echo *Echo
- }
-)
-
-// Use implements `Echo#Use()` for sub-routes within the Group.
-func (g *Group) Use(middleware ...MiddlewareFunc) {
- g.middleware = append(g.middleware, middleware...)
- if len(g.middleware) == 0 {
- return
- }
- // Allow all requests to reach the group as they might get dropped if router
- // doesn't find a match, making none of the group middleware process.
- g.Any("", NotFoundHandler)
- g.Any("/*", NotFoundHandler)
+// Group is a set of sub-routes for a specified route. It can be used for inner
+// routes that share a common middleware or functionality that should be separate
+// from the parent echo instance while still inheriting from it.
+type Group struct {
+ host string
+ prefix string
+ middleware []MiddlewareFunc
+ echo *Echo
}
-// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
-func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// Use implements `Echo#Use()` for sub-routes within the Group.
+// Group middlewares are not executed on request when there is no matching route found.
+func (g *Group) Use(middleware ...MiddlewareFunc) {
+ g.middleware = append(g.middleware, middleware...)
+}
+
+// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
+func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...)
}
-// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
-func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
+func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...)
}
-// GET implements `Echo#GET()` for sub-routes within the Group.
-func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
+func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...)
}
-// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
-func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
+func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...)
}
-// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
-func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
+func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...)
}
-// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
-func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
+func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...)
}
-// POST implements `Echo#POST()` for sub-routes within the Group.
-func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
+func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...)
}
-// PUT implements `Echo#PUT()` for sub-routes within the Group.
-func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
+func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...)
}
-// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
-func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
+func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...)
}
-// Any implements `Echo#Any()` for sub-routes within the Group.
-func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = g.Add(m, path, handler, middleware...)
+// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
+func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := g.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
}
-// Match implements `Echo#Match()` for sub-routes within the Group.
-func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = g.Add(m, path, handler, middleware...)
+// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
+func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := g.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
}
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
+// Important! Group middlewares are only executed in case there was exact route match and not
+// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
+// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
@@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
return
}
-// Static implements `Echo#Static()` for sub-routes within the Group.
-func (g *Group) Static(prefix, root string) {
- g.static(prefix, root, g.GET)
+// Static implements `Echo#Static()` for sub-routes within the Group. Panics on error.
+func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
+ return g.Add(
+ http.MethodGet,
+ prefix+"*",
+ StaticDirectoryHandler(root, false),
+ middleware...,
+ )
}
-// File implements `Echo#File()` for sub-routes within the Group.
-func (g *Group) File(path, file string) {
- g.file(path, file, g.GET)
+// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
+func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c Context) error {
+ return c.File(file)
+ }
+ return g.Add(http.MethodGet, path, handler, middleware...)
}
-// Add implements `Echo#Add()` for sub-routes within the Group.
-func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- // Combine into a new slice to avoid accidentally passing the same slice for
+// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
+func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := g.AddRoute(Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
+}
+
+// AddRoute registers a new Routable with Router
+func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
+ // Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
- m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
- m = append(m, g.middleware...)
- m = append(m, middleware...)
- return g.echo.add(g.host, method, g.prefix+path, handler, m...)
+ groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
+ return g.echo.add(g.host, groupRoute)
}
diff --git a/group_test.go b/group_test.go
index c51fd91e..606105fd 100644
--- a/group_test.go
+++ b/group_test.go
@@ -1,31 +1,68 @@
package echo
import (
+ "github.com/stretchr/testify/assert"
"io/ioutil"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
-
- "github.com/stretchr/testify/assert"
)
-// TODO: Fix me
-func TestGroup(t *testing.T) {
- g := New().Group("/group")
- h := func(Context) error { return nil }
- g.CONNECT("/", h)
- g.DELETE("/", h)
- g.GET("/", h)
- g.HEAD("/", h)
- g.OPTIONS("/", h)
- g.PATCH("/", h)
- g.POST("/", h)
- g.PUT("/", h)
- g.TRACE("/", h)
- g.Any("/", h)
- g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
- g.Static("/static", "/tmp")
- g.File("/walle", "_fixture/images//walle.png")
+func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware it will not be executed when there are no routes under that group
+ _ = e.Group("/group", mw)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware and routes when we have no match on some route the middlewares for that
+ // group will not be executed
+ g := e.Group("/group", mw)
+ g.GET("/yes", handlerFunc)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_multiLevelGroup(t *testing.T) {
+ e := New()
+
+ api := e.Group("/api")
+ users := api.Group("/users")
+ users.GET("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ status, body := request(http.MethodGet, "/api/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
}
func TestGroupFile(t *testing.T) {
@@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
}
m2 := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
- return c.String(http.StatusOK, c.Path())
+ return c.String(http.StatusOK, c.RouteInfo().Path())
}
}
h := func(c Context) error {
- return c.String(http.StatusOK, c.Path())
+ return c.String(http.StatusOK, c.RouteInfo().Path())
}
g.Use(m1)
g.GET("/help", h, m2)
@@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m)
}
+
+func TestGroup_CONNECT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.CONNECT("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodConnect, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_DELETE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.DELETE("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodDelete, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_HEAD(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.HEAD("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodHead, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_OPTIONS(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.OPTIONS("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodOptions, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PATCH(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PATCH("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPatch, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_POST(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.POST("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPost, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PUT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PUT("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodPut, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_TRACE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.TRACE("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodTrace, ri.Method())
+ assert.Equal(t, "/users/activate", ri.Path())
+ assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
+ assert.Nil(t, ri.Params())
+
+ status, body := request(http.MethodTrace, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_Any(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ris := users.Any("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ assert.Len(t, ris, 11)
+
+ for _, m := range methods {
+ status, body := request(m, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_AnyWithErrors(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ users.GET("/activate", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ errs := func() (errs []error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if tmpErr, ok := r.([]error); ok {
+ errs = tmpErr
+ return
+ }
+ panic(r)
+ }
+ }()
+
+ users.Any("/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ return nil
+ }()
+ assert.Len(t, errs, 1)
+ assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
+
+ for _, m := range methods {
+ status, body := request(m, "/users/activate", e)
+
+ expect := http.StatusTeapot
+ if m == http.MethodGet {
+ expect = http.StatusOK
+ }
+ assert.Equal(t, expect, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_Match(t *testing.T) {
+ e := New()
+
+ myMethods := []string{http.MethodGet, http.MethodPost}
+ users := e.Group("/users")
+ ris := users.Match(myMethods, "/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ assert.Len(t, ris, 2)
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_MatchWithErrors(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ users.GET("/activate", func(c Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ myMethods := []string{http.MethodGet, http.MethodPost}
+
+ errs := func() (errs []error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if tmpErr, ok := r.([]error); ok {
+ errs = tmpErr
+ return
+ }
+ panic(r)
+ }
+ }()
+
+ users.Match(myMethods, "/activate", func(c Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ return nil
+ }()
+ assert.Len(t, errs, 1)
+ assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+
+ expect := http.StatusTeapot
+ if m == http.MethodGet {
+ expect = http.StatusOK
+ }
+ assert.Equal(t, expect, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_Static(t *testing.T) {
+ e := New()
+
+ g := e.Group("/books")
+ ri := g.Static("/download", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method())
+ assert.Equal(t, "/books/download*", ri.Path())
+ assert.Equal(t, "GET:/books/download*", ri.Name())
+ assert.Equal(t, []string{"*"}, ri.Params())
+
+ req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ body := rec.Body.String()
+ assert.True(t, strings.HasPrefix(body, ""))
+}
+
+func TestGroup_StaticMultiTest(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPrefix string
+ givenRoot string
+ whenURL string
+ expectStatus int
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, without prefix",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "nok, without prefix does not serve dir index",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/", // `/test` + `*` creates route `/test*`
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/scripts",
+ whenURL: "/test/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenRoot: "_fixture",
+ whenURL: "/test/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ g := e.Group("/test")
+ g.Static(tc.givenPrefix, tc.givenRoot)
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
+}
diff --git a/httperror.go b/httperror.go
new file mode 100644
index 00000000..5c217dac
--- /dev/null
+++ b/httperror.go
@@ -0,0 +1,74 @@
+package echo
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+)
+
+// Errors
+var (
+ ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
+ ErrNotFound = NewHTTPError(http.StatusNotFound)
+ ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
+ ErrForbidden = NewHTTPError(http.StatusForbidden)
+ ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
+ ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
+ ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
+ ErrBadRequest = NewHTTPError(http.StatusBadRequest)
+ ErrBadGateway = NewHTTPError(http.StatusBadGateway)
+ ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
+ ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
+ ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
+ ErrValidatorNotRegistered = errors.New("validator not registered")
+ ErrRendererNotRegistered = errors.New("renderer not registered")
+ ErrInvalidRedirectCode = errors.New("invalid redirect status code")
+ ErrCookieNotFound = errors.New("cookie not found")
+ ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
+ ErrInvalidListenerNetwork = errors.New("invalid listener network")
+)
+
+// HTTPError represents an error that occurred while handling a request.
+type HTTPError struct {
+ Code int `json:"-"`
+ Message interface{} `json:"message"`
+ Internal error `json:"-"` // Stores the error returned by an external dependency
+}
+
+// NewHTTPError creates a new HTTPError instance.
+func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
+ he := &HTTPError{Code: code, Message: http.StatusText(code)}
+ if len(message) > 0 {
+ he.Message = message[0]
+ }
+ return he
+}
+
+// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
+func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
+ he := NewHTTPError(code, message...)
+ he.Internal = internalError
+ return he
+}
+
+// Error makes it compatible with `error` interface.
+func (he *HTTPError) Error() string {
+ if he.Internal == nil {
+ return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
+ }
+ return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
+}
+
+// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
+func (he *HTTPError) WithInternal(err error) *HTTPError {
+ return &HTTPError{
+ Code: he.Code,
+ Message: he.Message,
+ Internal: err,
+ }
+}
+
+// Unwrap satisfies the Go 1.13 error wrapper interface.
+func (he *HTTPError) Unwrap() error {
+ return he.Internal
+}
diff --git a/httperror_test.go b/httperror_test.go
new file mode 100644
index 00000000..f9d340f1
--- /dev/null
+++ b/httperror_test.go
@@ -0,0 +1,52 @@
+package echo
+
+import (
+ "errors"
+ "github.com/stretchr/testify/assert"
+ "net/http"
+ "testing"
+)
+
+func TestHTTPError(t *testing.T) {
+ t.Run("non-internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
+ "code": 12,
+ })
+
+ assert.Equal(t, "code=400, message=map[code:12]", err.Error())
+ })
+ t.Run("internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
+ "code": 12,
+ })
+ err = err.WithInternal(errors.New("internal error"))
+ assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
+ })
+}
+
+func TestNewHTTPErrorWithInternal(t *testing.T) {
+ he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
+ assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
+}
+
+func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
+ he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
+ assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
+}
+
+func TestHTTPError_Unwrap(t *testing.T) {
+ t.Run("non-internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
+ "code": 12,
+ })
+
+ assert.Nil(t, errors.Unwrap(err))
+ })
+ t.Run("internal", func(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
+ "code": 12,
+ })
+ err = err.WithInternal(errors.New("internal error"))
+ assert.Equal(t, "internal error", errors.Unwrap(err).Error())
+ })
+}
diff --git a/json.go b/json.go
index 16b2d057..16074fa2 100644
--- a/json.go
+++ b/json.go
@@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
err := json.NewDecoder(c.Request().Body).Decode(i)
if ute, ok := err.(*json.UnmarshalTypeError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
+ return NewHTTPErrorWithInternal(
+ http.StatusBadRequest,
+ err,
+ fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
+ )
} else if se, ok := err.(*json.SyntaxError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
+ return NewHTTPErrorWithInternal(http.StatusBadRequest,
+ err,
+ fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
+ )
}
return err
}
diff --git a/log.go b/log.go
index 3f8de590..49442535 100644
--- a/log.go
+++ b/log.go
@@ -1,41 +1,141 @@
package echo
import (
+ "bytes"
"io"
-
- "github.com/labstack/gommon/log"
+ "strconv"
+ "sync"
+ "time"
)
-type (
- // Logger defines the logging interface.
- Logger interface {
- Output() io.Writer
- SetOutput(w io.Writer)
- Prefix() string
- SetPrefix(p string)
- Level() log.Lvl
- SetLevel(v log.Lvl)
- SetHeader(h string)
- Print(i ...interface{})
- Printf(format string, args ...interface{})
- Printj(j log.JSON)
- Debug(i ...interface{})
- Debugf(format string, args ...interface{})
- Debugj(j log.JSON)
- Info(i ...interface{})
- Infof(format string, args ...interface{})
- Infoj(j log.JSON)
- Warn(i ...interface{})
- Warnf(format string, args ...interface{})
- Warnj(j log.JSON)
- Error(i ...interface{})
- Errorf(format string, args ...interface{})
- Errorj(j log.JSON)
- Fatal(i ...interface{})
- Fatalj(j log.JSON)
- Fatalf(format string, args ...interface{})
- Panic(i ...interface{})
- Panicj(j log.JSON)
- Panicf(format string, args ...interface{})
+//-----------------------------------------------------------------------------
+// Example for Zap (https://github.com/uber-go/zap)
+//func main() {
+// e := echo.New()
+// logger, _ := zap.NewProduction()
+// e.Logger = &ZapLogger{logger: logger}
+//}
+//type ZapLogger struct {
+// logger *zap.Logger
+//}
+//
+//func (l *ZapLogger) Write(p []byte) (n int, err error) {
+// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
+// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
+// return len(p), nil
+//}
+//
+//func (l *ZapLogger) Error(err error) {
+// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
+//}
+
+//-----------------------------------------------------------------------------
+// Example for Zerolog (https://github.com/rs/zerolog)
+//func main() {
+// e := echo.New()
+// logger := zerolog.New(os.Stdout)
+// e.Logger = &ZeroLogger{logger: &logger}
+//}
+//
+//type ZeroLogger struct {
+// logger *zerolog.Logger
+//}
+//
+//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
+// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
+// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
+// return len(p), nil
+//}
+//
+//func (l *ZeroLogger) Error(err error) {
+// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
+//}
+
+//-----------------------------------------------------------------------------
+// Example for Logrus (https://github.com/sirupsen/logrus)
+//func main() {
+// e := echo.New()
+// e.Logger = &LogrusLogger{logger: logrus.New()}
+//}
+//
+//type LogrusLogger struct {
+// logger *logrus.Logger
+//}
+//
+//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
+// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
+// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
+// return len(p), nil
+//}
+//
+//func (l *LogrusLogger) Error(err error) {
+// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
+//}
+
+// Logger defines the logging interface that Echo uses internally in few places.
+// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
+type Logger interface {
+ // Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
+ // `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
+ // and underlying FileSystem errors.
+ // `logger` middleware will use this method to write its JSON payload.
+ Write(p []byte) (n int, err error)
+ // Error logs the error
+ Error(err error)
+}
+
+// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
+// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
+// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
+type jsonLogger struct {
+ writer io.Writer
+ bufferPool sync.Pool
+
+ timeNow func() time.Time
+}
+
+func newJSONLogger(writer io.Writer) *jsonLogger {
+ return &jsonLogger{
+ writer: writer,
+ bufferPool: sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(make([]byte, 256))
+ },
+ },
+ timeNow: time.Now,
}
-)
+}
+
+func (l *jsonLogger) Write(p []byte) (n int, err error) {
+ pLen := len(p)
+ if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
+ (p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
+ (p[0] == '{' && p[pLen-1] == '}') {
+ return l.writer.Write(p)
+ }
+ // we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
+ // called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
+ // deserves at least WARN level.
+ return l.printf("WARN", string(p))
+}
+
+func (l *jsonLogger) Error(err error) {
+ _, _ = l.printf("ERROR", err.Error())
+}
+
+func (l *jsonLogger) printf(level string, message string) (n int, err error) {
+ buf := l.bufferPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer l.bufferPool.Put(buf)
+
+ buf.WriteString(`{"time":"`)
+ buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
+ buf.WriteString(`","level":"`)
+ buf.WriteString(level)
+ buf.WriteString(`","prefix":"echo","message":`)
+
+ buf.WriteString(strconv.Quote(message))
+ buf.WriteString("}\n")
+
+ return l.writer.Write(buf.Bytes())
+}
diff --git a/log_test.go b/log_test.go
new file mode 100644
index 00000000..ed635290
--- /dev/null
+++ b/log_test.go
@@ -0,0 +1,77 @@
+package echo
+
+import (
+ "bytes"
+ "github.com/stretchr/testify/assert"
+ "testing"
+ "time"
+)
+
+func TestJsonLogger_Write(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when []byte
+ expect string
+ }{
+ {
+ name: "ok, write non JSONlike message",
+ when: []byte("version: %v, build: %v"),
+ expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
+ },
+ {
+ name: "ok, write quoted message",
+ when: []byte(`version: "%v"`),
+ expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n",
+ },
+ {
+ name: "ok, write JSON",
+ when: []byte(`{"version": 123}` + "\n"),
+ expect: `{"version": 123}` + "\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ buf := new(bytes.Buffer)
+ logger := newJSONLogger(buf)
+ logger.timeNow = func() time.Time {
+ return time.Unix(1631045377, 0)
+ }
+
+ _, err := logger.Write(tc.when)
+
+ result := buf.String()
+ assert.Equal(t, tc.expect, result)
+ assert.NoError(t, err)
+ })
+ }
+}
+
+func TestJsonLogger_Error(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenError error
+ expect string
+ }{
+ {
+ name: "ok",
+ whenError: ErrForbidden,
+ expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ buf := new(bytes.Buffer)
+ logger := newJSONLogger(buf)
+ logger.timeNow = func() time.Time {
+ return time.Unix(1631045377, 0)
+ }
+
+ logger.Error(tc.whenError)
+
+ result := buf.String()
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md
new file mode 100644
index 00000000..68002cad
--- /dev/null
+++ b/middleware/DEVELOPMENT.md
@@ -0,0 +1,13 @@
+# Development Guidelines for middlewares
+
+// FIXME: add info about `MiddlewareConfigurator` interface
+
+## Best practices:
+
+* Do not use `panic` in middleware creator functions in case of invalid configuration.
+* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
+ because previous middlewares up in call chain could have logic for dealing with returned errors.
+* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
+ want to create middleware with panics or with returning errors on configuration errors.
+* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.
+
diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go
index 8cf1ed9f..73caeaf9 100644
--- a/middleware/basic_auth.go
+++ b/middleware/basic_auth.go
@@ -1,64 +1,59 @@
package middleware
import (
+ "bytes"
"encoding/base64"
+ "errors"
+ "fmt"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
-type (
- // BasicAuthConfig defines the config for BasicAuth middleware.
- BasicAuthConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
+type BasicAuthConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Validator is a function to validate BasicAuth credentials.
- // Required.
- Validator BasicAuthValidator
+ // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
+ // this function would be called once for each header until first valid result is returned
+ // Required.
+ Validator BasicAuthValidator
- // Realm is a string to define realm attribute of BasicAuth.
- // Default value "Restricted".
- Realm string
- }
+ // Realm is a string to define realm attribute of BasicAuthWithConfig.
+ // Default value "Restricted".
+ Realm string
+}
- // BasicAuthValidator defines a function to validate BasicAuth credentials.
- BasicAuthValidator func(string, string, echo.Context) (bool, error)
-)
+// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
+type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
const (
basic = "basic"
defaultRealm = "Restricted"
)
-var (
- // DefaultBasicAuthConfig is the default BasicAuth middleware config.
- DefaultBasicAuthConfig = BasicAuthConfig{
- Skipper: DefaultSkipper,
- Realm: defaultRealm,
- }
-)
-
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
- c := DefaultBasicAuthConfig
- c.Validator = fn
- return BasicAuthWithConfig(c)
+ return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
}
-// BasicAuthWithConfig returns an BasicAuth middleware with config.
-// See `BasicAuth()`.
+// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
+func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil {
- panic("echo: basic-auth middleware requires a validator function")
+ return nil, errors.New("echo basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
- config.Skipper = DefaultBasicAuthConfig.Skipper
+ config.Skipper = DefaultSkipper
}
if config.Realm == "" {
config.Realm = defaultRealm
@@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}
- auth := c.Request().Header.Get(echo.HeaderAuthorization)
+ var lastError error
l := len(basic)
-
- if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
- b, err := base64.StdEncoding.DecodeString(auth[l+1:])
- if err != nil {
- return err
+ for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
+ if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
+ continue
}
- cred := string(b)
- for i := 0; i < len(cred); i++ {
- if cred[i] == ':' {
- // Verify credentials
- valid, err := config.Validator(cred[:i], cred[i+1:], c)
- if err != nil {
- return err
- } else if valid {
- return next(c)
- }
- break
+
+ b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
+ if errDecode != nil {
+ lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
+ continue
+ }
+ idx := bytes.IndexByte(b, ':')
+ if idx >= 0 {
+ valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
+ if errValidate != nil {
+ lastError = errValidate
+ } else if valid {
+ return next(c)
}
}
}
+ if lastError != nil {
+ return lastError
+ }
+
realm := defaultRealm
if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm)
@@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
- }
+ }, nil
}
diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go
index 76039db0..41532312 100644
--- a/middleware/basic_auth_test.go
+++ b/middleware/basic_auth_test.go
@@ -2,6 +2,7 @@ package middleware
import (
"encoding/base64"
+ "errors"
"net/http"
"net/http/httptest"
"strings"
@@ -12,60 +13,146 @@ import (
)
func TestBasicAuth(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := httptest.NewRecorder()
- c := e.NewContext(req, res)
- f := func(u, p string, c echo.Context) (bool, error) {
+ validatorFunc := func(c echo.Context, u, p string) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
+ if u == "error" {
+ return false, errors.New(p)
+ }
return false, nil
}
- h := BasicAuth(f)(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
+ defaultConfig := BasicAuthConfig{Validator: validatorFunc}
- assert := assert.New(t)
+ var testCases = []struct {
+ name string
+ givenConfig BasicAuthConfig
+ whenAuth []string
+ expectHeader string
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenConfig: defaultConfig,
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, multiple",
+ givenConfig: defaultConfig,
+ whenAuth: []string{
+ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
+ basic + " NOT_BASE64",
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ },
+ },
+ {
+ name: "nok, invalid Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm=Restricted`,
+ expectErr: "code=401, message=Unauthorized",
+ },
+ {
+ name: "nok, not base64 Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
+ expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
+ },
+ {
+ name: "nok, missing Authorization header",
+ givenConfig: defaultConfig,
+ expectHeader: basic + ` realm=Restricted`,
+ expectErr: "code=401, message=Unauthorized",
+ },
+ {
+ name: "ok, realm",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, realm, case-insensitive header scheme",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "nok, realm, invalid Authorization header",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm="someRealm"`,
+ expectErr: "code=401, message=Unauthorized",
+ },
+ {
+ name: "nok, validator func returns an error",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
+ expectErr: "my_error",
+ },
+ {
+ name: "ok, skipped",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
+ return true
+ }},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ },
+ }
- // Valid credentials
- auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
- h = BasicAuthWithConfig(BasicAuthConfig{
- Skipper: nil,
- Validator: f,
- Realm: "someRealm",
- })(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
+ config := tc.givenConfig
- // Valid credentials
- auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
+ mw, err := config.ToMiddleware()
+ assert.NoError(t, err)
- // Case-insensitive header scheme
- auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
+ h := mw(func(c echo.Context) error {
+ return c.String(http.StatusTeapot, "test")
+ })
- // Invalid credentials
- auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- he := h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
- assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
+ if len(tc.whenAuth) != 0 {
+ for _, a := range tc.whenAuth {
+ req.Header.Add(echo.HeaderAuthorization, a)
+ }
+ }
+ err = h(c)
- // Missing Authorization header
- req.Header.Del(echo.HeaderAuthorization)
- he = h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
-
- // Invalid Authorization header
- auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- he = h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
+ if tc.expectErr != "" {
+ assert.Equal(t, http.StatusOK, res.Code)
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.NoError(t, err)
+ }
+ if tc.expectHeader != "" {
+ assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
+ }
+ })
+ }
+}
+
+func TestBasicAuth_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuth(nil)
+ assert.NotNil(t, mw)
+ })
+
+ mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ })
+ assert.NotNil(t, mw)
+}
+
+func TestBasicAuthWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
+ assert.NotNil(t, mw)
+ })
+
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ }})
+ assert.NotNil(t, mw)
}
diff --git a/middleware/body_dump.go b/middleware/body_dump.go
index ebd0d0ab..d04822df 100644
--- a/middleware/body_dump.go
+++ b/middleware/body_dump.go
@@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
+ "errors"
"io"
"io/ioutil"
"net"
@@ -11,63 +12,56 @@ import (
"github.com/labstack/echo/v4"
)
-type (
- // BodyDumpConfig defines the config for BodyDump middleware.
- BodyDumpConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// BodyDumpConfig defines the config for BodyDump middleware.
+type BodyDumpConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Handler receives request and response payload.
- // Required.
- Handler BodyDumpHandler
- }
+ // Handler receives request and response payload.
+ // Required.
+ Handler BodyDumpHandler
+}
- // BodyDumpHandler receives the request and response payload.
- BodyDumpHandler func(echo.Context, []byte, []byte)
+// BodyDumpHandler receives the request and response payload.
+type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
- bodyDumpResponseWriter struct {
- io.Writer
- http.ResponseWriter
- }
-)
-
-var (
- // DefaultBodyDumpConfig is the default BodyDump middleware config.
- DefaultBodyDumpConfig = BodyDumpConfig{
- Skipper: DefaultSkipper,
- }
-)
+type bodyDumpResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+}
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
- c := DefaultBodyDumpConfig
- c.Handler = handler
- return BodyDumpWithConfig(c)
+ return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
+func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil {
- panic("echo: body-dump middleware requires a handler function")
+ return nil, errors.New("echo body-dump middleware requires a handler function")
}
if config.Skipper == nil {
- config.Skipper = DefaultBodyDumpConfig.Skipper
+ config.Skipper = DefaultSkipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
+ return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Request
reqBody := []byte{}
- if c.Request().Body != nil { // Read
+ if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body)
}
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
@@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer
- if err = next(c); err != nil {
- c.Error(err)
- }
+ err := next(c)
// Callback
config.Handler(c, reqBody, resBody.Bytes())
- return
+ return err
}
- }
+ }, nil
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go
index e6e00f72..21ca4cac 100644
--- a/middleware/body_dump_test.go
+++ b/middleware/body_dump_test.go
@@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
requestBody := ""
responseBody := ""
- mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
+ mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
- })
+ }}.ToMiddleware()
+ assert.NoError(t, err)
- assert := assert.New(t)
-
- if assert.NoError(mw(h)(c)) {
- assert.Equal(requestBody, hw)
- assert.Equal(responseBody, hw)
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(hw, rec.Body.String())
+ if assert.NoError(t, mw(h)(c)) {
+ assert.Equal(t, requestBody, hw)
+ assert.Equal(t, responseBody, hw)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.String())
}
- // Must set default skipper
- BodyDumpWithConfig(BodyDumpConfig{
- Skipper: nil,
- Handler: func(c echo.Context, reqBody, resBody []byte) {
- requestBody = string(reqBody)
- responseBody = string(resBody)
- },
- })
}
-func TestBodyDumpFails(t *testing.T) {
+func TestBodyDump_skipper(t *testing.T) {
+ e := echo.New()
+
+ isCalled := false
+ mw, err := BodyDumpConfig{
+ Skipper: func(c echo.Context) bool {
+ return true
+ },
+ Handler: func(c echo.Context, reqBody, resBody []byte) {
+ isCalled = true
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c echo.Context) error {
+ return errors.New("some error")
+ }
+
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.False(t, isCalled)
+}
+
+func TestBodyDump_fails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
@@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
return errors.New("some error")
}
- mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
+ mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
+ assert.NoError(t, err)
- if !assert.Error(t, mw(h)(c)) {
- t.FailNow()
- }
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
- mw = BodyDumpWithConfig(BodyDumpConfig{
+ mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
+ assert.NotNil(t, mw)
})
assert.NotPanics(t, func() {
- mw = BodyDumpWithConfig(BodyDumpConfig{
- Skipper: func(c echo.Context) bool {
- return true
- },
- Handler: func(c echo.Context, reqBody, resBody []byte) {
- },
- })
-
- if !assert.Error(t, mw(h)(c)) {
- t.FailNow()
- }
+ mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
+ assert.NotNil(t, mw)
+ })
+}
+
+func TestBodyDump_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BodyDump(nil)
+ assert.NotNil(t, mw)
+ })
+
+ assert.NotPanics(t, func() {
+ BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
})
}
diff --git a/middleware/body_limit.go b/middleware/body_limit.go
index b436bd59..b00bbae6 100644
--- a/middleware/body_limit.go
+++ b/middleware/body_limit.go
@@ -1,98 +1,83 @@
package middleware
import (
- "fmt"
"io"
"sync"
"github.com/labstack/echo/v4"
- "github.com/labstack/gommon/bytes"
)
-type (
- // BodyLimitConfig defines the config for BodyLimit middleware.
- BodyLimitConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
+type BodyLimitConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Maximum allowed size for a request body, it can be specified
- // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
- Limit string `yaml:"limit"`
- limit int64
- }
+ // LimitBytes is maximum allowed size in bytes for a request body
+ LimitBytes int64
+}
- limitedReader struct {
- BodyLimitConfig
- reader io.ReadCloser
- read int64
- context echo.Context
- }
-)
-
-var (
- // DefaultBodyLimitConfig is the default BodyLimit middleware config.
- DefaultBodyLimitConfig = BodyLimitConfig{
- Skipper: DefaultSkipper,
- }
-)
+type limitedReader struct {
+ BodyLimitConfig
+ reader io.ReadCloser
+ read int64
+ context echo.Context
+}
// BodyLimit returns a BodyLimit middleware.
//
-// BodyLimit middleware sets the maximum allowed size for a request body, if the
-// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
-// response. The BodyLimit is determined based on both `Content-Length` request
+// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
+// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
-// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
-// G, T or P.
-func BodyLimit(limit string) echo.MiddlewareFunc {
- c := DefaultBodyLimitConfig
- c.Limit = limit
- return BodyLimitWithConfig(c)
+func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
+ return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
}
-// BodyLimitWithConfig returns a BodyLimit middleware with config.
-// See: `BodyLimit()`.
+// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
+// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
+// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
+// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
- // Defaults
- if config.Skipper == nil {
- config.Skipper = DefaultBodyLimitConfig.Skipper
- }
+ return toMiddlewareOrPanic(config)
+}
- limit, err := bytes.Parse(config.Limit)
- if err != nil {
- panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
+// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
+func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ pool := sync.Pool{
+ New: func() interface{} {
+ return &limitedReader{BodyLimitConfig: config}
+ },
}
- config.limit = limit
- pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
-
req := c.Request()
// Based on content length
- if req.ContentLength > config.limit {
+ if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader)
- r.Reset(req.Body, c)
+ r.Reset(c, req.Body)
defer pool.Put(r)
req.Body = r
return next(c)
}
- }
+ }, nil
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
- if r.read > r.limit {
+ if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
@@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close()
}
-func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
+func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
r.reader = reader
r.context = context
r.read = 0
}
-
-func limitedReaderPool(c BodyLimitConfig) sync.Pool {
- return sync.Pool{
- New: func() interface{} {
- return &limitedReader{BodyLimitConfig: c}
- },
- }
-}
diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go
index 0e8642a0..c03c6fd8 100644
--- a/middleware/body_limit_test.go
+++ b/middleware/body_limit_test.go
@@ -11,6 +11,137 @@ import (
"github.com/stretchr/testify/assert"
)
+func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c echo.Context) error {
+ body, err := ioutil.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ // Based on content length (within limit)
+ mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+ }
+
+ // Based on content read (overlimit)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he := mw(h)(c).(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
+
+ // Based on content read (within limit)
+ req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+
+ mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+
+ // Based on content read (overlimit)
+ req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he = mw(h)(c).(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
+}
+
+func TestBodyLimitReader(t *testing.T) {
+ hw := []byte("Hello, World!")
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+
+ config := BodyLimitConfig{
+ Skipper: DefaultSkipper,
+ LimitBytes: 2,
+ }
+ reader := &limitedReader{
+ BodyLimitConfig: config,
+ reader: ioutil.NopCloser(bytes.NewReader(hw)),
+ context: e.NewContext(req, rec),
+ }
+
+ // read all should return ErrStatusRequestEntityTooLarge
+ _, err := ioutil.ReadAll(reader)
+ he := err.(*echo.HTTPError)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
+
+ // reset reader and read two bytes must succeed
+ bt := make([]byte, 2)
+ reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
+ n, err := reader.Read(bt)
+ assert.Equal(t, 2, n)
+ assert.Equal(t, nil, err)
+}
+
+func TestBodyLimit_skipper(t *testing.T) {
+ e := echo.New()
+ h := func(c echo.Context) error {
+ body, err := ioutil.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+ mw, err := BodyLimitConfig{
+ Skipper: func(c echo.Context) bool {
+ return true
+ },
+ LimitBytes: 2,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
+func TestBodyLimitWithConfig(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c echo.Context) error {
+ body, err := ioutil.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
+
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
func TestBodyLimit(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
@@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}
- assert := assert.New(t)
+ mw := BodyLimit(2 * MB)
- // Based on content length (within limit)
- if assert.NoError(BodyLimit("2M")(h)(c)) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(hw, rec.Body.Bytes())
- }
-
- // Based on content read (overlimit)
- he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
- assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
-
- // Based on content read (within limit)
- req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- if assert.NoError(BodyLimit("2M")(h)(c)) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("Hello, World!", rec.Body.String())
- }
-
- // Based on content read (overlimit)
- req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
- assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
-}
-
-func TestBodyLimitReader(t *testing.T) {
- hw := []byte("Hello, World!")
- e := echo.New()
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec := httptest.NewRecorder()
-
- config := BodyLimitConfig{
- Skipper: DefaultSkipper,
- Limit: "2B",
- limit: 2,
- }
- reader := &limitedReader{
- BodyLimitConfig: config,
- reader: ioutil.NopCloser(bytes.NewReader(hw)),
- context: e.NewContext(req, rec),
- }
-
- // read all should return ErrStatusRequestEntityTooLarge
- _, err := ioutil.ReadAll(reader)
- he := err.(*echo.HTTPError)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
-
- // reset reader and read two bytes must succeed
- bt := make([]byte, 2)
- reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
- n, err := reader.Read(bt)
- assert.Equal(t, 2, n)
- assert.Equal(t, nil, err)
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
}
diff --git a/middleware/compress.go b/middleware/compress.go
index 6ae19745..8cd66f8d 100644
--- a/middleware/compress.go
+++ b/middleware/compress.go
@@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"compress/gzip"
+ "errors"
"io"
"io/ioutil"
"net"
@@ -13,50 +14,45 @@ import (
"github.com/labstack/echo/v4"
)
-type (
- // GzipConfig defines the config for Gzip middleware.
- GzipConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Gzip compression level.
- // Optional. Default value -1.
- Level int `yaml:"level"`
- }
-
- gzipResponseWriter struct {
- io.Writer
- http.ResponseWriter
- }
-)
-
const (
gzipScheme = "gzip"
)
-var (
- // DefaultGzipConfig is the default Gzip middleware config.
- DefaultGzipConfig = GzipConfig{
- Skipper: DefaultSkipper,
- Level: -1,
- }
-)
+// GzipConfig defines the config for Gzip middleware.
+type GzipConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
-// Gzip returns a middleware which compresses HTTP response using gzip compression
-// scheme.
-func Gzip() echo.MiddlewareFunc {
- return GzipWithConfig(DefaultGzipConfig)
+ // Gzip compression level.
+ // Optional. Default value -1.
+ Level int
}
-// GzipWithConfig return Gzip middleware with config.
-// See: `Gzip()`.
+type gzipResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+}
+
+// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
+func Gzip() echo.MiddlewareFunc {
+ return GzipWithConfig(GzipConfig{})
+}
+
+// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
+func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultGzipConfig.Skipper
+ config.Skipper = DefaultSkipper
+ }
+ if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
+ return nil, errors.New("invalid gzip level")
}
if config.Level == 0 {
- config.Level = DefaultGzipConfig.Level
+ config.Level = -1
}
pool := gzipCompressPool(config)
@@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
}
return next(c)
}
- }
+ }, nil
}
func (w *gzipResponseWriter) WriteHeader(code int) {
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index d16ffca4..1b4ebc8c 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -3,94 +3,128 @@ package middleware
import (
"bytes"
"compress/gzip"
- "io"
"io/ioutil"
"net/http"
"net/http/httptest"
+ "os"
"testing"
+ "time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
-func TestGzip(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
+func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
// Skip if no Accept-Encoding header
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
- h(c)
- assert := assert.New(t)
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- assert.Equal("test", rec.Body.String())
+ err := h(c)
+ assert.NoError(t, err)
- // Gzip
- req = httptest.NewRequest(http.MethodGet, "/", nil)
+ assert.Equal(t, "test", rec.Body.String())
+}
+
+func TestMustGzipWithConfig_panics(t *testing.T) {
+ assert.Panics(t, func() {
+ GzipWithConfig(GzipConfig{Level: 999})
+ })
+}
+
+func TestGzip_AcceptEncodingHeader(t *testing.T) {
+ h := Gzip()(func(c echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h(c)
- assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
- assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
+
r, err := gzip.NewReader(rec.Body)
- if assert.NoError(err) {
- buf := new(bytes.Buffer)
- defer r.Close()
- buf.ReadFrom(r)
- assert.Equal("test", buf.String())
- }
+ assert.NoError(t, err)
+ buf := new(bytes.Buffer)
+ defer r.Close()
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
+}
- chunkBuf := make([]byte, 5)
-
- // Gzip chunked
- req = httptest.NewRequest(http.MethodGet, "/", nil)
+func TestGzip_chunked(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec = httptest.NewRecorder()
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- c = e.NewContext(req, rec)
- Gzip()(func(c echo.Context) error {
+ chunkChan := make(chan struct{})
+ waitChan := make(chan struct{})
+ h := Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
- c.Response().Write([]byte("test\n"))
+ c.Response().Write([]byte("first\n"))
c.Response().Flush()
- // Read the first part of the data
- assert.True(rec.Flushed)
- assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
- r.Reset(rec.Body)
-
- _, err = io.ReadFull(r, chunkBuf)
- assert.NoError(err)
- assert.Equal("test\n", string(chunkBuf))
+ chunkChan <- struct{}{}
+ <-waitChan
// Write and flush the second part of the data
- c.Response().Write([]byte("test\n"))
+ c.Response().Write([]byte("second\n"))
c.Response().Flush()
- _, err = io.ReadFull(r, chunkBuf)
- assert.NoError(err)
- assert.Equal("test\n", string(chunkBuf))
+ chunkChan <- struct{}{}
+ <-waitChan
// Write the final part of the data and return
- c.Response().Write([]byte("test"))
- return nil
- })(c)
+ c.Response().Write([]byte("third"))
+ chunkChan <- struct{}{}
+ return nil
+ })
+
+ go func() {
+ err := h(c)
+ chunkChan <- struct{}{}
+ assert.NoError(t, err)
+ }()
+
+ <-chunkChan // wait for first write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for second write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for final write in handler
+ <-chunkChan // wait for return from handler
+ time.Sleep(5 * time.Millisecond) // to have time for flushing
+
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+
+ r, err := gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
buf := new(bytes.Buffer)
- defer r.Close()
buf.ReadFrom(r)
- assert.Equal("test", buf.String())
+ assert.Equal(t, "first\nsecond\nthird", buf.String())
}
-func TestGzipNoContent(t *testing.T) {
+func TestGzip_NoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
}
}
-func TestGzipErrorReturned(t *testing.T) {
+func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Gzip())
e.GET("/", func(c echo.Context) error {
@@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
-func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
- e := echo.New()
- // Invalid level
- e.Use(GzipWithConfig(GzipConfig{Level: 12}))
- e.GET("/", func(c echo.Context) error {
- c.Response().Write([]byte("test"))
- return nil
- })
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusInternalServerError, rec.Code)
- assert.Contains(t, rec.Body.String(), "gzip")
+func TestGzipWithConfig_invalidLevel(t *testing.T) {
+ mw, err := GzipConfig{Level: 12}.ToMiddleware()
+ assert.EqualError(t, err, "invalid gzip level")
+ assert.Nil(t, mw)
}
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
+ e.Filesystem = os.DirFS("../")
+
e.Use(Gzip())
- e.Static("/test", "../_fixture/images")
+ e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
+
e.ServeHTTP(rec, req)
+
assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set.
diff --git a/middleware/cors.go b/middleware/cors.go
index d6ef8964..c3db8749 100644
--- a/middleware/cors.go
+++ b/middleware/cors.go
@@ -9,60 +9,56 @@ import (
"github.com/labstack/echo/v4"
)
-type (
- // CORSConfig defines the config for CORS middleware.
- CORSConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// CORSConfig defines the config for CORS middleware.
+type CORSConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // AllowOrigin defines a list of origins that may access the resource.
- // Optional. Default value []string{"*"}.
- AllowOrigins []string `yaml:"allow_origins"`
+ // AllowOrigin defines a list of origins that may access the resource.
+ // Optional. Default value []string{"*"}.
+ AllowOrigins []string
- // AllowOriginFunc is a custom function to validate the origin. It takes the
- // origin as an argument and returns true if allowed or false otherwise. If
- // an error is returned, it is returned by the handler. If this option is
- // set, AllowOrigins is ignored.
- // Optional.
- AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
+ // AllowOriginFunc is a custom function to validate the origin. It takes the
+ // origin as an argument and returns true if allowed or false otherwise. If
+ // an error is returned, it is returned by the handler. If this option is
+ // set, AllowOrigins is ignored.
+ // Optional.
+ AllowOriginFunc func(origin string) (bool, error)
- // AllowMethods defines a list methods allowed when accessing the resource.
- // This is used in response to a preflight request.
- // Optional. Default value DefaultCORSConfig.AllowMethods.
- AllowMethods []string `yaml:"allow_methods"`
+ // AllowMethods defines a list methods allowed when accessing the resource.
+ // This is used in response to a preflight request.
+ // Optional. Default value DefaultCORSConfig.AllowMethods.
+ AllowMethods []string
- // AllowHeaders defines a list of request headers that can be used when
- // making the actual request. This is in response to a preflight request.
- // Optional. Default value []string{}.
- AllowHeaders []string `yaml:"allow_headers"`
+ // AllowHeaders defines a list of request headers that can be used when
+ // making the actual request. This is in response to a preflight request.
+ // Optional. Default value []string{}.
+ AllowHeaders []string
- // AllowCredentials indicates whether or not the response to the request
- // can be exposed when the credentials flag is true. When used as part of
- // a response to a preflight request, this indicates whether or not the
- // actual request can be made using credentials.
- // Optional. Default value false.
- AllowCredentials bool `yaml:"allow_credentials"`
+ // AllowCredentials indicates whether or not the response to the request
+ // can be exposed when the credentials flag is true. When used as part of
+ // a response to a preflight request, this indicates whether or not the
+ // actual request can be made using credentials.
+ // Optional. Default value false.
+ AllowCredentials bool
- // ExposeHeaders defines a whitelist headers that clients are allowed to
- // access.
- // Optional. Default value []string{}.
- ExposeHeaders []string `yaml:"expose_headers"`
+ // ExposeHeaders defines a whitelist headers that clients are allowed to
+ // access.
+ // Optional. Default value []string{}.
+ ExposeHeaders []string
- // MaxAge indicates how long (in seconds) the results of a preflight request
- // can be cached.
- // Optional. Default value 0.
- MaxAge int `yaml:"max_age"`
- }
-)
+ // MaxAge indicates how long (in seconds) the results of a preflight request
+ // can be cached.
+ // Optional. Default value 0.
+ MaxAge int
+}
-var (
- // DefaultCORSConfig is the default CORS middleware config.
- DefaultCORSConfig = CORSConfig{
- Skipper: DefaultSkipper,
- AllowOrigins: []string{"*"},
- AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
- }
-)
+// DefaultCORSConfig is the default CORS middleware config.
+var DefaultCORSConfig = CORSConfig{
+ Skipper: DefaultSkipper,
+ AllowOrigins: []string{"*"},
+ AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
+}
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
@@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig)
}
-// CORSWithConfig returns a CORS middleware with config.
+// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
+func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper
@@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
return c.NoContent(http.StatusNoContent)
}
- }
+ }, nil
}
diff --git a/middleware/cors_test.go b/middleware/cors_test.go
index 717abe49..8a654223 100644
--- a/middleware/cors_test.go
+++ b/middleware/cors_test.go
@@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := CORS()(echo.NotFoundHandler)
+ h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
- h = CORS()(echo.NotFoundHandler)
+ h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
@@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
- })(echo.NotFoundHandler)
+ })(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
- h = cors(echo.NotFoundHandler)
+ h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
- h = cors(echo.NotFoundHandler)
+ h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
- h = cors(echo.NotFoundHandler)
+ h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
@@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
})
- h = cors(echo.NotFoundHandler)
+ h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
- h := cors(echo.NotFoundHandler)
+ h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
- h := cors(echo.NotFoundHandler)
+ h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) {
//AllowCredentials: true,
//MaxAge: 3600,
})
- h := cors(echo.NotFoundHandler)
+ h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
@@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
- cors := CORSWithConfig(CORSConfig{
- AllowOriginFunc: allowOriginFunc,
- })
- h := cors(echo.NotFoundHandler)
- err := h(c)
+ cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := cors(func(c echo.Context) error { return echo.ErrNotFound })
+ err = h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
diff --git a/middleware/csrf.go b/middleware/csrf.go
index 7804997d..5859bd0f 100644
--- a/middleware/csrf.go
+++ b/middleware/csrf.go
@@ -8,89 +8,90 @@ import (
"time"
"github.com/labstack/echo/v4"
- "github.com/labstack/gommon/random"
)
-type (
- // CSRFConfig defines the config for CSRF middleware.
- CSRFConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// CSRFConfig defines the config for CSRF middleware.
+type CSRFConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // TokenLength is the length of the generated token.
- TokenLength uint8 `yaml:"token_length"`
- // Optional. Default value 32.
+ // TokenLength is the length of the generated token.
+ TokenLength uint8
+ // Optional. Default value 32.
- // TokenLookup is a string in the form of "