1
0
mirror of https://github.com/labstack/echo.git synced 2025-06-10 23:57:28 +02:00

WIP: logger examples

WIP: make default logger implemented custom writer for jsonlike logs
WIP: improve examples
WIP: defaultErrorHandler use errors.As to unwrap errors. Update readme
WIP: default logger logs json, restore e.Start method
WIP: clean router.Match a bit
WIP: func types/fields have echo.Context has first element
WIP: remove yaml tags as functions etc can not be serialized anyway
WIP: change BindPathParams,BindQueryParams,BindHeaders from methods to functions and reverse arguments to be like DefaultBinder.Bind is
WIP: improved comments, logger now extracts status from error
WIP: go mod tidy
WIP: rebase with 4.5.0
WIP:
* removed todos.
* removed StartAutoTLS and StartH2CServer methods from `StartConfig`
* KeyAuth middleware errorhandler can swallow the error and resume next middleware
WIP: add RouterConfig.UseEscapedPathForMatching to use escaped path for matching request against routes
WIP: FIXMEs
WIP: upgrade golang-jwt/jwt to `v4`
WIP: refactor http methods to return RouteInfo
WIP: refactor static not creating multiple routes
WIP: refactor route and middleware adding functions not to return error directly
WIP: Use 401 for problematic/missing headers for key auth and JWT middleware (#1552, #1402).
> In summary, a 401 Unauthorized response should be used for missing or bad authentication
WIP: replace `HTTPError.SetInternal` with `HTTPError.WithInternal` so we could not mutate global error variables
WIP: add RouteInfo and RouteMatchType into Context what we could know from in middleware what route was matched and/or type of that match (200/404/405)
WIP: make notFoundHandler and methodNotAllowedHandler private. encourage that all errors be handled in Echo.HTTPErrorHandler
WIP: server cleanup ideas
WIP: routable.ForGroup
WIP: note about logger middleware
WIP: bind should not default values on second try. use crypto rand for better randomness
WIP: router add route as interface and returns info as interface
WIP: improve flaky test (remains still flaky)
WIP: add notes about bind default values
WIP: every route can have their own path params names
WIP: routerCreator and different tests
WIP: different things
WIP: remove route implementation
WIP: support custom method types
WIP: extractor tests
WIP: v5.0.x proposal
over v4.4.0
This commit is contained in:
toimtoimtoim 2021-07-15 23:34:01 +03:00
parent c6f0c667f1
commit 6ef5f77bf2
80 changed files with 9216 additions and 4922 deletions

View File

@ -27,7 +27,8 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases # Echo tests with last four major releases
go: [1.14, 1.15, 1.16, 1.17] # except v5 starts from 1.16 until there is last four major releases after that
go: [1.16, 1.17]
name: ${{ matrix.os }} @ Go ${{ matrix.go }} name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:

View File

@ -1,21 +0,0 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x
- 1.15.x
- tip
env:
- GO111MODULE=on
install:
- go get -v golang.org/x/lint/golint
script:
- golint -set_exit_status ./...
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
after_success:
- bash <(curl -s https://codecov.io/bash)
matrix:
allow_failures:
- go: tip

View File

@ -24,11 +24,11 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST} @go test -race ${PKG_LIST}
benchmark: ## Run benchmarks benchmark: ## Run benchmarks
@go test -run="-" -bench=".*" ${PKG_LIST} @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.15" goversion ?= "1.16"
test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

View File

@ -12,6 +12,8 @@
## Supported Go versions ## Supported Go versions
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required: Therefore a Go version capable of understanding /vN suffixed imports is required:
@ -67,8 +69,8 @@ package main
import ( import (
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v5"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v5/middleware"
) )
func main() { func main() {
@ -83,7 +85,9 @@ func main() {
e.GET("/", hello) e.GET("/", hello)
// Start server // Start server
e.Logger.Fatal(e.Start(":1323")) if err := e.Start(":1323"); err != http.ErrServerClosed {
log.Fatal(err)
}
} }
// Handler // Handler

88
bind.go
View File

@ -11,42 +11,38 @@ import (
"strings" "strings"
) )
type ( // Binder is the interface that wraps the Bind method.
// Binder is the interface that wraps the Bind method. type Binder interface {
Binder interface { Bind(c Context, i interface{}) error
Bind(i interface{}, c Context) error }
}
// DefaultBinder is the default implementation of the Binder interface. // DefaultBinder is the default implementation of the Binder interface.
DefaultBinder struct{} type DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler // Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead. // will use that interface instead.
BindUnmarshaler interface { type BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param. // UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error UnmarshalParam(param string) error
} }
)
// BindPathParams binds path params to bindable object // BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { func BindPathParams(c Context, i interface{}) error {
names := c.ParamNames()
values := c.ParamValues()
params := map[string][]string{} params := map[string][]string{}
for i, name := range names { for _, param := range c.PathParams() {
params[name] = []string{values[i]} params[param.Name] = []string{param.Value}
} }
if err := b.bindData(i, params, "param"); err != nil { if err := bindData(i, params, "param"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
// BindQueryParams binds query params to bindable object // BindQueryParams binds query params to bindable object
func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { func BindQueryParams(c Context, i interface{}) error {
if err := b.bindData(i, c.QueryParams(), "query"); err != nil { if err := bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { func BindBody(c Context, i interface{}) (err error) {
req := c.Request() req := c.Request()
if req.ContentLength == 0 { if req.ContentLength == 0 {
return return
@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
case *HTTPError: case *HTTPError:
return err return err
default: default:
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
} }
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok { if ute, ok := err.(*xml.UnsupportedTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok { } else if se, ok := err.(*xml.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
} }
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
params, err := c.FormParams() params, err := c.FormParams()
if err != nil { if err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
if err = b.bindData(i, params, "form"); err != nil { if err = bindData(i, params, "form"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
default: default:
return ErrUnsupportedMediaType return ErrUnsupportedMediaType
@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
// BindHeaders binds HTTP headers to a bindable object // BindHeaders binds HTTP headers to a bindable object
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
if err := b.bindData(i, c.Request().Header, "header"); err != nil { if err := bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
} }
return nil return nil
} }
// Bind implements the `Binder#Bind` function. // Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. // step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
if err := b.BindPathParams(c, i); err != nil { if err := BindPathParams(c, i); err != nil {
return err return err
} }
// Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH)
@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
// i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding.
// This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body
if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete {
if err = b.BindQueryParams(c, i); err != nil { if err = BindQueryParams(c, i); err != nil {
return err return err
} }
} }
return b.BindBody(c, i) return BindBody(c, i)
} }
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { func bindData(destination interface{}, data map[string][]string, tag string) error {
if destination == nil || len(data) == 0 { if destination == nil || len(data) == 0 {
return nil return nil
} }
@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag // structs that implement BindUnmarshaler are binded only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
return err return err
} }
} }
@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
func setIntField(value string, bitSize int, field reflect.Value) error { func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0" return nil
} }
intVal, err := strconv.ParseInt(value, 10, bitSize) intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil { if err == nil {
@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
func setUintField(value string, bitSize int, field reflect.Value) error { func setUintField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0" return nil
} }
uintVal, err := strconv.ParseUint(value, 10, bitSize) uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil { if err == nil {
@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
func setBoolField(value string, field reflect.Value) error { func setBoolField(value string, field reflect.Value) error {
if value == "" { if value == "" {
value = "false" return nil
} }
boolVal, err := strconv.ParseBool(value) boolVal, err := strconv.ParseBool(value)
if err == nil { if err == nil {
@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error {
func setFloatField(value string, bitSize int, field reflect.Value) error { func setFloatField(value string, bitSize int, field reflect.Value) error {
if value == "" { if value == "" {
value = "0.0" return nil
} }
floatVal, err := strconv.ParseFloat(value, bitSize) floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil { if err == nil {

View File

@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
} }
} }
func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
req.Header.Set("language", "de")
req.Header.Set("length", "99")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPathParams(PathParams{{
Name: "id",
Value: "999",
}})
type SearchOpts struct {
ID int `param:"id"`
Length int `query:"length"`
Page int `query:"page"`
Search string `query:"search"`
Language string `query:"language" header:"language"`
}
opts := SearchOpts{
Length: 100,
Page: 0,
Search: "default value",
Language: "en",
}
err := c.Bind(&opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // bind from query
assert.Equal(t, 10, opts.Page) // bind from query
assert.Equal(t, 999, opts.ID) // bind from path param
assert.Equal(t, "et", opts.Language) // bind from query
assert.Equal(t, "default value", opts.Search) // default value stays
// make sure another bind will not mess already set values unless there are new values
err = (&DefaultBinder{}).BindHeaders(c, &opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
assert.Equal(t, 10, opts.Page)
assert.Equal(t, 999, opts.ID)
assert.Equal(t, "de", opts.Language) // header overwrites now this value
assert.Equal(t, "default value", opts.Search)
}
func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
func TestBindUnmarshalText(t *testing.T) { func TestBindUnmarshalText(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
func TestBindUnmarshalTextPtr(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
func TestBindbindData(t *testing.T) { func TestBindbindData(t *testing.T) {
a := assert.New(t) a := assert.New(t)
ts := new(bindTestStruct) ts := new(bindTestStruct)
b := new(DefaultBinder) err := bindData(ts, values, "form")
err := b.bindData(ts, values, "form")
a.NoError(err) a.NoError(err)
a.Equal(0, ts.I) a.Equal(0, ts.I)
@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
func TestBindParam(t *testing.T) { func TestBindParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(GET, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
c.SetPath("/users/:id/:name") cc := c.(EditableContext)
c.SetParamNames("id", "name") cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
c.SetParamValues("1", "Jon Snow") cc.SetPathParams(PathParams{
{Name: "id", Value: "1"},
{Name: "name", Value: "Jon Snow"},
})
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)
@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param // Second test for the absence of a param
c2 := e.NewContext(req, rec) c2 := e.NewContext(req, rec)
c2.SetPath("/users/:id") cc2 := c2.(EditableContext)
c2.SetParamNames("id") cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
c2.SetParamValues("1") cc2.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user) u = new(user)
err = c2.Bind(u) err = c2.Bind(u)
@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload // Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New() e2 := New()
req2 := httptest.NewRequest(POST, "/", body) req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON) req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder() rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2) c3 := e2.NewContext(req2, rec2)
c3.SetPath("/users/:id") cc3 := c3.(EditableContext)
c3.SetParamNames("id") cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
c3.SetParamValues("1") cc3.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user) u = new(user)
err = c3.Bind(u) err = c3.Bind(u)
@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
} }
func TestBindSetFields(t *testing.T) { func TestSetIntField(t *testing.T) {
assert := assert.New(t) ts := new(bindTestStruct)
ts.I = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 100, ts.I)
// second set with value sets the value
err = setIntField("5", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
}
func TestSetUintField(t *testing.T) {
ts := new(bindTestStruct)
ts.UI = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(100), ts.UI)
// second set with value sets the value
err = setUintField("5", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
}
func TestSetFloatField(t *testing.T) {
ts := new(bindTestStruct)
ts.F32 = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(100), ts.F32)
// second set with value sets the value
err = setFloatField("15.5", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
}
func TestSetBoolField(t *testing.T) {
ts := new(bindTestStruct)
ts.B = true
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// second set with value sets the value
err = setBoolField("true", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// fourth set to false
err = setBoolField("false", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, false, ts.B)
}
func TestUnmarshalFieldNonPtr(t *testing.T) {
ts := new(bindTestStruct) ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem() val := reflect.ValueOf(ts).Elem()
// Int
if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
assert.Equal(5, ts.I)
}
if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
assert.Equal(0, ts.I)
}
// Uint
if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
assert.Equal(uint(10), ts.UI)
}
if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
assert.Equal(uint(0), ts.UI)
}
// Float
if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
assert.Equal(float32(15.5), ts.F32)
}
if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
assert.Equal(float32(0.0), ts.F32)
}
// Bool
if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
assert.Equal(true, ts.B)
}
if assert.NoError(setBoolField("", val.FieldByName("B"))) {
assert.Equal(false, ts.B)
}
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(ok, true) assert.True(t, ok)
assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
} }
} }
@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
assert := assert.New(b) assert := assert.New(b)
ts := new(bindTestStructWithTags) ts := new(bindTestStructWithTags)
binder := new(DefaultBinder)
var err error var err error
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err = binder.bindData(ts, values, "form") err = bindData(ts, values, "form")
} }
assert.NoError(err) assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts)) assertBindTestStruct(assert, (*bindTestStruct)(ts))
@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if !tc.whenNoPathParams { if !tc.whenNoPathParams {
c.SetParamNames("node") cc := c.(EditableContext)
c.SetParamValues("node_from_path") cc.SetPathParams(PathParams{
{Name: "node", Value: "node_from_path"},
})
} }
var bindTarget interface{} var bindTarget interface{}
@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
} }
b := new(DefaultBinder) b := new(DefaultBinder)
err := b.Bind(bindTarget, c) err := b.Bind(c, bindTarget)
if tc.expectError != "" { if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError) assert.EqualError(t, err, tc.expectError)
} else { } else {
@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if !tc.whenNoPathParams { if !tc.whenNoPathParams {
c.SetParamNames("node") cc := c.(EditableContext)
c.SetParamValues("real_node") cc.SetPathParams(PathParams{
{Name: "node", Value: "real_node"},
})
} }
var bindTarget interface{} var bindTarget interface{}
@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else { } else {
bindTarget = &Node{} bindTarget = &Node{}
} }
b := new(DefaultBinder)
err := b.BindBody(c, bindTarget) err := BindBody(c, bindTarget)
if tc.expectError != "" { if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError) assert.EqualError(t, err, tc.expectError)
} else { } else {

View File

@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
func PathParamsBinder(c Context) *ValueBinder { func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{ return &ValueBinder{
failFast: true, failFast: true,
ValueFunc: c.Param, ValueFunc: c.PathParam,
ValuesFunc: func(sourceParam string) []string { ValuesFunc: func(sourceParam string) []string {
// path parameter should not have multiple values so getting values does not make sense but lets not error out here // path parameter should not have multiple values so getting values does not make sense but lets not error out here
value := c.Param(sourceParam) value := c.PathParam(sourceParam)
if value == "" { if value == "" {
return nil return nil
} }

View File

@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if len(pathParams) > 0 { if len(pathParams) > 0 {
names := make([]string, 0) params := make(PathParams, 0)
values := make([]string, 0)
for name, value := range pathParams { for name, value := range pathParams {
names = append(names, name) params = append(params, PathParam{
values = append(values, value) Name: name,
Value: value,
})
} }
c.SetParamNames(names...) cc := c.(EditableContext)
c.SetParamValues(values...) cc.SetPathParams(params)
} }
return c return c

View File

@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if len(pathParams) > 0 { if len(pathParams) > 0 {
names := make([]string, 0) params := make(PathParams, 0)
values := make([]string, 0)
for name, value := range pathParams { for name, value := range pathParams {
names = append(names, name) params = append(params, PathParam{
values = append(values, value) Name: name,
Value: value,
})
} }
c.SetParamNames(names...) cc := c.(EditableContext)
c.SetParamValues(values...) cc.SetPathParams(params)
} }
return c return c
@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder) binder := new(DefaultBinder)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var dest Opts var dest Opts
_ = binder.Bind(&dest, c) _ = binder.Bind(c, &dest)
} }
} }
@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder) binder := new(DefaultBinder)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var dest Opts var dest Opts
_ = binder.Bind(&dest, c) _ = binder.Bind(c, &dest)
if dest.Int64 != 1 { if dest.Int64 != 1 {
b.Fatalf("int64!=1") b.Fatalf("int64!=1")
} }

View File

@ -3,212 +3,233 @@ package echo
import ( import (
"bytes" "bytes"
"encoding/xml" "encoding/xml"
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"mime/multipart" "mime/multipart"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
) )
type ( // Context represents the context of the current HTTP request. It holds request and
// Context represents the context of the current HTTP request. It holds request and // response objects, path, path parameters, data and registered handler.
// response objects, path, path parameters, data and registered handler. type Context interface {
Context interface { // Request returns `*http.Request`.
// Request returns `*http.Request`. Request() *http.Request
Request() *http.Request
// SetRequest sets `*http.Request`. // SetRequest sets `*http.Request`.
SetRequest(r *http.Request) SetRequest(r *http.Request)
// SetResponse sets `*Response`. // SetResponse sets `*Response`.
SetResponse(r *Response) SetResponse(r *Response)
// Response returns `*Response`. // Response returns `*Response`.
Response() *Response Response() *Response
// IsTLS returns true if HTTP connection is TLS otherwise false. // IsTLS returns true if HTTP connection is TLS otherwise false.
IsTLS() bool IsTLS() bool
// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
IsWebSocket() bool IsWebSocket() bool
// Scheme returns the HTTP protocol scheme, `http` or `https`. // Scheme returns the HTTP protocol scheme, `http` or `https`.
Scheme() string Scheme() string
// RealIP returns the client's network address based on `X-Forwarded-For` // RealIP returns the client's network address based on `X-Forwarded-For`
// or `X-Real-IP` request header. // or `X-Real-IP` request header.
// The behavior can be configured using `Echo#IPExtractor`. // The behavior can be configured using `Echo#IPExtractor`.
RealIP() string RealIP() string
// Path returns the registered path for the handler. // RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type
Path() string // of match router found and how this request context handler chain could end:
// * route match - this path + method had matching route.
// * not found - this path did not match any routes enough to be considered match
// * method not allowed - path had routes registered but for other method types then current request is
// * unknown - initial state for fresh context before router tries to do routing
//
// Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried
// to match request to route.
RouteMatchType() RouteMatchType
// SetPath sets the registered path for the handler. // RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
SetPath(p string) // In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
RouteInfo() RouteInfo
// Param returns path parameter by name. // Path returns the registered path for the handler.
Param(name string) string Path() string
// ParamNames returns path parameter names. // PathParam returns path parameter by name.
ParamNames() []string PathParam(name string) string
// SetParamNames sets path parameter names. // PathParams returns path parameter values.
SetParamNames(names ...string) PathParams() PathParams
// ParamValues returns path parameter values. // SetPathParams set path parameter for during current request lifecycle.
ParamValues() []string SetPathParams(params PathParams)
// SetParamValues sets path parameter values. // QueryParam returns the query param for the provided name.
SetParamValues(values ...string) QueryParam(name string) string
// QueryParam returns the query param for the provided name. // QueryParams returns the query parameters as `url.Values`.
QueryParam(name string) string QueryParams() url.Values
// QueryParams returns the query parameters as `url.Values`. // QueryString returns the URL query string.
QueryParams() url.Values QueryString() string
// QueryString returns the URL query string. // FormValue returns the form field value for the provided name.
QueryString() string FormValue(name string) string
// FormValue returns the form field value for the provided name. // FormParams returns the form parameters as `url.Values`.
FormValue(name string) string FormParams() (url.Values, error)
// FormParams returns the form parameters as `url.Values`. // FormFile returns the multipart form file for the provided name.
FormParams() (url.Values, error) FormFile(name string) (*multipart.FileHeader, error)
// FormFile returns the multipart form file for the provided name. // MultipartForm returns the multipart form.
FormFile(name string) (*multipart.FileHeader, error) MultipartForm() (*multipart.Form, error)
// MultipartForm returns the multipart form. // Cookie returns the named cookie provided in the request.
MultipartForm() (*multipart.Form, error) Cookie(name string) (*http.Cookie, error)
// Cookie returns the named cookie provided in the request. // SetCookie adds a `Set-Cookie` header in HTTP response.
Cookie(name string) (*http.Cookie, error) SetCookie(cookie *http.Cookie)
// SetCookie adds a `Set-Cookie` header in HTTP response. // Cookies returns the HTTP cookies sent with the request.
SetCookie(cookie *http.Cookie) Cookies() []*http.Cookie
// Cookies returns the HTTP cookies sent with the request. // Get retrieves data from the context.
Cookies() []*http.Cookie Get(key string) interface{}
// Get retrieves data from the context. // Set saves data in the context.
Get(key string) interface{} Set(key string, val interface{})
// Set saves data in the context. // Bind binds the request body into provided type `i`. The default binder
Set(key string, val interface{}) // does it based on Content-Type header.
Bind(i interface{}) error
// Bind binds the request body into provided type `i`. The default binder // Validate validates provided `i`. It is usually called after `Context#Bind()`.
// does it based on Content-Type header. // Validator must be registered using `Echo#Validator`.
Bind(i interface{}) error Validate(i interface{}) error
// Validate validates provided `i`. It is usually called after `Context#Bind()`. // Render renders a template with data and sends a text/html response with status
// Validator must be registered using `Echo#Validator`. // code. Renderer must be registered using `Echo.Renderer`.
Validate(i interface{}) error Render(code int, name string, data interface{}) error
// Render renders a template with data and sends a text/html response with status // HTML sends an HTTP response with status code.
// code. Renderer must be registered using `Echo.Renderer`. HTML(code int, html string) error
Render(code int, name string, data interface{}) error
// HTML sends an HTTP response with status code. // HTMLBlob sends an HTTP blob response with status code.
HTML(code int, html string) error HTMLBlob(code int, b []byte) error
// HTMLBlob sends an HTTP blob response with status code. // String sends a string response with status code.
HTMLBlob(code int, b []byte) error String(code int, s string) error
// String sends a string response with status code. // JSON sends a JSON response with status code.
String(code int, s string) error JSON(code int, i interface{}) error
// JSON sends a JSON response with status code. // JSONPretty sends a pretty-print JSON with status code.
JSON(code int, i interface{}) error JSONPretty(code int, i interface{}, indent string) error
// JSONPretty sends a pretty-print JSON with status code. // JSONBlob sends a JSON blob response with status code.
JSONPretty(code int, i interface{}, indent string) error JSONBlob(code int, b []byte) error
// JSONBlob sends a JSON blob response with status code. // JSONP sends a JSONP response with status code. It uses `callback` to construct
JSONBlob(code int, b []byte) error // the JSONP payload.
JSONP(code int, callback string, i interface{}) error
// JSONP sends a JSONP response with status code. It uses `callback` to construct // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
// the JSONP payload. // to construct the JSONP payload.
JSONP(code int, callback string, i interface{}) error JSONPBlob(code int, callback string, b []byte) error
// JSONPBlob sends a JSONP blob response with status code. It uses `callback` // XML sends an XML response with status code.
// to construct the JSONP payload. XML(code int, i interface{}) error
JSONPBlob(code int, callback string, b []byte) error
// XML sends an XML response with status code. // XMLPretty sends a pretty-print XML with status code.
XML(code int, i interface{}) error XMLPretty(code int, i interface{}, indent string) error
// XMLPretty sends a pretty-print XML with status code. // XMLBlob sends an XML blob response with status code.
XMLPretty(code int, i interface{}, indent string) error XMLBlob(code int, b []byte) error
// XMLBlob sends an XML blob response with status code. // Blob sends a blob response with status code and content type.
XMLBlob(code int, b []byte) error Blob(code int, contentType string, b []byte) error
// Blob sends a blob response with status code and content type. // Stream sends a streaming response with status code and content type.
Blob(code int, contentType string, b []byte) error Stream(code int, contentType string, r io.Reader) error
// Stream sends a streaming response with status code and content type. // File sends a response with the content of the file.
Stream(code int, contentType string, r io.Reader) error File(file string) error
// File sends a response with the content of the file. // Attachment sends a response as attachment, prompting client to save the
File(file string) error // file.
Attachment(file string, name string) error
// Attachment sends a response as attachment, prompting client to save the // Inline sends a response as inline, opening the file in the browser.
// file. Inline(file string, name string) error
Attachment(file string, name string) error
// Inline sends a response as inline, opening the file in the browser. // NoContent sends a response with no body and a status code.
Inline(file string, name string) error NoContent(code int) error
// NoContent sends a response with no body and a status code. // Redirect redirects the request to a provided URL with status code.
NoContent(code int) error Redirect(code int, url string) error
// Redirect redirects the request to a provided URL with status code. // Error invokes the registered HTTP error handler.
Redirect(code int, url string) error // NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error.
Error(err error)
// Error invokes the registered HTTP error handler. Generally used by middleware. // Echo returns the `Echo` instance.
Error(err error) Echo() *Echo
}
// Handler returns the matched handler by router. // EditableContext is additional interface that structure implementing Context must implement. Methods inside this
Handler() HandlerFunc // interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares.
type EditableContext interface {
Context
// SetHandler sets the matched handler by router. // RawPathParams returns raw path pathParams value.
SetHandler(h HandlerFunc) RawPathParams() *PathParams
// Logger returns the `Logger` instance. // SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
Logger() Logger SetRawPathParams(params *PathParams)
// Set the logger // SetPath sets the registered path for the handler.
SetLogger(l Logger) SetPath(p string)
// Echo returns the `Echo` instance. // SetRouteMatchType sets the RouteMatchType of router match for this request.
Echo() *Echo SetRouteMatchType(t RouteMatchType)
// Reset resets the context after request completes. It must be called along // SetRouteInfo sets the route info of this request to the context.
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. SetRouteInfo(ri RouteInfo)
// See `Echo#ServeHTTP()`
Reset(r *http.Request, w http.ResponseWriter)
}
context struct { // Reset resets the context after request completes. It must be called along
request *http.Request // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
response *Response // See `Echo#ServeHTTP()`
path string Reset(r *http.Request, w http.ResponseWriter)
pnames []string }
pvalues []string
query url.Values type context struct {
handler HandlerFunc request *http.Request
store Map response *Response
echo *Echo
logger Logger matchType RouteMatchType
lock sync.RWMutex route RouteInfo
} path string
)
// pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
pathParams *PathParams
// currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
// Lifecycle is not handle by Echo and could have excess allocations per served Request
currentParams PathParams
query url.Values
store Map
echo *Echo
lock sync.RWMutex
}
const ( const (
defaultMemory = 32 << 20 // 32 MB defaultMemory = 32 << 20 // 32 MB
@ -296,52 +317,50 @@ func (c *context) SetPath(p string) {
c.path = p c.path = p
} }
func (c *context) Param(name string) string { func (c *context) RouteMatchType() RouteMatchType {
for i, n := range c.pnames { return c.matchType
if i < len(c.pvalues) {
if n == name {
return c.pvalues[i]
}
}
}
return ""
} }
func (c *context) ParamNames() []string { func (c *context) SetRouteMatchType(t RouteMatchType) {
return c.pnames c.matchType = t
} }
func (c *context) SetParamNames(names ...string) { func (c *context) RouteInfo() RouteInfo {
c.pnames = names return c.route
l := len(names)
if *c.echo.maxParam < l {
*c.echo.maxParam = l
}
if len(c.pvalues) < l {
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
// probably those values will be overriden in a Context#SetParamValues
newPvalues := make([]string, l)
copy(newPvalues, c.pvalues)
c.pvalues = newPvalues
}
} }
func (c *context) ParamValues() []string { func (c *context) SetRouteInfo(ri RouteInfo) {
return c.pvalues[:len(c.pnames)] c.route = ri
} }
func (c *context) SetParamValues(values ...string) { func (c *context) RawPathParams() *PathParams {
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times return c.pathParams
// It will brake the Router#Find code }
limit := len(values)
if limit > *c.echo.maxParam { func (c *context) SetRawPathParams(params *PathParams) {
limit = *c.echo.maxParam c.pathParams = params
}
func (c *context) PathParam(name string) string {
if c.currentParams != nil {
return c.currentParams.Get(name, "")
} }
for i := 0; i < limit; i++ {
c.pvalues[i] = values[i] return c.pathParams.Get(name, "")
}
func (c *context) PathParams() PathParams {
if c.currentParams != nil {
return c.currentParams
} }
result := make(PathParams, len(*c.pathParams))
copy(result, *c.pathParams)
return result
}
func (c *context) SetPathParams(params PathParams) {
c.currentParams = params
} }
func (c *context) QueryParam(name string) string { func (c *context) QueryParam(name string) string {
@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) {
} }
func (c *context) Bind(i interface{}) error { func (c *context) Bind(i interface{}) error {
return c.echo.Binder.Bind(i, c) return c.echo.Binder.Bind(c, i)
} }
func (c *context) Validate(i interface{}) error { func (c *context) Validate(i interface{}) error {
@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return return
} }
func (c *context) File(file string) (err error) { func (c *context) File(file string) error {
f, err := os.Open(file) return c.FsFile(file, c.echo.Filesystem)
}
func (c *context) FsFile(file string, filesystem fs.FS) error {
// FIXME: should we add this method into echo.Context interface?
f, err := filesystem.Open(file)
if err != nil { if err != nil {
return NotFoundHandler(c) return ErrNotFound
} }
defer f.Close() defer f.Close()
fi, _ := f.Stat() fi, _ := f.Stat()
if fi.IsDir() { if fi.IsDir() {
file = filepath.Join(file, indexPage) file = filepath.Join(file, indexPage)
f, err = os.Open(file) f, err = filesystem.Open(file)
if err != nil { if err != nil {
return NotFoundHandler(c) return ErrNotFound
} }
defer f.Close() defer f.Close()
if fi, err = f.Stat(); err != nil { if fi, err = f.Stat(); err != nil {
return return err
} }
} }
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) ff, ok := f.(io.ReadSeeker)
return if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
} }
func (c *context) Attachment(file, name string) error { func (c *context) Attachment(file, name string) error {
@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error {
} }
func (c *context) Error(err error) { func (c *context) Error(err error) {
c.echo.HTTPErrorHandler(err, c) c.echo.HTTPErrorHandler(c, err)
} }
func (c *context) Echo() *Echo { func (c *context) Echo() *Echo {
return c.echo return c.echo
} }
func (c *context) Handler() HandlerFunc {
return c.handler
}
func (c *context) SetHandler(h HandlerFunc) {
c.handler = h
}
func (c *context) Logger() Logger {
res := c.logger
if res != nil {
return res
}
return c.echo.Logger
}
func (c *context) SetLogger(l Logger) {
c.logger = l
}
func (c *context) Reset(r *http.Request, w http.ResponseWriter) { func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r c.request = r
c.response.reset(w) c.response.reset(w)
c.query = nil c.query = nil
c.handler = NotFoundHandler
c.store = nil c.store = nil
c.matchType = RouteMatchUnknown
c.route = nil
c.path = "" c.path = ""
c.pnames = nil // NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times
c.logger = nil *c.pathParams = (*c.pathParams)[:0]
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times c.currentParams = nil
for i := 0; i < *c.echo.maxParam; i++ {
c.pvalues[i] = ""
}
} }

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
@ -18,21 +19,19 @@ import (
"text/template" "text/template"
"time" "time"
"github.com/labstack/gommon/log" "github.com/stretchr/testify/assert"
testify "github.com/stretchr/testify/assert"
) )
type ( type Template struct {
Template struct { templates *template.Template
templates *template.Template }
}
)
var testUser = user{1, "Jon Snow"} var testUser = user{1, "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSONP(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) { func BenchmarkAllocXML(b *testing.B) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
@ -106,16 +107,14 @@ func TestContext(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Echo // Echo
assert.Equal(e, c.Echo()) assert.Equal(t, e, c.Echo())
// Request // Request
assert.NotNil(c.Request()) assert.NotNil(t, c.Request())
// Response // Response
assert.NotNil(c.Response()) assert.NotNil(t, c.Response())
//-------- //--------
// Render // Render
@ -126,23 +125,23 @@ func TestContext(t *testing.T) {
} }
c.echo.Renderer = tmpl c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow") err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("Hello, Jon Snow!", rec.Body.String()) assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
} }
c.echo.Renderer = nil c.echo.Renderer = nil
err = c.Render(http.StatusOK, "hello", "Jon Snow") err = c.Render(http.StatusOK, "hello", "Jon Snow")
assert.Error(err) assert.Error(t, err)
// JSON // JSON
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON+"\n", rec.Body.String()) assert.Equal(t, userJSON+"\n", rec.Body.String())
} }
// JSON with "?pretty" // JSON with "?pretty"
@ -150,10 +149,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String()) assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
} }
req = httptest.NewRequest(http.MethodGet, "/", nil) // reset req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
@ -161,37 +160,37 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String()) assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
} }
// JSON (error) // JSON (error)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, make(chan bool)) err = c.JSON(http.StatusOK, make(chan bool))
assert.Error(err) assert.Error(t, err)
// JSONP // JSONP
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
callback := "callback" callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
} }
// XML // XML
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"}) err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rec.Body.String())
} }
// XML with "?pretty" // XML with "?pretty"
@ -199,10 +198,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"}) err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
} }
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
@ -210,22 +209,22 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, make(chan bool)) err = c.XML(http.StatusOK, make(chan bool))
assert.Error(err) assert.Error(t, err)
// XML response write error // XML response write error
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.response.Writer = responseWriterErr{} c.response.Writer = responseWriterErr{}
err = c.XML(0, 0) err = c.XML(0, 0)
testify.Error(t, err) assert.Error(t, err)
// XMLPretty // XMLPretty
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
} }
t.Run("empty indent", func(t *testing.T) { t.Run("empty indent", func(t *testing.T) {
@ -237,7 +236,6 @@ func TestContext(t *testing.T) {
t.Run("json", func(t *testing.T) { t.Run("json", func(t *testing.T) {
buf.Reset() buf.Reset()
assert := testify.New(t)
// New JSONBlob with empty indent // New JSONBlob with empty indent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -246,16 +244,15 @@ func TestContext(t *testing.T) {
enc.SetIndent(emptyIndent, emptyIndent) enc.SetIndent(emptyIndent, emptyIndent)
err = enc.Encode(u) err = enc.Encode(u)
err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(buf.String(), rec.Body.String()) assert.Equal(t, buf.String(), rec.Body.String())
} }
}) })
t.Run("xml", func(t *testing.T) { t.Run("xml", func(t *testing.T) {
buf.Reset() buf.Reset()
assert := testify.New(t)
// New XMLBlob with empty indent // New XMLBlob with empty indent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
@ -264,10 +261,10 @@ func TestContext(t *testing.T) {
enc.Indent(emptyIndent, emptyIndent) enc.Indent(emptyIndent, emptyIndent)
err = enc.Encode(u) err = enc.Encode(u)
err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+buf.String(), rec.Body.String()) assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
} }
}) })
}) })
@ -276,12 +273,12 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
data, err := json.Marshal(user{1, "Jon Snow"}) data, err := json.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data) err = c.JSONBlob(http.StatusOK, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON, rec.Body.String()) assert.Equal(t, userJSON, rec.Body.String())
} }
// Legacy JSONPBlob // Legacy JSONPBlob
@ -289,44 +286,44 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
callback = "callback" callback = "callback"
data, err = json.Marshal(user{1, "Jon Snow"}) data, err = json.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data) err = c.JSONPBlob(http.StatusOK, callback, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+");", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
} }
// Legacy XMLBlob // Legacy XMLBlob
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
data, err = xml.Marshal(user{1, "Jon Snow"}) data, err = xml.Marshal(user{1, "Jon Snow"})
assert.NoError(err) assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data) err = c.XMLBlob(http.StatusOK, data)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rec.Body.String())
} }
// String // String
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.String(http.StatusOK, "Hello, World!") err = c.String(http.StatusOK, "Hello, World!")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, World!", rec.Body.String()) assert.Equal(t, "Hello, World!", rec.Body.String())
} }
// HTML // HTML
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>") err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, <strong>World!</strong>", rec.Body.String()) assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
} }
// Stream // Stream
@ -334,55 +331,55 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
r := strings.NewReader("response from a stream") r := strings.NewReader("response from a stream")
err = c.Stream(http.StatusOK, "application/octet-stream", r) err = c.Stream(http.StatusOK, "application/octet-stream", r)
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
assert.Equal("response from a stream", rec.Body.String()) assert.Equal(t, "response from a stream", rec.Body.String())
} }
// Attachment // Attachment
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.Attachment("_fixture/images/walle.png", "walle.png") err = c.Attachment("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
// Inline // Inline
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
err = c.Inline("_fixture/images/walle.png", "walle.png") err = c.Inline("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
// NoContent // NoContent
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.NoContent(http.StatusOK) c.NoContent(http.StatusOK)
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
// Error // Error
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context) c = e.NewContext(req, rec).(*context)
c.Error(errors.New("error")) c.Error(errors.New("error"))
assert.Equal(http.StatusInternalServerError, rec.Code) assert.Equal(t, http.StatusInternalServerError, rec.Code)
// Reset // Reset
c.SetParamNames("foo") c.pathParams = &PathParams{
c.SetParamValues("bar") {Name: "foo", Value: "bar"},
}
c.Set("foe", "ban") c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": {"baz"}}) c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder()) c.Reset(req, httptest.NewRecorder())
assert.Equal(0, len(c.ParamValues())) assert.Equal(t, 0, len(c.PathParams()))
assert.Equal(0, len(c.ParamNames())) assert.Equal(t, 0, len(c.store))
assert.Equal(0, len(c.store)) assert.Equal(t, nil, c.RouteInfo())
assert.Equal("", c.Path()) assert.Equal(t, 0, len(c.QueryParams()))
assert.Equal(0, len(c.QueryParams()))
} }
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
assert := testify.New(t) if assert.NoError(t, err) {
if assert.NoError(err) { assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(http.StatusCreated, rec.Code) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String())
assert.Equal(userJSON+"\n", rec.Body.String())
} }
} }
@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
assert := testify.New(t) if assert.Error(t, err) {
if assert.Error(err) { assert.False(t, c.response.Committed)
assert.False(c.response.Committed)
} }
} }
@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context) c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Read single // Read single
cookie, err := c.Cookie("theme") cookie, err := c.Cookie("theme")
if assert.NoError(err) { if assert.NoError(t, err) {
assert.Equal("theme", cookie.Name) assert.Equal(t, "theme", cookie.Name)
assert.Equal("light", cookie.Value) assert.Equal(t, "light", cookie.Value)
} }
// Read multiple // Read multiple
for _, cookie := range c.Cookies() { for _, cookie := range c.Cookies() {
switch cookie.Name { switch cookie.Name {
case "theme": case "theme":
assert.Equal("light", cookie.Value) assert.Equal(t, "light", cookie.Value)
case "user": case "user":
assert.Equal("Jon Snow", cookie.Value) assert.Equal(t, "Jon Snow", cookie.Value)
} }
} }
@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true, HttpOnly: true,
} }
c.SetCookie(cookie) c.SetCookie(cookie)
assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPath(t *testing.T) {
e := New()
r := e.Router()
handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
r.Add(http.MethodGet, "/users/:id", handler)
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1", c)
assert := testify.New(t)
assert.Equal("/users/:id", c.Path())
r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
c = e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1/files/1", c)
assert.Equal("/users/:uid/files/:fid", c.Path())
} }
func TestContextPathParam(t *testing.T) { func TestContextPathParam(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil).(*context)
params := &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
}
// ParamNames // ParamNames
c.SetParamNames("uid", "fid") c.pathParams = params
testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) assert.EqualValues(t, *params, c.PathParams())
// ParamValues
c.SetParamValues("101", "501")
testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
// Param // Param
testify.Equal(t, "501", c.Param("fid")) assert.Equal(t, "501", c.PathParam("fid"))
testify.Equal(t, "", c.Param("undefined")) assert.Equal(t, "", c.PathParam("undefined"))
} }
func TestContextGetAndSetParam(t *testing.T) { func TestContextGetAndSetParam(t *testing.T) {
e := New() e := New()
r := e.Router() r := e.Router()
r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) _, err := r.Add(Route{
Method: http.MethodGet,
Path: "/:foo",
Name: "",
Handler: func(Context) error { return nil },
Middlewares: nil,
})
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/:foo", nil) req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
c.SetParamNames("foo")
params := &PathParams{{Name: "foo", Value: "101"}}
// ParamNames
c.(*context).pathParams = params
// round-trip param values with modification // round-trip param values with modification
paramVals := c.ParamValues() paramVals := c.PathParams()
testify.EqualValues(t, []string{""}, c.ParamValues()) assert.Equal(t, *params, c.PathParams())
paramVals[0] = "bar"
c.SetParamValues(paramVals...) paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context
testify.EqualValues(t, []string{"bar"}, c.ParamValues()) assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams())
pathParams := PathParams{
{Name: "aaa", Value: "bbb"},
{Name: "ccc", Value: "ddd"},
}
c.SetPathParams(pathParams)
assert.Equal(t, pathParams, c.PathParams())
// shouldn't explode during Reset() afterwards! // shouldn't explode during Reset() afterwards!
testify.NotPanics(t, func() { assert.NotPanics(t, func() {
c.Reset(nil, nil) c.(EditableContext).Reset(nil, nil)
}) })
assert.Equal(t, PathParams{}, c.PathParams())
assert.Len(t, *c.(*context).pathParams, 0)
assert.Equal(t, cap(*c.(*context).pathParams), 1)
} }
// Issue #1655 // Issue #1655
func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) {
assert := testify.New(t)
e := New() e := New()
assert.Equal(0, *e.maxParam) c := e.NewContext(nil, nil).(*context)
expectedOneParam := []string{"one"} assert.Equal(t, 0, e.contextPathParamAllocSize)
expectedTwoParams := []string{"one", "two"} expectedTwoParams := &PathParams{
expectedThreeParams := []string{"one", "two", ""} {Name: "1", Value: "one"},
expectedABCParams := []string{"A", "B", "C"} {Name: "2", Value: "two"},
}
c.SetRawPathParams(expectedTwoParams)
assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(t, *expectedTwoParams, c.PathParams())
c := e.NewContext(nil, nil) expectedThreeParams := PathParams{
c.SetParamNames("1", "2") {Name: "1", Value: "one"},
c.SetParamValues(expectedTwoParams...) {Name: "2", Value: "two"},
assert.Equal(2, *e.maxParam) {Name: "3", Value: "three"},
assert.EqualValues(expectedTwoParams, c.ParamValues()) }
c.SetPathParams(expectedThreeParams)
c.SetParamNames("1") assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(2, *e.maxParam) assert.Equal(t, expectedThreeParams, c.PathParams())
// Here for backward compatibility the ParamValues remains as they are
assert.EqualValues(expectedOneParam, c.ParamValues())
c.SetParamNames("1", "2", "3")
assert.Equal(3, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
assert.EqualValues(expectedThreeParams, c.ParamValues())
c.SetParamValues("A", "B", "C", "D")
assert.Equal(3, *e.maxParam)
// Here D shouldn't be returned
assert.EqualValues(expectedABCParams, c.ParamValues())
} }
func TestContextFormValue(t *testing.T) { func TestContextFormValue(t *testing.T) {
@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// FormValue // FormValue
testify.Equal(t, "Jon Snow", c.FormValue("name")) assert.Equal(t, "Jon Snow", c.FormValue("name"))
testify.Equal(t, "jon@labstack.com", c.FormValue("email")) assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams // FormParams
params, err := c.FormParams() params, err := c.FormParams()
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.Equal(t, url.Values{ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"}, "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"}, "email": []string{"jon@labstack.com"},
}, params) }, params)
@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEMultipartForm) req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil) c = e.NewContext(req, nil)
params, err = c.FormParams() params, err = c.FormParams()
testify.Nil(t, params) assert.Nil(t, params)
testify.Error(t, err) assert.Error(t, err)
} }
func TestContextQueryParam(t *testing.T) { func TestContextQueryParam(t *testing.T) {
@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) {
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// QueryParam // QueryParam
testify.Equal(t, "Jon Snow", c.QueryParam("name")) assert.Equal(t, "Jon Snow", c.QueryParam("name"))
testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams // QueryParams
testify.Equal(t, url.Values{ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"}, "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"}, "email": []string{"jon@labstack.com"},
}, c.QueryParams()) }, c.QueryParams())
@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf) mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test") w, err := mr.CreateFormFile("file", "test")
if testify.NoError(t, err) { if assert.NoError(t, err) {
w.Write([]byte("test")) w.Write([]byte("test"))
} }
mr.Close() mr.Close()
@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.FormFile("file") f, err := c.FormFile("file")
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.Equal(t, "test", f.Filename) assert.Equal(t, "test", f.Filename)
} }
} }
@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.MultipartForm() f, err := c.MultipartForm()
if testify.NoError(t, err) { if assert.NoError(t, err) {
testify.NotNil(t, f) assert.NotNil(t, f)
} }
} }
@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
testify.Equal(t, http.StatusMovedPermanently, rec.Code) assert.Equal(t, http.StatusMovedPermanently, rec.Code)
testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
} }
func TestContextStore(t *testing.T) { func TestContextStore(t *testing.T) {
var c Context = new(context) var c Context = new(context)
c.Set("name", "Jon Snow") c.Set("name", "Jon Snow")
testify.Equal(t, "Jon Snow", c.Get("name")) assert.Equal(t, "Jon Snow", c.Get("name"))
} }
func BenchmarkContext_Store(b *testing.B) { func BenchmarkContext_Store(b *testing.B) {
@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) {
} }
} }
func TestContextHandler(t *testing.T) {
e := New()
r := e.Router()
b := new(bytes.Buffer)
r.Add(http.MethodGet, "/handler", func(Context) error {
_, err := b.Write([]byte("handler"))
return err
})
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/handler", c)
err := c.Handler()(c)
testify.Equal(t, "handler", b.String())
testify.NoError(t, err)
}
func TestContext_SetHandler(t *testing.T) {
var c Context = new(context)
testify.Nil(t, c.Handler())
c.SetHandler(func(c Context) error {
return nil
})
testify.NotNil(t, c.Handler())
}
func TestContext_Path(t *testing.T) {
path := "/pa/th"
var c Context = new(context)
c.SetPath(path)
testify.Equal(t, path, c.Path())
}
type validator struct{} type validator struct{}
func (*validator) Validate(i interface{}) error { func (*validator) Validate(i interface{}) error {
@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) {
e := New() e := New()
c := e.NewContext(nil, nil) c := e.NewContext(nil, nil)
testify.Error(t, c.Validate(struct{}{})) assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{} e.Validator = &validator{}
testify.NoError(t, c.Validate(struct{}{})) assert.NoError(t, c.Validate(struct{}{}))
} }
func TestContext_QueryString(t *testing.T) { func TestContext_QueryString(t *testing.T) {
@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) {
queryString := "query=string&var=val" queryString := "query=string&var=val"
req := httptest.NewRequest(GET, "/?"+queryString, nil) req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
testify.Equal(t, queryString, c.QueryString()) assert.Equal(t, queryString, c.QueryString())
} }
func TestContext_Request(t *testing.T) { func TestContext_Request(t *testing.T) {
var c Context = new(context) var c Context = new(context)
testify.Nil(t, c.Request()) assert.Nil(t, c.Request())
req := httptest.NewRequest(GET, "/path", nil) req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req) c.SetRequest(req)
testify.Equal(t, req, c.Request()) assert.Equal(t, req, c.Request())
} }
func TestContext_Scheme(t *testing.T) { func TestContext_Scheme(t *testing.T) {
@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.Scheme()) assert.Equal(t, tt.s, tt.c.Scheme())
} }
} }
func TestContext_IsWebSocket(t *testing.T) { func TestContext_IsWebSocket(t *testing.T) {
tests := []struct { tests := []struct {
c Context c Context
ws testify.BoolAssertionFunc ws assert.BoolAssertionFunc
}{ }{
{ {
&context{ &context{
@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"websocket"}}, Header: http.Header{HeaderUpgrade: []string{"websocket"}},
}, },
}, },
testify.True, assert.True,
}, },
{ {
&context{ &context{
@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
}, },
}, },
testify.True, assert.True,
}, },
{ {
&context{ &context{
request: &http.Request{}, request: &http.Request{},
}, },
testify.False, assert.False,
}, },
{ {
&context{ &context{
@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"other"}}, Header: http.Header{HeaderUpgrade: []string{"other"}},
}, },
}, },
testify.False, assert.False,
}, },
} }
@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) {
func TestContext_Bind(t *testing.T) { func TestContext_Bind(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
u := new(user) u := new(user)
req.Header.Add(HeaderContentType, MIMEApplicationJSON) req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u) err := c.Bind(u)
testify.NoError(t, err) assert.NoError(t, err)
testify.Equal(t, &user{1, "Jon Snow"}, u) assert.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_Logger(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
log1 := c.Logger()
testify.NotNil(t, log1)
log2 := log.New("echo2")
c.SetLogger(log2)
testify.Equal(t, log2, c.Logger())
// Resetting the context returns the initial logger
c.Reset(nil, nil)
testify.Equal(t, log1, c.Logger())
} }
func TestContext_RealIP(t *testing.T) { func TestContext_RealIP(t *testing.T) {
@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.RealIP()) assert.Equal(t, tt.s, tt.c.RealIP())
} }
} }

992
echo.go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

16
go.mod
View File

@ -1,17 +1,13 @@
module github.com/labstack/echo/v4 module github.com/labstack/echo/v4
go 1.15 go 1.16
require ( require (
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/davecgh/go-spew v1.1.1 // indirect
github.com/labstack/gommon v0.3.0 github.com/golang-jwt/jwt/v4 v4.0.0
github.com/mattn/go-colorable v0.1.8 // indirect github.com/stretchr/testify v1.7.0
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/stretchr/testify v1.4.0
github.com/valyala/fasttemplate v1.2.1 github.com/valyala/fasttemplate v1.2.1
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/net v0.0.0-20210913180222-943fd674d43e
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )

48
go.sum
View File

@ -1,51 +1,29 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

174
group.go
View File

@ -4,95 +4,117 @@ import (
"net/http" "net/http"
) )
type ( // Group is a set of sub-routes for a specified route. It can be used for inner
// Group is a set of sub-routes for a specified route. It can be used for inner // routes that share a common middleware or functionality that should be separate
// routes that share a common middleware or functionality that should be separate // from the parent echo instance while still inheriting from it.
// from the parent echo instance while still inheriting from it. type Group struct {
Group struct { host string
common prefix string
host string middleware []MiddlewareFunc
prefix string echo *Echo
middleware []MiddlewareFunc
echo *Echo
}
)
// Use implements `Echo#Use()` for sub-routes within the Group.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
if len(g.middleware) == 0 {
return
}
// Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process.
g.Any("", NotFoundHandler)
g.Any("/*", NotFoundHandler)
} }
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. // Use implements `Echo#Use()` for sub-routes within the Group.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { // Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
}
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...) return g.Add(http.MethodConnect, path, h, m...)
} }
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. // DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...) return g.Add(http.MethodDelete, path, h, m...)
} }
// GET implements `Echo#GET()` for sub-routes within the Group. // GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...) return g.Add(http.MethodGet, path, h, m...)
} }
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. // HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...) return g.Add(http.MethodHead, path, h, m...)
} }
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. // OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...) return g.Add(http.MethodOptions, path, h, m...)
} }
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. // PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...) return g.Add(http.MethodPatch, path, h, m...)
} }
// POST implements `Echo#POST()` for sub-routes within the Group. // POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...) return g.Add(http.MethodPost, path, h, m...)
} }
// PUT implements `Echo#PUT()` for sub-routes within the Group. // PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...) return g.Add(http.MethodPut, path, h, m...)
} }
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. // TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...) return g.Add(http.MethodTrace, path, h, m...)
} }
// Any implements `Echo#Any()` for sub-routes within the Group. // Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
routes := make([]*Route, len(methods)) errs := make([]error, 0)
for i, m := range methods { ris := make(Routes, 0)
routes[i] = g.Add(m, path, handler, middleware...) for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
ris = append(ris, ri)
} }
return routes if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
} }
// Match implements `Echo#Match()` for sub-routes within the Group. // Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
routes := make([]*Route, len(methods)) errs := make([]error, 0)
for i, m := range methods { ris := make(Routes, 0)
routes[i] = g.Add(m, path, handler, middleware...) for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
ris = append(ris, ri)
} }
return routes if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
} }
// Group creates a new sub-group with prefix and optional sub-group-level middleware. // Group creates a new sub-group with prefix and optional sub-group-level middleware.
// Important! Group middlewares are only executed in case there was exact route match and not
// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...) m = append(m, g.middleware...)
@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
return return
} }
// Static implements `Echo#Static()` for sub-routes within the Group. // Static implements `Echo#Static()` for sub-routes within the Group. Panics on error.
func (g *Group) Static(prefix, root string) { func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
g.static(prefix, root, g.GET) return g.Add(
http.MethodGet,
prefix+"*",
StaticDirectoryHandler(root, false),
middleware...,
)
} }
// File implements `Echo#File()` for sub-routes within the Group. // File implements `Echo#File()` for sub-routes within the Group. Panics on error.
func (g *Group) File(path, file string) { func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
g.file(path, file, g.GET) handler := func(c Context) error {
return c.File(file)
}
return g.Add(http.MethodGet, path, handler, middleware...)
} }
// Add implements `Echo#Add()` for sub-routes within the Group. // Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
// Combine into a new slice to avoid accidentally passing the same slice for ri, err := g.AddRoute(Route{
Method: method,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ri
}
// AddRoute registers a new Routable with Router
func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
// Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the // multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls. // middleware from earlier calls.
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
m = append(m, g.middleware...) return g.echo.add(g.host, groupRoute)
m = append(m, middleware...)
return g.echo.add(g.host, method, g.prefix+path, handler, m...)
} }

View File

@ -1,31 +1,68 @@
package echo package echo
import ( import (
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
// TODO: Fix me func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
func TestGroup(t *testing.T) { e := New()
g := New().Group("/group")
h := func(Context) error { return nil } called := false
g.CONNECT("/", h) mw := func(next HandlerFunc) HandlerFunc {
g.DELETE("/", h) return func(c Context) error {
g.GET("/", h) called = true
g.HEAD("/", h) return c.NoContent(http.StatusTeapot)
g.OPTIONS("/", h) }
g.PATCH("/", h) }
g.POST("/", h) // even though group has middleware it will not be executed when there are no routes under that group
g.PUT("/", h) _ = e.Group("/group", mw)
g.TRACE("/", h)
g.Any("/", h) status, body := request(http.MethodGet, "/group/nope", e)
g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) assert.Equal(t, http.StatusNotFound, status)
g.Static("/static", "/tmp") assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
g.File("/walle", "_fixture/images//walle.png")
assert.False(t, called)
}
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware and routes when we have no match on some route the middlewares for that
// group will not be executed
g := e.Group("/group", mw)
g.GET("/yes", handlerFunc)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_multiLevelGroup(t *testing.T) {
e := New()
api := e.Group("/api")
users := api.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
status, body := request(http.MethodGet, "/api/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
} }
func TestGroupFile(t *testing.T) { func TestGroupFile(t *testing.T) {
@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
} }
m2 := func(next HandlerFunc) HandlerFunc { m2 := func(next HandlerFunc) HandlerFunc {
return func(c Context) error { return func(c Context) error {
return c.String(http.StatusOK, c.Path()) return c.String(http.StatusOK, c.RouteInfo().Path())
} }
} }
h := func(c Context) error { h := func(c Context) error {
return c.String(http.StatusOK, c.Path()) return c.String(http.StatusOK, c.RouteInfo().Path())
} }
g.Use(m1) g.Use(m1)
g.GET("/help", h, m2) g.GET("/help", h, m2)
@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m) assert.Equal(t, "/*", m)
} }
func TestGroup_CONNECT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.CONNECT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodConnect, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodConnect, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_DELETE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.DELETE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodDelete, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodDelete, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_HEAD(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.HEAD("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodHead, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodHead, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_OPTIONS(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.OPTIONS("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodOptions, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodOptions, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PATCH(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PATCH("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPatch, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPatch, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_POST(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.POST("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPost, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPost, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PUT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PUT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPut, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPut, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_TRACE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.TRACE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodTrace, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodTrace, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_Any(t *testing.T) {
e := New()
users := e.Group("/users")
ris := users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 11)
for _, m := range methods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_AnyWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range methods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Match(t *testing.T) {
e := New()
myMethods := []string{http.MethodGet, http.MethodPost}
users := e.Group("/users")
ris := users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 2)
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_MatchWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
myMethods := []string{http.MethodGet, http.MethodPost}
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Static(t *testing.T) {
e := New()
g := e.Group("/books")
ri := g.Static("/download", "_fixture")
assert.Equal(t, http.MethodGet, ri.Method())
assert.Equal(t, "/books/download*", ri.Path())
assert.Equal(t, "GET:/books/download*", ri.Name())
assert.Equal(t, []string{"*"}, ri.Params())
req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
body := rec.Body.String()
assert.True(t, strings.HasPrefix(body, "<!doctype html>"))
}
func TestGroup_StaticMultiTest(t *testing.T) {
var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "ok, without prefix",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "nok, without prefix does not serve dir index",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/test/", // `/test` + `*` creates route `/test*`
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/test/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenRoot: "_fixture",
whenURL: "/test/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/static/",
expectBodyStartsWith: "",
},
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/test")
g.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}

74
httperror.go Normal file
View File

@ -0,0 +1,74 @@
package echo
import (
"errors"
"fmt"
"net/http"
)
// Errors
var (
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
ErrBadRequest = NewHTTPError(http.StatusBadRequest)
ErrBadGateway = NewHTTPError(http.StatusBadGateway)
ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// HTTPError represents an error that occurred while handling a request.
type HTTPError struct {
Code int `json:"-"`
Message interface{} `json:"message"`
Internal error `json:"-"` // Stores the error returned by an external dependency
}
// NewHTTPError creates a new HTTPError instance.
func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
he := &HTTPError{Code: code, Message: http.StatusText(code)}
if len(message) > 0 {
he.Message = message[0]
}
return he
}
// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
he := NewHTTPError(code, message...)
he.Internal = internalError
return he
}
// Error makes it compatible with `error` interface.
func (he *HTTPError) Error() string {
if he.Internal == nil {
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
}
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
}
// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
func (he *HTTPError) WithInternal(err error) *HTTPError {
return &HTTPError{
Code: he.Code,
Message: he.Message,
Internal: err,
}
}
// Unwrap satisfies the Go 1.13 error wrapper interface.
func (he *HTTPError) Unwrap() error {
return he.Internal
}

52
httperror_test.go Normal file
View File

@ -0,0 +1,52 @@
package echo
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestHTTPError(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
})
}
func TestNewHTTPErrorWithInternal(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
}
func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
}
func TestHTTPError_Unwrap(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Nil(t, errors.Unwrap(err))
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
})
}

11
json.go
View File

@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
err := json.NewDecoder(c.Request().Body).Decode(i) err := json.NewDecoder(c.Request().Body).Decode(i)
if ute, ok := err.(*json.UnmarshalTypeError); ok { if ute, ok := err.(*json.UnmarshalTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) return NewHTTPErrorWithInternal(
http.StatusBadRequest,
err,
fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
)
} else if se, ok := err.(*json.SyntaxError); ok { } else if se, ok := err.(*json.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) return NewHTTPErrorWithInternal(http.StatusBadRequest,
err,
fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
)
} }
return err return err
} }

168
log.go
View File

@ -1,41 +1,141 @@
package echo package echo
import ( import (
"bytes"
"io" "io"
"strconv"
"github.com/labstack/gommon/log" "sync"
"time"
) )
type ( //-----------------------------------------------------------------------------
// Logger defines the logging interface. // Example for Zap (https://github.com/uber-go/zap)
Logger interface { //func main() {
Output() io.Writer // e := echo.New()
SetOutput(w io.Writer) // logger, _ := zap.NewProduction()
Prefix() string // e.Logger = &ZapLogger{logger: logger}
SetPrefix(p string) //}
Level() log.Lvl //type ZapLogger struct {
SetLevel(v log.Lvl) // logger *zap.Logger
SetHeader(h string) //}
Print(i ...interface{}) //
Printf(format string, args ...interface{}) //func (l *ZapLogger) Write(p []byte) (n int, err error) {
Printj(j log.JSON) // // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
Debug(i ...interface{}) // l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
Debugf(format string, args ...interface{}) // return len(p), nil
Debugj(j log.JSON) //}
Info(i ...interface{}) //
Infof(format string, args ...interface{}) //func (l *ZapLogger) Error(err error) {
Infoj(j log.JSON) // l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
Warn(i ...interface{}) //}
Warnf(format string, args ...interface{})
Warnj(j log.JSON) //-----------------------------------------------------------------------------
Error(i ...interface{}) // Example for Zerolog (https://github.com/rs/zerolog)
Errorf(format string, args ...interface{}) //func main() {
Errorj(j log.JSON) // e := echo.New()
Fatal(i ...interface{}) // logger := zerolog.New(os.Stdout)
Fatalj(j log.JSON) // e.Logger = &ZeroLogger{logger: &logger}
Fatalf(format string, args ...interface{}) //}
Panic(i ...interface{}) //
Panicj(j log.JSON) //type ZeroLogger struct {
Panicf(format string, args ...interface{}) // logger *zerolog.Logger
//}
//
//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZeroLogger) Error(err error) {
// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
//}
//-----------------------------------------------------------------------------
// Example for Logrus (https://github.com/sirupsen/logrus)
//func main() {
// e := echo.New()
// e.Logger = &LogrusLogger{logger: logrus.New()}
//}
//
//type LogrusLogger struct {
// logger *logrus.Logger
//}
//
//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *LogrusLogger) Error(err error) {
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
//}
// Logger defines the logging interface that Echo uses internally in few places.
// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
type Logger interface {
// Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
// `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
// and underlying FileSystem errors.
// `logger` middleware will use this method to write its JSON payload.
Write(p []byte) (n int, err error)
// Error logs the error
Error(err error)
}
// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
type jsonLogger struct {
writer io.Writer
bufferPool sync.Pool
timeNow func() time.Time
}
func newJSONLogger(writer io.Writer) *jsonLogger {
return &jsonLogger{
writer: writer,
bufferPool: sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
},
},
timeNow: time.Now,
} }
) }
func (l *jsonLogger) Write(p []byte) (n int, err error) {
pLen := len(p)
if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
(p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
(p[0] == '{' && p[pLen-1] == '}') {
return l.writer.Write(p)
}
// we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
// called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
// deserves at least WARN level.
return l.printf("WARN", string(p))
}
func (l *jsonLogger) Error(err error) {
_, _ = l.printf("ERROR", err.Error())
}
func (l *jsonLogger) printf(level string, message string) (n int, err error) {
buf := l.bufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer l.bufferPool.Put(buf)
buf.WriteString(`{"time":"`)
buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
buf.WriteString(`","level":"`)
buf.WriteString(level)
buf.WriteString(`","prefix":"echo","message":`)
buf.WriteString(strconv.Quote(message))
buf.WriteString("}\n")
return l.writer.Write(buf.Bytes())
}

77
log_test.go Normal file
View File

@ -0,0 +1,77 @@
package echo
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestJsonLogger_Write(t *testing.T) {
var testCases = []struct {
name string
when []byte
expect string
}{
{
name: "ok, write non JSONlike message",
when: []byte("version: %v, build: %v"),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
},
{
name: "ok, write quoted message",
when: []byte(`version: "%v"`),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n",
},
{
name: "ok, write JSON",
when: []byte(`{"version": 123}` + "\n"),
expect: `{"version": 123}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
_, err := logger.Write(tc.when)
result := buf.String()
assert.Equal(t, tc.expect, result)
assert.NoError(t, err)
})
}
}
func TestJsonLogger_Error(t *testing.T) {
var testCases = []struct {
name string
whenError error
expect string
}{
{
name: "ok",
whenError: ErrForbidden,
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
logger.Error(tc.whenError)
result := buf.String()
assert.Equal(t, tc.expect, result)
})
}
}

13
middleware/DEVELOPMENT.md Normal file
View File

@ -0,0 +1,13 @@
# Development Guidelines for middlewares
// FIXME: add info about `MiddlewareConfigurator` interface
## Best practices:
* Do not use `panic` in middleware creator functions in case of invalid configuration.
* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
because previous middlewares up in call chain could have logic for dealing with returned errors.
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
want to create middleware with panics or with returning errors on configuration errors.
* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.

View File

@ -1,64 +1,59 @@
package middleware package middleware
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
// BasicAuthConfig defines the config for BasicAuth middleware. type BasicAuthConfig struct {
BasicAuthConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Validator is a function to validate BasicAuth credentials. // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
// Required. // this function would be called once for each header until first valid result is returned
Validator BasicAuthValidator // Required.
Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth. // Realm is a string to define realm attribute of BasicAuthWithConfig.
// Default value "Restricted". // Default value "Restricted".
Realm string Realm string
} }
// BasicAuthValidator defines a function to validate BasicAuth credentials. // BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error) type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
)
const ( const (
basic = "basic" basic = "basic"
defaultRealm = "Restricted" defaultRealm = "Restricted"
) )
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
}
)
// BasicAuth returns an BasicAuth middleware. // BasicAuth returns an BasicAuth middleware.
// //
// For valid credentials it calls the next handler. // For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response. // For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
c.Validator = fn
return BasicAuthWithConfig(c)
} }
// BasicAuthWithConfig returns an BasicAuth middleware with config. // BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil { if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function") return nil, errors.New("echo basic-auth middleware requires a validator function")
} }
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Realm == "" { if config.Realm == "" {
config.Realm = defaultRealm config.Realm = defaultRealm
@ -70,29 +65,33 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
auth := c.Request().Header.Get(echo.HeaderAuthorization) var lastError error
l := len(basic) l := len(basic)
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:]) continue
if err != nil {
return err
} }
cred := string(b)
for i := 0; i < len(cred); i++ { b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if cred[i] == ':' { if errDecode != nil {
// Verify credentials lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
valid, err := config.Validator(cred[:i], cred[i+1:], c) continue
if err != nil { }
return err idx := bytes.IndexByte(b, ':')
} else if valid { if idx >= 0 {
return next(c) valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
} if errValidate != nil {
break lastError = errValidate
} else if valid {
return next(c)
} }
} }
} }
if lastError != nil {
return lastError
}
realm := defaultRealm realm := defaultRealm
if config.Realm != defaultRealm { if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm) realm = strconv.Quote(config.Realm)
@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized return echo.ErrUnauthorized
} }
} }, nil
} }

View File

@ -2,6 +2,7 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -12,60 +13,146 @@ import (
) )
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
e := echo.New() validatorFunc := func(c echo.Context, u, p string) (bool, error) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" { if u == "joe" && p == "secret" {
return true, nil return true, nil
} }
if u == "error" {
return false, errors.New(p)
}
return false, nil return false, nil
} }
h := BasicAuth(f)(func(c echo.Context) error { defaultConfig := BasicAuthConfig{Validator: validatorFunc}
return c.String(http.StatusOK, "test")
})
assert := assert.New(t) var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, multiple",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
basic + " NOT_BASE64",
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}
// Valid credentials for _, tc := range testCases {
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) t.Run(tc.name, func(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth) e := echo.New()
assert.NoError(h(c)) req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
h = BasicAuthWithConfig(BasicAuthConfig{ config := tc.givenConfig
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials mw, err := config.ToMiddleware()
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) assert.NoError(t, err)
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Case-insensitive header scheme h := mw(func(c echo.Context) error {
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) return c.String(http.StatusTeapot, "test")
req.Header.Set(echo.HeaderAuthorization, auth) })
assert.NoError(h(c))
// Invalid credentials if len(tc.whenAuth) != 0 {
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) for _, a := range tc.whenAuth {
req.Header.Set(echo.HeaderAuthorization, auth) req.Header.Add(echo.HeaderAuthorization, a)
he := h(c).(*echo.HTTPError) }
assert.Equal(http.StatusUnauthorized, he.Code) }
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) err = h(c)
// Missing Authorization header if tc.expectErr != "" {
req.Header.Del(echo.HeaderAuthorization) assert.Equal(t, http.StatusOK, res.Code)
he = h(c).(*echo.HTTPError) assert.EqualError(t, err, tc.expectErr)
assert.Equal(http.StatusUnauthorized, he.Code) } else {
assert.Equal(t, http.StatusTeapot, res.Code)
// Invalid Authorization header assert.NoError(t, err)
auth = base64.StdEncoding.EncodeToString([]byte("invalid")) }
req.Header.Set(echo.HeaderAuthorization, auth) if tc.expectHeader != "" {
he = h(c).(*echo.HTTPError) assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
assert.Equal(http.StatusUnauthorized, he.Code) }
})
}
}
func TestBasicAuth_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuth(nil)
assert.NotNil(t, mw)
})
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
return true, nil
})
assert.NotNil(t, mw)
}
func TestBasicAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
assert.NotNil(t, mw)
})
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
return true, nil
}})
assert.NotNil(t, mw)
} }

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -11,63 +12,56 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // BodyDumpConfig defines the config for BodyDump middleware.
// BodyDumpConfig defines the config for BodyDump middleware. type BodyDumpConfig struct {
BodyDumpConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Handler receives request and response payload. // Handler receives request and response payload.
// Required. // Required.
Handler BodyDumpHandler Handler BodyDumpHandler
} }
// BodyDumpHandler receives the request and response payload. // BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte) type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
bodyDumpResponseWriter struct { type bodyDumpResponseWriter struct {
io.Writer io.Writer
http.ResponseWriter http.ResponseWriter
} }
)
var (
// DefaultBodyDumpConfig is the default BodyDump middleware config.
DefaultBodyDumpConfig = BodyDumpConfig{
Skipper: DefaultSkipper,
}
)
// BodyDump returns a BodyDump middleware. // BodyDump returns a BodyDump middleware.
// //
// BodyDump middleware captures the request and response payload and calls the // BodyDump middleware captures the request and response payload and calls the
// registered handler. // registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
c := DefaultBodyDumpConfig return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
c.Handler = handler
return BodyDumpWithConfig(c)
} }
// BodyDumpWithConfig returns a BodyDump middleware with config. // BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`. // See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil { if config.Handler == nil {
panic("echo: body-dump middleware requires a handler function") return nil, errors.New("echo body-dump middleware requires a handler function")
} }
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultSkipper
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
// Request // Request
reqBody := []byte{} reqBody := []byte{}
if c.Request().Body != nil { // Read if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body) reqBody, _ = ioutil.ReadAll(c.Request().Body)
} }
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer c.Response().Writer = writer
if err = next(c); err != nil { err := next(c)
c.Error(err)
}
// Callback // Callback
config.Handler(c, reqBody, resBody.Bytes()) config.Handler(c, reqBody, resBody.Bytes())
return return err
} }
} }, nil
} }
func (w *bodyDumpResponseWriter) WriteHeader(code int) { func (w *bodyDumpResponseWriter) WriteHeader(code int) {

View File

@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
requestBody := "" requestBody := ""
responseBody := "" responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody) requestBody = string(reqBody)
responseBody = string(resBody) responseBody = string(resBody)
}) }}.ToMiddleware()
assert.NoError(t, err)
assert := assert.New(t) if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
if assert.NoError(mw(h)(c)) { assert.Equal(t, responseBody, hw)
assert.Equal(requestBody, hw) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(responseBody, hw) assert.Equal(t, hw, rec.Body.String())
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
} }
// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
},
})
} }
func TestBodyDumpFails(t *testing.T) { func TestBodyDump_skipper(t *testing.T) {
e := echo.New()
isCalled := false
mw, err := BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
isCalled = true
},
}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.False(t, isCalled)
}
func TestBodyDump_fails(t *testing.T) {
e := echo.New() e := echo.New()
hw := "Hello, World!" hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
return errors.New("some error") return errors.New("some error")
} }
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
assert.NoError(t, err)
if !assert.Error(t, mw(h)(c)) { err = mw(h)(c)
t.FailNow() assert.EqualError(t, err, "some error")
} assert.Equal(t, http.StatusOK, rec.Code)
}
func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() { assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{ mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil, Skipper: nil,
Handler: nil, Handler: nil,
}) })
assert.NotNil(t, mw)
}) })
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{ mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
Skipper: func(c echo.Context) bool { assert.NotNil(t, mw)
return true })
}, }
Handler: func(c echo.Context, reqBody, resBody []byte) {
}, func TestBodyDump_panic(t *testing.T) {
}) assert.Panics(t, func() {
mw := BodyDump(nil)
if !assert.Error(t, mw(h)(c)) { assert.NotNil(t, mw)
t.FailNow() })
}
assert.NotPanics(t, func() {
BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
}) })
} }

View File

@ -1,98 +1,83 @@
package middleware package middleware
import ( import (
"fmt"
"io" "io"
"sync" "sync"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
) )
type ( // BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
// BodyLimitConfig defines the config for BodyLimit middleware. type BodyLimitConfig struct {
BodyLimitConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Maximum allowed size for a request body, it can be specified // LimitBytes is maximum allowed size in bytes for a request body
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. LimitBytes int64
Limit string `yaml:"limit"` }
limit int64
}
limitedReader struct { type limitedReader struct {
BodyLimitConfig BodyLimitConfig
reader io.ReadCloser reader io.ReadCloser
read int64 read int64
context echo.Context context echo.Context
} }
)
var (
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: DefaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware. // BodyLimit returns a BodyLimit middleware.
// //
// BodyLimit middleware sets the maximum allowed size for a request body, if the // BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
// size exceeds the configured limit, it sends "413 - Request Entity Too Large" // sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure. // header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
// G, T or P. return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
func BodyLimit(limit string) echo.MiddlewareFunc {
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
} }
// BodyLimitWithConfig returns a BodyLimit middleware with config. // BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
// See: `BodyLimit()`. // a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
if config.Skipper == nil { }
config.Skipper = DefaultBodyLimitConfig.Skipper
}
limit, err := bytes.Parse(config.Limit) // ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
if err != nil { func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
pool := sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: config}
},
} }
config.limit = limit
pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
req := c.Request() req := c.Request()
// Based on content length // Based on content length
if req.ContentLength > config.limit { if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge return echo.ErrStatusRequestEntityTooLarge
} }
// Based on content read // Based on content read
r := pool.Get().(*limitedReader) r := pool.Get().(*limitedReader)
r.Reset(req.Body, c) r.Reset(c, req.Body)
defer pool.Put(r) defer pool.Put(r)
req.Body = r req.Body = r
return next(c) return next(c)
} }
} }, nil
} }
func (r *limitedReader) Read(b []byte) (n int, err error) { func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b) n, err = r.reader.Read(b)
r.read += int64(n) r.read += int64(n)
if r.read > r.limit { if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge return n, echo.ErrStatusRequestEntityTooLarge
} }
return return
@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close() return r.reader.Close()
} }
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
r.reader = reader r.reader = reader
r.context = context r.context = context
r.read = 0 r.read = 0
} }
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}

View File

@ -11,6 +11,137 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
// Based on content length (within limit)
mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he := mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he = mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
LimitBytes: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
}
func TestBodyLimit_skipper(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw, err := BodyLimitConfig{
Skipper: func(c echo.Context) bool {
return true
},
LimitBytes: 2,
}.ToMiddleware()
assert.NoError(t, err)
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimitWithConfig(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimit(t *testing.T) { func TestBodyLimit(t *testing.T) {
e := echo.New() e := echo.New()
hw := []byte("Hello, World!") hw := []byte("Hello, World!")
@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body)) return c.String(http.StatusOK, string(body))
} }
assert := assert.New(t) mw := BodyLimit(2 * MB)
// Based on content length (within limit) err := mw(h)(c)
if assert.NoError(BodyLimit("2M")(h)(c)) { assert.NoError(t, err)
assert.Equal(http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes()) assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
Limit: "2B",
limit: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
} }

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bufio" "bufio"
"compress/gzip" "compress/gzip"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -13,50 +14,45 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type (
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip compression level.
// Optional. Default value -1.
Level int `yaml:"level"`
}
gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
const ( const (
gzipScheme = "gzip" gzipScheme = "gzip"
) )
var ( // GzipConfig defines the config for Gzip middleware.
// DefaultGzipConfig is the default Gzip middleware config. type GzipConfig struct {
DefaultGzipConfig = GzipConfig{ // Skipper defines a function to skip middleware.
Skipper: DefaultSkipper, Skipper Skipper
Level: -1,
}
)
// Gzip returns a middleware which compresses HTTP response using gzip compression // Gzip compression level.
// scheme. // Optional. Default value -1.
func Gzip() echo.MiddlewareFunc { Level int
return GzipWithConfig(DefaultGzipConfig)
} }
// GzipWithConfig return Gzip middleware with config. type gzipResponseWriter struct {
// See: `Gzip()`. io.Writer
http.ResponseWriter
}
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper config.Skipper = DefaultSkipper
}
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
return nil, errors.New("invalid gzip level")
} }
if config.Level == 0 { if config.Level == 0 {
config.Level = DefaultGzipConfig.Level config.Level = -1
} }
pool := gzipCompressPool(config) pool := gzipCompressPool(config)
@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }
func (w *gzipResponseWriter) WriteHeader(code int) { func (w *gzipResponseWriter) WriteHeader(code int) {

View File

@ -3,94 +3,128 @@ package middleware
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"testing" "testing"
"time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGzip(t *testing.T) { func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Skip if no Accept-Encoding header // Skip if no Accept-Encoding header
h := Gzip()(func(c echo.Context) error { h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t) e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
assert.Equal("test", rec.Body.String()) err := h(c)
assert.NoError(t, err)
// Gzip assert.Equal(t, "test", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/", nil) }
func TestMustGzipWithConfig_panics(t *testing.T) {
assert.Panics(t, func() {
GzipWithConfig(GzipConfig{Level: 999})
})
}
func TestGzip_AcceptEncodingHeader(t *testing.T) {
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec) rec := httptest.NewRecorder()
h(c) c := e.NewContext(req, rec)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) err := h(c)
assert.NoError(t, err)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body) r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) { assert.NoError(t, err)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
defer r.Close() defer r.Close()
buf.ReadFrom(r) buf.ReadFrom(r)
assert.Equal("test", buf.String()) assert.Equal(t, "test", buf.String())
} }
chunkBuf := make([]byte, 5) func TestGzip_chunked(t *testing.T) {
e := echo.New()
// Gzip chunked req := httptest.NewRequest(http.MethodGet, "/", nil)
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c = e.NewContext(req, rec) chunkChan := make(chan struct{})
Gzip()(func(c echo.Context) error { waitChan := make(chan struct{})
h := Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked") c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data // Write and flush the first part of the data
c.Response().Write([]byte("test\n")) c.Response().Write([]byte("first\n"))
c.Response().Flush() c.Response().Flush()
// Read the first part of the data chunkChan <- struct{}{}
assert.True(rec.Flushed) <-waitChan
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
// Write and flush the second part of the data // Write and flush the second part of the data
c.Response().Write([]byte("test\n")) c.Response().Write([]byte("second\n"))
c.Response().Flush() c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf) chunkChan <- struct{}{}
assert.NoError(err) <-waitChan
assert.Equal("test\n", string(chunkBuf))
// Write the final part of the data and return // Write the final part of the data and return
c.Response().Write([]byte("test")) c.Response().Write([]byte("third"))
return nil
})(c)
chunkChan <- struct{}{}
return nil
})
go func() {
err := h(c)
chunkChan <- struct{}{}
assert.NoError(t, err)
}()
<-chunkChan // wait for first write
waitChan <- struct{}{}
<-chunkChan // wait for second write
waitChan <- struct{}{}
<-chunkChan // wait for final write in handler
<-chunkChan // wait for return from handler
time.Sleep(5 * time.Millisecond) // to have time for flushing
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
assert.NoError(t, err)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r) buf.ReadFrom(r)
assert.Equal("test", buf.String()) assert.Equal(t, "first\nsecond\nthird", buf.String())
} }
func TestGzipNoContent(t *testing.T) { func TestGzip_NoContent(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
} }
} }
func TestGzipErrorReturned(t *testing.T) { func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New() e := echo.New()
e.Use(Gzip()) e.Use(Gzip())
e.GET("/", func(c echo.Context) error { e.GET("/", func(c echo.Context) error {
@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
func TestGzipErrorReturnedInvalidConfig(t *testing.T) { func TestGzipWithConfig_invalidLevel(t *testing.T) {
e := echo.New() mw, err := GzipConfig{Level: 12}.ToMiddleware()
// Invalid level assert.EqualError(t, err, "invalid gzip level")
e.Use(GzipWithConfig(GzipConfig{Level: 12})) assert.Nil(t, mw)
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
} }
// Issue #806 // Issue #806
func TestGzipWithStatic(t *testing.T) { func TestGzipWithStatic(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../")
e.Use(Gzip()) e.Use(Gzip())
e.Static("/test", "../_fixture/images") e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only // Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set. // validate the content length if it's not set.

View File

@ -9,60 +9,56 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // CORSConfig defines the config for CORS middleware.
// CORSConfig defines the config for CORS middleware. type CORSConfig struct {
CORSConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource. // AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the // AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If // origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is // an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored. // set, AllowOrigins is ignored.
// Optional. // Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
AllowMethods []string `yaml:"allow_methods"` AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when // AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request. // making the actual request. This is in response to a preflight request.
// Optional. Default value []string{}. // Optional. Default value []string{}.
AllowHeaders []string `yaml:"allow_headers"` AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request // AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of // can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the // a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. // actual request can be made using credentials.
// Optional. Default value false. // Optional. Default value false.
AllowCredentials bool `yaml:"allow_credentials"` AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to // ExposeHeaders defines a whitelist headers that clients are allowed to
// access. // access.
// Optional. Default value []string{}. // Optional. Default value []string{}.
ExposeHeaders []string `yaml:"expose_headers"` ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request // MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached. // can be cached.
// Optional. Default value 0. // Optional. Default value 0.
MaxAge int `yaml:"max_age"` MaxAge int
} }
)
var ( // DefaultCORSConfig is the default CORS middleware config.
// DefaultCORSConfig is the default CORS middleware config. var DefaultCORSConfig = CORSConfig{
DefaultCORSConfig = CORSConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, AllowOrigins: []string{"*"},
AllowOrigins: []string{"*"}, AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, }
}
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig) return CORSWithConfig(DefaultCORSConfig)
} }
// CORSWithConfig returns a CORS middleware with config. // CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`. // See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper config.Skipper = DefaultCORSConfig.Skipper
@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
} }, nil
} }

View File

@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler) h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, "/", nil) req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler) h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
AllowOrigins: []string{"localhost"}, AllowOrigins: []string{"localhost"},
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
})(echo.NotFoundHandler) })(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true, AllowCredentials: true,
MaxAge: 3600, MaxAge: 3600,
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{ cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{ cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"}, AllowOrigins: []string{"http://*.example.com"},
}) })
h = cors(echo.NotFoundHandler) h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{ cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern}, AllowOrigins: []string{tt.pattern},
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
if tt.expected { if tt.expected {
@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
cors := CORSWithConfig(CORSConfig{ cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern}, AllowOrigins: []string{tt.pattern},
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
if tt.expected { if tt.expected {
@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) {
//AllowCredentials: true, //AllowCredentials: true,
//MaxAge: 3600, //MaxAge: 3600,
}) })
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c) h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin) req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{ cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
AllowOriginFunc: allowOriginFunc, assert.NoError(t, err)
})
h := cors(echo.NotFoundHandler) h := cors(func(c echo.Context) error { return echo.ErrNotFound })
err := h(c) err = h(c)
expected, expectedErr := allowOriginFunc(origin) expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil { if expectedErr != nil {

View File

@ -8,89 +8,90 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( // CSRFConfig defines the config for CSRF middleware.
// CSRFConfig defines the config for CSRF middleware. type CSRFConfig struct {
CSRFConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// TokenLength is the length of the generated token. // TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"` TokenLength uint8
// Optional. Default value 32. // Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<key>" that is used // Generator defines a function to generate token.
// to extract token from the request. // Optional. Defaults tp randomString(TokenLength).
// Optional. Default value "header:X-CSRF-Token". Generator func() string
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context. // TokenLookup is a string in the form of "<source>:<key>" that is used
// Optional. Default value "csrf". // to extract token from the request.
ContextKey string `yaml:"context_key"` // Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
TokenLookup string
// Name of the CSRF cookie. This cookie will store CSRF token. // Context key to store generated CSRF token into context.
// Optional. Default value "csrf". // Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"` ContextKey string
// Domain of the CSRF cookie. // Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value none. // Optional. Default value "csrf".
CookieDomain string `yaml:"cookie_domain"` CookieName string
// Path of the CSRF cookie. // Domain of the CSRF cookie.
// Optional. Default value none. // Optional. Default value none.
CookiePath string `yaml:"cookie_path"` CookieDomain string
// Max age (in seconds) of the CSRF cookie. // Path of the CSRF cookie.
// Optional. Default value 86400 (24hr). // Optional. Default value none.
CookieMaxAge int `yaml:"cookie_max_age"` CookiePath string
// Indicates if CSRF cookie is secure. // Max age (in seconds) of the CSRF cookie.
// Optional. Default value false. // Optional. Default value 86400 (24hr).
CookieSecure bool `yaml:"cookie_secure"` CookieMaxAge int
// Indicates if CSRF cookie is HTTP only. // Indicates if CSRF cookie is secure.
// Optional. Default value false. // Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"` CookieSecure bool
// Indicates SameSite mode of the CSRF cookie. // Indicates if CSRF cookie is HTTP only.
// Optional. Default value SameSiteDefaultMode. // Optional. Default value false.
CookieSameSite http.SameSite `yaml:"cookie_same_site"` CookieHTTPOnly bool
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns // Indicates SameSite mode of the CSRF cookie.
// either a token or an error. // Optional. Default value SameSiteDefaultMode.
csrfTokenExtractor func(echo.Context) (string, error) CookieSameSite http.SameSite
) }
var ( // csrfTokenExtractor defines a function that takes `echo.Context` and returns either a token or an error.
// DefaultCSRFConfig is the default CSRF middleware config. type csrfTokenExtractor func(echo.Context) (string, error)
DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper, // DefaultCSRFConfig is the default CSRF middleware config.
TokenLength: 32, var DefaultCSRFConfig = CSRFConfig{
TokenLookup: "header:" + echo.HeaderXCSRFToken, Skipper: DefaultSkipper,
ContextKey: "csrf", TokenLength: 32,
CookieName: "_csrf", TokenLookup: "header:" + echo.HeaderXCSRFToken,
CookieMaxAge: 86400, ContextKey: "csrf",
CookieSameSite: http.SameSiteDefaultMode, CookieName: "_csrf",
} CookieMaxAge: 86400,
) CookieSameSite: http.SameSiteDefaultMode,
}
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc { func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig return CSRFWithConfig(DefaultCSRFConfig)
return CSRFWithConfig(c)
} }
// CSRFWithConfig returns a CSRF middleware with config. // CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper config.Skipper = DefaultCSRFConfig.Skipper
@ -98,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 { if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength config.TokenLength = DefaultCSRFConfig.TokenLength
} }
if config.Generator == nil {
config.Generator = createRandomStringGenerator(config.TokenLength)
}
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup config.TokenLookup = DefaultCSRFConfig.TokenLookup
} }
@ -136,7 +140,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Generate token // Generate token
if err != nil { if err != nil {
token = random.String(config.TokenLength) token = config.Generator()
} else { } else {
// Reuse token // Reuse token
token = k.Value token = k.Value
@ -181,7 +185,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
} }, nil
} }
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the

View File

@ -9,11 +9,26 @@ import (
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCSRF(t *testing.T) { func TestCSRF(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRF()
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Generate CSRF token
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
}
func TestMustCSRFWithConfig(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -43,7 +58,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
token := random.String(16) token := randomString(16)
req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token) req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {
@ -145,9 +160,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{ csrf, err := CSRFConfig{
CookieSameSite: http.SameSiteNoneMode, CookieSameSite: http.SameSiteNoneMode,
}) }.ToMiddleware()
assert.NoError(t, err)
h := csrf(func(c echo.Context) error { h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")

View File

@ -11,18 +11,16 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // DecompressConfig defines the config for Decompress middleware.
// DecompressConfig defines the config for Decompress middleware. type DecompressConfig struct {
DecompressConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor GzipDecompressPool Decompressor
} }
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. // GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip" const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
@ -30,14 +28,6 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool gzipDecompressPool() sync.Pool
} }
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface // DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct { type DefaultGzipDecompressPool struct {
} }
@ -65,19 +55,23 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
} }
} }
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config // Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc { func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig) return DecompressWithConfig(DecompressConfig{})
} }
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config // DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.GzipDecompressPool == nil { if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool config.GzipDecompressPool = &DefaultGzipDecompressPool{}
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -116,5 +110,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -17,6 +17,31 @@ import (
func TestDecompress(t *testing.T) { func TestDecompress(t *testing.T) {
e := echo.New() e := echo.New()
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
// Decompress request body
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestDecompress_skippedIfNoHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t) err := h(c)
assert.Equal("test", rec.Body.String()) assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
} }
func TestDecompressDefaultConfig(t *testing.T) { func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { h, err := DecompressConfig{}.ToMiddleware()
assert.NoError(t, err)
err = h(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
}
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
}) })
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress // Decompress
body := `{"name": "echo"}` body := `{"name": "echo"}`
@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := ioutil.ReadAll(req.Body)
assert.NoError(err) assert.NoError(t, err)
assert.Equal(body, string(b)) assert.Equal(t, body, string(b))
} }
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.NewContext(req, rec) e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body) b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) {
h := Decompress()(func(c echo.Context) error { h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
}) })
if assert.NoError(t, h(c)) {
err := h(c)
if assert.NoError(t, err) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes())) assert.Equal(t, 0, len(rec.Body.Bytes()))
@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code) assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
} }
@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)
@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body) reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err) assert.NoError(t, err)

148
middleware/extractor.go Normal file
View File

@ -0,0 +1,148 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"net/textproto"
"strings"
)
// ErrExtractionValueMissing denotes an error raised when value could not be extracted from request
var ErrExtractionValueMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed value")
// ExtractorType is enum type for where extractor will take its data
type ExtractorType string
const (
// HeaderExtractor tells extractor to take values from request header
HeaderExtractor ExtractorType = "header"
// QueryExtractor tells extractor to take values from request query parameters
QueryExtractor ExtractorType = "query"
// ParamExtractor tells extractor to take values from request route parameters
ParamExtractor ExtractorType = "param"
// CookieExtractor tells extractor to take values from request cookie
CookieExtractor ExtractorType = "cookie"
// FormExtractor tells extractor to take values from request form fields
FormExtractor ExtractorType = "form"
)
func createExtractors(lookups string) ([]valuesExtractor, error) {
sources := strings.Split(lookups, ",")
var extractors []valuesExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
}
switch ExtractorType(parts[0]) {
case QueryExtractor:
extractors = append(extractors, valuesFromQuery(parts[1]))
case ParamExtractor:
extractors = append(extractors, valuesFromParam(parts[1]))
case CookieExtractor:
extractors = append(extractors, valuesFromCookie(parts[1]))
case FormExtractor:
extractors = append(extractors, valuesFromForm(parts[1]))
case HeaderExtractor:
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
func valuesFromHeader(header string, valuePrefix string) valuesExtractor {
prefixLen := len(valuePrefix)
return func(c echo.Context) ([]string, ExtractorType, error) {
values := textproto.MIMEHeader(c.Request().Header).Values(header)
if len(values) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, value := range values {
if prefixLen == 0 {
result = append(result, value)
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
}
}
if len(result) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
return result, HeaderExtractor, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, QueryExtractor, ErrExtractionValueMissing
}
return result, QueryExtractor, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := make([]string, 0)
for _, p := range c.PathParams() {
if param == p.Name {
result = append(result, p.Value)
}
}
if len(result) == 0 {
return nil, ParamExtractor, ErrExtractionValueMissing
}
return result, ParamExtractor, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
}
}
if len(result) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
return result, CookieExtractor, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
if err := c.Request().ParseForm(); err != nil {
return nil, FormExtractor, fmt.Errorf("valuesFromForm parse form failed: %w", err)
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, FormExtractor, ErrExtractionValueMissing
}
result := append([]string{}, values...)
return result, FormExtractor, nil
}
}

View File

@ -0,0 +1,498 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectExtractorType ExtractorType
expectCreateError string
expectError string
}{
{
name: "ok, header",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
expectExtractorType: HeaderExtractor,
},
{
name: "ok, form",
givenRequest: func() *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
expectExtractorType: FormExtractor,
},
{
name: "ok, cookie",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
expectExtractorType: CookieExtractor,
},
{
name: "ok, param",
givenPathParams: echo.PathParams{
{Name: "id", Value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
expectExtractorType: ParamExtractor,
},
{
name: "ok, query",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
whenLoopups: "query:id",
expectValues: []string{"999"},
expectExtractorType: QueryExtractor,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups)
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
}
assert.NoError(t, err)
for _, e := range extractors {
values, eType, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, tc.expectExtractorType, eType)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
}
assert.NoError(t, eErr)
}
})
}
}
func TestValuesFromHeader(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, single value, case insensitive",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
name: "ok, empty prefix",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Bearer ",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderWWWAuthenticate,
whenValuePrefix: "",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no headers",
givenRequest: nil,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, HeaderExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromQuery(t *testing.T) {
var testCases = []struct {
name string
givenQueryPart string
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenQueryPart: "?id=123&name=test",
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
expectValues: []string{"123", "456"},
},
{
name: "nok, missing value",
givenQueryPart: "?id=123&name=test",
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromQuery(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, QueryExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := echo.PathParams{
{Name: "id", Value: "123"},
{Name: "gid", Value: "456"},
{Name: "gid", Value: "789"},
}
var testCases = []struct {
name string
givenPathParams echo.PathParams
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenPathParams: examplePathParams,
whenName: "gid",
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching value",
givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ParamExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromCookie(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: "_csrf",
expectValues: []string{"token"},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Add(echo.HeaderCookie, "_csrf=token")
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
expectValues: []string{"token", "token2"},
},
{
name: "nok, no matching cookie",
givenRequest: exampleRequest,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no cookies at all",
givenRequest: nil,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromCookie(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, CookieExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromForm(t *testing.T) {
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
}
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
whenName string
expectValues []string
expectError string
}{
{
name: "ok, POST form, single value",
givenRequest: examplePostFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, POST form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, GET form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "nok, POST form, value missing",
givenRequest: examplePostFormRequest(nil),
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := tc.givenRequest
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromForm(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, FormExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -1,134 +1,84 @@
// +build go1.15
package middleware package middleware
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
) )
type ( // JWTConfig defines the config for JWT middleware.
// JWTConfig defines the config for JWT middleware. type JWTConfig struct {
JWTConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token. // SuccessHandler defines a function which is executed for a valid token.
SuccessHandler JWTSuccessHandler SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token. // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// It may be used to define a custom JWT error. // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
ErrorHandler JWTErrorHandler // It may be used to define a custom JWT error.
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain.
ErrorHandler JWTErrorHandlerWithContext
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. // Context key to store user information from the token into context.
ErrorHandlerWithContext JWTErrorHandlerWithContext // Optional. Default value "user".
ContextKey string
// Signing key to validate token. // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// This is one of the three options to provide a token validation key. // to extract token(s) from the request.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // Optional. Default value "header:Authorization:Bearer ".
// Required if neither user-defined KeyFunc nor SigningKeys is provided. // Possible values:
SigningKey interface{} // - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
// Multiple sources example:
// - "header:Authorization,cookie:myowncookie"
TokenLookup string
// Map of signing keys to validate token with kid field usage. // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// This is one of the three options to provide a token validation key. // parsing fails or parsed token is invalid.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // NB: could be called multiple times per request when token lookup is able to extract multiple token values (i.e. multiple Authorization headers)
// Required if neither user-defined KeyFunc nor SigningKey is provided. // See `jwt_external_test.go` for example implementation using `github.com/golang-jwt/jwt` as JWT implementation library
SigningKeys map[string]interface{} ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
}
// Signing method used to check the token's signing algorithm. // JWTSuccessHandler defines a function which is executed for a valid token.
// Optional. Default value HS256. type JWTSuccessHandler func(c echo.Context)
SigningMethod string
// Context key to store user information from the token into context. // JWTErrorHandler defines a function which is executed for an invalid token.
// Optional. Default value "user". type JWTErrorHandler func(err error) error
ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
// Not used if custom ParseTokenFunc is set. type JWTErrorHandlerWithContext func(c echo.Context, err error) error
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used type valuesExtractor func(c echo.Context) ([]string, ExtractorType, error)
// to extract token from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
// Multiply sources example:
// - "header: Authorization,cookie: myowncookie"
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error
jwtExtractor func(echo.Context) (string, error)
)
// Algorithms
const ( const (
// AlgorithmHS256 is token signing algorithm
AlgorithmHS256 = "HS256" AlgorithmHS256 = "HS256"
) )
// Errors // ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request
var ( var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt")
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
var ( // ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired
// DefaultJWTConfig is the default JWT auth middleware config. var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper, // DefaultJWTConfig is the default JWT auth middleware config.
SigningMethod: AlgorithmHS256, var DefaultJWTConfig = JWTConfig{
ContextKey: "user", Skipper: DefaultSkipper,
TokenLookup: "header:" + echo.HeaderAuthorization, ContextKey: "user",
AuthScheme: "Bearer", TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
Claims: jwt.MapClaims{}, }
KeyFunc: nil,
}
)
// JWT returns a JSON Web Token (JWT) auth middleware. // JWT returns a JSON Web Token (JWT) auth middleware.
// //
@ -137,64 +87,43 @@ var (
// For missing token, it returns "400 - Bad Request" error. // For missing token, it returns "400 - Bad Request" error.
// //
// See: https://jwt.io/introduction // See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup` func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
func JWT(key interface{}) echo.MiddlewareFunc {
c := DefaultJWTConfig c := DefaultJWTConfig
c.SigningKey = key c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c) return JWTWithConfig(c)
} }
// JWTWithConfig returns a JWT auth middleware with config. // JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid.
// See: `JWT()`. //
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration
func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper config.Skipper = DefaultJWTConfig.Skipper
} }
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { if config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key") return nil, errors.New("echo jwt middleware requires parse token function")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
} }
if config.ContextKey == "" { if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey config.ContextKey = DefaultJWTConfig.ContextKey
} }
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup config.TokenLookup = DefaultJWTConfig.TokenLookup
} }
if config.AuthScheme == "" { extractors, err := createExtractors(config.TokenLookup)
config.AuthScheme = DefaultJWTConfig.AuthScheme if err != nil {
return nil, fmt.Errorf("echo jwt middleware could not create token extractor: %w", err)
} }
if config.KeyFunc == nil { if len(extractors) == 0 {
config.KeyFunc = config.defaultKeyFunc return nil, errors.New("echo jwt middleware could not create extractors from TokenLookup string")
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
}
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -206,142 +135,55 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil { if config.BeforeFunc != nil {
config.BeforeFunc(c) config.BeforeFunc(c)
} }
var auth string var lastExtractorErr error
var err error var lastTokenErr error
for _, extractor := range extractors { for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and auths, _, extrErr := extractor(c)
// set auth if extrErr != nil {
auth, err = extractor(c) lastExtractorErr = extrErr
if err == nil { continue
break }
for _, auth := range auths {
token, err := config.ParseTokenFunc(c, auth)
if err != nil {
lastTokenErr = err
continue
}
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
config.SuccessHandler(c)
}
return next(c)
} }
} }
// If none of extractor has a token, handle error
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
}
if config.ErrorHandlerWithContext != nil { // prioritize token errors over extracting errors
return config.ErrorHandlerWithContext(err, c) err := lastTokenErr
}
return err
}
token, err := config.ParseTokenFunc(auth, c)
if err == nil { if err == nil {
// Store user information from token into context. err = lastExtractorErr
c.Set(config.ContextKey, token) }
if config.SuccessHandler != nil { if config.ErrorHandler != nil {
config.SuccessHandler(c) if err == ErrExtractionValueMissing {
err = ErrJWTMissing
}
// Allow error handler to swallow the error and continue handler chain execution
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public token to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
} }
return next(c) return next(c)
} }
if config.ErrorHandler != nil { if err == ErrExtractionValueMissing {
return config.ErrorHandler(err) return ErrJWTMissing
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
} }
// everything else goes under http.StatusUnauthorized to avoid exposing JWT internals with generic error
return &echo.HTTPError{ return &echo.HTTPError{
Code: ErrJWTInvalid.Code, Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message, Message: ErrJWTInvalid.Message,
Internal: err, Internal: err,
} }
} }
} }, nil
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return config.SigningKey, nil
}
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
} }

View File

@ -0,0 +1,76 @@
package middleware_test
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
"net/http/httptest"
)
// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc
//
// signingKey is signing key to validate token.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKeys is not provided.
//
// signingKeys is Map of signing keys to validate token with kid field usage.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != middleware.AlgorithmHS256 {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(signingKeys) == 0 {
return signingKey, nil
}
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := signingKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
func ExampleJWTConfig_withJWTGoAsTokenParser() {
mw := middleware.JWTWithConfig(middleware.JWTConfig{
ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil),
})
e := echo.New()
e.Use(mw)
e.GET("/", func(c echo.Context) error {
user := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusTeapot, user.Claims)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
fmt.Printf("status: %v, body: %v", res.Code, res.Body.String())
// Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"}
}

View File

@ -1,5 +1,3 @@
// +build go1.15
package middleware package middleware
import ( import (
@ -11,11 +9,32 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != signingMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return signingKey, nil
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
// jwtCustomInfo defines some custom types we're going to use within our tokens. // jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct { type jwtCustomInfo struct {
Name string `json:"name"` Name string `json:"name"`
@ -28,43 +47,7 @@ type jwtCustomClaims struct {
jwtCustomInfo jwtCustomInfo
} }
func TestJWTRace(t *testing.T) { func TestJWT_combinations(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) {
e := echo.New() e := echo.New()
handler := func(c echo.Context) error { handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
@ -72,344 +55,236 @@ func TestJWT(t *testing.T) {
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
validKey := []byte("secret") validKey := []byte("secret")
invalidKey := []byte("invalid-key") invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token validAuth := "Bearer " + token
for _, tc := range []struct { var testCases = []struct {
expPanic bool name string
expErrCode int // 0 for Success
config JWTConfig config JWTConfig
reqURL string // "/" if empty reqURL string // "/" if empty
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string formValues map[string]string
info string expPanic bool
expErrCode int // 0 for Success
}{ }{
{ {
expPanic: true, expPanic: true,
info: "No signing key provided", name: "No signing key provided",
},
{
expErrCode: http.StatusBadRequest,
config: JWTConfig{
SigningKey: validKey,
SigningMethod: "RS256",
},
info: "Unexpected signing method",
}, },
{ {
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey}, config: JWTConfig{
info: "Invalid key", ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
},
name: "Unexpected signing method",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey),
},
name: "Invalid key",
}, },
{ {
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{
info: "Valid JWT", ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT",
}, },
{ {
hdrAuth: "Token" + " " + token, hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, config: JWTConfig{
info: "Valid JWT with custom AuthScheme", TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ",
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT with custom AuthScheme",
}, },
{ {
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
Claims: &jwtCustomClaims{}, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
SigningKey: []byte("secret"),
}, },
info: "Valid JWT with custom claims", name: "Valid JWT with custom claims",
}, },
{ {
hdrAuth: "invalid-auth", hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{
info: "Invalid Authorization header", ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
}, },
{ name: "Invalid Authorization header",
config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest,
info: "Empty header auth field",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", },
expErrCode: http.StatusUnauthorized,
name: "Empty header auth field",
},
{
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=" + token, reqURL: "/?a=b&jwt=" + token,
info: "Valid query method", name: "Valid query method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwtxyz=" + token, reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Invalid query param name", name: "Invalid query param name",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=invalid-token", reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Invalid query param value", name: "Invalid query param value",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b", reqURL: "/?a=b",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Empty query", name: "Empty query",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "param:jwt", TokenLookup: "param:jwt",
}, },
reqURL: "/" + token, reqURL: "/" + token,
info: "Valid param method", name: "Valid param method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Valid cookie method", name: "Valid cookie method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt,cookie:jwt", TokenLookup: "query:jwt,cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Multiple jwt lookuop", name: "Multiple jwt lookuop",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid", hdrCookie: "jwt=invalid",
info: "Invalid token with cookie method", name: "Invalid token with cookie method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusBadRequest, expErrCode: http.StatusUnauthorized,
info: "Empty cookie", name: "Empty cookie",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
formValues: map[string]string{"jwt": token}, formValues: map[string]string{"jwt": token},
info: "Valid form method", name: "Valid form method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"}, formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method", name: "Invalid token with form method",
}, },
{ {
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
info: "Valid JWT with a valid key using a user-defined KeyFunc",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return invalidKey, nil
},
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc", name: "Empty form field",
}, },
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return nil, errors.New("faulty KeyFunc")
},
},
expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
},
{
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT with lower case AuthScheme",
},
} {
if tc.reqURL == "" {
tc.reqURL = "/"
}
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
c := e.NewContext(req, res)
if tc.reqURL == "/"+token {
c.SetParamNames("jwt")
c.SetParamValues(token)
}
if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
}, tc.info)
continue
}
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
continue
}
h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.info)
assert.Equal(t, claims.Admin, true, tc.info)
default:
panic("unexpected type of claims")
}
}
} }
}
func TestJWTwithKID(t *testing.T) { for _, tc := range testCases {
test := assert.New(t) t.Run(tc.name, func(t *testing.T) {
if tc.reqURL == "" {
e := echo.New() tc.reqURL = "/"
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
staticSecret := []byte("static_secret")
invalidStaticSecret := []byte("invalid_secret")
for _, tc := range []struct {
expErrCode int // 0 for Success
config JWTConfig
hdrAuth string
info string
}{
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: validKeys},
info: "First token valid",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Second token valid",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Wrong key id token",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: staticSecret},
info: "Valid static secret token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: invalidStaticSecret},
info: "Invalid static secret",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys first token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys second token",
},
} {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
c := e.NewContext(req, res)
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
continue
}
h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
default:
panic("unexpected type of claims")
} }
}
var req *http.Request
if len(tc.formValues) > 0 {
form := url.Values{}
for k, v := range tc.formValues {
form.Set(k, v)
}
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
req.ParseForm()
} else {
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
}
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
c := e.NewContext(req, res)
if tc.reqURL == "/"+token {
cc := c.(echo.EditableContext)
cc.SetPathParams(echo.PathParams{
{Name: "jwt", Value: token},
})
}
if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
}, tc.name)
return
}
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code)
return
}
h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe")
case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
default:
panic("unexpected type of claims")
}
}
})
} }
} }
@ -420,7 +295,7 @@ func TestJWTConfig_skipper(t *testing.T) {
Skipper: func(context echo.Context) bool { Skipper: func(context echo.Context) bool {
return true // skip everything return true // skip everything
}, },
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
})) }))
isCalled := false isCalled := false
@ -448,11 +323,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) {
BeforeFunc: func(context echo.Context) { BeforeFunc: func(context echo.Context) {
isCalled = true isCalled = true
}, },
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
})) }))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
@ -469,18 +344,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) {
{ {
name: "ok, ErrorHandler is executed", name: "ok, ErrorHandler is executed",
given: JWTConfig{ given: JWTConfig{
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(err error) error { ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error") return echo.NewHTTPError(http.StatusTeapot, "custom_error")
}, },
}, },
@ -515,23 +380,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
{ {
name: "ok, ErrorHandler is executed", name: "ok, ErrorHandler is executed",
given: JWTConfig{ given: JWTConfig{
SigningKey: []byte("secret"), ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(err error) error { ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
}, },
}, },
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
}, },
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -544,14 +399,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
config := tc.given config := tc.given
parseTokenCalled := false parseTokenCalled := false
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
parseTokenCalled = true parseTokenCalled = true
return nil, errors.New("parsing failed") return nil, errors.New("parsing failed")
} }
e.Use(JWTWithConfig(config)) e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
@ -574,7 +429,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
signingKey := []byte("secret") signingKey := []byte("secret")
config := JWTConfig{ config := JWTConfig{
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" { if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
@ -597,9 +452,161 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
e.Use(JWTWithConfig(config)) e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder() res := httptest.NewRecorder()
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, http.StatusTeapot, res.Code)
} }
func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
success := c.Get("success").(string)
user := c.Get("user").(string)
return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user))
})
mw, err := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return auth, nil
},
SuccessHandler: func(c echo.Context) {
c.Set("success", "yes")
},
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, "yes:valid_token_base64", res.Body.String())
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) {
var testCases = []struct {
name string
givenCallNext bool
givenErrorHandler JWTErrorHandlerWithContext
givenTokenLookup string
whenAuthHeaders []string
whenCookies []string
whenParseReturn string
whenParseError error
expectHandlerCalled bool
expect string
expectCode int
}{
{
name: "ok, with valid JWT from auth header",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return nil
},
whenAuthHeaders: []string{"Bearer valid_token_base64"},
whenParseReturn: "valid_token",
expectCode: http.StatusTeapot,
expect: "valid_token",
},
{
name: "ok, missing header, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err != ErrJWTMissing {
panic("must get ErrJWTMissing")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "ok, invalid token, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
// this is probably not realistic usecase. on parse error you probably want to return error
if err.Error() != "parser_error" {
panic("must get parser_error")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "nok, invalid token, return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err.Error() != "parser_error" {
panic("must get parser_error")
}
return err
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusInternalServerError,
expect: "{\"message\":\"Internal Server Error\"}\n",
},
{
name: "nok, callNext but return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
{
name: "nok, callNext=false",
givenCallNext: false,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(string)
return c.String(http.StatusTeapot, token)
})
mw, err := JWTConfig{
TokenLookup: tc.givenTokenLookup,
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return tc.whenParseReturn, tc.whenParseError
},
ErrorHandler: tc.givenErrorHandler,
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
for _, a := range tc.whenAuthHeaders {
req.Header.Add(echo.HeaderAuthorization, a)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expect, res.Body.String())
assert.Equal(t, tc.expectCode, res.Code)
})
}
}

View File

@ -3,58 +3,59 @@ package middleware
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
) )
type ( // KeyAuthConfig defines the config for KeyAuth middleware.
// KeyAuthConfig defines the config for KeyAuth middleware. type KeyAuthConfig struct {
KeyAuthConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used // KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request. // to extract key(s) from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization:Bearer ".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>:<value prefix>"
// - "query:<name>" // - "query:<name>"
// - "form:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
KeyLookup string `yaml:"key_lookup"` // - "form:<name>"
// Multiple sources example:
// - "header:Authorization:Bearer ,cookie:myowncookie"
KeyLookup string
// AuthScheme to be used in the Authorization header. // Validator is a function to validate key.
// Optional. Default value "Bearer". // Required.
AuthScheme string Validator KeyAuthValidator
// Validator is a function to validate key. // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// Required. // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
Validator KeyAuthValidator // It may be used to define a custom error.
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler
}
// ErrorHandler defines a function which is executed for an invalid key. // KeyAuthValidator defines a function to validate KeyAuth credentials.
// It may be used to define a custom error. type KeyAuthValidator func(c echo.Context, key string, keyType ExtractorType) (bool, error)
ErrorHandler KeyAuthErrorHandler
}
// KeyAuthValidator defines a function to validate KeyAuth credentials. // KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthValidator func(string, echo.Context) (bool, error) type KeyAuthErrorHandler func(c echo.Context, err error) error
keyExtractor func(echo.Context) (string, error) // ErrKeyMissing denotes an error raised when key value could not be extracted from request
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
// KeyAuthErrorHandler defines a function which is executed for an invalid key. // ErrInvalidKey denotes an error raised when key value is invalid by validator
KeyAuthErrorHandler func(error, echo.Context) error var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
)
var ( // DefaultKeyAuthConfig is the default KeyAuth middleware config.
// DefaultKeyAuthConfig is the default KeyAuth middleware config. var DefaultKeyAuthConfig = KeyAuthConfig{
DefaultKeyAuthConfig = KeyAuthConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
KeyLookup: "header:" + echo.HeaderAuthorization, }
AuthScheme: "Bearer",
}
)
// KeyAuth returns an KeyAuth middleware. // KeyAuth returns an KeyAuth middleware.
// //
@ -67,34 +68,32 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c) return KeyAuthWithConfig(c)
} }
// KeyAuthWithConfig returns an KeyAuth middleware with config. // KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
// See `KeyAuth()`. //
// For first valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper config.Skipper = DefaultKeyAuthConfig.Skipper
} }
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" { if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
} }
if config.Validator == nil { if config.Validator == nil {
panic("echo: key-auth middleware requires a validator function") return nil, errors.New("echo key-auth middleware requires a validator function")
} }
extractors, err := createExtractors(config.KeyLookup)
// Initialize if err != nil {
parts := strings.Split(config.KeyLookup, ":") return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
extractor := keyFromHeader(parts[1], config.AuthScheme) }
switch parts[0] { if len(extractors) == 0 {
case "query": return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
case "cookie":
extractor = keyFromCookie(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -103,79 +102,50 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
// Extract and verify key var lastExtractorErr error
key, err := extractor(c) var lastValidatorErr error
if err != nil { for _, extractor := range extractors {
if config.ErrorHandler != nil { keys, keyType, extrErr := extractor(c)
return config.ErrorHandler(err, c) if extrErr != nil {
lastExtractorErr = extrErr
continue
}
for _, key := range keys {
valid, err := config.Validator(c, key, keyType)
if err != nil {
lastValidatorErr = err
continue
}
if !valid {
lastValidatorErr = ErrInvalidKey
continue
}
return next(c)
} }
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
} }
valid, err := config.Validator(key, c)
if err != nil { // prioritize validator errors over extracting errors
if config.ErrorHandler != nil { err := lastValidatorErr
return config.ErrorHandler(err, c) if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil {
// Allow error handler to swallow the error and continue handler chain execution
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
} }
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Internal: err,
}
} else if valid {
return next(c) return next(c)
} }
return echo.ErrUnauthorized if err == ErrExtractionValueMissing {
} return ErrKeyMissing // do not wrap extractor errors
} }
} return &echo.HTTPError{
Code: http.StatusUnauthorized,
// keyFromHeader returns a `keyExtractor` that extracts key from the request header. Message: "Unauthorized",
func keyFromHeader(header string, authScheme string) keyExtractor { Internal: err,
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
} }
return "", errors.New("invalid key in the request header")
} }
return auth, nil }, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
func keyFromCookie(cookieName string) keyExtractor {
return func(c echo.Context) (string, error) {
key, err := c.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("missing key in cookies: %w", err)
}
return key.Value, nil
}
} }

View File

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func testKeyValidator(key string, c echo.Context) (bool, error) { func testKeyValidator(c echo.Context, key string, keyType ExtractorType) (bool, error) {
switch key { switch key {
case "valid-key": case "valid-key":
return true, nil return true, nil
@ -28,7 +28,7 @@ func TestKeyAuth(t *testing.T) {
handlerCalled = true handlerCalled = true
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
middlewareChain := KeyAuth(testKeyValidator)(handler) middlewareChain := KeyAuthWithConfig(KeyAuthConfig{Validator: testKeyValidator})(handler)
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized", expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key",
}, },
{ {
name: "nok, defaults, invalid scheme in header", name: "nok, defaults, invalid scheme in header",
@ -84,13 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "nok, defaults, missing header", name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {}, givenRequest: func(req *http.Request) {},
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, header", name: "ok, custom key lookup, header",
@ -110,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "header:API-Key" conf.KeyLookup = "header:API-Key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, query", name: "ok, custom key lookup, query",
@ -130,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "query:key" conf.KeyLookup = "query:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, form", name: "ok, custom key lookup, form",
@ -155,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "form:key" conf.KeyLookup = "form:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form", expectError: "code=401, message=missing key",
}, },
{ {
name: "ok, custom key lookup, cookie", name: "ok, custom key lookup, cookie",
@ -179,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key" conf.KeyLookup = "cookie:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies: http: named cookie not present", expectError: "code=401, message=missing key",
}, },
{ {
name: "nok, custom errorHandler, error from extractor", name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) { whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token" conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error { conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err httpError.Internal = err
return httpError return httpError
} }
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header", expectError: "code=418, message=custom, internal=code=400, message=missing or malformed value",
}, },
{ {
name: "nok, custom errorHandler, error from validator", name: "nok, custom errorHandler, error from validator",
@ -200,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
}, },
whenConfig: func(conf *KeyAuthConfig) { whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error { conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err httpError.Internal = err
return httpError return httpError
@ -216,7 +216,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
}, },
whenConfig: func(conf *KeyAuthConfig) {}, whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error", expectError: "code=401, message=Unauthorized, internal=some user defined error",
}, },
} }
@ -257,3 +257,96 @@ func TestKeyAuthWithConfig(t *testing.T) {
}) })
} }
} }
func TestKeyAuthWithConfig_errors(t *testing.T) {
var testCases = []struct {
name string
whenConfig KeyAuthConfig
expectError string
}{
{
name: "ok, no error",
whenConfig: KeyAuthConfig{
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
},
{
name: "ok, missing validator func",
whenConfig: KeyAuthConfig{
Validator: nil,
},
expectError: "echo key-auth middleware requires a validator function",
},
{
name: "ok, extractor source can not be split",
whenConfig: KeyAuthConfig{
KeyLookup: "nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
},
{
name: "ok, no extractors",
whenConfig: KeyAuthConfig{
KeyLookup: "nope:nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mw, err := tc.whenConfig.ToMiddleware()
if tc.expectError != "" {
assert.Nil(t, mw)
assert.EqualError(t, err, tc.expectError)
} else {
assert.NotNil(t, mw)
assert.NoError(t, err)
}
})
}
}
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
KeyAuthWithConfig(KeyAuthConfig{})
})
}
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
handlerCalled := false
var authValue string
handler := func(c echo.Context) error {
handlerCalled = true
authValue = c.Get("auth").(string)
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(c echo.Context, err error) error {
// could check error to decide if we can swallow the error
c.Set("auth", "public")
return nil
},
})(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
// no auth header this time
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := middlewareChain(c)
assert.NoError(t, err)
assert.True(t, handlerCalled)
assert.Equal(t, "public", authValue)
}

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"strconv" "strconv"
"strings" "strings"
@ -10,81 +11,78 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/color"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
type ( // LoggerConfig defines the config for Logger middleware.
// LoggerConfig defines the config for Logger middleware. type LoggerConfig struct {
LoggerConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Tags to construct the logger format. // Tags to construct the logger format.
// //
// - time_unix // - time_unix
// - time_unix_nano // - time_unix_nano
// - time_rfc3339 // - time_rfc3339
// - time_rfc3339_nano // - time_rfc3339_nano
// - time_custom // - time_custom
// - id (Request ID) // - id (Request ID)
// - remote_ip // - remote_ip
// - uri // - uri
// - host // - host
// - method // - method
// - path // - path
// - protocol // - protocol
// - referer // - referer
// - user_agent // - user_agent
// - status // - status
// - error // - error
// - latency (In nanoseconds) // - latency (In nanoseconds)
// - latency_human (Human readable) // - latency_human (Human readable)
// - bytes_in (Bytes received) // - bytes_in (Bytes received)
// - bytes_out (Bytes sent) // - bytes_out (Bytes sent)
// - header:<NAME> // - header:<NAME>
// - query:<NAME> // - query:<NAME>
// - form:<NAME> // - form:<NAME>
// //
// Example "${remote_ip} ${status}" // Example "${remote_ip} ${status}"
// //
// Optional. Default value DefaultLoggerConfig.Format. // Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"` Format string
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat. // Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"` CustomTimeFormat string
// Output is a writer where logs in JSON format are written. // Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout. // Optional. Default destination `echo.Logger.Infof()`
Output io.Writer Output io.Writer
template *fasttemplate.Template template *fasttemplate.Template
colorer *color.Color pool *sync.Pool
pool *sync.Pool }
}
)
var ( // DefaultLoggerConfig is the default Logger middleware config.
// DefaultLoggerConfig is the default Logger middleware config. var DefaultLoggerConfig = LoggerConfig{
DefaultLoggerConfig = LoggerConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", CustomTimeFormat: "2006-01-02 15:04:05.00000",
CustomTimeFormat: "2006-01-02 15:04:05.00000", }
colorer: color.New(),
}
)
// Logger returns a middleware that logs HTTP requests. // Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc { func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig) return LoggerWithConfig(DefaultLoggerConfig)
} }
// LoggerWithConfig returns a Logger middleware with config. // LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration.
// See: `Logger()`.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration
func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper config.Skipper = DefaultLoggerConfig.Skipper
@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
if config.Format == "" { if config.Format == "" {
config.Format = DefaultLoggerConfig.Format config.Format = DefaultLoggerConfig.Format
} }
if config.Output == nil {
config.Output = DefaultLoggerConfig.Output
}
config.template = fasttemplate.New(config.Format, "${", "}") config.template = fasttemplate.New(config.Format, "${", "}")
config.colorer = color.New()
config.colorer.SetOutput(config.Output)
config.pool = &sync.Pool{ config.pool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256)) return bytes.NewBuffer(make([]byte, 256))
@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
start := time.Now() start := time.Now()
if err = next(c); err != nil { err := next(c)
c.Error(err)
}
stop := time.Now() stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer) buf := config.pool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
defer config.pool.Put(buf) defer config.pool.Put(buf)
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { _, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "time_unix": case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "user_agent": case "user_agent":
return buf.WriteString(req.UserAgent()) return buf.WriteString(req.UserAgent())
case "status": case "status":
n := res.Status status := res.Status
s := config.colorer.Green(n) if err != nil {
switch { if httpErr, ok := err.(*echo.HTTPError); ok {
case n >= 500: status = httpErr.Code
s = config.colorer.Red(n) }
case n >= 400:
s = config.colorer.Yellow(n)
case n >= 300:
s = config.colorer.Cyan(n)
} }
return buf.WriteString(s) return buf.WriteString(strconv.Itoa(status))
case "error": case "error":
if err != nil { if err != nil {
// Error may contain invalid JSON e.g. `"` // Error may contain invalid JSON e.g. `"`
@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case strings.HasPrefix(tag, "form:"): case strings.HasPrefix(tag, "form:"):
return buf.Write([]byte(c.FormValue(tag[5:]))) return buf.Write([]byte(c.FormValue(tag[5:])))
case strings.HasPrefix(tag, "cookie:"): case strings.HasPrefix(tag, "cookie:"):
cookie, err := c.Cookie(tag[7:]) cookie, cookieErr := c.Cookie(tag[7:])
if err == nil { if cookieErr == nil {
return buf.Write([]byte(cookie.Value)) return buf.Write([]byte(cookie.Value))
} }
} }
} }
return 0, nil return 0, nil
}); err != nil { })
return if tmplErr != nil {
if err != nil {
return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err)
}
return fmt.Errorf("failed to create log from template: %w", tmplErr)
} }
if config.Output == nil { if config.Output != nil {
_, err = c.Logger().Output().Write(buf.Bytes()) if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil {
return return lErr
}
} else {
if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil {
return lErr
}
} }
_, err = config.Output.Write(buf.Bytes()) return err
return
} }
} }, nil
} }

View File

@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
ip := "127.0.0.1" ip := "127.0.0.1"
h := Logger()(func(c echo.Context) error { h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")

View File

@ -6,28 +6,24 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // MethodOverrideConfig defines the config for MethodOverride middleware.
// MethodOverrideConfig defines the config for MethodOverride middleware. type MethodOverrideConfig struct {
MethodOverrideConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Getter is a function that gets overridden method from the request. // Getter is a function that gets overridden method from the request.
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
Getter MethodOverrideGetter Getter MethodOverrideGetter
} }
// MethodOverrideGetter is a function that gets overridden method from the request // MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string type MethodOverrideGetter func(echo.Context) string
)
var ( // DefaultMethodOverrideConfig is the default MethodOverride middleware config.
// DefaultMethodOverrideConfig is the default MethodOverride middleware config. var DefaultMethodOverrideConfig = MethodOverrideConfig{
DefaultMethodOverrideConfig = MethodOverrideConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), }
}
)
// MethodOverride returns a MethodOverride middleware. // MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and // MethodOverride middleware checks for the overridden method from the request and
@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig) return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
} }
// MethodOverrideWithConfig returns a MethodOverride middleware with config. // MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
// See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper config.Skipper = DefaultMethodOverrideConfig.Skipper
@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from

View File

@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
m(h)(c)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_formParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with form parameter // Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) assert.NoError(t, err)
rec = httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
m(h)(c)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_queryParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with query parameter // Override with query parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) assert.NoError(t, err)
rec = httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
c = e.NewContext(req, rec) rec := httptest.NewRecorder()
m(h)(c) c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method) assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_ignoreGet(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Ignore `GET` // Ignore `GET`
req = httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodGet, req.Method) assert.Equal(t, http.MethodGet, req.Method)
} }

View File

@ -9,14 +9,11 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // Skipper defines a function to skip middleware. Returning true skips processing the middleware.
// Skipper defines a function to skip middleware. Returning true skips processing type Skipper func(c echo.Context) bool
// the middleware.
Skipper func(echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context) type BeforeFunc func(c echo.Context)
)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1) groups := pattern.FindAllStringSubmatch(input, -1)
@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
func DefaultSkipper(echo.Context) bool { func DefaultSkipper(echo.Context) bool {
return false return false
} }
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

View File

@ -2,6 +2,7 @@ package middleware
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
@ -20,85 +21,81 @@ import (
// TODO: Handle TLS proxy // TODO: Handle TLS proxy
type ( // ProxyConfig defines the config for Proxy middleware.
// ProxyConfig defines the config for Proxy middleware. type ProxyConfig struct {
ProxyConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Balancer defines a load balancing technique. // Balancer defines a load balancing technique.
// Required. // Required.
Balancer ProxyBalancer Balancer ProxyBalancer
// Rewrite defines URL path rewrite rules. The values captured in asterisk can be // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on. // retrieved by index e.g. $1, $2 and so on.
// Examples: // Examples:
// "/old": "/new", // "/old": "/new",
// "/api/*": "/$1", // "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1", // "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2", // "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string Rewrite map[string]string
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures // RegexRewrite defines rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "^/old/[0.9]+/": "/new", // "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1", // "^/api/.+?/(.*)": "/v2/$1",
RegexRewrite map[*regexp.Regexp]string RegexRewrite map[*regexp.Regexp]string
// Context key to store selected ProxyTarget into context. // Context key to store selected ProxyTarget into context.
// Optional. Default value "target". // Optional. Default value "target".
ContextKey string ContextKey string
// To customize the transport to remote. // To customize the transport to remote.
// Examples: If custom TLS certificates are required. // Examples: If custom TLS certificates are required.
Transport http.RoundTripper Transport http.RoundTripper
// ModifyResponse defines function to modify response from ProxyTarget. // ModifyResponse defines function to modify response from ProxyTarget.
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
} }
// ProxyTarget defines the upstream target. // ProxyTarget defines the upstream target.
ProxyTarget struct { type ProxyTarget struct {
Name string Name string
URL *url.URL URL *url.URL
Meta echo.Map Meta echo.Map
} }
// ProxyBalancer defines an interface to implement a load balancing technique. // ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface { type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget Next(echo.Context) *ProxyTarget
} }
commonBalancer struct { type commonBalancer struct {
targets []*ProxyTarget targets []*ProxyTarget
mutex sync.RWMutex mutex sync.RWMutex
} }
// RandomBalancer implements a random load balancing technique. // RandomBalancer implements a random load balancing technique.
randomBalancer struct { type randomBalancer struct {
*commonBalancer *commonBalancer
random *rand.Rand random *rand.Rand
} }
// RoundRobinBalancer implements a round-robin load balancing technique. // RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct { type roundRobinBalancer struct {
*commonBalancer *commonBalancer
i uint32 i uint32
} }
)
var ( // DefaultProxyConfig is the default Proxy middleware config.
// DefaultProxyConfig is the default Proxy middleware config. var DefaultProxyConfig = ProxyConfig{
DefaultProxyConfig = ProxyConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, ContextKey: "target",
ContextKey: "target", }
}
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack() in, _, err := c.Response().Hijack()
if err != nil { if err != nil {
@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c) return ProxyWithConfig(c)
} }
// ProxyWithConfig returns a Proxy middleware with config. // ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
// See: `Proxy()` //
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper config.Skipper = DefaultProxyConfig.Skipper
} }
if config.ContextKey == "" {
config.ContextKey = DefaultProxyConfig.ContextKey
}
if config.Balancer == nil { if config.Balancer == nil {
panic("echo: proxy middleware requires balancer") return nil, errors.New("echo proxy middleware requires balancer")
} }
if config.Rewrite != nil { if config.Rewrite != nil {
@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy // Proxy
switch { switch {
case c.IsWebSocket(): case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req) proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream": case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default: default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req) proxyHTTP(c, tgt, config).ServeHTTP(res, req)
} }
if e, ok := c.Get("_error").(error); ok { if e, ok := c.Get("_error").(error); ok {
err = e err = e
@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return return
} }
} }, nil
} }
// StatusCodeContextCanceled is a custom HTTP status code for situations // StatusCodeContextCanceled is a custom HTTP status code for situations
@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation // 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499 const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String() desc := tgt.URL.String()

View File

@ -55,7 +55,7 @@ func TestProxy(t *testing.T) {
// Random // Random
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
@ -77,7 +77,7 @@ func TestProxy(t *testing.T) {
// Round-robin // Round-robin
rrb := NewRoundRobinBalancer(targets) rrb := NewRoundRobinBalancer(targets)
e = echo.New() e = echo.New()
e.Use(Proxy(rrb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
@ -113,15 +113,20 @@ func TestProxy(t *testing.T) {
return nil return nil
} }
} }
rrb1 := NewRoundRobinBalancer(targets)
e = echo.New() e = echo.New()
e.Use(contextObserver) e.Use(contextObserver)
e.Use(Proxy(rrb1)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
} }
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
assert.Panics(t, func() {
ProxyWithConfig(ProxyConfig{Balancer: nil})
})
}
func TestProxyRealIPHeader(t *testing.T) { func TestProxyRealIPHeader(t *testing.T) {
// Setup // Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) {
url, _ := url.Parse(upstream.URL) url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
e := echo.New() e := echo.New()
e.Use(Proxy(rrb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) {
// Random // Random
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
rb := NewRandomBalancer(nil) rb := NewRandomBalancer(nil)
assert.True(t, rb.AddTarget(target)) assert.True(t, rb.AddTarget(target))
e := echo.New() e := echo.New()
e.Use(Proxy(rb)) e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(req.Context()) ctx, cancel := context.WithCancel(req.Context())

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -9,39 +10,33 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
type ( // RateLimiterStore is the interface to be implemented by custom stores.
// RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface {
RateLimiterStore interface { Allow(identifier string) (bool, error)
// Stores for the rate limiter have to implement the Allow method }
Allow(identifier string) (bool, error)
}
)
type ( // RateLimiterConfig defines the configuration for the rate limiter
// RateLimiterConfig defines the configuration for the rate limiter type RateLimiterConfig struct {
RateLimiterConfig struct { Skipper Skipper
Skipper Skipper BeforeFunc BeforeFunc
BeforeFunc BeforeFunc // IdentifierExtractor uses echo.Context to extract the identifier for a visitor
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor IdentifierExtractor Extractor
IdentifierExtractor Extractor // Store defines a store for the rate limiter
// Store defines a store for the rate limiter Store RateLimiterStore
Store RateLimiterStore // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error ErrorHandler func(context echo.Context, err error) error
ErrorHandler func(context echo.Context, err error) error // DenyHandler provides a handler to be called when RateLimiter denies access
// DenyHandler provides a handler to be called when RateLimiter denies access DenyHandler func(context echo.Context, identifier string, err error) error
DenyHandler func(context echo.Context, identifier string, err error) error }
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors // Extractor is used to extract data from echo.Context
var ( type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
// ErrExtractorError denotes an error raised when extractor function is unsuccessful var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
) // ErrExtractorError denotes an error raised when extractor function is unsuccessful
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
// DefaultRateLimiterConfig defines default values for RateLimiterConfig // DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{ var DefaultRateLimiterConfig = RateLimiterConfig{
@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware
}, middleware.RateLimiterWithConfig(config)) }, middleware.RateLimiterWithConfig(config))
*/ */
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper config.Skipper = DefaultRateLimiterConfig.Skipper
} }
@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
} }
if config.Store == nil { if config.Store == nil {
panic("Store configuration must be provided") return nil, errors.New("echo rate limiter store configuration must be provided")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
@ -137,35 +137,32 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
identifier, err := config.IdentifierExtractor(c) identifier, err := config.IdentifierExtractor(c)
if err != nil { if err != nil {
c.Error(config.ErrorHandler(c, err)) return config.ErrorHandler(c, err)
return nil
} }
if allow, err := config.Store.Allow(identifier); !allow { if allow, allowErr := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err)) return config.DenyHandler(c, identifier, allowErr)
return nil
} }
return next(c) return next(c)
} }
} }, nil
} }
type ( // RateLimiterMemoryStore is the built-in store implementation for RateLimiter
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter type RateLimiterMemoryStore struct {
RateLimiterMemoryStore struct { visitors map[string]*Visitor
visitors map[string]*Visitor mutex sync.Mutex
mutex sync.Mutex rate rate.Limit
rate rate.Limit burst int
burst int expiresIn time.Duration
expiresIn time.Duration lastCleanup time.Time
lastCleanup time.Time }
}
// Visitor signifies a unique user's limiter details // Visitor signifies a unique user's limiter details
Visitor struct { type Visitor struct {
*rate.Limiter *rate.Limiter
lastSeen time.Time lastSeen time.Time
} }
)
/* /*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
mw := RateLimiter(inMemoryStore) mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
assert.Equal(t, tc.code, rec.Code) if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
func TestRateLimiter_panicBehaviour(t *testing.T) { func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() { assert.Panics(t, func() {
RateLimiter(nil) RateLimiterWithConfig(RateLimiterConfig{})
}) })
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
RateLimiter(inMemoryStore) RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
}) })
} }
@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) { IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP) id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" { if id == "" {
@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
return ctx.JSON(http.StatusBadRequest, nil) return ctx.JSON(http.StatusBadRequest, nil)
}, },
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code) assert.Equal(t, tc.code, rec.Code)
} }
} }
@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) { IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP) id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" { if id == "" {
@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return id, nil return id, nil
}, },
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"", http.StatusForbidden}, {expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
if tc.expectErr != "" {
assert.Equal(t, tc.code, rec.Code) assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Store: inMemoryStore, Store: inMemoryStore,
}) }.ToMiddleware()
assert.NoError(t, err)
testCases := []struct { testCases := []struct {
id string id string
code int expectErr string
}{ }{
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusOK}, {id: "127.0.0.1"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{"127.0.0.1", http.StatusTooManyRequests}, {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
} }
for _, tc := range testCases { for _, tc := range testCases {
@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
_ = mw(handler)(c) err := mw(handler)(c)
if tc.expectErr != "" {
assert.Equal(t, tc.code, rec.Code) assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
} }
@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
return true return true
}, },
@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan) assert.Equal(t, false, beforeFuncRan)
} }
@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
return false return false
}, },
@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) _ = mw(handler)(c)
@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{ mw, err := RateLimiterConfig{
BeforeFunc: func(c echo.Context) { BeforeFunc: func(c echo.Context) {
beforeRan = true beforeRan = true
}, },
@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) { IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil return "127.0.0.1", nil
}, },
}) }.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c) err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, true, beforeRan) assert.Equal(t, true, beforeRan)
} }
@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string { func generateAddressList(count int) []string {
addrs := make([]string, count) addrs := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
addrs[i] = random.String(15) addrs[i] = randomString(15)
} }
return addrs return addrs
} }

View File

@ -5,44 +5,34 @@ import (
"runtime" "runtime"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
) )
type ( // RecoverConfig defines the config for Recover middleware.
// RecoverConfig defines the config for Recover middleware. type RecoverConfig struct {
RecoverConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Size of the stack to be printed. // Size of the stack to be printed.
// Optional. Default value 4KB. // Optional. Default value 4KB.
StackSize int `yaml:"stack_size"` StackSize int
// DisableStackAll disables formatting stack traces of all other goroutines // DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine. // into buffer after the trace for the current goroutine.
// Optional. Default value false. // Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"` DisableStackAll bool
// DisablePrintStack disables printing stack trace. // DisablePrintStack disables printing stack trace.
// Optional. Default value as false. // Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"` DisablePrintStack bool
}
// LogLevel is log level to printing stack trace. // DefaultRecoverConfig is the default Recover middleware config.
// Optional. Default value 0 (Print). var DefaultRecoverConfig = RecoverConfig{
LogLevel log.Lvl Skipper: DefaultSkipper,
} StackSize: 4 << 10, // 4 KB
) DisableStackAll: false,
DisablePrintStack: false,
var ( }
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
}
)
// Recover returns a middleware which recovers from panics anywhere in the chain // Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler. // and handles the control to the centralized HTTPErrorHandler.
@ -50,9 +40,13 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig) return RecoverWithConfig(DefaultRecoverConfig)
} }
// RecoverWithConfig returns a Recover middleware with config. // RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
// See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper config.Skipper = DefaultRecoverConfig.Skipper
@ -62,40 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) (err error) {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
err, ok := r.(error) tmpErr, ok := r.(error)
if !ok { if !ok {
err = fmt.Errorf("%v", r) tmpErr = fmt.Errorf("%v", r)
} }
stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll)
if !config.DisablePrintStack { if !config.DisablePrintStack {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) stack := make([]byte, config.StackSize)
switch config.LogLevel { length := runtime.Stack(stack, !config.DisableStackAll)
case log.DEBUG: tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length])
c.Logger().Debug(msg)
case log.INFO:
c.Logger().Info(msg)
case log.WARN:
c.Logger().Warn(msg)
case log.ERROR:
c.Logger().Error(msg)
case log.OFF:
// None.
default:
c.Logger().Print(msg)
}
} }
c.Error(err) err = tmpErr
} }
}() }()
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -2,82 +2,109 @@ package middleware
import ( import (
"bytes" "bytes"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
e := echo.New() e := echo.New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Recover()(echo.HandlerFunc(func(c echo.Context) error { h := Recover()(func(c echo.Context) error {
panic("test") panic("test")
})) })
h(c) err := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
assert.Contains(t, buf.String(), "PANIC RECOVER") assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
assert.Contains(t, buf.String(), "") // nothing is logged
} }
func TestRecoverWithConfig_LogLevel(t *testing.T) { func TestRecover_skipper(t *testing.T) {
tests := []struct { e := echo.New()
logLevel log.Lvl
levelName string
}{{
logLevel: log.DEBUG,
levelName: "DEBUG",
}, {
logLevel: log.INFO,
levelName: "INFO",
}, {
logLevel: log.WARN,
levelName: "WARN",
}, {
logLevel: log.ERROR,
levelName: "ERROR",
}, {
logLevel: log.OFF,
levelName: "OFF",
}}
for _, tt := range tests { req := httptest.NewRequest(http.MethodGet, "/", nil)
tt := tt rec := httptest.NewRecorder()
t.Run(tt.levelName, func(t *testing.T) { c := e.NewContext(req, rec)
config := RecoverConfig{
Skipper: func(c echo.Context) bool {
return true
},
}
h := RecoverWithConfig(config)(func(c echo.Context) error {
panic("testPANIC")
})
var err error
assert.Panics(t, func() {
err = h(c)
})
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
func TestRecoverWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenNoPanic bool
whenConfig RecoverConfig
expectErrContain string
expectErr string
}{
{
name: "ok, default config",
whenConfig: DefaultRecoverConfig,
expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
},
{
name: "ok, no panic",
givenNoPanic: true,
whenConfig: DefaultRecoverConfig,
expectErrContain: "",
},
{
name: "ok, DisablePrintStack",
whenConfig: RecoverConfig{
DisablePrintStack: true,
},
expectErr: "testPANIC",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Logger.SetLevel(log.DEBUG)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
config := DefaultRecoverConfig config := tc.whenConfig
config.LogLevel = tt.logLevel h := RecoverWithConfig(config)(func(c echo.Context) error {
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { if tc.givenNoPanic {
panic("test") return nil
})) }
panic("testPANIC")
})
h(c) err := h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code) if tc.expectErrContain != "" {
assert.Contains(t, err.Error(), tc.expectErrContain)
output := buf.String() } else if tc.expectErr != "" {
if tt.logLevel == log.OFF { assert.Contains(t, err.Error(), tc.expectErr)
assert.Empty(t, output)
} else { } else {
assert.Contains(t, output, "PANIC RECOVER") assert.NoError(t, err)
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
} }
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}) })
} }
} }

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"strings" "strings"
@ -14,7 +15,9 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request. // Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently. // Optional. Default value http.StatusMovedPermanently.
Code int `yaml:"code"` Code int
redirect redirectLogic
} }
// redirectLogic represents a function that given a scheme, host and uri // redirectLogic represents a function that given a scheme, host and uri
@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www." const www = "www."
// DefaultRedirectConfig is the default Redirect middleware config. // RedirectHTTPSConfig is the HTTPS Redirect middleware config.
var DefaultRedirectConfig = RedirectConfig{ var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
Skipper: DefaultSkipper,
Code: http.StatusMovedPermanently, // RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
} var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
// RedirectWWWConfig is the WWW Redirect middleware config.
var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
// RedirectNonWWWConfig is the non WWW Redirect middleware config.
var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
// HTTPSRedirect redirects http requests to https. // HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com. // For example, http://labstack.com will be redirect to https://labstack.com.
// //
// Usage `Echo#Pre(HTTPSRedirect())` // Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc { func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(DefaultRedirectConfig) return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
} }
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
// See `HTTPSRedirect()`.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectHTTPS
if scheme != "https" { return toMiddlewareOrPanic(config)
return true, "https://" + host + uri
}
return false, ""
})
} }
// HTTPSWWWRedirect redirects http requests to https www. // HTTPSWWWRedirect redirects http requests to https www.
@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(HTTPSWWWRedirect())` // Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc { func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
} }
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
// See `HTTPSWWWRedirect()`.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectHTTPSWWW
if scheme != "https" && !strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, "https://www." + host + uri
}
return false, ""
})
} }
// HTTPSNonWWWRedirect redirects http requests to https non www. // HTTPSNonWWWRedirect redirects http requests to https non www.
@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(HTTPSNonWWWRedirect())` // Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc { func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
} }
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
// See `HTTPSNonWWWRedirect()`.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) { config.redirect = redirectNonHTTPSWWW
if scheme != "https" { return toMiddlewareOrPanic(config)
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
})
} }
// WWWRedirect redirects non www requests to www. // WWWRedirect redirects non www requests to www.
@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(WWWRedirect())` // Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc { func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(DefaultRedirectConfig) return WWWRedirectWithConfig(RedirectWWWConfig)
} }
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
// See `WWWRedirect()`.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectWWW
if !strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, scheme + "://www." + host + uri
}
return false, ""
})
} }
// NonWWWRedirect redirects www requests to non www. // NonWWWRedirect redirects www requests to non www.
@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(NonWWWRedirect())` // Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc { func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(DefaultRedirectConfig) return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
} }
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
// See `NonWWWRedirect()`.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) { config.redirect = redirectNonWWW
if strings.HasPrefix(host, www) { return toMiddlewareOrPanic(config)
return true, scheme + "://" + host[4:] + uri
}
return false, ""
})
} }
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { // ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRedirectConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Code == 0 { if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code config.Code = http.StatusMovedPermanently
}
if config.redirect == nil {
return nil, errors.New("redirectConfig is missing redirect function")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
req, scheme := c.Request(), c.Scheme() req, scheme := c.Request(), c.Scheme()
host := req.Host host := req.Host
if ok, url := cb(scheme, host, req.RequestURI); ok { if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url) return c.Redirect(config.Code, url)
} }
return next(c) return next(c)
} }
} }, nil
}
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
}
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
}
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
}
var redirectWWW = func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
}
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
} }

View File

@ -2,45 +2,38 @@ package middleware
import ( import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( // RequestIDConfig defines the config for RequestID middleware.
// RequestIDConfig defines the config for RequestID middleware. type RequestIDConfig struct {
RequestIDConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Generator defines a function to generate an ID. // Generator defines a function to generate an ID.
// Optional. Default value random.String(32). // Optional. Default value random.String(32).
Generator func() string Generator func() string
// RequestIDHandler defines a function which is executed for a request id. // RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string) RequestIDHandler func(c echo.Context, requestID string)
} }
)
var (
// DefaultRequestIDConfig is the default RequestID middleware config.
DefaultRequestIDConfig = RequestIDConfig{
Skipper: DefaultSkipper,
Generator: generator,
}
)
// RequestID returns a X-Request-ID middleware. // RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc { func RequestID() echo.MiddlewareFunc {
return RequestIDWithConfig(DefaultRequestIDConfig) return RequestIDWithConfig(RequestIDConfig{})
} }
// RequestIDWithConfig returns a X-Request-ID middleware with config. // RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultRequestIDConfig.Skipper config.Skipper = DefaultSkipper
} }
if config.Generator == nil { if config.Generator == nil {
config.Generator = generator config.Generator = createRandomStringGenerator(32)
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -62,9 +55,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
} }, nil
}
func generator() string {
return random.String(32)
} }

View File

@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
rid := RequestIDWithConfig(RequestIDConfig{}) rid := RequestID()
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
}
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
generatorCalled := false
e.Use(RequestIDWithConfig(RequestIDConfig{
Skipper: func(c echo.Context) bool {
return true
},
Generator: func() string {
generatorCalled = true
return "customGenerator"
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "test", res.Body.String())
assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
assert.False(t, generatorCalled)
}
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
called := false
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
RequestIDHandler: func(c echo.Context, s string) {
called = true
},
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, called)
}
func TestRequestIDWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid, err := RequestIDConfig{}.ToMiddleware()
assert.NoError(t, err)
h := rid(handler) h := rid(handler)
h(c) h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
// Custom generator and handler // Custom generator
customID := "customGenerator"
calledHandler := false
rid = RequestIDWithConfig(RequestIDConfig{ rid = RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return customID }, Generator: func() string { return "customGenerator" },
RequestIDHandler: func(_ echo.Context, id string) {
calledHandler = true
assert.Equal(t, customID, id)
},
}) })
h = rid(handler) h = rid(handler)
h(c) h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, calledHandler)
} }
func TestRequestID_IDNotAltered(t *testing.T) { func TestRequestID_IDNotAltered(t *testing.T) {

View File

@ -24,6 +24,7 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info(). // logger.Info().
// Date("request_start", v.StartTime).
// Str("URI", v.URI). // Str("URI", v.URI).
// Int("status", v.Status). // Int("status", v.Status).
// Msg("request") // Msg("request")
@ -39,6 +40,7 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info("request", // logger.Info("request",
// zap.Time("request_start", v.StartTime),
// zap.String("URI", v.URI), // zap.String("URI", v.URI),
// zap.Int("status", v.Status), // zap.Int("status", v.Status),
// ) // )
@ -54,8 +56,9 @@ import (
// LogStatus: true, // LogStatus: true,
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { // LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error {
// log.WithFields(logrus.Fields{ // log.WithFields(logrus.Fields{
// "URI": values.URI, // "request_start": values.StartTime,
// "status": values.Status, // "URI": values.URI,
// "status": values.Status,
// }).Info("request") // }).Info("request")
// //
// return nil // return nil
@ -158,15 +161,15 @@ type RequestLoggerValues struct {
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
ResponseSize int64 ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice // Headers are list of headers from request. Note: request can contain more than one header with same value so slice
// of values is been logger for each given header. // of values is what will be returned/logged for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
// with same name so slice of values is been logger for each given query param name. // with same name so slice of values is what will be returned/logged for each given query param name.
QueryParams map[string][]string QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
// same name so slice of values is been logger for each given form value name. // same name so slice of values is what will be returned/logged for each given form value name.
FormValues map[string][]string FormValues map[string][]string
} }

View File

@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) {
req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec).(echo.EditableContext)
c.SetPath("/test*") c.SetPath("/test*")

View File

@ -1,62 +1,58 @@
package middleware package middleware
import ( import (
"errors"
"regexp" "regexp"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // RewriteConfig defines the config for Rewrite middleware.
// RewriteConfig defines the config for Rewrite middleware. type RewriteConfig struct {
RewriteConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Rules defines the URL path rewrite rules. The values captured in asterisk can be // Rules defines the URL path rewrite rules. The values captured in asterisk can be
// retrieved by index e.g. $1, $2 and so on. // retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "/old": "/new", // "/old": "/new",
// "/api/*": "/$1", // "/api/*": "/$1",
// "/js/*": "/public/javascripts/$1", // "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2", // "/users/*/orders/*": "/user/$1/order/$2",
// Required. // Required.
Rules map[string]string `yaml:"rules"` Rules map[string]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example: // Example:
// "^/old/[0.9]+/": "/new", // "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1", // "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` RegexRules map[*regexp.Regexp]string
} }
)
var (
// DefaultRewriteConfig is the default Rewrite middleware config.
DefaultRewriteConfig = RewriteConfig{
Skipper: DefaultSkipper,
}
)
// Rewrite returns a Rewrite middleware. // Rewrite returns a Rewrite middleware.
// //
// Rewrite middleware rewrites the URL path based on the provided rules. // Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc { func Rewrite(rules map[string]string) echo.MiddlewareFunc {
c := DefaultRewriteConfig c := RewriteConfig{}
c.Rules = rules c.Rules = rules
return RewriteWithConfig(c) return RewriteWithConfig(c)
} }
// RewriteWithConfig returns a Rewrite middleware with config. // RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
// See: `Rewrite()`. //
// Rewrite middleware rewrites the URL path based on the provided rules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
if config.Rules == nil && config.RegexRules == nil { }
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
}
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper config.Skipper = DefaultSkipper
}
if config.Rules == nil && config.RegexRules == nil {
return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
} }
if config.RegexRules == nil { if config.RegexRules == nil {
@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) {
}, },
})) }))
e.GET("/public/*", func(c echo.Context) error { e.GET("/public/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*")) return c.String(http.StatusOK, c.PathParam("*"))
}) })
e.GET("/*", func(c echo.Context) error { e.GET("/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*")) return c.String(http.StatusOK, c.PathParam("*"))
}) })
var testCases = []struct { var testCases = []struct {
@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) {
} }
} }
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
assert.Panics(t, func() {
RewriteWithConfig(RewriteConfig{})
})
}
func TestMustRewriteWithConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
givenSkipper func(c echo.Context) bool
whenURL string
expectURL string
expectStatus int
}{
{
name: "not skipped",
whenURL: "/old",
expectURL: "/new",
expectStatus: http.StatusOK,
},
{
name: "skipped",
givenSkipper: func(c echo.Context) bool {
return true
},
whenURL: "/old",
expectURL: "/old",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Pre(RewriteWithConfig(
RewriteConfig{
Skipper: tc.givenSkipper,
Rules: map[string]string{"/old": "/new"}},
))
e.GET("/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
assert.Equal(t, tc.expectStatus, rec.Code)
})
}
}
// Issue #1086 // Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) { func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New() e := echo.New()
r := e.Router()
// Rewrite old url to new one // Rewrite old url to new one
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(Rewrite(map[string]string{ e.Pre(RewriteWithConfig(RewriteConfig{
"/old": "/new", Rules: map[string]string{"/old": "/new"}}),
}, )
))
// Route // Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error { e.Add(http.MethodGet, "/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
}) })
@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
// Issue #1143 // Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New() e := echo.New()
r := e.Router()
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{ e.Pre(RewriteWithConfig(RewriteConfig{
@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
}, },
})) }))
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
return c.String(http.StatusOK, "hosts") return c.String(http.StatusOK, "hosts")
}) })
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
return c.String(http.StatusOK, "eng") return c.String(http.StatusOK, "eng")
}) })

View File

@ -6,84 +6,80 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // SecureConfig defines the config for Secure middleware.
// SecureConfig defines the config for Secure middleware. type SecureConfig struct {
SecureConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS) // XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header. // by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block". // Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"` XSSProtection string
// ContentTypeNosniff provides protection against overriding Content-Type // ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header. // header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff". // Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"` ContentTypeNosniff string
// XFrameOptions can be used to indicate whether or not a browser should // XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> . // be allowed to render a page in a <frame>, <iframe> or <object> .
// Sites can use this to avoid clickjacking attacks, by ensuring that their // Sites can use this to avoid clickjacking attacks, by ensuring that their
// content is not embedded into other sites.provides protection against // content is not embedded into other sites.provides protection against
// clickjacking. // clickjacking.
// Optional. Default value "SAMEORIGIN". // Optional. Default value "SAMEORIGIN".
// Possible values: // Possible values:
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself. // - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so. // - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin. // - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"` XFrameOptions string
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how // HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to // long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping // be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks. // man-in-the-middle (MITM) attacks.
// Optional. Default value 0. // Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"` HSTSMaxAge int
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security` // HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect // header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value. // unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false. // Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"` HSTSExcludeSubdomains bool
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing // ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code // security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the // injection attacks resulting from execution of malicious content in the
// trusted web page context. // trusted web page context.
// Optional. Default value "". // Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"` ContentSecurityPolicy string
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead // CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the // of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would // content security policy by only reporting the violations that would
// have occurred instead of blocking the resource. // have occurred instead of blocking the resource.
// Optional. Default value false. // Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"` CSPReportOnly bool
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security` // HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list // header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/ // maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false. // Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"` HSTSPreloadEnabled bool
// ReferrerPolicy sets the `Referrer-Policy` header providing security against // ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties. // leaking potentially sensitive request paths to third parties.
// Optional. Default value "". // Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"` ReferrerPolicy string
} }
)
var ( // DefaultSecureConfig is the default Secure middleware config.
// DefaultSecureConfig is the default Secure middleware config. var DefaultSecureConfig = SecureConfig{
DefaultSecureConfig = SecureConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, XSSProtection: "1; mode=block",
XSSProtection: "1; mode=block", ContentTypeNosniff: "nosniff",
ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN",
XFrameOptions: "SAMEORIGIN", HSTSPreloadEnabled: false,
HSTSPreloadEnabled: false, }
}
)
// Secure returns a Secure middleware. // Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack, // Secure middleware provides protection against cross-site scripting (XSS) attack,
@ -93,9 +89,13 @@ func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig) return SecureWithConfig(DefaultSecureConfig)
} }
// SecureWithConfig returns a Secure middleware with config. // SecureWithConfig returns a Secure middleware with config or panics on invalid configuration.
// See: `Secure()`.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts SecureConfig to middleware or returns an error for invalid configuration
func (config SecureConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper config.Skipper = DefaultSecureConfig.Skipper
@ -141,5 +141,5 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
} }, nil
} }

View File

@ -19,26 +19,40 @@ func TestSecure(t *testing.T) {
} }
// Default // Default
Secure()(h)(c) err := Secure()(h)(c)
assert.NoError(t, err)
assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy))
}
// Custom func TestSecureWithConfig(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{ mw, err := SecureConfig{
XSSProtection: "", XSSProtection: "",
ContentTypeNosniff: "", ContentTypeNosniff: "",
XFrameOptions: "", XFrameOptions: "",
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
ContentSecurityPolicy: "default-src 'self'", ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "origin", ReferrerPolicy: "origin",
})(h)(c) }.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -47,11 +61,21 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_CSPReportOnly(t *testing.T) {
// Custom with CSPReportOnly flag // Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
XSSProtection: "", XSSProtection: "",
ContentTypeNosniff: "", ContentTypeNosniff: "",
XFrameOptions: "", XFrameOptions: "",
@ -60,6 +84,8 @@ func TestSecure(t *testing.T) {
CSPReportOnly: true, CSPReportOnly: true,
ReferrerPolicy: "origin", ReferrerPolicy: "origin",
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions)) assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -67,25 +93,51 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly)) assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy)) assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_HSTSPreloadEnabled(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled // Custom, with preload option enabled
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
HSTSPreloadEnabled: true, HSTSPreloadEnabled: true,
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}
func TestSecureWithConfig_HSTSExcludeSubdomains(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled and subdomains excluded // Custom, with preload option enabled and subdomains excluded
req.Header.Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder() rec := httptest.NewRecorder()
c = e.NewContext(req, rec) c := e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600, HSTSMaxAge: 3600,
HSTSPreloadEnabled: true, HSTSPreloadEnabled: true,
HSTSExcludeSubdomains: true, HSTSExcludeSubdomains: true,
})(h)(c) })(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity)) assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
} }

View File

@ -1,44 +1,45 @@
package middleware package middleware
import ( import (
"errors"
"net/http"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
type ( // AddTrailingSlashConfig is the middleware config for adding trailing slash to the request.
// TrailingSlashConfig defines the config for TrailingSlash middleware. type AddTrailingSlashConfig struct {
TrailingSlashConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Status code to be used when redirecting the request. // Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code. // Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"` // Valid status codes: [300...308]
} RedirectCode int
) }
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: DefaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a // AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`. // trailing slash to the request `URL#Path`.
// //
// Usage `Echo#Pre(AddTrailingSlash())` // Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc { func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig) return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
} }
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config. // AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration.
// See `AddTrailingSlash()`. func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config)
// Defaults }
// ToMiddleware converts AddTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config AddTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for add trailing slash middleware")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -69,7 +70,17 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
} }
return next(c) return next(c)
} }
} }, nil
}
// RemoveTrailingSlashConfig is the middleware config for removing trailing slash from the request.
type RemoveTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int
} }
// RemoveTrailingSlash returns a root level (before router) middleware which removes // RemoveTrailingSlash returns a root level (before router) middleware which removes
@ -77,15 +88,22 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
// //
// Usage `Echo#Pre(RemoveTrailingSlash())` // Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc { func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{}) return RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{})
} }
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config. // RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config or panics on invalid configuration.
// See `RemoveTrailingSlash()`. func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc {
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config)
// Defaults }
// ToMiddleware converts RemoveTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config RemoveTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for remove trailing slash middleware")
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -117,7 +135,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
} }
return next(c) return next(c)
} }
} }, nil
} }
func sanitizeURI(uri string) string { func sanitizeURI(uri string) string {

View File

@ -67,7 +67,7 @@ func TestAddTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) { t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New() e := echo.New()
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{ mw := AddTrailingSlashWithConfig(AddTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
}) })
h := mw(func(c echo.Context) error { h := mw(func(c echo.Context) error {
@ -203,7 +203,7 @@ func TestRemoveTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) { t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New() e := echo.New()
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{ mw := RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
}) })
h := mw(func(c echo.Context) error { h := mw(func(c echo.Context) error {

View File

@ -1,55 +1,65 @@
package middleware package middleware
import ( import (
"errors"
"fmt" "fmt"
"html/template" "html/template"
"io"
"io/fs"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
) )
type ( // StaticConfig defines the config for Static middleware.
// StaticConfig defines the config for Static middleware. type StaticConfig struct {
StaticConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// Root directory from where the static content is served. // Root directory from where the static content is served (relative to given Filesystem).
// Required. // `Root: "."` means root folder from Filesystem.
Root string `yaml:"root"` // Required.
Root string
// Index file for serving a directory. // Filesystem provides access to the static content.
// Optional. Default value "index.html". // Optional. Defaults to echo.Filesystem (serves files from `.` folder where executable is started)
Index string `yaml:"index"` Filesystem fs.FS
// Enable HTML5 mode by forwarding all not-found requests to root so that // Index file for serving a directory.
// SPA (single-page application) can handle the routing. // Optional. Default value "index.html".
// Optional. Default value false. Index string
HTML5 bool `yaml:"html5"`
// Enable directory browsing. // Enable HTML5 mode by forwarding all not-found requests to root so that
// Optional. Default value false. // SPA (single-page application) can handle the routing.
Browse bool `yaml:"browse"` // Optional. Default value false.
HTML5 bool
// Enable ignoring of the base of the URL path. // Enable directory browsing.
// Example: when assigning a static middleware to a non root path group, // Optional. Default value false.
// the filesystem path is not doubled Browse bool
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
// Filesystem provides access to the static content. // Enable ignoring of the base of the URL path.
// Optional. Defaults to http.Dir(config.Root) // Example: when assigning a static middleware to a non root path group,
Filesystem http.FileSystem `yaml:"-"` // the filesystem path is not doubled
} // Optional. Default value false.
) IgnoreBase bool
const html = ` // DisablePathUnescaping disables path parameter (param: *) unescaping. This is useful when router is set to unescape
// all parameter and doing it again in this middleware would corrupt filename that is requested.
DisablePathUnescaping bool
// DirectoryListTemplate is template to list directory contents
// Optional. Default to `directoryListHTMLTemplate` constant below.
DirectoryListTemplate string
}
const directoryListHTMLTemplate = `
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
@ -121,25 +131,26 @@ const html = `
</html> </html>
` `
var ( // DefaultStaticConfig is the default Static middleware config.
// DefaultStaticConfig is the default Static middleware config. var DefaultStaticConfig = StaticConfig{
DefaultStaticConfig = StaticConfig{ Skipper: DefaultSkipper,
Skipper: DefaultSkipper, Index: "index.html",
Index: "index.html", }
}
)
// Static returns a Static middleware to serves static content from the provided // Static returns a Static middleware to serves static content from the provided root directory.
// root directory.
func Static(root string) echo.MiddlewareFunc { func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig c := DefaultStaticConfig
c.Root = root c.Root = root
return StaticWithConfig(c) return StaticWithConfig(c)
} }
// StaticWithConfig returns a Static middleware with config. // StaticWithConfig returns a Static middleware to serves static content or panics on invalid configuration.
// See `Static()`.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts StaticConfig to middleware or returns an error for invalid configuration
func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults // Defaults
if config.Root == "" { if config.Root == "" {
config.Root = "." // For security we want to restrict to CWD. config.Root = "." // For security we want to restrict to CWD.
@ -150,30 +161,32 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if config.Index == "" { if config.Index == "" {
config.Index = DefaultStaticConfig.Index config.Index = DefaultStaticConfig.Index
} }
if config.Filesystem == nil { if config.DirectoryListTemplate == "" {
config.Filesystem = http.Dir(config.Root) config.DirectoryListTemplate = directoryListHTMLTemplate
config.Root = "."
} }
// Index template dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
t, err := template.New("index").Parse(html)
if err != nil { if err != nil {
panic(fmt.Sprintf("echo: %v", err)) return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) { return func(c echo.Context) error {
if config.Skipper(c) { if config.Skipper(c) {
return next(c) return next(c)
} }
p := c.Request().URL.Path p := c.Request().URL.Path
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`. pathUnescape := true
p = c.Param("*") if c.RouteMatchType() == echo.RouteMatchFound && strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.PathParam("*")
pathUnescape = !config.DisablePathUnescaping // because router could already do PathUnescape
} }
p, err = url.PathUnescape(p) if pathUnescape {
if err != nil { p, err = url.PathUnescape(p)
return if err != nil {
return err
}
} }
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
@ -186,22 +199,29 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
} }
file, err := openFile(config.Filesystem, name) currentFS := config.Filesystem
if currentFS == nil {
currentFS = c.Echo().Filesystem
}
file, err := openFile(currentFS, name)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return err return err
} }
if err = next(c); err == nil { // when file does not exist let handler to handle that request. if it succeeds then we are done
return err err = next(c)
if err == nil {
return nil
} }
he, ok := err.(*echo.HTTPError) he, ok := err.(*echo.HTTPError)
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
return err return err
} }
// is case HTML5 mode is enabled + echo 404 we serve index to the client
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index)) file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
if err != nil { if err != nil {
return err return err
} }
@ -215,10 +235,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
if info.IsDir() { if info.IsDir() {
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index)) index, err := openFile(currentFS, filepath.Join(name, config.Index))
if err != nil { if err != nil {
if config.Browse { if config.Browse {
return listDir(t, name, file, c.Response()) return listDir(dirListTemplate, name, currentFS, file, c.Response())
} }
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -238,25 +258,24 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return serveFile(c, file, info) return serveFile(c, file, info)
} }
} }, nil
} }
func openFile(fs http.FileSystem, name string) (http.File, error) { func openFile(fs fs.FS, name string) (fs.File, error) {
pathWithSlashes := filepath.ToSlash(name) pathWithSlashes := filepath.ToSlash(name)
return fs.Open(pathWithSlashes) return fs.Open(pathWithSlashes)
} }
func serveFile(c echo.Context, file http.File, info os.FileInfo) error { func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file) ff, ok := file.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), ff)
return nil return nil
} }
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) { func listDir(t *template.Template, name string, filesystem fs.FS, dir fs.File, res *echo.Response) error {
files, err := dir.Readdir(-1)
if err != nil {
return
}
// Create directory index // Create directory index
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8) res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
data := struct { data := struct {
@ -265,12 +284,60 @@ func listDir(t *template.Template, name string, dir http.File, res *echo.Respons
}{ }{
Name: name, Name: name,
} }
for _, f := range files { err := fs.WalkDir(filesystem, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
info, infoErr := d.Info()
if infoErr != nil {
return fmt.Errorf("static middleware list dir error when getting file info: %w", err)
}
data.Files = append(data.Files, struct { data.Files = append(data.Files, struct {
Name string Name string
Dir bool Dir bool
Size string Size string
}{f.Name(), f.IsDir(), bytes.Format(f.Size())}) }{d.Name(), d.IsDir(), format(info.Size())})
return nil
})
if err != nil {
return err
} }
return t.Execute(res, data) return t.Execute(res, data)
} }
// format formats bytes integer to human readable string.
// For example, 31323 bytes will return 30.59KB.
func format(b int64) string {
multiple := ""
value := float64(b)
switch {
case b >= EB:
value /= float64(EB)
multiple = "EB"
case b >= PB:
value /= float64(PB)
multiple = "PB"
case b >= TB:
value /= float64(TB)
multiple = "TB"
case b >= GB:
value /= float64(GB)
multiple = "GB"
case b >= MB:
value /= float64(MB)
multiple = "MB"
case b >= KB:
value /= float64(KB)
multiple = "KB"
case b == 0:
return "0"
default:
return strconv.FormatInt(b, 10) + "B"
}
return fmt.Sprintf("%.2f%s", value, multiple)
}

View File

@ -81,14 +81,16 @@ func TestStatic_CustomFS(t *testing.T) {
config := StaticConfig{ config := StaticConfig{
Root: ".", Root: ".",
Filesystem: http.FS(tc.filesystem), Filesystem: tc.filesystem,
} }
if tc.root != "" { if tc.root != "" {
config.Root = tc.root config.Root = tc.root
} }
middlewareFunc := StaticWithConfig(config) middlewareFunc, err := config.ToMiddleware()
assert.NoError(t, err)
e.Use(middlewareFunc) e.Use(middlewareFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"testing" "testing"
@ -10,6 +11,37 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestStatic_useCaseForApiAndSPAs(t *testing.T) {
e := echo.New()
// serve single page application (SPA) files from server root
e.Use(StaticWithConfig(StaticConfig{
Root: ".",
// by default Echo filesystem is fixed to `./` but this does not allow `../` (moving up in folder structure past filesystem root)
Filesystem: os.DirFS("../_fixture"),
}))
// all requests to `/api/*` will end up in echo handlers (assuming there is not `api` folder and files)
api := e.Group("/api")
users := api.Group("/users")
users.GET("/info", func(c echo.Context) error {
return c.String(http.StatusOK, "users info")
})
req := httptest.NewRequest(http.MethodGet, "/api/users/info", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "users info", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/index.html", nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "<title>Echo</title>")
}
func TestStatic(t *testing.T) { func TestStatic(t *testing.T) {
var testCases = []struct { var testCases = []struct {
name string name string
@ -35,7 +67,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, when html5 mode serve index for any static file that does not exist", name: "ok, when html5 mode serve index for any static file that does not exist",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture", Root: "_fixture",
HTML5: true, HTML5: true,
}, },
whenURL: "/random", whenURL: "/random",
@ -45,7 +77,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve index as directory index listing files directory", name: "ok, serve index as directory index listing files directory",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/certs", Root: "_fixture/certs",
Browse: true, Browse: true,
}, },
whenURL: "/", whenURL: "/",
@ -55,7 +87,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve directory index with IgnoreBase and browse", name: "ok, serve directory index with IgnoreBase and browse",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true, IgnoreBase: true,
Browse: true, Browse: true,
}, },
@ -67,7 +99,7 @@ func TestStatic(t *testing.T) {
{ {
name: "ok, serve file with IgnoreBase", name: "ok, serve file with IgnoreBase",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true, IgnoreBase: true,
Browse: true, Browse: true,
}, },
@ -95,15 +127,27 @@ func TestStatic(t *testing.T) {
expectContains: "{\"message\":\"Not Found\"}\n", expectContains: "{\"message\":\"Not Found\"}\n",
}, },
{ {
name: "ok, do not serve file, when a handler took care of the request", name: "ok, when no file then a handler will care of the request",
whenURL: "/regular-handler", whenURL: "/regular-handler",
expectCode: http.StatusOK, expectCode: http.StatusOK,
expectContains: "ok", expectContains: "ok",
}, },
{
name: "ok, skip middleware and serve handler",
givenConfig: &StaticConfig{
Root: "_fixture/images/",
Skipper: func(c echo.Context) bool {
return true
},
},
whenURL: "/walle.png",
expectCode: http.StatusTeapot,
expectContains: "walle",
},
{ {
name: "nok, when html5 fail if the index file does not exist", name: "nok, when html5 fail if the index file does not exist",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "../_fixture", Root: "_fixture",
HTML5: true, HTML5: true,
Index: "missing.html", Index: "missing.html",
}, },
@ -114,7 +158,7 @@ func TestStatic(t *testing.T) {
name: "ok, serve from http.FileSystem", name: "ok, serve from http.FileSystem",
givenConfig: &StaticConfig{ givenConfig: &StaticConfig{
Root: "_fixture", Root: "_fixture",
Filesystem: http.Dir(".."), Filesystem: os.DirFS(".."),
}, },
whenURL: "/", whenURL: "/",
expectCode: http.StatusOK, expectCode: http.StatusOK,
@ -125,8 +169,9 @@ func TestStatic(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../")
config := StaticConfig{Root: "../_fixture"} config := StaticConfig{Root: "_fixture"}
if tc.givenConfig != nil { if tc.givenConfig != nil {
config = *tc.givenConfig config = *tc.givenConfig
} }
@ -136,14 +181,17 @@ func TestStatic(t *testing.T) {
subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc) subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc)
// group without http handlers (routes) does not do anything. // group without http handlers (routes) does not do anything.
// Request is matched against http handlers (routes) that have group middleware attached to them // Request is matched against http handlers (routes) that have group middleware attached to them
subGroup.GET("", echo.NotFoundHandler) subGroup.GET("", func(c echo.Context) error { return echo.ErrNotFound })
subGroup.GET("/*", echo.NotFoundHandler) subGroup.GET("/*", func(c echo.Context) error { return echo.ErrNotFound })
} else { } else {
// middleware is on root level // middleware is on root level
e.Use(middlewareFunc) e.Use(middlewareFunc)
e.GET("/regular-handler", func(c echo.Context) error { e.GET("/regular-handler", func(c echo.Context) error {
return c.String(http.StatusOK, "ok") return c.String(http.StatusOK, "ok")
}) })
e.GET("/walle.png", func(c echo.Context) error {
return c.String(http.StatusTeapot, "walle")
})
} }
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
@ -177,7 +225,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "ok", name: "ok",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/images", givenRoot: "_fixture/images",
whenURL: "/group/images/walle.png", whenURL: "/group/images/walle.png",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
@ -185,7 +233,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "No file", name: "No file",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/scripts", givenRoot: "_fixture/scripts",
whenURL: "/group/images/bolt.png", whenURL: "/group/images/bolt.png",
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -193,7 +241,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory not found (no trailing slash)", name: "Directory not found (no trailing slash)",
givenPrefix: "/images", givenPrefix: "/images",
givenRoot: "../_fixture/images", givenRoot: "_fixture/images",
whenURL: "/group/images/", whenURL: "/group/images/",
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -201,7 +249,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory redirect", name: "Directory redirect",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/folder", whenURL: "/group/folder",
expectStatus: http.StatusMovedPermanently, expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/group/folder/", expectHeaderLocation: "/group/folder/",
@ -211,7 +259,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory 404 (request URL without slash)", name: "Prefixed directory 404 (request URL without slash)",
givenGroup: "_fixture", givenGroup: "_fixture",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -220,7 +268,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory redirect (without slash redirect to slash)", name: "Prefixed directory redirect (without slash redirect to slash)",
givenGroup: "_fixture", givenGroup: "_fixture",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently, expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/_fixture/folder/", expectHeaderLocation: "/_fixture/folder/",
@ -229,7 +277,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Directory with index.html", name: "Directory with index.html",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/", whenURL: "/group/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -237,7 +285,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Prefixed directory with index.html (prefix ending with slash)", name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/", givenPrefix: "/assets/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/assets/", whenURL: "/group/assets/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -245,7 +293,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Prefixed directory with index.html (prefix ending without slash)", name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets", givenPrefix: "/assets",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/assets/", whenURL: "/group/assets/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -253,7 +301,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "Sub-directory with index.html", name: "Sub-directory with index.html",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture", givenRoot: "_fixture",
whenURL: "/group/folder/", whenURL: "/group/folder/",
expectStatus: http.StatusOK, expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>", expectBodyStartsWith: "<!doctype html>",
@ -261,7 +309,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "do not allow directory traversal (backslash - windows separator)", name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture/", givenRoot: "_fixture/",
whenURL: `/group/..\\middleware/basic_auth.go`, whenURL: `/group/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -269,7 +317,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{ {
name: "do not allow directory traversal (slash - unix separator)", name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/", givenPrefix: "/",
givenRoot: "../_fixture/", givenRoot: "_fixture/",
whenURL: `/group/../middleware/basic_auth.go`, whenURL: `/group/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound, expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -279,6 +327,8 @@ func TestStatic_GroupWithStatic(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
e.Filesystem = os.DirFS("../") // so we can access test files
group := "/group" group := "/group"
if tc.givenGroup != "" { if tc.givenGroup != "" {
group = tc.givenGroup group = tc.givenGroup
@ -288,7 +338,9 @@ func TestStatic_GroupWithStatic(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code) assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String() body := rec.Body.String()
if tc.expectBodyStartsWith != "" { if tc.expectBodyStartsWith != "" {
@ -306,3 +358,73 @@ func TestStatic_GroupWithStatic(t *testing.T) {
}) })
} }
} }
func TestMustStaticWithConfig_panicsInvalidDirListTemplate(t *testing.T) {
assert.Panics(t, func() {
StaticWithConfig(StaticConfig{DirectoryListTemplate: `{{}`})
})
}
func TestFormat(t *testing.T) {
var testCases = []struct {
name string
when int64
expect string
}{
{
name: "byte",
when: 0,
expect: "0",
},
{
name: "bytes",
when: 515,
expect: "515B",
},
{
name: "KB",
when: 31323,
expect: "30.59KB",
},
{
name: "MB",
when: 13231323,
expect: "12.62MB",
},
{
name: "GB",
when: 7323232398,
expect: "6.82GB",
},
{
name: "TB",
when: 1_099_511_627_776,
expect: "1.00TB",
},
{
name: "PB",
when: 9923232398434432,
expect: "8.81PB",
},
{
// test with 7EB because of https://github.com/labstack/gommon/pull/38 and https://github.com/labstack/gommon/pull/43
//
// 8 exbi equals 2^64, therefore it cannot be stored in int64. The tests use
// the fact that on x86_64 the following expressions holds true:
// int64(0) - 1 == math.MaxInt64.
//
// However, this is not true for other platforms, specifically aarch64, s390x
// and ppc64le.
name: "EB",
when: 8070450532247929000,
expect: "7.00EB",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := format(tc.when)
assert.Equal(t, tc.expect, result)
})
}
}

View File

@ -2,10 +2,9 @@ package middleware
import ( import (
"context" "context"
"github.com/labstack/echo/v4"
"net/http" "net/http"
"time" "time"
"github.com/labstack/echo/v4"
) )
// --------------------------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------------------------
@ -55,51 +54,43 @@ import (
// }) // })
// //
type ( // TimeoutConfig defines the config for Timeout middleware.
// TimeoutConfig defines the config for Timeout middleware. type TimeoutConfig struct {
TimeoutConfig struct { // Skipper defines a function to skip middleware.
// Skipper defines a function to skip middleware. Skipper Skipper
Skipper Skipper
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message // It can be used to define a custom timeout error message
ErrorMessage string ErrorMessage string
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client. // request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context) OnTimeoutRouteErrorHandler func(c echo.Context, err error)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout // Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable // difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration Timeout time.Duration
} }
)
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler // Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution. // call runs for longer than its time limit. NB: timeout does not stop handler execution.
func Timeout() echo.MiddlewareFunc { func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig) return TimeoutWithConfig(TimeoutConfig{})
} }
// TimeoutWithConfig returns a Timeout middleware with config. // TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults return toMiddlewareOrPanic(config)
}
// ToMiddleware converts TimeoutConfig to middleware or returns an error for invalid configuration
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper config.Skipper = DefaultSkipper
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -108,29 +99,30 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
errChan := make(chan error, 1)
handlerWrapper := echoHandlerFuncWrapper{ handlerWrapper := echoHandlerFuncWrapper{
ctx: c, ctx: c,
handler: next, handler: next,
errChan: make(chan error, 1), errChan: errChan,
errHandler: config.OnTimeoutRouteErrorHandler, errHandler: config.OnTimeoutRouteErrorHandler,
} }
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request()) handler.ServeHTTP(c.Response().Writer, c.Request())
select { select {
case err := <-handlerWrapper.errChan: case err := <-errChan:
return err return err
default: default:
return nil return nil
} }
} }
} }, nil
} }
type echoHandlerFuncWrapper struct { type echoHandlerFuncWrapper struct {
ctx echo.Context ctx echo.Context
handler echo.HandlerFunc handler echo.HandlerFunc
errHandler func(err error, c echo.Context) errHandler func(c echo.Context, err error)
errChan chan error errChan chan error
} }
@ -156,7 +148,7 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
err := t.handler(t.ctx) err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil { if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx) t.errHandler(t.ctx, err)
} }
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
} }

View File

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
@ -14,9 +16,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
) )
func TestTimeoutSkipper(t *testing.T) { func TestTimeoutSkipper(t *testing.T) {
@ -111,7 +110,7 @@ func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
actualErrChan := make(chan error, 1) actualErrChan := make(chan error, 1)
m := TimeoutWithConfig(TimeoutConfig{ m := TimeoutWithConfig(TimeoutConfig{
Timeout: 1 * time.Millisecond, Timeout: 1 * time.Millisecond,
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) { OnTimeoutRouteErrorHandler: func(c echo.Context, err error) {
actualErrChan <- err actualErrChan <- err
}, },
}) })
@ -360,7 +359,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) {
e := echo.New() e := echo.New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger.SetOutput(buf) e.Logger = &testLogger{output: buf}
// NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first
// FIXME: I have no idea how to fix this without adding mutexes. // FIXME: I have no idea how to fix this without adding mutexes.
@ -424,7 +423,7 @@ func startServer(e *echo.Echo) (*http.Server, string, error) {
s := http.Server{ s := http.Server{
Handler: e, Handler: e,
ErrorLog: log.New(e.Logger.Output(), "echo:", 0), ErrorLog: log.New(e.Logger, "echo:", 0),
} }
errCh := make(chan error) errCh := make(chan error)

View File

@ -1,9 +1,27 @@
package middleware package middleware
import ( import (
"crypto/rand"
"fmt"
"strings" "strings"
) )
const (
_ = int64(1 << (10 * iota)) // ignore first value by assigning to blank identifier
// KB is 1 KiloByte = 1024 bytes
KB
// MB is 1 Megabyte = 1_048_576 bytes
MB
// GB is 1 Gigabyte = 1_073_741_824 bytes
GB
// TB is 1 Terabyte = 1_099_511_627_776 bytes
TB
// PB is 1 Petabyte = 1_125_899_906_842_624 bytes
PB
// EB is 1 Exabyte = 1_152_921_504_606_847_000 bytes
EB
)
func matchScheme(domain, pattern string) bool { func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":") didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":") pidx := strings.Index(pattern, ":")
@ -52,3 +70,24 @@ func matchSubdomain(domain, pattern string) bool {
} }
return false return false
} }
func createRandomStringGenerator(length uint8) func() string {
return func() string {
return randomString(length)
}
}
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -1,11 +1,23 @@
package middleware package middleware
import ( import (
"testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io"
"testing"
) )
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}
func Test_matchScheme(t *testing.T) { func Test_matchScheme(t *testing.T) {
tests := []struct { tests := []struct {
domain, pattern string domain, pattern string
@ -93,3 +105,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern)) assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
} }
} }
func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}

View File

@ -2,24 +2,23 @@ package echo
import ( import (
"bufio" "bufio"
"errors"
"net" "net"
"net/http" "net/http"
) )
type ( // Response wraps an http.ResponseWriter and implements its interface to be used
// Response wraps an http.ResponseWriter and implements its interface to be used // by an HTTP handler to construct an HTTP response.
// by an HTTP handler to construct an HTTP response. // See: https://golang.org/pkg/net/http/#ResponseWriter
// See: https://golang.org/pkg/net/http/#ResponseWriter type Response struct {
Response struct { echo *Echo
echo *Echo beforeFuncs []func()
beforeFuncs []func() afterFuncs []func()
afterFuncs []func() Writer http.ResponseWriter
Writer http.ResponseWriter Status int
Status int Size int64
Size int64 Committed bool
Committed bool }
}
)
// NewResponse creates a new instance of Response. // NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) { func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
@ -47,13 +46,15 @@ func (r *Response) After(fn func()) {
r.afterFuncs = append(r.afterFuncs, fn) r.afterFuncs = append(r.afterFuncs, fn)
} }
var errHeaderAlreadyCommitted = errors.New("response already committed")
// WriteHeader sends an HTTP response header with status code. If WriteHeader is // WriteHeader sends an HTTP response header with status code. If WriteHeader is
// not called explicitly, the first call to Write will trigger an implicit // not called explicitly, the first call to Write will trigger an implicit
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly // WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
// used to send error codes. // used to send error codes.
func (r *Response) WriteHeader(code int) { func (r *Response) WriteHeader(code int) {
if r.Committed { if r.Committed {
r.echo.Logger.Warn("response already committed") r.echo.Logger.Error(errHeaderAlreadyCommitted)
return return
} }
r.Status = code r.Status = code

182
route.go Normal file
View File

@ -0,0 +1,182 @@
package echo
import (
"bytes"
"errors"
"fmt"
"reflect"
"runtime"
)
// Route contains information to adding/registering new route with the router.
// Method+Path pair uniquely identifies the Route. It is mandatory to provide Method+Path+Handler fields.
type Route struct {
Method string
Path string
Handler HandlerFunc
Middlewares []MiddlewareFunc
Name string
}
// ToRouteInfo converts Route to RouteInfo
func (r Route) ToRouteInfo(params []string) RouteInfo {
name := r.Name
if name == "" {
name = r.Method + ":" + r.Path
}
return routeInfo{
method: r.Method,
path: r.Path,
params: append([]string(nil), params...),
name: name,
}
}
// ToRoute returns Route which Router uses to register the method handler for path.
func (r Route) ToRoute() Route {
return r
}
// ForGroup recreates Route with added group prefix and group middlewares it is grouped to.
func (r Route) ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable {
r.Path = pathPrefix + r.Path
if len(middlewares) > 0 {
m := make([]MiddlewareFunc, 0, len(middlewares)+len(r.Middlewares))
m = append(m, middlewares...)
m = append(m, r.Middlewares...)
r.Middlewares = m
}
return r
}
type routeInfo struct {
method string
path string
params []string
name string
}
func (r routeInfo) Method() string {
return r.method
}
func (r routeInfo) Path() string {
return r.path
}
func (r routeInfo) Params() []string {
return append([]string(nil), r.params...)
}
func (r routeInfo) Name() string {
return r.name
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r routeInfo) Reverse(params ...interface{}) string {
uri := new(bytes.Buffer)
ln := len(params)
n := 0
for i, l := 0, len(r.path); i < l; i++ {
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
for ; i < l && r.path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(r.path[i])
}
}
return uri.String()
}
// HandlerName returns string name for given function.
func HandlerName(h HandlerFunc) string {
t := reflect.ValueOf(h).Type()
if t.Kind() == reflect.Func {
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
}
return t.String()
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r Routes) Reverse(name string, params ...interface{}) (string, error) {
for _, rr := range r {
if rr.Name() == name {
return rr.Reverse(params...), nil
}
}
return "", errors.New("route not found")
}
// FindByMethodPath searched for matching route info by method and path
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) {
if r == nil {
return nil, errors.New("route not found by method and path")
}
for _, rr := range r {
if rr.Method() == method && rr.Path() == path {
return rr, nil
}
}
return nil, errors.New("route not found by method and path")
}
// FilterByMethod searched for matching route info by method
func (r Routes) FilterByMethod(method string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by method")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Method() == method {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by method")
}
return result, nil
}
// FilterByPath searched for matching route info by path
func (r Routes) FilterByPath(path string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by path")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Path() == path {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by path")
}
return result, nil
}
// FilterByName searched for matching route info by name
func (r Routes) FilterByName(name string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by name")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Name() == name {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by name")
}
return result, nil
}

423
route_test.go Normal file
View File

@ -0,0 +1,423 @@
package echo
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
var myNamedHandler = func(c Context) error {
return nil
}
type NameStruct struct {
}
func (n *NameStruct) getUsers(c Context) error {
return nil
}
func TestHandlerName(t *testing.T) {
myNameFuncVar := func(c Context) error {
return nil
}
tmp := NameStruct{}
var testCases = []struct {
name string
whenHandlerFunc HandlerFunc
expect string
}{
{
name: "ok, func as anonymous func",
whenHandlerFunc: func(c Context) error {
return nil
},
expect: "github.com/labstack/echo/v4.TestHandlerName.func2",
},
{
name: "ok, func as named package variable",
whenHandlerFunc: myNamedHandler,
expect: "github.com/labstack/echo/v4.glob..func3",
},
{
name: "ok, func as named function variable",
whenHandlerFunc: myNameFuncVar,
expect: "github.com/labstack/echo/v4.TestHandlerName.func1",
},
{
name: "ok, func as struct method",
whenHandlerFunc: tmp.getUsers,
expect: "github.com/labstack/echo/v4.(*NameStruct).getUsers-fm",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
name := HandlerName(tc.whenHandlerFunc)
assert.Equal(t, tc.expect, name)
})
}
}
func TestHandlerName_differentFuncSameName(t *testing.T) {
handlerCreator := func(name string) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusTeapot, name)
}
}
h1 := handlerCreator("name1")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func2", HandlerName(h1))
h2 := handlerCreator("name2")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func3", HandlerName(h2))
}
func TestRoute_ToRouteInfo(t *testing.T) {
var testCases = []struct {
name string
given Route
whenParams []string
expect RouteInfo
}{
{
name: "ok, no params, with name",
given: Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
},
expect: routeInfo{
method: http.MethodGet,
path: "/test",
params: nil,
name: "test route",
},
},
{
name: "ok, params",
given: Route{
Method: http.MethodGet,
Path: "users/:id/:file", // no slash prefix
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "",
},
whenParams: []string{"id", "file"},
expect: routeInfo{
method: http.MethodGet,
path: "users/:id/:file",
params: []string{"id", "file"},
name: "GET:users/:id/:file",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri := tc.given.ToRouteInfo(tc.whenParams)
assert.Equal(t, tc.expect, ri)
})
}
}
func TestRoute_ToRoute(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
r := route.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/test")
assert.NotNil(t, r.Handler)
assert.Nil(t, r.Middlewares)
assert.Equal(t, r.Name, "test route")
}
func TestRoute_ForGroup(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return next(c)
}
}
gr := route.ForGroup("/users", []MiddlewareFunc{mw})
r := gr.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/users/test")
assert.NotNil(t, r.Handler)
assert.Len(t, r.Middlewares, 1)
assert.Equal(t, r.Name, "test route")
}
func exampleRoutes() Routes {
return Routes{
routeInfo{
method: http.MethodGet,
path: "/users",
params: nil,
name: "GET:/users",
},
routeInfo{
method: http.MethodGet,
path: "/users/:id",
params: []string{"id"},
name: "GET:/users/:id",
},
routeInfo{
method: http.MethodPost,
path: "/users/:id",
params: []string{"id"},
name: "POST:/users/:id",
},
routeInfo{
method: http.MethodDelete,
path: "/groups",
params: nil,
name: "non_unique_name",
},
routeInfo{
method: http.MethodPost,
path: "/groups",
params: nil,
name: "non_unique_name",
},
}
}
func TestRoutes_FindByMethodPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
whenPath string
expectName string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "GET:/users/:id",
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri, err := tc.given.FindByMethodPath(tc.whenMethod, tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
assert.Nil(t, ri)
} else {
assert.NoError(t, err)
}
if tc.expectName != "" {
assert.Equal(t, tc.expectName, ri.Name())
}
})
}
}
func TestRoutes_FilterByMethod(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
expectNames: []string{"GET:/users", "GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
expectNames: nil,
expectError: "route not found by method",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
expectNames: nil,
expectError: "route not found by method",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByMethod(tc.whenMethod)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenPath string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenPath: "/users/:id",
expectNames: []string{"GET:/users/:id", "POST:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenPath: "/",
expectNames: nil,
expectError: "route not found by path",
},
{
name: "nok, not found from nil",
given: nil,
whenPath: "/users/:id",
expectNames: nil,
expectError: "route not found by path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByPath(tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByName(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenName string
expectMethodPath []string
expectError string
}{
{
name: "ok, found multiple",
given: exampleRoutes(),
whenName: "non_unique_name",
expectMethodPath: []string{"DELETE:/groups", "POST:/groups"},
},
{
name: "ok, found single",
given: exampleRoutes(),
whenName: "GET:/users/:id",
expectMethodPath: []string{"GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenName: "/",
expectMethodPath: nil,
expectError: "route not found by name",
},
{
name: "nok, not found from nil",
given: nil,
whenName: "/users/:id",
expectMethodPath: nil,
expectError: "route not found by name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByName(tc.whenName)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectMethodPath) > 0 {
assert.Len(t, ris, len(tc.expectMethodPath))
for _, ri := range ris {
assert.Contains(t, tc.expectMethodPath, fmt.Sprintf("%v:%v", ri.Method(), ri.Path()))
}
} else {
assert.Nil(t, ris)
}
})
}
}

783
router.go
View File

@ -1,50 +1,134 @@
package echo package echo
import ( import (
"errors"
"net/http" "net/http"
"net/url"
) )
type ( // Router is interface for routing requests to registered routes.
// Router is the registry of all registered routes for an `Echo` instance for type Router interface {
// request matching and URL path parameter parsing. // Add registers Routable with the Router and returns registered RouteInfo
Router struct { Add(routable Routable) (RouteInfo, error)
tree *node // Remove removes route from the Router
routes map[string]*Route Remove(method string, path string) error
echo *Echo // Routes returns information about all registered routes
} Routes() Routes
node struct {
kind kind // Match searches Router for matching route and applies it to result fields.
label byte Match(req *http.Request, params *PathParams) RouteMatch
prefix string }
parent *node
staticChildren children // Routable is interface for registering Route with Router. During route registration process the Router will
ppath string // convert Routable to RouteInfo with ToRouteInfo method. By creating custom implementation of Routable additional
pnames []string // information about registered route can be stored in Routes (i.e. privileges used with route etc.)
methodHandler *methodHandler type Routable interface {
paramChild *node // ToRouteInfo converts Routable to RouteInfo
anyChild *node //
// isLeaf indicates that node does not have child routes // This method is meant to be used by Router after it parses url for path parameters, to store information about
isLeaf bool // route just added.
// isHandler indicates that node has at least one handler registered to it ToRouteInfo(params []string) RouteInfo
isHandler bool // ToRoute converts Routable to Route which Router uses to register the method handler for path.
} //
kind uint8 // This method is meant to be used by Router to get fields (including handler and middleware functions) needed to
children []*node // add Route to Router.
methodHandler struct { ToRoute() Route
connect HandlerFunc // ForGroup recreates routable with added group prefix and group middlewares it is grouped to.
delete HandlerFunc //
get HandlerFunc // Is necessary for Echo.Group to be able to add/register Routable with Router and having group prefix and group
head HandlerFunc // middlewares included in actually registered Route.
options HandlerFunc ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable
patch HandlerFunc }
post HandlerFunc
propfind HandlerFunc // Routes is collection of RouteInfo instances with various helper methods.
put HandlerFunc type Routes []RouteInfo
trace HandlerFunc
report HandlerFunc // RouteInfo describes registered route base fields.
} // Method+Path pair uniquely identifies the Route. Name can have duplicates.
type RouteInfo interface {
Method() string
Path() string
Name() string
Params() []string
Reverse(params ...interface{}) string
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
// it is not always 100% known if handler function already wraps middlewares or not. In Echo handler could be one
// function or several functions wrapping each other.
}
// RouteMatchType describes possible states that request could be in perspective of routing
type RouteMatchType uint8
const (
// RouteMatchUnknown is state before routing is done. Default state for fresh context.
RouteMatchUnknown RouteMatchType = iota
// RouteMatchNotFound is state when router did not find matching route for current request
RouteMatchNotFound
// RouteMatchMethodNotAllowed is state when router did not find route with matching path + method for current request.
// Although router had matching route with that path but different method.
RouteMatchMethodNotAllowed
// RouteMatchFound is state when router found exact match for path + method combination
RouteMatchFound
) )
// RouteMatch is result object for Router.Match. Its main purpose is to avoid allocating memory for PathParams inside router.
type RouteMatch struct {
// Type contains result as enumeration of Router.Match and helps to understand did Router actually matched Route or
// what kind of error case (404/405) we have at the end of the handler chain.
Type RouteMatchType
// RoutePath contains original path with what matched route was registered with (including placeholders etc).
RoutePath string
// Handler is function(chain) that was matched by router. In case of no match could result to ErrNotFound or ErrMethodNotAllowed.
Handler HandlerFunc
// RouteInfo is information about route we just matched
RouteInfo RouteInfo
}
// PathParams is collections of PathParam instances with various helper methods
type PathParams []PathParam
// PathParam is tuple pf path parameter name and its value in request path
type PathParam struct {
Name string
Value string
}
// DefaultRouter is the registry of all registered routes for an `Echo` instance for
// request matching and URL path parameter parsing.
// Note: DefaultRouter is not coroutine-safe. Do not Add/Remove routes after HTTP server has been started with Echo.
type DefaultRouter struct {
tree *node
routes Routes
echo *Echo
allowOverwritingRoute bool
unescapePathParamValues bool
useEscapedPathForRouting bool
}
type children []*node
type node struct {
kind kind
label byte
prefix string
parent *node
staticChildren children
originalPath string
methods *routeMethods
paramChild *node
anyChild *node
paramsCount int
// isLeaf indicates that node does not have child routes
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool
}
type kind uint8
const ( const (
staticKind kind = iota staticKind kind = iota
paramKind paramKind
@ -54,90 +138,362 @@ const (
anyLabel = byte('*') anyLabel = byte('*')
) )
func (m *methodHandler) isHandler() bool { type routeMethod struct {
return m.connect != nil || *routeInfo
m.delete != nil || handler HandlerFunc
m.get != nil || orgRouteInfo RouteInfo
m.head != nil ||
m.options != nil ||
m.patch != nil ||
m.post != nil ||
m.propfind != nil ||
m.put != nil ||
m.trace != nil ||
m.report != nil
} }
// NewRouter returns a new Router instance. type routeMethods struct {
func NewRouter(e *Echo) *Router { connect *routeMethod
return &Router{ delete *routeMethod
tree: &node{ get *routeMethod
methodHandler: new(methodHandler), head *routeMethod
}, options *routeMethod
routes: map[string]*Route{}, patch *routeMethod
echo: e, post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
}
func (m *routeMethods) set(method string, r *routeMethod) {
switch method {
case http.MethodConnect:
m.connect = r
case http.MethodDelete:
m.delete = r
case http.MethodGet:
m.get = r
case http.MethodHead:
m.head = r
case http.MethodOptions:
m.options = r
case http.MethodPatch:
m.patch = r
case http.MethodPost:
m.post = r
case PROPFIND:
m.propfind = r
case http.MethodPut:
m.put = r
case http.MethodTrace:
m.trace = r
case REPORT:
m.report = r
default:
if m.anyOther == nil {
m.anyOther = make(map[string]*routeMethod)
}
if r.handler == nil {
delete(m.anyOther, method)
} else {
m.anyOther[method] = r
}
} }
} }
// Add registers a new route for method and path with matching handler. func (m *routeMethods) find(method string) *routeMethod {
func (r *Router) Add(method, path string, h HandlerFunc) { switch method {
// Validate path case http.MethodConnect:
return m.connect
case http.MethodDelete:
return m.delete
case http.MethodGet:
return m.get
case http.MethodHead:
return m.head
case http.MethodOptions:
return m.options
case http.MethodPatch:
return m.patch
case http.MethodPost:
return m.post
case PROPFIND:
return m.propfind
case http.MethodPut:
return m.put
case http.MethodTrace:
return m.trace
case REPORT:
return m.report
default:
return m.anyOther[method]
}
}
func (m *routeMethods) isHandler() bool {
return m.get != nil ||
m.post != nil ||
m.options != nil ||
m.put != nil ||
m.delete != nil ||
m.connect != nil ||
m.head != nil ||
m.patch != nil ||
m.propfind != nil ||
m.trace != nil ||
m.report != nil ||
len(m.anyOther) != 0
}
// RouterConfig is configuration options for (default) router
type RouterConfig struct {
// AllowOverwritingRoute instructs Router NOT to return error when new route is registered with the same method+path
// and replaces matching route with the new one.
AllowOverwritingRoute bool
// UnescapePathParamValues instructs Router to unescape path parameter value when request if matched to the routes
UnescapePathParamValues bool
// UseEscapedPathForMatching instructs Router to use escaped request URL path (req.URL.Path) for matching the request.
UseEscapedPathForMatching bool
}
// NewRouter returns a new Router instance.
func NewRouter(e *Echo, config RouterConfig) *DefaultRouter {
r := &DefaultRouter{
tree: &node{
methods: new(routeMethods),
isLeaf: true,
isHandler: false,
},
routes: make(Routes, 0),
echo: e,
allowOverwritingRoute: config.AllowOverwritingRoute,
unescapePathParamValues: config.UnescapePathParamValues,
useEscapedPathForRouting: config.UseEscapedPathForMatching,
}
return r
}
// Routes returns all registered routes
func (r *DefaultRouter) Routes() Routes {
return r.routes
}
// Remove unregisters registered route
func (r *DefaultRouter) Remove(method string, path string) error {
currentNode := r.tree
if currentNode == nil || (currentNode.isLeaf && !currentNode.isHandler) {
return errors.New("router has no routes to remove")
}
if path == "" { if path == "" {
path = "/" path = "/"
} }
if path[0] != '/' { if path[0] != '/' {
path = "/" + path path = "/" + path
} }
pnames := []string{} // Param names
ppath := path // Pristine path
if h == nil && r.echo.Logger != nil { var nodeToRemove *node
// FIXME: in future we should return error prefixLen := 0
r.echo.Logger.Errorf("Adding route without handler function: %v:%v", method, path) for {
if currentNode.originalPath == path && currentNode.isHandler {
nodeToRemove = currentNode
break
}
if currentNode.kind == staticKind {
prefixLen = prefixLen + len(currentNode.prefix)
} else {
prefixLen = len(currentNode.originalPath)
}
if prefixLen >= len(path) {
break
}
next := path[prefixLen]
switch next {
case paramLabel:
currentNode = currentNode.paramChild
case anyLabel:
currentNode = currentNode.anyChild
default:
currentNode = currentNode.findStaticChild(next)
}
if currentNode == nil {
break
}
} }
if nodeToRemove == nil {
return errors.New("could not find route to remove by given path")
}
if !nodeToRemove.isHandler {
return errors.New("could not find route to remove by given path")
}
if mh := nodeToRemove.methods.find(method); mh == nil {
return errors.New("could not find route to remove by given path and method")
}
nodeToRemove.setHandler(method, nil)
var rIndex int
for i, rr := range r.routes {
if rr.Method() == method && rr.Path() == path {
rIndex = i
break
}
}
r.routes = append(r.routes[:rIndex], r.routes[rIndex+1:]...)
if !nodeToRemove.isHandler && nodeToRemove.isLeaf {
// TODO: if !nodeToRemove.isLeaf and has at least 2 children merge paths for remaining nodes?
current := nodeToRemove
for {
parent := current.parent
if parent == nil {
break
}
switch current.kind {
case staticKind:
var index int
for i, c := range parent.staticChildren {
if c == current {
index = i
break
}
}
parent.staticChildren = append(parent.staticChildren[:index], parent.staticChildren[index+1:]...)
case paramKind:
parent.paramChild = nil
case anyKind:
parent.anyChild = nil
}
parent.isLeaf = parent.anyChild == nil && parent.paramChild == nil && len(parent.staticChildren) == 0
if !parent.isLeaf || parent.isHandler {
break
}
current = parent
}
}
return nil
}
// AddRouteError is error returned by Router.Add containing information what actual route adding failed. Useful for
// mass adding (i.e. Any() routes)
type AddRouteError struct {
Method string
Path string
Err error
}
func (e *AddRouteError) Error() string { return e.Method + " " + e.Path + ": " + e.Err.Error() }
func (e *AddRouteError) Unwrap() error { return e.Err }
func newAddRouteError(route Route, err error) *AddRouteError {
return &AddRouteError{
Method: route.Method,
Path: route.Path,
Err: err,
}
}
// Add registers a new route for method and path with matching handler.
func (r *DefaultRouter) Add(routable Routable) (RouteInfo, error) {
route := routable.ToRoute()
if route.Handler == nil {
return nil, newAddRouteError(route, errors.New("adding route without handler function"))
}
method := route.Method
path := route.Path
h := applyMiddleware(route.Handler, route.Middlewares...)
if !r.allowOverwritingRoute {
for _, rr := range r.routes {
if route.Method == rr.Method() && route.Path == rr.Path() {
return nil, newAddRouteError(route, errors.New("adding duplicate route (same method+path) is not allowed"))
}
}
}
if path == "" {
path = "/"
}
if path[0] != '/' {
path = "/" + path
}
paramNames := make([]string, 0)
originalPath := path
wasAdded := false
var ri RouteInfo
for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { for i, lcpIndex := 0, len(path); i < lcpIndex; i++ {
if path[i] == ':' { if path[i] == paramLabel {
if i > 0 && path[i-1] == '\\' { if i > 0 && path[i-1] == '\\' {
continue continue
} }
j := i + 1 j := i + 1
r.insert(method, path[:i], nil, staticKind, "", nil) r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
for ; i < lcpIndex && path[i] != '/'; i++ { for ; i < lcpIndex && path[i] != '/'; i++ {
} }
pnames = append(pnames, path[j:i]) paramNames = append(paramNames, path[j:i])
path = path[:j] + path[i:] path = path[:j] + path[i:]
i, lcpIndex = j, len(path) i, lcpIndex = j, len(path)
if i == lcpIndex { if i == lcpIndex {
// path node is last fragment of route path. ie. `/users/:id` // path node is last fragment of route path. ie. `/users/:id`
r.insert(method, path[:i], h, paramKind, ppath, pnames) ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(paramKind, path[:i], method, rm)
wasAdded = true
break
} else { } else {
r.insert(method, path[:i], nil, paramKind, "", nil) r.insert(paramKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
} }
} else if path[i] == '*' { } else if path[i] == anyLabel {
r.insert(method, path[:i], nil, staticKind, "", nil) r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
pnames = append(pnames, "*") paramNames = append(paramNames, "*")
r.insert(method, path[:i+1], h, anyKind, ppath, pnames) ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(anyKind, path[:i+1], method, rm)
wasAdded = true
break
} }
} }
r.insert(method, path, h, staticKind, ppath, pnames) if !wasAdded {
ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(staticKind, path, method, rm)
}
r.storeRouteInfo(ri)
return ri, nil
} }
func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { func (r *DefaultRouter) storeRouteInfo(ri RouteInfo) {
// Adjust max param for i, rr := range r.routes {
paramLen := len(pnames) if ri.Method() == rr.Method() && ri.Path() == rr.Path() {
if *r.echo.maxParam < paramLen { r.routes[i] = ri
*r.echo.maxParam = paramLen return
}
} }
r.routes = append(r.routes, ri)
}
func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMethod) {
currentNode := r.tree // Current node as root currentNode := r.tree // Current node as root
if currentNode == nil {
panic("echo: invalid method")
}
search := path search := path
for { for {
@ -157,11 +513,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
// At root node // At root node
currentNode.label = search[0] currentNode.label = search[0]
currentNode.prefix = search currentNode.prefix = search
if h != nil { if ri.handler != nil {
currentNode.kind = t currentNode.kind = t
currentNode.addHandler(method, h) currentNode.setHandler(method, &ri)
currentNode.ppath = ppath currentNode.paramsCount = len(ri.params)
currentNode.pnames = pnames currentNode.originalPath = ri.path
} }
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen { } else if lcpLen < prefixLen {
@ -171,9 +527,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix[lcpLen:], currentNode.prefix[lcpLen:],
currentNode, currentNode,
currentNode.staticChildren, currentNode.staticChildren,
currentNode.methodHandler, currentNode.methods,
currentNode.ppath, currentNode.paramsCount,
currentNode.pnames, currentNode.originalPath,
currentNode.paramChild, currentNode.paramChild,
currentNode.anyChild, currentNode.anyChild,
) )
@ -193,9 +549,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.label = currentNode.prefix[0] currentNode.label = currentNode.prefix[0]
currentNode.prefix = currentNode.prefix[:lcpLen] currentNode.prefix = currentNode.prefix[:lcpLen]
currentNode.staticChildren = nil currentNode.staticChildren = nil
currentNode.methodHandler = new(methodHandler) currentNode.methods = new(routeMethods)
currentNode.ppath = "" currentNode.originalPath = ""
currentNode.pnames = nil currentNode.paramsCount = 0
currentNode.paramChild = nil currentNode.paramChild = nil
currentNode.anyChild = nil currentNode.anyChild = nil
currentNode.isLeaf = false currentNode.isLeaf = false
@ -207,13 +563,18 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
if lcpLen == searchLen { if lcpLen == searchLen {
// At parent node // At parent node
currentNode.kind = t currentNode.kind = t
currentNode.addHandler(method, h) if ri.handler != nil {
currentNode.ppath = ppath currentNode.setHandler(method, &ri)
currentNode.pnames = pnames currentNode.paramsCount = len(ri.params)
currentNode.originalPath = ri.path
}
} else { } else {
// Create child node // Create child node
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n = newNode(t, search[lcpLen:], currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
n.addHandler(method, h) if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
// Only Static children could reach here // Only Static children could reach here
currentNode.addStaticChild(n) currentNode.addStaticChild(n)
} }
@ -227,8 +588,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue continue
} }
// Create child node // Create child node
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n := newNode(t, search, currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
n.addHandler(method, h) if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
switch t { switch t {
case staticKind: case staticKind:
currentNode.addStaticChild(n) currentNode.addStaticChild(n)
@ -240,28 +604,26 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else { } else {
// Node already exists // Node already exists
if h != nil { if ri.handler != nil {
currentNode.addHandler(method, h) currentNode.setHandler(method, &ri)
currentNode.ppath = ppath currentNode.paramsCount = len(ri.params)
if len(currentNode.pnames) == 0 { // Issue #729 currentNode.originalPath = ri.path
currentNode.pnames = pnames
}
} }
} }
return return
} }
} }
func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { func newNode(t kind, pre string, p *node, sc children, mh *routeMethods, paramsCount int, ppath string, paramChildren, anyChildren *node) *node {
return &node{ return &node{
kind: t, kind: t,
label: pre[0], label: pre[0],
prefix: pre, prefix: pre,
parent: p, parent: p,
staticChildren: sc, staticChildren: sc,
ppath: ppath, originalPath: ppath,
pnames: pnames, paramsCount: paramsCount,
methodHandler: mh, methods: mh,
paramChild: paramChildren, paramChild: paramChildren,
anyChild: anyChildren, anyChild: anyChildren,
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
@ -297,99 +659,77 @@ func (n *node) findChildWithLabel(l byte) *node {
return nil return nil
} }
func (n *node) addHandler(method string, h HandlerFunc) { func (n *node) setHandler(method string, r *routeMethod) {
switch method { n.methods.set(method, r)
case http.MethodConnect: if r != nil && r.handler != nil {
n.methodHandler.connect = h
case http.MethodDelete:
n.methodHandler.delete = h
case http.MethodGet:
n.methodHandler.get = h
case http.MethodHead:
n.methodHandler.head = h
case http.MethodOptions:
n.methodHandler.options = h
case http.MethodPatch:
n.methodHandler.patch = h
case http.MethodPost:
n.methodHandler.post = h
case PROPFIND:
n.methodHandler.propfind = h
case http.MethodPut:
n.methodHandler.put = h
case http.MethodTrace:
n.methodHandler.trace = h
case REPORT:
n.methodHandler.report = h
}
if h != nil {
n.isHandler = true n.isHandler = true
} else { } else {
n.isHandler = n.methodHandler.isHandler() n.isHandler = n.methods.isHandler()
} }
} }
func (n *node) findHandler(method string) HandlerFunc { const (
switch method { // NotFoundRouteName is name of RouteInfo returned when router did not find matching route (404: not found).
case http.MethodConnect: NotFoundRouteName = "EchoRouteNotFound"
return n.methodHandler.connect // MethodNotAllowedRouteName is name of RouteInfo returned when router did not find matching method for route (404: method not allowed).
case http.MethodDelete: MethodNotAllowedRouteName = "EchoRouteMethodNotAllowed"
return n.methodHandler.delete )
case http.MethodGet:
return n.methodHandler.get // Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to RouteMatch
case http.MethodHead: var notFoundRouteInfo = &routeInfo{
return n.methodHandler.head method: "",
case http.MethodOptions: path: "",
return n.methodHandler.options params: nil,
case http.MethodPatch: name: NotFoundRouteName,
return n.methodHandler.patch
case http.MethodPost:
return n.methodHandler.post
case PROPFIND:
return n.methodHandler.propfind
case http.MethodPut:
return n.methodHandler.put
case http.MethodTrace:
return n.methodHandler.trace
case REPORT:
return n.methodHandler.report
default:
return nil
}
} }
func (n *node) checkMethodNotAllowed() HandlerFunc { // Note: methodNotAllowedRouteInfo exists to avoid allocations when setting 405 RouteInfo to RouteMatch
for _, m := range methods { var methodNotAllowedRouteInfo = &routeInfo{
if h := n.findHandler(m); h != nil { method: "",
return MethodNotAllowedHandler path: "",
} params: nil,
} name: MethodNotAllowedRouteName,
return NotFoundHandler
} }
// Find lookup a handler registered for method and path. It also parses URL for path // notFoundHandler is handler for 404 cases
// parameters and load them into context. // Handle returned ErrNotFound errors in Echo.HTTPErrorHandler
var notFoundHandler = func(c Context) error {
return ErrNotFound
}
// methodNotAllowedHandler is handler for case when route for path+method match was not found (http code 405)
// Handle returned ErrMethodNotAllowed errors in Echo.HTTPErrorHandler
var methodNotAllowedHandler = func(c Context) error {
return ErrMethodNotAllowed
}
// Match looks up a handler registered for method and path. It also parses URL for path parameters and loads them
// into context.
// //
// For performance: // For performance:
// //
// - Get context from `Echo#AcquireContext()` // - Get context from `Echo#AcquireContext()`
// - Reset it `Context#Reset()` // - Reset it `Context#Reset()`
// - Return it `Echo#ReleaseContext()`. // - Return it `Echo#ReleaseContext()`.
func (r *Router) Find(method, path string, c Context) { func (r *DefaultRouter) Match(req *http.Request, pathParams *PathParams) RouteMatch {
ctx := c.(*context) *pathParams = (*pathParams)[0:cap(*pathParams)]
ctx.path = path
currentNode := r.tree // Current node as root
path := req.URL.Path
if !r.useEscapedPathForRouting && req.URL.RawPath != "" {
// Difference between URL.RawPath and URL.Path is:
// * URL.Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
// * URL.RawPath is an optional field which only gets set if the default encoding is different from Path.
path = req.URL.RawPath
}
var ( var (
currentNode = r.tree // root as current node
previousBestMatchNode *node previousBestMatchNode *node
matchedHandler HandlerFunc matchedRouteMethod *routeMethod
// search stores the remaining path to check for match. By each iteration we move from start of path to end of the path // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path
// and search value gets shorter and shorter. // and search value gets shorter and shorter.
search = path search = path
searchIndex = 0 searchIndex = 0
paramIndex int // Param counter paramIndex int // Param counter
paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice
) )
// Backtracking is needed when a dead end (leaf node) is reached in the router tree. // Backtracking is needed when a dead end (leaf node) is reached in the router tree.
@ -421,8 +761,8 @@ func (r *Router) Find(method, path string, c Context) {
paramIndex-- paramIndex--
// for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue
// for that index as it would also contain part of path we cut off before moving into node we are backtracking from // for that index as it would also contain part of path we cut off before moving into node we are backtracking from
searchIndex -= len(paramValues[paramIndex]) searchIndex -= len((*pathParams)[paramIndex].Value)
paramValues[paramIndex] = "" (*pathParams)[paramIndex].Value = ""
} }
search = path[searchIndex:] search = path[searchIndex:]
return return
@ -456,7 +796,7 @@ func (r *Router) Find(method, path string, c Context) {
// No matching prefix, let's backtrack to the first possible alternative node of the decision path // No matching prefix, let's backtrack to the first possible alternative node of the decision path
nk, ok := backtrackToNextNodeKind(staticKind) nk, ok := backtrackToNextNodeKind(staticKind)
if !ok { if !ok {
return // No other possibilities on the decision path break // No other possibilities on the decision path
} else if nk == paramKind { } else if nk == paramKind {
goto Param goto Param
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
@ -479,8 +819,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil { if previousBestMatchNode == nil {
previousBestMatchNode = currentNode previousBestMatchNode = currentNode
} }
if h := currentNode.findHandler(method); h != nil { if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedHandler = h matchedRouteMethod = rMethod
break break
} }
} }
@ -507,7 +847,7 @@ func (r *Router) Find(method, path string, c Context) {
} }
} }
paramValues[paramIndex] = search[:i] (*pathParams)[paramIndex].Value = search[:i]
paramIndex++ paramIndex++
search = search[i:] search = search[i:]
searchIndex = searchIndex + i searchIndex = searchIndex + i
@ -519,7 +859,7 @@ func (r *Router) Find(method, path string, c Context) {
if child := currentNode.anyChild; child != nil { if child := currentNode.anyChild; child != nil {
// If any node is found, use remaining path for paramValues // If any node is found, use remaining path for paramValues
currentNode = child currentNode = child
paramValues[len(currentNode.pnames)-1] = search (*pathParams)[currentNode.paramsCount-1].Value = search
// update indexes/search in case we need to backtrack when no handler match is found // update indexes/search in case we need to backtrack when no handler match is found
paramIndex++ paramIndex++
searchIndex += +len(search) searchIndex += +len(search)
@ -530,8 +870,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil { if previousBestMatchNode == nil {
previousBestMatchNode = currentNode previousBestMatchNode = currentNode
} }
if h := currentNode.findHandler(method); h != nil { if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedHandler = h matchedRouteMethod = rMethod
break break
} }
} }
@ -550,20 +890,63 @@ func (r *Router) Find(method, path string, c Context) {
} }
} }
result := RouteMatch{
Type: RouteMatchNotFound,
Handler: notFoundHandler,
RoutePath: "",
RouteInfo: notFoundRouteInfo,
}
if currentNode == nil && previousBestMatchNode == nil { if currentNode == nil && previousBestMatchNode == nil {
return // nothing matched at all *pathParams = (*pathParams)[0:0]
return result // nothing matched at all with given path
} }
if matchedHandler != nil { if matchedRouteMethod != nil {
ctx.handler = matchedHandler result.Type = RouteMatchFound
result.Handler = matchedRouteMethod.handler
result.RoutePath = matchedRouteMethod.routeInfo.path
result.RouteInfo = matchedRouteMethod.routeInfo
} else { } else {
// use previous match as basis. although we have no matching handler we have path match. // use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames
return // this here is only reason why `RouteMatch.RoutePath` exists. We do not want to create new RouteInfo just for path.
result.RoutePath = currentNode.originalPath
if currentNode.isHandler {
// TODO: in case of OPTIONS method we could respond with list of methods that node has. See https://httpwg.org/specs/rfc7231.html#OPTIONS
result.Type = RouteMatchMethodNotAllowed
result.Handler = methodNotAllowedHandler
result.RouteInfo = methodNotAllowedRouteInfo
}
}
*pathParams = (*pathParams)[0:currentNode.paramsCount]
if matchedRouteMethod != nil {
for i, name := range matchedRouteMethod.params {
(*pathParams)[i].Name = name
}
}
if r.unescapePathParamValues && currentNode.kind != staticKind {
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
for i, p := range *pathParams {
tmpVal, err := url.PathUnescape(p.Value)
if err == nil { // handle problems by ignoring them.
(*pathParams)[i].Value = tmpVal
}
}
}
return result
}
// Get returns path parameter value for given name or default value.
func (p PathParams) Get(name string, defaultValue string) string {
for _, param := range p {
if param.Name == name {
return param.Value
}
}
return defaultValue
} }

File diff suppressed because it is too large Load Diff

220
server.go Normal file
View File

@ -0,0 +1,220 @@
package echo
import (
stdContext "context"
"crypto/tls"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"time"
)
const (
website = "https://echo.labstack.com"
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
banner = `
____ __
/ __/___/ / ___
/ _// __/ _ \/ _ \
/___/\__/_//_/\___/ %s
High performance, minimalist Go web framework
%s
____________________________________O/_______
O\
`
)
// StartConfig is for creating configured http.Server instance to start serve http(s) requests with given Echo instance
type StartConfig struct {
// Address for the server to listen on (if not using custom listener)
Address string
// ListenerNetwork allows setting listener network (see net.Listen for allowed values)
// Optional: defaults to "tcp"
ListenerNetwork string
// CertFilesystem is file system used to load certificates and keys (if certs/keys are given as paths)
CertFilesystem fs.FS
// DisableHTTP2 disables supports for HTTP2 in TLS server
DisableHTTP2 bool
// HideBanner does not log Echo banner on server startup
HideBanner bool
// HidePort does not log port on server startup
HidePort bool
// GracefulContext is context that completion signals graceful shutdown start
GracefulContext stdContext.Context
// GracefulTimeout is period which server allows listeners to finish serving ongoing requests. If this time is exceeded process is exited
// Defaults to 10 seconds
GracefulTimeout time.Duration
// OnShutdownError allows customization of what happens when (graceful) server Shutdown method returns an error.
// Defaults to calling e.logger.Error(err)
OnShutdownError func(err error)
// TLSConfigFunc allows modifying TLS configuration before listener is created with it.
TLSConfigFunc func(tlsConfig *tls.Config)
// ListenerAddrFunc allows getting listener address before server starts serving requests on listener. Useful when
// address is set as random (`:0`) port.
ListenerAddrFunc func(addr net.Addr)
// BeforeServeFunc allows customizing/accessing server before server starts serving requests on listener.
BeforeServeFunc func(s *http.Server) error
}
// Start starts a HTTP server.
func (sc StartConfig) Start(e *Echo) error {
logger := e.Logger
server := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
var tlsConfig *tls.Config = nil
if sc.TLSConfigFunc != nil {
tlsConfig = &tls.Config{}
configureTLS(&sc, tlsConfig)
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &server, listener, logger)
}
// StartTLS starts a HTTPS server.
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
func (sc StartConfig) StartTLS(e *Echo, certFile, keyFile interface{}) error {
logger := e.Logger
s := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
certFs := sc.CertFilesystem
if certFs == nil {
certFs = os.DirFS(".")
}
cert, err := filepathOrContent(certFile, certFs)
if err != nil {
return err
}
key, err := filepathOrContent(keyFile, certFs)
if err != nil {
return err
}
cer, err := tls.X509KeyPair(cert, key)
if err != nil {
return err
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
configureTLS(&sc, tlsConfig)
if sc.TLSConfigFunc != nil {
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &s, listener, logger)
}
func serve(sc *StartConfig, server *http.Server, listener net.Listener, logger Logger) error {
if sc.BeforeServeFunc != nil {
if err := sc.BeforeServeFunc(server); err != nil {
return err
}
}
startupGreetings(sc, logger, listener)
if sc.GracefulContext != nil {
ctx, cancel := stdContext.WithCancel(sc.GracefulContext)
defer cancel() // make sure this graceful coroutine will end when serve returns by some other means
go gracefulShutdown(ctx, sc, server, logger)
}
return server.Serve(listener)
}
func configureTLS(sc *StartConfig, tlsConfig *tls.Config) {
if !sc.DisableHTTP2 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2")
}
}
func createListener(sc *StartConfig, tlsConfig *tls.Config) (net.Listener, error) {
listenerNetwork := sc.ListenerNetwork
if listenerNetwork == "" {
listenerNetwork = "tcp"
}
var listener net.Listener
var err error
if tlsConfig != nil {
listener, err = tls.Listen(listenerNetwork, sc.Address, tlsConfig)
} else {
listener, err = net.Listen(listenerNetwork, sc.Address)
}
if err != nil {
return nil, err
}
if sc.ListenerAddrFunc != nil {
sc.ListenerAddrFunc(listener.Addr())
}
return listener, nil
}
func startupGreetings(sc *StartConfig, logger Logger, listener net.Listener) {
if !sc.HideBanner {
bannerText := fmt.Sprintf(banner, "v"+Version, website)
logger.Write([]byte(bannerText))
}
if !sc.HidePort {
logger.Write([]byte(fmt.Sprintf("⇨ http(s) server started on %s\n", listener.Addr())))
}
}
func filepathOrContent(fileOrContent interface{}, certFilesystem fs.FS) (content []byte, err error) {
switch v := fileOrContent.(type) {
case string:
return fs.ReadFile(certFilesystem, v)
case []byte:
return v, nil
default:
return nil, ErrInvalidCertOrKeyType
}
}
func gracefulShutdown(gracefulCtx stdContext.Context, sc *StartConfig, server *http.Server, logger Logger) {
<-gracefulCtx.Done() // wait until shutdown context is closed.
// note: is server if closed by other means this method is still run but is good as no-op
timeout := sc.GracefulTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
shutdownCtx, cancel := stdContext.WithTimeout(stdContext.Background(), timeout)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
// we end up here when listeners are not shut down within given timeout
if sc.OnShutdownError != nil {
sc.OnShutdownError(err)
return
}
logger.Error(fmt.Errorf("failed to shut down server within given timeout: %w", err))
}
}

815
server_test.go Normal file
View File

@ -0,0 +1,815 @@
package echo
import (
"bytes"
stdContext "context"
"crypto/tls"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
)
func startOnRandomPort(ctx stdContext.Context, e *Echo) (string, error) {
addrChan := make(chan string)
errCh := make(chan error)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
return waitForServerStart(addrChan, errCh)
}
func waitForServerStart(addrChan <-chan string, errCh <-chan error) (string, error) {
waitCtx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer cancel()
// wait for addr to arrive
for {
select {
case <-waitCtx.Done():
return "", waitCtx.Err()
case addr := <-addrChan:
return addr, nil
case err := <-errCh:
if err == http.ErrServerClosed { // was closed normally before listener callback was called. should not be possible
return "", nil
}
// failed to start and we did not manage to get even listener part.
return "", err
}
}
}
func doGet(url string) (int, string, error) {
resp, err := http.Get(url)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, "", err
}
return resp.StatusCode, string(body), nil
}
func TestStartConfig_Start(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
// check if server is actually up
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
if err == nil {
t.Errorf("missing error")
return
}
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
func TestStartConfig_GracefulShutdown(t *testing.T) {
var testCases = []struct {
name string
whenHandlerTakesLonger bool
expectBody string
expectGracefulError string
}{
{
name: "ok, all handlers returns before graceful shutdown deadline",
whenHandlerTakesLonger: false,
expectBody: "OK",
expectGracefulError: "",
},
{
name: "nok, handlers do not returns before graceful shutdown deadline",
whenHandlerTakesLonger: true,
expectBody: "timeout",
expectGracefulError: stdContext.DeadlineExceeded.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
msg := "OK"
if tc.whenHandlerTakesLonger {
time.Sleep(150 * time.Millisecond)
msg = "timeout"
}
return c.String(http.StatusOK, msg)
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 50*time.Millisecond)
defer shutdown()
shutdownErrChan := make(chan error, 1)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 50 * time.Millisecond,
OnShutdownError: func(err error) {
shutdownErrChan <- err
},
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, tc.expectBody, body)
var shutdownErr error
select {
case shutdownErr = <-shutdownErrChan:
default:
}
if tc.expectGracefulError != "" {
assert.EqualError(t, shutdownErr, tc.expectGracefulError)
} else {
assert.NoError(t, shutdownErr)
}
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Error(t, err)
if err != nil {
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
})
}
}
func TestStartConfig_Start_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("not_implemented")
}
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_Start_createListenerError(t *testing.T) {
e := New()
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
func TestStartConfig_StartTLS(t *testing.T) {
var testCases = []struct {
name string
addr string
certFile string
keyFile string
expectError string
}{
{
name: "ok",
addr: ":0",
},
{
name: "nok, invalid certFile",
addr: ":0",
certFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, invalid keyFile",
addr: ":0",
keyFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, failed to create cert out of certFile and keyFile",
addr: ":0",
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
},
{
name: "nok, invalid tls address",
addr: "nope",
expectError: "listen tcp: address nope: missing port in address",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
if tc.certFile != "" {
certFile = tc.certFile
}
keyFile := "_fixture/certs/key.pem"
if tc.keyFile != "" {
keyFile = tc.keyFile
}
s := &StartConfig{
Address: tc.addr,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectError != "" {
if _, ok := err.(*os.PathError); ok {
assert.Error(t, err) // error messages for unix and windows are different. so name only error type here
} else {
assert.EqualError(t, err, tc.expectError)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestStartConfig_StartTLS_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
assert.Len(t, tlsConfig.Certificates, 1)
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_StartTLSAndStart(t *testing.T) {
// We name if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
e := New()
e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
tlsCtx, tlsShutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer tlsShutdown()
addrTLSChan := make(chan string)
errTLSChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: tlsCtx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrTLSChan <- addr.String()
},
}
errTLSChan <- s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
}()
tlsAddr, err := waitForServerStart(addrTLSChan, errTLSChan)
assert.NoError(t, err)
// check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer shutdown()
addrChan := make(chan string)
errChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errChan <- s.Start(e)
}()
addr, err := waitForServerStart(addrChan, errChan)
assert.NoError(t, err)
// now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
res, err = client.Get(fmt.Sprintf("http://%v", addr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// see if HTTPS works after HTTP listener is also added
res, err = client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestFilepathOrContent(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err)
testCases := []struct {
name string
cert interface{}
key interface{}
expectedErr error
}{
{
name: `ValidCertAndKeyFilePath`,
cert: "_fixture/certs/cert.pem",
key: "_fixture/certs/key.pem",
expectedErr: nil,
},
{
name: `ValidCertAndKeyByteString`,
cert: cert,
key: key,
expectedErr: nil,
},
{
name: `InvalidKeyType`,
cert: cert,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertType`,
cert: 0,
key: key,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertAndKeyTypes`,
cert: 0,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: ":0",
CertFilesystem: os.DirFS("."),
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, tc.cert, tc.key)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func supportsIPv6() bool {
addrs, _ := net.InterfaceAddrs()
for _, addr := range addrs {
// Check if any interface has local IPv6 assigned
if strings.Contains(addr.String(), "::1") {
return true
}
}
return false
}
func TestStartConfig_WithListenerNetwork(t *testing.T) {
testCases := []struct {
name string
network string
address string
}{
{
name: "tcp ipv4 address",
network: "tcp",
address: "127.0.0.1:1323",
},
{
name: "tcp ipv6 address",
network: "tcp",
address: "[::1]:1323",
},
{
name: "tcp4 ipv4 address",
network: "tcp4",
address: "127.0.0.1:1323",
},
{
name: "tcp6 ipv6 address",
network: "tcp6",
address: "[::1]:1323",
},
}
hasIPv6 := supportsIPv6()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if !hasIPv6 && strings.Contains(tc.address, "::") {
t.Skip("Skipping testing IPv6 for " + tc.address + ", not available")
}
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: tc.address,
ListenerNetwork: tc.network,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.Start(e)
}()
_, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%s/ok", tc.address))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
})
}
}
func TestStartConfig_WithHideBanner(t *testing.T) {
var testCases = []struct {
name string
hideBanner bool
}{
{
name: "hide banner on startup",
hideBanner: true,
},
{
name: "show banner on startup",
hideBanner: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HideBanner: tc.hideBanner,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
contains := strings.Contains(buf.String(), "High performance, minimalist Go web framework")
if tc.hideBanner {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithHidePort(t *testing.T) {
var testCases = []struct {
name string
hidePort bool
}{
{
name: "hide port on startup",
hidePort: true,
},
{
name: "show port on startup",
hidePort: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HidePort: tc.hidePort,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
portMsg := fmt.Sprintf("http(s) server started on")
contains := strings.Contains(buf.String(), portMsg)
if tc.hidePort {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithBeforeServeFunc(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
s := &StartConfig{
Address: ":0",
BeforeServeFunc: func(s *http.Server) error {
return errors.New("is called before serve")
},
}
err := s.Start(e)
assert.EqualError(t, err, "is called before serve")
}
func TestWithDisableHTTP2(t *testing.T) {
var testCases = []struct {
name string
disableHTTP2 bool
}{
{
name: "HTTP2 enabled",
disableHTTP2: false,
},
{
name: "HTTP2 disabled",
disableHTTP2: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
keyFile := "_fixture/certs/key.pem"
s := &StartConfig{
Address: ":0",
DisableHTTP2: tc.disableHTTP2,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
url := fmt.Sprintf("https://%v/ok", addr)
// do ordinary http(s) request
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(url)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// do HTTP2 request
client.Transport = &http2.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
resp, err := client.Get(url)
if err != nil {
if tc.disableHTTP2 {
assert.True(t, strings.Contains(err.Error(), `http2: unexpected ALPN protocol ""; want "h2"`))
return
}
log.Fatalf("Failed get: %s", err)
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed reading response body: %s", err)
}
assert.Equal(t, "OK", string(body))
})
}
}
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Printf(format string, args ...interface{}) {
_, _ = l.output.Write([]byte(fmt.Sprintf(format, args...)))
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}