1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-22 20:06:21 +02:00

WIP: logger examples

WIP: make default logger implemented custom writer for jsonlike logs
WIP: improve examples
WIP: defaultErrorHandler use errors.As to unwrap errors. Update readme
WIP: default logger logs json, restore e.Start method
WIP: clean router.Match a bit
WIP: func types/fields have echo.Context has first element
WIP: remove yaml tags as functions etc can not be serialized anyway
WIP: change BindPathParams,BindQueryParams,BindHeaders from methods to functions and reverse arguments to be like DefaultBinder.Bind is
WIP: improved comments, logger now extracts status from error
WIP: go mod tidy
WIP: rebase with 4.5.0
WIP:
* removed todos.
* removed StartAutoTLS and StartH2CServer methods from `StartConfig`
* KeyAuth middleware errorhandler can swallow the error and resume next middleware
WIP: add RouterConfig.UseEscapedPathForMatching to use escaped path for matching request against routes
WIP: FIXMEs
WIP: upgrade golang-jwt/jwt to `v4`
WIP: refactor http methods to return RouteInfo
WIP: refactor static not creating multiple routes
WIP: refactor route and middleware adding functions not to return error directly
WIP: Use 401 for problematic/missing headers for key auth and JWT middleware (#1552, #1402).
> In summary, a 401 Unauthorized response should be used for missing or bad authentication
WIP: replace `HTTPError.SetInternal` with `HTTPError.WithInternal` so we could not mutate global error variables
WIP: add RouteInfo and RouteMatchType into Context what we could know from in middleware what route was matched and/or type of that match (200/404/405)
WIP: make notFoundHandler and methodNotAllowedHandler private. encourage that all errors be handled in Echo.HTTPErrorHandler
WIP: server cleanup ideas
WIP: routable.ForGroup
WIP: note about logger middleware
WIP: bind should not default values on second try. use crypto rand for better randomness
WIP: router add route as interface and returns info as interface
WIP: improve flaky test (remains still flaky)
WIP: add notes about bind default values
WIP: every route can have their own path params names
WIP: routerCreator and different tests
WIP: different things
WIP: remove route implementation
WIP: support custom method types
WIP: extractor tests
WIP: v5.0.x proposal
over v4.4.0
This commit is contained in:
toimtoimtoim 2021-07-15 23:34:01 +03:00
parent c6f0c667f1
commit 6ef5f77bf2
80 changed files with 9216 additions and 4922 deletions

View File

@ -27,7 +27,8 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with last four major releases
go: [1.14, 1.15, 1.16, 1.17]
# except v5 starts from 1.16 until there is last four major releases after that
go: [1.16, 1.17]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:

View File

@ -1,21 +0,0 @@
arch:
- amd64
- ppc64le
language: go
go:
- 1.14.x
- 1.15.x
- tip
env:
- GO111MODULE=on
install:
- go get -v golang.org/x/lint/golint
script:
- golint -set_exit_status ./...
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
after_success:
- bash <(curl -s https://codecov.io/bash)
matrix:
allow_failures:
- go: tip

View File

@ -24,11 +24,11 @@ race: ## Run tests with data race detector
@go test -race ${PKG_LIST}
benchmark: ## Run benchmarks
@go test -run="-" -bench=".*" ${PKG_LIST}
@go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.15"
test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15
goversion ?= "1.16"
test_version: ## Run tests inside Docker with given version (defaults to 1.16 oldest supported). Example: make test_version goversion=1.16
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"

View File

@ -12,6 +12,8 @@
## Supported Go versions
Echo supports last four major releases. `v5` starts from 1.16 until there is last four major releases after that.
As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
Therefore a Go version capable of understanding /vN suffixed imports is required:
@ -67,8 +69,8 @@ package main
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
)
func main() {
@ -83,7 +85,9 @@ func main() {
e.GET("/", hello)
// Start server
e.Logger.Fatal(e.Start(":1323"))
if err := e.Start(":1323"); err != http.ErrServerClosed {
log.Fatal(err)
}
}
// Handler

70
bind.go
View File

@ -11,42 +11,38 @@ import (
"strings"
)
type (
// Binder is the interface that wraps the Bind method.
Binder interface {
Bind(i interface{}, c Context) error
type Binder interface {
Bind(c Context, i interface{}) error
}
// DefaultBinder is the default implementation of the Binder interface.
DefaultBinder struct{}
type DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead.
BindUnmarshaler interface {
type BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error
}
)
// BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
names := c.ParamNames()
values := c.ParamValues()
func BindPathParams(c Context, i interface{}) error {
params := map[string][]string{}
for i, name := range names {
params[name] = []string{values[i]}
for _, param := range c.PathParams() {
params[param.Name] = []string{param.Value}
}
if err := b.bindData(i, params, "param"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
if err := bindData(i, params, "param"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// BindQueryParams binds query params to bindable object
func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
if err := b.bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
func BindQueryParams(c Context, i interface{}) error {
if err := bindData(i, c.QueryParams(), "query"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
@ -56,7 +52,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error {
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
func BindBody(c Context, i interface{}) (err error) {
req := c.Request()
if req.ContentLength == 0 {
return
@ -70,25 +66,25 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
case *HTTPError:
return err
default:
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
}
case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
}
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
params, err := c.FormParams()
if err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
if err = b.bindData(i, params, "form"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
if err = bindData(i, params, "form"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
default:
return ErrUnsupportedMediaType
@ -98,17 +94,17 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) {
// BindHeaders binds HTTP headers to a bindable object
func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error {
if err := b.bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
if err := bindData(i, c.Request().Header, "header"); err != nil {
return NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
}
return nil
}
// Bind implements the `Binder#Bind` function.
// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
if err := b.BindPathParams(c, i); err != nil {
// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams.
func (b *DefaultBinder) Bind(c Context, i interface{}) (err error) {
if err := BindPathParams(c, i); err != nil {
return err
}
// Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH)
@ -116,15 +112,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
// i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding.
// This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body
if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete {
if err = b.BindQueryParams(c, i); err != nil {
if err = BindQueryParams(c, i); err != nil {
return err
}
}
return b.BindBody(c, i)
return BindBody(c, i)
}
// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error {
func bindData(destination interface{}, data map[string][]string, tag string) error {
if destination == nil || len(data) == 0 {
return nil
}
@ -170,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
// If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags).
// structs that implement BindUnmarshaler are binded only when they have explicit tag
if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
if err := bindData(structField.Addr().Interface(), data, tag); err != nil {
return err
}
}
@ -297,7 +293,7 @@ func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"
return nil
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err == nil {
@ -308,7 +304,7 @@ func setIntField(value string, bitSize int, field reflect.Value) error {
func setUintField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"
return nil
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err == nil {
@ -319,7 +315,7 @@ func setUintField(value string, bitSize int, field reflect.Value) error {
func setBoolField(value string, field reflect.Value) error {
if value == "" {
value = "false"
return nil
}
boolVal, err := strconv.ParseBool(value)
if err == nil {
@ -330,7 +326,7 @@ func setBoolField(value string, field reflect.Value) error {
func setFloatField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0.0"
return nil
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err == nil {

View File

@ -300,6 +300,52 @@ func TestBindHeaderParamBadType(t *testing.T) {
}
}
func TestBind_CombineQueryWithHeaderParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/products/999?length=50&page=10&language=et", nil)
req.Header.Set("language", "de")
req.Header.Set("length", "99")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPathParams(PathParams{{
Name: "id",
Value: "999",
}})
type SearchOpts struct {
ID int `param:"id"`
Length int `query:"length"`
Page int `query:"page"`
Search string `query:"search"`
Language string `query:"language" header:"language"`
}
opts := SearchOpts{
Length: 100,
Page: 0,
Search: "default value",
Language: "en",
}
err := c.Bind(&opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // bind from query
assert.Equal(t, 10, opts.Page) // bind from query
assert.Equal(t, 999, opts.ID) // bind from path param
assert.Equal(t, "et", opts.Language) // bind from query
assert.Equal(t, "default value", opts.Search) // default value stays
// make sure another bind will not mess already set values unless there are new values
err = (&DefaultBinder{}).BindHeaders(c, &opts)
assert.NoError(t, err)
assert.Equal(t, 50, opts.Length) // does not have tag in struct although header exists
assert.Equal(t, 10, opts.Page)
assert.Equal(t, 999, opts.ID)
assert.Equal(t, "de", opts.Language) // header overwrites now this value
assert.Equal(t, "default value", opts.Search)
}
func TestBindUnmarshalParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
@ -330,7 +376,7 @@ func TestBindUnmarshalParam(t *testing.T) {
func TestBindUnmarshalText(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@ -406,7 +452,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
func TestBindUnmarshalTextPtr(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@ -439,8 +485,7 @@ func TestBindUnsupportedMediaType(t *testing.T) {
func TestBindbindData(t *testing.T) {
a := assert.New(t)
ts := new(bindTestStruct)
b := new(DefaultBinder)
err := b.bindData(ts, values, "form")
err := bindData(ts, values, "form")
a.NoError(err)
a.Equal(0, ts.I)
@ -462,12 +507,15 @@ func TestBindbindData(t *testing.T) {
func TestBindParam(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/", nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.SetPath("/users/:id/:name")
c.SetParamNames("id", "name")
c.SetParamValues("1", "Jon Snow")
cc := c.(EditableContext)
cc.SetRouteInfo(routeInfo{path: "/users/:id/:name"})
cc.SetPathParams(PathParams{
{Name: "id", Value: "1"},
{Name: "name", Value: "Jon Snow"},
})
u := new(user)
err := c.Bind(u)
@ -478,9 +526,11 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param
c2 := e.NewContext(req, rec)
c2.SetPath("/users/:id")
c2.SetParamNames("id")
c2.SetParamValues("1")
cc2 := c2.(EditableContext)
cc2.SetRouteInfo(routeInfo{path: "/users/:id"})
cc2.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user)
err = c2.Bind(u)
@ -492,15 +542,17 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New()
req2 := httptest.NewRequest(POST, "/", body)
req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2)
c3.SetPath("/users/:id")
c3.SetParamNames("id")
c3.SetParamValues("1")
cc3 := c3.(EditableContext)
cc3.SetRouteInfo(routeInfo{path: "/users/:id"})
cc3.SetPathParams(PathParams{
{Name: "id", Value: "1"},
})
u = new(user)
err = c3.Bind(u)
@ -556,47 +608,115 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
func TestBindSetFields(t *testing.T) {
assert := assert.New(t)
func TestSetIntField(t *testing.T) {
ts := new(bindTestStruct)
ts.I = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 100, ts.I)
// second set with value sets the value
err = setIntField("5", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setIntField("", 0, val.FieldByName("I"))
assert.NoError(t, err)
assert.Equal(t, 5, ts.I)
}
func TestSetUintField(t *testing.T) {
ts := new(bindTestStruct)
ts.UI = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(100), ts.UI)
// second set with value sets the value
err = setUintField("5", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setUintField("", 0, val.FieldByName("UI"))
assert.NoError(t, err)
assert.Equal(t, uint(5), ts.UI)
}
func TestSetFloatField(t *testing.T) {
ts := new(bindTestStruct)
ts.F32 = 100
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(100), ts.F32)
// second set with value sets the value
err = setFloatField("15.5", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setFloatField("", 0, val.FieldByName("F32"))
assert.NoError(t, err)
assert.Equal(t, float32(15.5), ts.F32)
}
func TestSetBoolField(t *testing.T) {
ts := new(bindTestStruct)
ts.B = true
val := reflect.ValueOf(ts).Elem()
// empty value does nothing to field
// in that way we can have default values by setting field value before binding
err := setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// second set with value sets the value
err = setBoolField("true", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// third set without value does nothing to the value
// in that way multiple binds (ala query + header) do not reset fields to 0s
err = setBoolField("", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, true, ts.B)
// fourth set to false
err = setBoolField("false", val.FieldByName("B"))
assert.NoError(t, err)
assert.Equal(t, false, ts.B)
}
func TestUnmarshalFieldNonPtr(t *testing.T) {
ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem()
// Int
if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
assert.Equal(5, ts.I)
}
if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
assert.Equal(0, ts.I)
}
// Uint
if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
assert.Equal(uint(10), ts.UI)
}
if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
assert.Equal(uint(0), ts.UI)
}
// Float
if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
assert.Equal(float32(15.5), ts.F32)
}
if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
assert.Equal(float32(0.0), ts.F32)
}
// Bool
if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
assert.Equal(true, ts.B)
}
if assert.NoError(setBoolField("", val.FieldByName("B"))) {
assert.Equal(false, ts.B)
}
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(err) {
assert.Equal(ok, true)
assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
if assert.NoError(t, err) {
assert.True(t, ok)
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
}
}
@ -604,11 +724,10 @@ func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs()
assert := assert.New(b)
ts := new(bindTestStructWithTags)
binder := new(DefaultBinder)
var err error
b.ResetTimer()
for i := 0; i < b.N; i++ {
err = binder.bindData(ts, values, "form")
err = bindData(ts, values, "form")
}
assert.NoError(err)
assertBindTestStruct(assert, (*bindTestStruct)(ts))
@ -840,8 +959,10 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
c.SetParamNames("node")
c.SetParamValues("node_from_path")
cc := c.(EditableContext)
cc.SetPathParams(PathParams{
{Name: "node", Value: "node_from_path"},
})
}
var bindTarget interface{}
@ -852,7 +973,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
}
b := new(DefaultBinder)
err := b.Bind(bindTarget, c)
err := b.Bind(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
@ -1021,8 +1142,10 @@ func TestDefaultBinder_BindBody(t *testing.T) {
c := e.NewContext(req, rec)
if !tc.whenNoPathParams {
c.SetParamNames("node")
c.SetParamValues("real_node")
cc := c.(EditableContext)
cc.SetPathParams(PathParams{
{Name: "node", Value: "real_node"},
})
}
var bindTarget interface{}
@ -1031,9 +1154,8 @@ func TestDefaultBinder_BindBody(t *testing.T) {
} else {
bindTarget = &Node{}
}
b := new(DefaultBinder)
err := b.BindBody(c, bindTarget)
err := BindBody(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {

View File

@ -118,10 +118,10 @@ func QueryParamsBinder(c Context) *ValueBinder {
func PathParamsBinder(c Context) *ValueBinder {
return &ValueBinder{
failFast: true,
ValueFunc: c.Param,
ValueFunc: c.PathParam,
ValuesFunc: func(sourceParam string) []string {
// path parameter should not have multiple values so getting values does not make sense but lets not error out here
value := c.Param(sourceParam)
value := c.PathParam(sourceParam)
if value == "" {
return nil
}

View File

@ -30,14 +30,15 @@ func createTestContext15(URL string, body io.Reader, pathParams map[string]strin
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
names := make([]string, 0)
values := make([]string, 0)
params := make(PathParams, 0)
for name, value := range pathParams {
names = append(names, name)
values = append(values, value)
params = append(params, PathParam{
Name: name,
Value: value,
})
}
c.SetParamNames(names...)
c.SetParamValues(values...)
cc := c.(EditableContext)
cc.SetPathParams(params)
}
return c

View File

@ -25,14 +25,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string)
c := e.NewContext(req, rec)
if len(pathParams) > 0 {
names := make([]string, 0)
values := make([]string, 0)
params := make(PathParams, 0)
for name, value := range pathParams {
names = append(names, name)
values = append(values, value)
params = append(params, PathParam{
Name: name,
Value: value,
})
}
c.SetParamNames(names...)
c.SetParamValues(values...)
cc := c.(EditableContext)
cc.SetPathParams(params)
}
return c
@ -2643,7 +2644,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
_ = binder.Bind(&dest, c)
_ = binder.Bind(c, &dest)
}
}
@ -2710,7 +2711,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
binder := new(DefaultBinder)
for i := 0; i < b.N; i++ {
var dest Opts
_ = binder.Bind(&dest, c)
_ = binder.Bind(c, &dest)
if dest.Int64 != 1 {
b.Fatalf("int64!=1")
}

View File

@ -3,22 +3,22 @@ package echo
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
"io/fs"
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
)
type (
// Context represents the context of the current HTTP request. It holds request and
// response objects, path, path parameters, data and registered handler.
Context interface {
type Context interface {
// Request returns `*http.Request`.
Request() *http.Request
@ -45,26 +45,32 @@ type (
// The behavior can be configured using `Echo#IPExtractor`.
RealIP() string
// RouteMatchType returns router match type for current context. This helps middlewares to distinguish which type
// of match router found and how this request context handler chain could end:
// * route match - this path + method had matching route.
// * not found - this path did not match any routes enough to be considered match
// * method not allowed - path had routes registered but for other method types then current request is
// * unknown - initial state for fresh context before router tries to do routing
//
// Note: for pre-middleware (Echo.Pre) this method result is always RouteMatchUnknown as at point router has not tried
// to match request to route.
RouteMatchType() RouteMatchType
// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
// In case of 404 (route not found) and 405 (method not allowed) RouteInfo returns generic struct for these cases.
RouteInfo() RouteInfo
// Path returns the registered path for the handler.
Path() string
// SetPath sets the registered path for the handler.
SetPath(p string)
// PathParam returns path parameter by name.
PathParam(name string) string
// Param returns path parameter by name.
Param(name string) string
// PathParams returns path parameter values.
PathParams() PathParams
// ParamNames returns path parameter names.
ParamNames() []string
// SetParamNames sets path parameter names.
SetParamNames(names ...string)
// ParamValues returns path parameter values.
ParamValues() []string
// SetParamValues sets path parameter values.
SetParamValues(values ...string)
// SetPathParams set path parameter for during current request lifecycle.
SetPathParams(params PathParams)
// QueryParam returns the query param for the provided name.
QueryParam(name string) string
@ -171,23 +177,33 @@ type (
// Redirect redirects the request to a provided URL with status code.
Redirect(code int, url string) error
// Error invokes the registered HTTP error handler. Generally used by middleware.
// Error invokes the registered HTTP error handler.
// NB: Avoid using this method. It is better to return errors so middlewares up in chain could act on returned error.
Error(err error)
// Handler returns the matched handler by router.
Handler() HandlerFunc
// SetHandler sets the matched handler by router.
SetHandler(h HandlerFunc)
// Logger returns the `Logger` instance.
Logger() Logger
// Set the logger
SetLogger(l Logger)
// Echo returns the `Echo` instance.
Echo() *Echo
}
// EditableContext is additional interface that structure implementing Context must implement. Methods inside this
// interface are meant for Echo internal usage (for mainly routing) and should not be used in middlewares.
type EditableContext interface {
Context
// RawPathParams returns raw path pathParams value.
RawPathParams() *PathParams
// SetRawPathParams replaces any existing param values with new values for this context lifetime (request).
SetRawPathParams(params *PathParams)
// SetPath sets the registered path for the handler.
SetPath(p string)
// SetRouteMatchType sets the RouteMatchType of router match for this request.
SetRouteMatchType(t RouteMatchType)
// SetRouteInfo sets the route info of this request to the context.
SetRouteInfo(ri RouteInfo)
// Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
@ -195,20 +211,25 @@ type (
Reset(r *http.Request, w http.ResponseWriter)
}
context struct {
type context struct {
request *http.Request
response *Response
matchType RouteMatchType
route RouteInfo
path string
pnames []string
pvalues []string
// pathParams holds path/uri parameters determined by Router. Lifecycle is handled by Echo to reduce allocations.
pathParams *PathParams
// currentParams hold path parameters set by non-Echo implementation (custom middlewares, handlers) during the lifetime of Request.
// Lifecycle is not handle by Echo and could have excess allocations per served Request
currentParams PathParams
query url.Values
handler HandlerFunc
store Map
echo *Echo
logger Logger
lock sync.RWMutex
}
)
const (
defaultMemory = 32 << 20 // 32 MB
@ -296,52 +317,50 @@ func (c *context) SetPath(p string) {
c.path = p
}
func (c *context) Param(name string) string {
for i, n := range c.pnames {
if i < len(c.pvalues) {
if n == name {
return c.pvalues[i]
}
}
}
return ""
func (c *context) RouteMatchType() RouteMatchType {
return c.matchType
}
func (c *context) ParamNames() []string {
return c.pnames
func (c *context) SetRouteMatchType(t RouteMatchType) {
c.matchType = t
}
func (c *context) SetParamNames(names ...string) {
c.pnames = names
l := len(names)
if *c.echo.maxParam < l {
*c.echo.maxParam = l
func (c *context) RouteInfo() RouteInfo {
return c.route
}
if len(c.pvalues) < l {
// Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them,
// probably those values will be overriden in a Context#SetParamValues
newPvalues := make([]string, l)
copy(newPvalues, c.pvalues)
c.pvalues = newPvalues
}
func (c *context) SetRouteInfo(ri RouteInfo) {
c.route = ri
}
func (c *context) ParamValues() []string {
return c.pvalues[:len(c.pnames)]
func (c *context) RawPathParams() *PathParams {
return c.pathParams
}
func (c *context) SetParamValues(values ...string) {
// NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times
// It will brake the Router#Find code
limit := len(values)
if limit > *c.echo.maxParam {
limit = *c.echo.maxParam
func (c *context) SetRawPathParams(params *PathParams) {
c.pathParams = params
}
for i := 0; i < limit; i++ {
c.pvalues[i] = values[i]
func (c *context) PathParam(name string) string {
if c.currentParams != nil {
return c.currentParams.Get(name, "")
}
return c.pathParams.Get(name, "")
}
func (c *context) PathParams() PathParams {
if c.currentParams != nil {
return c.currentParams
}
result := make(PathParams, len(*c.pathParams))
copy(result, *c.pathParams)
return result
}
func (c *context) SetPathParams(params PathParams) {
c.currentParams = params
}
func (c *context) QueryParam(name string) string {
@ -422,7 +441,7 @@ func (c *context) Set(key string, val interface{}) {
}
func (c *context) Bind(i interface{}) error {
return c.echo.Binder.Bind(i, c)
return c.echo.Binder.Bind(c, i)
}
func (c *context) Validate(i interface{}) error {
@ -562,27 +581,36 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return
}
func (c *context) File(file string) (err error) {
f, err := os.Open(file)
func (c *context) File(file string) error {
return c.FsFile(file, c.echo.Filesystem)
}
func (c *context) FsFile(file string, filesystem fs.FS) error {
// FIXME: should we add this method into echo.Context interface?
f, err := filesystem.Open(file)
if err != nil {
return NotFoundHandler(c)
return ErrNotFound
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
f, err = filesystem.Open(file)
if err != nil {
return NotFoundHandler(c)
return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
return err
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
}
func (c *context) Attachment(file, name string) error {
@ -613,44 +641,23 @@ func (c *context) Redirect(code int, url string) error {
}
func (c *context) Error(err error) {
c.echo.HTTPErrorHandler(err, c)
c.echo.HTTPErrorHandler(c, err)
}
func (c *context) Echo() *Echo {
return c.echo
}
func (c *context) Handler() HandlerFunc {
return c.handler
}
func (c *context) SetHandler(h HandlerFunc) {
c.handler = h
}
func (c *context) Logger() Logger {
res := c.logger
if res != nil {
return res
}
return c.echo.Logger
}
func (c *context) SetLogger(l Logger) {
c.logger = l
}
func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
c.request = r
c.response.reset(w)
c.query = nil
c.handler = NotFoundHandler
c.store = nil
c.matchType = RouteMatchUnknown
c.route = nil
c.path = ""
c.pnames = nil
c.logger = nil
// NOTE: Don't reset because it has to have length c.echo.maxParam at all times
for i := 0; i < *c.echo.maxParam; i++ {
c.pvalues[i] = ""
}
// NOTE: Don't reset because it has to have length c.echo.contextPathParamAllocSize at all times
*c.pathParams = (*c.pathParams)[:0]
c.currentParams = nil
}

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"mime/multipart"
"net/http"
@ -18,21 +19,19 @@ import (
"text/template"
"time"
"github.com/labstack/gommon/log"
testify "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
)
type (
Template struct {
type Template struct {
templates *template.Template
}
)
var testUser = user{1, "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) {
e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@ -46,7 +45,8 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) {
e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@ -60,7 +60,8 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) {
e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
e.Logger = &jsonLogger{writer: ioutil.Discard}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
@ -106,16 +107,14 @@ func TestContext(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Echo
assert.Equal(e, c.Echo())
assert.Equal(t, e, c.Echo())
// Request
assert.NotNil(c.Request())
assert.NotNil(t, c.Request())
// Response
assert.NotNil(c.Response())
assert.NotNil(t, c.Response())
//--------
// Render
@ -126,23 +125,23 @@ func TestContext(t *testing.T) {
}
c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, Jon Snow!", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
}
c.echo.Renderer = nil
err = c.Render(http.StatusOK, "hello", "Jon Snow")
assert.Error(err)
assert.Error(t, err)
// JSON
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON+"\n", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSON+"\n", rec.Body.String())
}
// JSON with "?pretty"
@ -150,10 +149,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
@ -161,37 +160,37 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSONPretty+"\n", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
// JSON (error)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.JSON(http.StatusOK, make(chan bool))
assert.Error(err)
assert.Error(t, err)
// JSONP
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
}
// XML
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXML, rec.Body.String())
}
// XML with "?pretty"
@ -199,10 +198,10 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/", nil)
@ -210,22 +209,22 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XML(http.StatusOK, make(chan bool))
assert.Error(err)
assert.Error(t, err)
// XML response write error
c = e.NewContext(req, rec).(*context)
c.response.Writer = responseWriterErr{}
err = c.XML(0, 0)
testify.Error(t, err)
assert.Error(t, err)
// XMLPretty
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
}
t.Run("empty indent", func(t *testing.T) {
@ -237,7 +236,6 @@ func TestContext(t *testing.T) {
t.Run("json", func(t *testing.T) {
buf.Reset()
assert := testify.New(t)
// New JSONBlob with empty indent
rec = httptest.NewRecorder()
@ -246,16 +244,15 @@ func TestContext(t *testing.T) {
enc.SetIndent(emptyIndent, emptyIndent)
err = enc.Encode(u)
err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(buf.String(), rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, buf.String(), rec.Body.String())
}
})
t.Run("xml", func(t *testing.T) {
buf.Reset()
assert := testify.New(t)
// New XMLBlob with empty indent
rec = httptest.NewRecorder()
@ -264,10 +261,10 @@ func TestContext(t *testing.T) {
enc.Indent(emptyIndent, emptyIndent)
err = enc.Encode(u)
err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+buf.String(), rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
}
})
})
@ -276,12 +273,12 @@ func TestContext(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
data, err := json.Marshal(user{1, "Jon Snow"})
assert.NoError(err)
assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON, rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSON, rec.Body.String())
}
// Legacy JSONPBlob
@ -289,44 +286,44 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context)
callback = "callback"
data, err = json.Marshal(user{1, "Jon Snow"})
assert.NoError(err)
assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(callback+"("+userJSON+");", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
}
// Legacy XMLBlob
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
data, err = xml.Marshal(user{1, "Jon Snow"})
assert.NoError(err)
assert.NoError(t, err)
err = c.XMLBlob(http.StatusOK, data)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(xml.Header+userXML, rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXML, rec.Body.String())
}
// String
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.String(http.StatusOK, "Hello, World!")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, World!", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, World!", rec.Body.String())
}
// HTML
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal("Hello, <strong>World!</strong>", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
}
// Stream
@ -334,55 +331,55 @@ func TestContext(t *testing.T) {
c = e.NewContext(req, rec).(*context)
r := strings.NewReader("response from a stream")
err = c.Stream(http.StatusOK, "application/octet-stream", r)
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType))
assert.Equal("response from a stream", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
assert.Equal(t, "response from a stream", rec.Body.String())
}
// Attachment
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.Attachment("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len())
}
// Inline
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
err = c.Inline("_fixture/images/walle.png", "walle.png")
if assert.NoError(err) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(219885, rec.Body.Len())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len())
}
// NoContent
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c.NoContent(http.StatusOK)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(t, http.StatusOK, rec.Code)
// Error
rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*context)
c.Error(errors.New("error"))
assert.Equal(http.StatusInternalServerError, rec.Code)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
// Reset
c.SetParamNames("foo")
c.SetParamValues("bar")
c.pathParams = &PathParams{
{Name: "foo", Value: "bar"},
}
c.Set("foe", "ban")
c.query = url.Values(map[string][]string{"fon": {"baz"}})
c.Reset(req, httptest.NewRecorder())
assert.Equal(0, len(c.ParamValues()))
assert.Equal(0, len(c.ParamNames()))
assert.Equal(0, len(c.store))
assert.Equal("", c.Path())
assert.Equal(0, len(c.QueryParams()))
assert.Equal(t, 0, len(c.PathParams()))
assert.Equal(t, 0, len(c.store))
assert.Equal(t, nil, c.RouteInfo())
assert.Equal(t, 0, len(c.QueryParams()))
}
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
@ -392,11 +389,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
assert := testify.New(t)
if assert.NoError(err) {
assert.Equal(http.StatusCreated, rec.Code)
assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(userJSON+"\n", rec.Body.String())
if assert.NoError(t, err) {
assert.Equal(t, http.StatusCreated, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSON+"\n", rec.Body.String())
}
}
@ -407,9 +403,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
c := e.NewContext(req, rec).(*context)
err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
assert := testify.New(t)
if assert.Error(err) {
assert.False(c.response.Committed)
if assert.Error(t, err) {
assert.False(t, c.response.Committed)
}
}
@ -423,22 +418,20 @@ func TestContextCookie(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
assert := testify.New(t)
// Read single
cookie, err := c.Cookie("theme")
if assert.NoError(err) {
assert.Equal("theme", cookie.Name)
assert.Equal("light", cookie.Value)
if assert.NoError(t, err) {
assert.Equal(t, "theme", cookie.Name)
assert.Equal(t, "light", cookie.Value)
}
// Read multiple
for _, cookie := range c.Cookies() {
switch cookie.Name {
case "theme":
assert.Equal("light", cookie.Value)
assert.Equal(t, "light", cookie.Value)
case "user":
assert.Equal("Jon Snow", cookie.Value)
assert.Equal(t, "Jon Snow", cookie.Value)
}
}
@ -453,104 +446,95 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true,
}
c.SetCookie(cookie)
assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com")
assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure")
assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPath(t *testing.T) {
e := New()
r := e.Router()
handler := func(c Context) error { return c.String(http.StatusOK, "OK") }
r.Add(http.MethodGet, "/users/:id", handler)
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1", c)
assert := testify.New(t)
assert.Equal("/users/:id", c.Path())
r.Add(http.MethodGet, "/users/:uid/files/:fid", handler)
c = e.NewContext(nil, nil)
r.Find(http.MethodGet, "/users/1/files/1", c)
assert.Equal("/users/:uid/files/:fid", c.Path())
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
func TestContextPathParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, nil)
c := e.NewContext(req, nil).(*context)
params := &PathParams{
{Name: "uid", Value: "101"},
{Name: "fid", Value: "501"},
}
// ParamNames
c.SetParamNames("uid", "fid")
testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
// ParamValues
c.SetParamValues("101", "501")
testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
c.pathParams = params
assert.EqualValues(t, *params, c.PathParams())
// Param
testify.Equal(t, "501", c.Param("fid"))
testify.Equal(t, "", c.Param("undefined"))
assert.Equal(t, "501", c.PathParam("fid"))
assert.Equal(t, "", c.PathParam("undefined"))
}
func TestContextGetAndSetParam(t *testing.T) {
e := New()
r := e.Router()
r.Add(http.MethodGet, "/:foo", func(Context) error { return nil })
_, err := r.Add(Route{
Method: http.MethodGet,
Path: "/:foo",
Name: "",
Handler: func(Context) error { return nil },
Middlewares: nil,
})
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
c := e.NewContext(req, nil)
c.SetParamNames("foo")
params := &PathParams{{Name: "foo", Value: "101"}}
// ParamNames
c.(*context).pathParams = params
// round-trip param values with modification
paramVals := c.ParamValues()
testify.EqualValues(t, []string{""}, c.ParamValues())
paramVals[0] = "bar"
c.SetParamValues(paramVals...)
testify.EqualValues(t, []string{"bar"}, c.ParamValues())
paramVals := c.PathParams()
assert.Equal(t, *params, c.PathParams())
paramVals[0] = PathParam{Name: "xxx", Value: "yyy"} // PathParams() returns copy and modifying it does nothing to context
assert.Equal(t, PathParams{{Name: "foo", Value: "101"}}, c.PathParams())
pathParams := PathParams{
{Name: "aaa", Value: "bbb"},
{Name: "ccc", Value: "ddd"},
}
c.SetPathParams(pathParams)
assert.Equal(t, pathParams, c.PathParams())
// shouldn't explode during Reset() afterwards!
testify.NotPanics(t, func() {
c.Reset(nil, nil)
assert.NotPanics(t, func() {
c.(EditableContext).Reset(nil, nil)
})
assert.Equal(t, PathParams{}, c.PathParams())
assert.Len(t, *c.(*context).pathParams, 0)
assert.Equal(t, cap(*c.(*context).pathParams), 1)
}
// Issue #1655
func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) {
assert := testify.New(t)
func TestContext_SetParamNamesShouldNotModifyPathParams(t *testing.T) {
e := New()
assert.Equal(0, *e.maxParam)
c := e.NewContext(nil, nil).(*context)
expectedOneParam := []string{"one"}
expectedTwoParams := []string{"one", "two"}
expectedThreeParams := []string{"one", "two", ""}
expectedABCParams := []string{"A", "B", "C"}
assert.Equal(t, 0, e.contextPathParamAllocSize)
expectedTwoParams := &PathParams{
{Name: "1", Value: "one"},
{Name: "2", Value: "two"},
}
c.SetRawPathParams(expectedTwoParams)
assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(t, *expectedTwoParams, c.PathParams())
c := e.NewContext(nil, nil)
c.SetParamNames("1", "2")
c.SetParamValues(expectedTwoParams...)
assert.Equal(2, *e.maxParam)
assert.EqualValues(expectedTwoParams, c.ParamValues())
c.SetParamNames("1")
assert.Equal(2, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are
assert.EqualValues(expectedOneParam, c.ParamValues())
c.SetParamNames("1", "2", "3")
assert.Equal(3, *e.maxParam)
// Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam
assert.EqualValues(expectedThreeParams, c.ParamValues())
c.SetParamValues("A", "B", "C", "D")
assert.Equal(3, *e.maxParam)
// Here D shouldn't be returned
assert.EqualValues(expectedABCParams, c.ParamValues())
expectedThreeParams := PathParams{
{Name: "1", Value: "one"},
{Name: "2", Value: "two"},
{Name: "3", Value: "three"},
}
c.SetPathParams(expectedThreeParams)
assert.Equal(t, 0, e.contextPathParamAllocSize)
assert.Equal(t, expectedThreeParams, c.PathParams())
}
func TestContextFormValue(t *testing.T) {
@ -564,13 +548,13 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil)
// FormValue
testify.Equal(t, "Jon Snow", c.FormValue("name"))
testify.Equal(t, "jon@labstack.com", c.FormValue("email"))
assert.Equal(t, "Jon Snow", c.FormValue("name"))
assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams
params, err := c.FormParams()
if testify.NoError(t, err) {
testify.Equal(t, url.Values{
if assert.NoError(t, err) {
assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, params)
@ -581,8 +565,8 @@ func TestContextFormValue(t *testing.T) {
req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil)
params, err = c.FormParams()
testify.Nil(t, params)
testify.Error(t, err)
assert.Nil(t, params)
assert.Error(t, err)
}
func TestContextQueryParam(t *testing.T) {
@ -594,11 +578,11 @@ func TestContextQueryParam(t *testing.T) {
c := e.NewContext(req, nil)
// QueryParam
testify.Equal(t, "Jon Snow", c.QueryParam("name"))
testify.Equal(t, "jon@labstack.com", c.QueryParam("email"))
assert.Equal(t, "Jon Snow", c.QueryParam("name"))
assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams
testify.Equal(t, url.Values{
assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, c.QueryParams())
@ -609,7 +593,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test")
if testify.NoError(t, err) {
if assert.NoError(t, err) {
w.Write([]byte("test"))
}
mr.Close()
@ -618,8 +602,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.FormFile("file")
if testify.NoError(t, err) {
testify.Equal(t, "test", f.Filename)
if assert.NoError(t, err) {
assert.Equal(t, "test", f.Filename)
}
}
@ -634,8 +618,8 @@ func TestContextMultipartForm(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.MultipartForm()
if testify.NoError(t, err) {
testify.NotNil(t, f)
if assert.NoError(t, err) {
assert.NotNil(t, f)
}
}
@ -644,16 +628,16 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
testify.Equal(t, http.StatusMovedPermanently, rec.Code)
testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
}
func TestContextStore(t *testing.T) {
var c Context = new(context)
c.Set("name", "Jon Snow")
testify.Equal(t, "Jon Snow", c.Get("name"))
assert.Equal(t, "Jon Snow", c.Get("name"))
}
func BenchmarkContext_Store(b *testing.B) {
@ -671,42 +655,6 @@ func BenchmarkContext_Store(b *testing.B) {
}
}
func TestContextHandler(t *testing.T) {
e := New()
r := e.Router()
b := new(bytes.Buffer)
r.Add(http.MethodGet, "/handler", func(Context) error {
_, err := b.Write([]byte("handler"))
return err
})
c := e.NewContext(nil, nil)
r.Find(http.MethodGet, "/handler", c)
err := c.Handler()(c)
testify.Equal(t, "handler", b.String())
testify.NoError(t, err)
}
func TestContext_SetHandler(t *testing.T) {
var c Context = new(context)
testify.Nil(t, c.Handler())
c.SetHandler(func(c Context) error {
return nil
})
testify.NotNil(t, c.Handler())
}
func TestContext_Path(t *testing.T) {
path := "/pa/th"
var c Context = new(context)
c.SetPath(path)
testify.Equal(t, path, c.Path())
}
type validator struct{}
func (*validator) Validate(i interface{}) error {
@ -717,10 +665,10 @@ func TestContext_Validate(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
testify.Error(t, c.Validate(struct{}{}))
assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{}
testify.NoError(t, c.Validate(struct{}{}))
assert.NoError(t, c.Validate(struct{}{}))
}
func TestContext_QueryString(t *testing.T) {
@ -728,21 +676,21 @@ func TestContext_QueryString(t *testing.T) {
queryString := "query=string&var=val"
req := httptest.NewRequest(GET, "/?"+queryString, nil)
req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil)
testify.Equal(t, queryString, c.QueryString())
assert.Equal(t, queryString, c.QueryString())
}
func TestContext_Request(t *testing.T) {
var c Context = new(context)
testify.Nil(t, c.Request())
assert.Nil(t, c.Request())
req := httptest.NewRequest(GET, "/path", nil)
req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req)
testify.Equal(t, req, c.Request())
assert.Equal(t, req, c.Request())
}
func TestContext_Scheme(t *testing.T) {
@ -799,14 +747,14 @@ func TestContext_Scheme(t *testing.T) {
}
for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.Scheme())
assert.Equal(t, tt.s, tt.c.Scheme())
}
}
func TestContext_IsWebSocket(t *testing.T) {
tests := []struct {
c Context
ws testify.BoolAssertionFunc
ws assert.BoolAssertionFunc
}{
{
&context{
@ -814,7 +762,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"websocket"}},
},
},
testify.True,
assert.True,
},
{
&context{
@ -822,13 +770,13 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
},
},
testify.True,
assert.True,
},
{
&context{
request: &http.Request{},
},
testify.False,
assert.False,
},
{
&context{
@ -836,7 +784,7 @@ func TestContext_IsWebSocket(t *testing.T) {
Header: http.Header{HeaderUpgrade: []string{"other"}},
},
},
testify.False,
assert.False,
},
}
@ -849,30 +797,14 @@ func TestContext_IsWebSocket(t *testing.T) {
func TestContext_Bind(t *testing.T) {
e := New()
req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, nil)
u := new(user)
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
testify.NoError(t, err)
testify.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_Logger(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
log1 := c.Logger()
testify.NotNil(t, log1)
log2 := log.New("echo2")
c.SetLogger(log2)
testify.Equal(t, log2, c.Logger())
// Resetting the context returns the initial logger
c.Reset(nil, nil)
testify.Equal(t, log1, c.Logger())
assert.NoError(t, err)
assert.Equal(t, &user{1, "Jon Snow"}, u)
}
func TestContext_RealIP(t *testing.T) {
@ -925,6 +857,6 @@ func TestContext_RealIP(t *testing.T) {
}
for _, tt := range tests {
testify.Equal(t, tt.s, tt.c.RealIP())
assert.Equal(t, tt.s, tt.c.RealIP())
}
}

890
echo.go

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

16
go.mod
View File

@ -1,17 +1,13 @@
module github.com/labstack/echo/v4
go 1.15
go 1.16
require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/labstack/gommon v0.3.0
github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/stretchr/testify v1.4.0
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang-jwt/jwt/v4 v4.0.0
github.com/stretchr/testify v1.7.0
github.com/valyala/fasttemplate v1.2.1
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
golang.org/x/net v0.0.0-20210913180222-943fd674d43e
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

48
go.sum
View File

@ -1,51 +1,29 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o=
github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4=
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8=
golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg=
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

150
group.go
View File

@ -4,95 +4,117 @@ import (
"net/http"
)
type (
// Group is a set of sub-routes for a specified route. It can be used for inner
// routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it.
Group struct {
common
type Group struct {
host string
prefix string
middleware []MiddlewareFunc
echo *Echo
}
)
// Use implements `Echo#Use()` for sub-routes within the Group.
// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
if len(g.middleware) == 0 {
return
}
// Allow all requests to reach the group as they might get dropped if router
// doesn't find a match, making none of the group middleware process.
g.Any("", NotFoundHandler)
g.Any("/*", NotFoundHandler)
}
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...)
}
// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...)
}
// GET implements `Echo#GET()` for sub-routes within the Group.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...)
}
// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...)
}
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...)
}
// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...)
}
// POST implements `Echo#POST()` for sub-routes within the Group.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...)
}
// PUT implements `Echo#PUT()` for sub-routes within the Group.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...)
}
// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...)
}
// Any implements `Echo#Any()` for sub-routes within the Group.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
routes := make([]*Route, len(methods))
for i, m := range methods {
routes[i] = g.Add(m, path, handler, middleware...)
// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
errs := make([]error, 0)
ris := make(Routes, 0)
for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
return routes
ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
}
// Match implements `Echo#Match()` for sub-routes within the Group.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
routes := make([]*Route, len(methods))
for i, m := range methods {
routes[i] = g.Add(m, path, handler, middleware...)
// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
errs := make([]error, 0)
ris := make(Routes, 0)
for _, m := range methods {
ri, err := g.AddRoute(Route{
Method: m,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
errs = append(errs, err)
continue
}
return routes
ris = append(ris, ri)
}
if len(errs) > 0 {
panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ris
}
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
// Important! Group middlewares are only executed in case there was exact route match and not
// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
@ -102,23 +124,43 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
return
}
// Static implements `Echo#Static()` for sub-routes within the Group.
func (g *Group) Static(prefix, root string) {
g.static(prefix, root, g.GET)
// Static implements `Echo#Static()` for sub-routes within the Group. Panics on error.
func (g *Group) Static(prefix, root string, middleware ...MiddlewareFunc) RouteInfo {
return g.Add(
http.MethodGet,
prefix+"*",
StaticDirectoryHandler(root, false),
middleware...,
)
}
// File implements `Echo#File()` for sub-routes within the Group.
func (g *Group) File(path, file string) {
g.file(path, file, g.GET)
// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
handler := func(c Context) error {
return c.File(file)
}
return g.Add(http.MethodGet, path, handler, middleware...)
}
// Add implements `Echo#Add()` for sub-routes within the Group.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
// Combine into a new slice to avoid accidentally passing the same slice for
// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
ri, err := g.AddRoute(Route{
Method: method,
Path: path,
Handler: handler,
Middlewares: middleware,
})
if err != nil {
panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
return ri
}
// AddRoute registers a new Routable with Router
func (g *Group) AddRoute(route Routable) (RouteInfo, error) {
// Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
return g.echo.add(g.host, method, g.prefix+path, handler, m...)
groupRoute := route.ForGroup(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
return g.echo.add(g.host, groupRoute)
}

View File

@ -1,31 +1,68 @@
package echo
import (
"github.com/stretchr/testify/assert"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// TODO: Fix me
func TestGroup(t *testing.T) {
g := New().Group("/group")
h := func(Context) error { return nil }
g.CONNECT("/", h)
g.DELETE("/", h)
g.GET("/", h)
g.HEAD("/", h)
g.OPTIONS("/", h)
g.PATCH("/", h)
g.POST("/", h)
g.PUT("/", h)
g.TRACE("/", h)
g.Any("/", h)
g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
g.Static("/static", "/tmp")
g.File("/walle", "_fixture/images//walle.png")
func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware it will not be executed when there are no routes under that group
_ = e.Group("/group", mw)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
e := New()
called := false
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
called = true
return c.NoContent(http.StatusTeapot)
}
}
// even though group has middleware and routes when we have no match on some route the middlewares for that
// group will not be executed
g := e.Group("/group", mw)
g.GET("/yes", handlerFunc)
status, body := request(http.MethodGet, "/group/nope", e)
assert.Equal(t, http.StatusNotFound, status)
assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
assert.False(t, called)
}
func TestGroup_multiLevelGroup(t *testing.T) {
e := New()
api := e.Group("/api")
users := api.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
status, body := request(http.MethodGet, "/api/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroupFile(t *testing.T) {
@ -92,11 +129,11 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
}
m2 := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusOK, c.Path())
return c.String(http.StatusOK, c.RouteInfo().Path())
}
}
h := func(c Context) error {
return c.String(http.StatusOK, c.Path())
return c.String(http.StatusOK, c.RouteInfo().Path())
}
g.Use(m1)
g.GET("/help", h, m2)
@ -119,3 +156,442 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m)
}
func TestGroup_CONNECT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.CONNECT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodConnect, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodConnect, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_DELETE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.DELETE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodDelete, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodDelete, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_HEAD(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.HEAD("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodHead, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodHead, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_OPTIONS(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.OPTIONS("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodOptions, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodOptions, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PATCH(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PATCH("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPatch, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPatch, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_POST(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.POST("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPost, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPost, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PUT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PUT("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPut, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodPut, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_TRACE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.TRACE("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodTrace, ri.Method())
assert.Equal(t, "/users/activate", ri.Path())
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name())
assert.Nil(t, ri.Params())
status, body := request(http.MethodTrace, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_Any(t *testing.T) {
e := New()
users := e.Group("/users")
ris := users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 11)
for _, m := range methods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_AnyWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Any("/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range methods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Match(t *testing.T) {
e := New()
myMethods := []string{http.MethodGet, http.MethodPost}
users := e.Group("/users")
ris := users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Len(t, ris, 2)
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_MatchWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
myMethods := []string{http.MethodGet, http.MethodPost}
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Match(myMethods, "/activate", func(c Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
func TestGroup_Static(t *testing.T) {
e := New()
g := e.Group("/books")
ri := g.Static("/download", "_fixture")
assert.Equal(t, http.MethodGet, ri.Method())
assert.Equal(t, "/books/download*", ri.Path())
assert.Equal(t, "GET:/books/download*", ri.Name())
assert.Equal(t, []string{"*"}, ri.Params())
req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
body := rec.Body.String()
assert.True(t, strings.HasPrefix(body, "<!doctype html>"))
}
func TestGroup_StaticMultiTest(t *testing.T) {
var testCases = []struct {
name string
givenPrefix string
givenRoot string
whenURL string
expectStatus int
expectHeaderLocation string
expectBodyStartsWith string
}{
{
name: "ok",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "ok, without prefix",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
},
{
name: "nok, without prefix does not serve dir index",
givenPrefix: "",
givenRoot: "_fixture/images",
whenURL: "/test/", // `/test` + `*` creates route `/test*`
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "No file",
givenPrefix: "/images",
givenRoot: "_fixture/scripts",
whenURL: "/test/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory",
givenPrefix: "/images",
givenRoot: "_fixture/images",
whenURL: "/test/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Directory Redirect",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory Redirect with non-root path",
givenPrefix: "/static",
givenRoot: "_fixture",
whenURL: "/test/static",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/static/",
expectBodyStartsWith: "",
},
{
name: "Prefixed directory 404 (request URL without slash)",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "Prefixed directory redirect (without slash redirect to slash)",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "_fixture",
whenURL: "/test/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/test/folder/",
expectBodyStartsWith: "",
},
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "_fixture",
whenURL: "/test/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "_fixture",
whenURL: "/test/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
},
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "_fixture/",
whenURL: `/test/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
g := e.Group("/test")
g.Static(tc.givenPrefix, tc.givenRoot)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
} else {
assert.Equal(t, "", body)
}
if tc.expectHeaderLocation != "" {
assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
} else {
_, ok := rec.Result().Header["Location"]
assert.False(t, ok)
}
})
}
}

74
httperror.go Normal file
View File

@ -0,0 +1,74 @@
package echo
import (
"errors"
"fmt"
"net/http"
)
// Errors
var (
ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
ErrNotFound = NewHTTPError(http.StatusNotFound)
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrForbidden = NewHTTPError(http.StatusForbidden)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
ErrBadRequest = NewHTTPError(http.StatusBadRequest)
ErrBadGateway = NewHTTPError(http.StatusBadGateway)
ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// HTTPError represents an error that occurred while handling a request.
type HTTPError struct {
Code int `json:"-"`
Message interface{} `json:"message"`
Internal error `json:"-"` // Stores the error returned by an external dependency
}
// NewHTTPError creates a new HTTPError instance.
func NewHTTPError(code int, message ...interface{}) *HTTPError { // FIXME: this need cleanup - why vararg if [0] is only used?
he := &HTTPError{Code: code, Message: http.StatusText(code)}
if len(message) > 0 {
he.Message = message[0]
}
return he
}
// NewHTTPErrorWithInternal creates a new HTTPError instance with internal error set.
func NewHTTPErrorWithInternal(code int, internalError error, message ...interface{}) *HTTPError {
he := NewHTTPError(code, message...)
he.Internal = internalError
return he
}
// Error makes it compatible with `error` interface.
func (he *HTTPError) Error() string {
if he.Internal == nil {
return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message)
}
return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
}
// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field
func (he *HTTPError) WithInternal(err error) *HTTPError {
return &HTTPError{
Code: he.Code,
Message: he.Message,
Internal: err,
}
}
// Unwrap satisfies the Go 1.13 error wrapper interface.
func (he *HTTPError) Unwrap() error {
return he.Internal
}

52
httperror_test.go Normal file
View File

@ -0,0 +1,52 @@
package echo
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestHTTPError(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Equal(t, "code=400, message=map[code:12]", err.Error())
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error())
})
}
func TestNewHTTPErrorWithInternal(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"), "test message")
assert.Equal(t, "code=400, message=test message, internal=test", he.Error())
}
func TestNewHTTPErrorWithInternal_noCustomMessage(t *testing.T) {
he := NewHTTPErrorWithInternal(http.StatusBadRequest, errors.New("test"))
assert.Equal(t, "code=400, message=Bad Request, internal=test", he.Error())
}
func TestHTTPError_Unwrap(t *testing.T) {
t.Run("non-internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
assert.Nil(t, errors.Unwrap(err))
})
t.Run("internal", func(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
"code": 12,
})
err = err.WithInternal(errors.New("internal error"))
assert.Equal(t, "internal error", errors.Unwrap(err).Error())
})
}

11
json.go
View File

@ -23,9 +23,16 @@ func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string
func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error {
err := json.NewDecoder(c.Request().Body).Decode(i)
if ute, ok := err.(*json.UnmarshalTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
return NewHTTPErrorWithInternal(
http.StatusBadRequest,
err,
fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset),
)
} else if se, ok := err.(*json.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
return NewHTTPErrorWithInternal(http.StatusBadRequest,
err,
fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()),
)
}
return err
}

168
log.go
View File

@ -1,41 +1,141 @@
package echo
import (
"bytes"
"io"
"github.com/labstack/gommon/log"
"strconv"
"sync"
"time"
)
type (
// Logger defines the logging interface.
Logger interface {
Output() io.Writer
SetOutput(w io.Writer)
Prefix() string
SetPrefix(p string)
Level() log.Lvl
SetLevel(v log.Lvl)
SetHeader(h string)
Print(i ...interface{})
Printf(format string, args ...interface{})
Printj(j log.JSON)
Debug(i ...interface{})
Debugf(format string, args ...interface{})
Debugj(j log.JSON)
Info(i ...interface{})
Infof(format string, args ...interface{})
Infoj(j log.JSON)
Warn(i ...interface{})
Warnf(format string, args ...interface{})
Warnj(j log.JSON)
Error(i ...interface{})
Errorf(format string, args ...interface{})
Errorj(j log.JSON)
Fatal(i ...interface{})
Fatalj(j log.JSON)
Fatalf(format string, args ...interface{})
Panic(i ...interface{})
Panicj(j log.JSON)
Panicf(format string, args ...interface{})
//-----------------------------------------------------------------------------
// Example for Zap (https://github.com/uber-go/zap)
//func main() {
// e := echo.New()
// logger, _ := zap.NewProduction()
// e.Logger = &ZapLogger{logger: logger}
//}
//type ZapLogger struct {
// logger *zap.Logger
//}
//
//func (l *ZapLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info(string(p), zap.String("subsystem", "echo")) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZapLogger) Error(err error) {
// l.logger.Error(err.Error(), zap.Error(err), zap.String("subsystem", "echo"))
//}
//-----------------------------------------------------------------------------
// Example for Zerolog (https://github.com/rs/zerolog)
//func main() {
// e := echo.New()
// logger := zerolog.New(os.Stdout)
// e.Logger = &ZeroLogger{logger: &logger}
//}
//
//type ZeroLogger struct {
// logger *zerolog.Logger
//}
//
//func (l *ZeroLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.Info().Str("subsystem", "echo").Msg(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *ZeroLogger) Error(err error) {
// l.logger.Error().Str("subsystem", "echo").Err(err).Msg(err.Error())
//}
//-----------------------------------------------------------------------------
// Example for Logrus (https://github.com/sirupsen/logrus)
//func main() {
// e := echo.New()
// e.Logger = &LogrusLogger{logger: logrus.New()}
//}
//
//type LogrusLogger struct {
// logger *logrus.Logger
//}
//
//func (l *LogrusLogger) Write(p []byte) (n int, err error) {
// // Note: if `logger` middleware is used it will send json bytes here, and it will not look beautiful at all.
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Info(string(p)) // naively log everything as string message.
// return len(p), nil
//}
//
//func (l *LogrusLogger) Error(err error) {
// l.logger.WithFields(logrus.Fields{"subsystem": "echo"}).Error(err)
//}
// Logger defines the logging interface that Echo uses internally in few places.
// For logging in handlers use your own logger instance (dependency injected or package/public variable) from logging framework of your choice.
type Logger interface {
// Write provides writer interface for http.Server `ErrorLog` and for logging startup messages.
// `http.Server.ErrorLog` logs errors from accepting connections, unexpected behavior from handlers,
// and underlying FileSystem errors.
// `logger` middleware will use this method to write its JSON payload.
Write(p []byte) (n int, err error)
// Error logs the error
Error(err error)
}
// jsonLogger is similar logger formatting implementation as `v4` had. It is not particularly fast or efficient. Only
// goal it to exist is to have somewhat backwards compatibility with `v4` for Echo internals logging formatting.
// It is not meant for logging in handlers/middlewares. Use some real logging library for those cases.
type jsonLogger struct {
writer io.Writer
bufferPool sync.Pool
timeNow func() time.Time
}
func newJSONLogger(writer io.Writer) *jsonLogger {
return &jsonLogger{
writer: writer,
bufferPool: sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
},
},
timeNow: time.Now,
}
}
func (l *jsonLogger) Write(p []byte) (n int, err error) {
pLen := len(p)
if pLen >= 2 && // naively try to avoid JSON values to be wrapped into message
(p[0] == '{' && p[pLen-2] == '}' && p[pLen-1] == '\n') ||
(p[0] == '{' && p[pLen-1] == '}') {
return l.writer.Write(p)
}
// we log with WARN level as we have no idea what that message level should be. From Echo perspective this method is
// called when we pass Echo logger to http.Server.ErrorLog and there are problems inside http.Server - which probably
// deserves at least WARN level.
return l.printf("WARN", string(p))
}
func (l *jsonLogger) Error(err error) {
_, _ = l.printf("ERROR", err.Error())
}
func (l *jsonLogger) printf(level string, message string) (n int, err error) {
buf := l.bufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer l.bufferPool.Put(buf)
buf.WriteString(`{"time":"`)
buf.WriteString(l.timeNow().Format(time.RFC3339Nano))
buf.WriteString(`","level":"`)
buf.WriteString(level)
buf.WriteString(`","prefix":"echo","message":`)
buf.WriteString(strconv.Quote(message))
buf.WriteString("}\n")
return l.writer.Write(buf.Bytes())
}
)

77
log_test.go Normal file
View File

@ -0,0 +1,77 @@
package echo
import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestJsonLogger_Write(t *testing.T) {
var testCases = []struct {
name string
when []byte
expect string
}{
{
name: "ok, write non JSONlike message",
when: []byte("version: %v, build: %v"),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: %v, build: %v"}` + "\n",
},
{
name: "ok, write quoted message",
when: []byte(`version: "%v"`),
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"WARN","prefix":"echo","message":"version: \"%v\""}` + "\n",
},
{
name: "ok, write JSON",
when: []byte(`{"version": 123}` + "\n"),
expect: `{"version": 123}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
_, err := logger.Write(tc.when)
result := buf.String()
assert.Equal(t, tc.expect, result)
assert.NoError(t, err)
})
}
}
func TestJsonLogger_Error(t *testing.T) {
var testCases = []struct {
name string
whenError error
expect string
}{
{
name: "ok",
whenError: ErrForbidden,
expect: `{"time":"2021-09-07T23:09:37+03:00","level":"ERROR","prefix":"echo","message":"code=403, message=Forbidden"}` + "\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
logger := newJSONLogger(buf)
logger.timeNow = func() time.Time {
return time.Unix(1631045377, 0)
}
logger.Error(tc.whenError)
result := buf.String()
assert.Equal(t, tc.expect, result)
})
}
}

13
middleware/DEVELOPMENT.md Normal file
View File

@ -0,0 +1,13 @@
# Development Guidelines for middlewares
// FIXME: add info about `MiddlewareConfigurator` interface
## Best practices:
* Do not use `panic` in middleware creator functions in case of invalid configuration.
* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
because previous middlewares up in call chain could have logic for dealing with returned errors.
* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
want to create middleware with panics or with returning errors on configuration errors.
* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.

View File

@ -1,64 +1,59 @@
package middleware
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
"github.com/labstack/echo/v4"
)
type (
// BasicAuthConfig defines the config for BasicAuth middleware.
BasicAuthConfig struct {
// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
type BasicAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Validator is a function to validate BasicAuth credentials.
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
// this function would be called once for each header until first valid result is returned
// Required.
Validator BasicAuthValidator
// Realm is a string to define realm attribute of BasicAuth.
// Realm is a string to define realm attribute of BasicAuthWithConfig.
// Default value "Restricted".
Realm string
}
// BasicAuthValidator defines a function to validate BasicAuth credentials.
BasicAuthValidator func(string, string, echo.Context) (bool, error)
)
// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
type BasicAuthValidator func(c echo.Context, user string, password string) (bool, error)
const (
basic = "basic"
defaultRealm = "Restricted"
)
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
}
)
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
}
// BasicAuthWithConfig returns an BasicAuth middleware with config.
// See `BasicAuth()`.
// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function")
return nil, errors.New("echo basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Realm == "" {
config.Realm = defaultRealm
@ -70,27 +65,31 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}
auth := c.Request().Header.Get(echo.HeaderAuthorization)
var lastError error
l := len(basic)
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
continue
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
continue
}
idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
if errValidate != nil {
lastError = errValidate
} else if valid {
return next(c)
}
break
}
}
if lastError != nil {
return lastError
}
realm := defaultRealm
@ -102,5 +101,5 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
}
}, nil
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
@ -12,60 +13,146 @@ import (
)
func TestBasicAuth(t *testing.T) {
validatorFunc := func(c echo.Context, u, p string) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
if u == "error" {
return false, errors.New(p)
}
return false, nil
}
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
var testCases = []struct {
name string
givenConfig BasicAuthConfig
whenAuth []string
expectHeader string
expectErr string
}{
{
name: "ok",
givenConfig: defaultConfig,
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, multiple",
givenConfig: defaultConfig,
whenAuth: []string{
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
basic + " NOT_BASE64",
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
},
},
{
name: "nok, invalid Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, not base64 Authorization header",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
},
{
name: "nok, missing Authorization header",
givenConfig: defaultConfig,
expectHeader: basic + ` realm=Restricted`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "ok, realm",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "ok, realm, case-insensitive header scheme",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
},
{
name: "nok, realm, invalid Authorization header",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectHeader: basic + ` realm="someRealm"`,
expectErr: "code=401, message=Unauthorized",
},
{
name: "nok, validator func returns an error",
givenConfig: defaultConfig,
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
expectErr: "my_error",
},
{
name: "ok, skipped",
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
return true
}},
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
config := tc.givenConfig
mw, err := config.ToMiddleware()
assert.NoError(t, err)
h := mw(func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
if len(tc.whenAuth) != 0 {
for _, a := range tc.whenAuth {
req.Header.Add(echo.HeaderAuthorization, a)
}
}
err = h(c)
if tc.expectErr != "" {
assert.Equal(t, http.StatusOK, res.Code)
assert.EqualError(t, err, tc.expectErr)
} else {
assert.Equal(t, http.StatusTeapot, res.Code)
assert.NoError(t, err)
}
if tc.expectHeader != "" {
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
}
})
}
}
func TestBasicAuth_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuth(nil)
assert.NotNil(t, mw)
})
mw := BasicAuth(func(c echo.Context, user string, password string) (bool, error) {
return true, nil
})
assert.NotNil(t, mw)
}
return false, nil
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
func TestBasicAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
assert.NotNil(t, mw)
})
assert := assert.New(t)
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Validator: f,
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c echo.Context, user string, password string) (bool, error) {
return true, nil
}})
assert.NotNil(t, mw)
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"io/ioutil"
"net"
@ -11,9 +12,8 @@ import (
"github.com/labstack/echo/v4"
)
type (
// BodyDumpConfig defines the config for BodyDump middleware.
BodyDumpConfig struct {
type BodyDumpConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -23,51 +23,45 @@ type (
}
// BodyDumpHandler receives the request and response payload.
BodyDumpHandler func(echo.Context, []byte, []byte)
type BodyDumpHandler func(c echo.Context, reqBody []byte, resBody []byte)
bodyDumpResponseWriter struct {
type bodyDumpResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
var (
// DefaultBodyDumpConfig is the default BodyDump middleware config.
DefaultBodyDumpConfig = BodyDumpConfig{
Skipper: DefaultSkipper,
}
)
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
c := DefaultBodyDumpConfig
c.Handler = handler
return BodyDumpWithConfig(c)
return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil {
panic("echo: body-dump middleware requires a handler function")
return nil, errors.New("echo body-dump middleware requires a handler function")
}
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
config.Skipper = DefaultSkipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Request
reqBody := []byte{}
if c.Request().Body != nil { // Read
if c.Request().Body != nil {
reqBody, _ = ioutil.ReadAll(c.Request().Body)
}
c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
@ -78,16 +72,14 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
c.Response().Writer = writer
if err = next(c); err != nil {
c.Error(err)
}
err := next(c)
// Callback
config.Handler(c, reqBody, resBody.Bytes())
return
}
return err
}
}, nil
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {

View File

@ -28,31 +28,48 @@ func TestBodyDump(t *testing.T) {
requestBody := ""
responseBody := ""
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
})
}}.ToMiddleware()
assert.NoError(t, err)
assert := assert.New(t)
if assert.NoError(mw(h)(c)) {
assert.Equal(requestBody, hw)
assert.Equal(responseBody, hw)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}
// Must set default skipper
BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: func(c echo.Context, reqBody, resBody []byte) {
requestBody = string(reqBody)
responseBody = string(resBody)
}
func TestBodyDump_skipper(t *testing.T) {
e := echo.New()
isCalled := false
mw, err := BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
})
Handler: func(c echo.Context, reqBody, resBody []byte) {
isCalled = true
},
}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
return errors.New("some error")
}
func TestBodyDumpFails(t *testing.T) {
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.False(t, isCalled)
}
func TestBodyDump_fails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
@ -62,30 +79,37 @@ func TestBodyDumpFails(t *testing.T) {
return errors.New("some error")
}
mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
mw, err := BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.EqualError(t, err, "some error")
assert.Equal(t, http.StatusOK, rec.Code)
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
assert.NotNil(t, mw)
})
assert.NotPanics(t, func() {
mw = BodyDumpWithConfig(BodyDumpConfig{
Skipper: func(c echo.Context) bool {
return true
},
Handler: func(c echo.Context, reqBody, resBody []byte) {
},
mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c echo.Context, reqBody, resBody []byte) {}})
assert.NotNil(t, mw)
})
}
func TestBodyDump_panic(t *testing.T) {
assert.Panics(t, func() {
mw := BodyDump(nil)
assert.NotNil(t, mw)
})
if !assert.Error(t, mw(h)(c)) {
t.FailNow()
}
assert.NotPanics(t, func() {
BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
})
}

View File

@ -1,98 +1,83 @@
package middleware
import (
"fmt"
"io"
"sync"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
)
type (
// BodyLimitConfig defines the config for BodyLimit middleware.
BodyLimitConfig struct {
// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
type BodyLimitConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Maximum allowed size for a request body, it can be specified
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
Limit string `yaml:"limit"`
limit int64
// LimitBytes is maximum allowed size in bytes for a request body
LimitBytes int64
}
limitedReader struct {
type limitedReader struct {
BodyLimitConfig
reader io.ReadCloser
read int64
context echo.Context
}
)
var (
// DefaultBodyLimitConfig is the default BodyLimit middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: DefaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware.
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The BodyLimit is determined based on both `Content-Length` request
// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
}
// BodyLimitWithConfig returns a BodyLimit middleware with config.
// See: `BodyLimit()`.
// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultBodyLimitConfig.Skipper
return toMiddlewareOrPanic(config)
}
limit, err := bytes.Parse(config.Limit)
if err != nil {
panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
pool := sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: config}
},
}
config.limit = limit
pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
// Based on content length
if req.ContentLength > config.limit {
if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
r := pool.Get().(*limitedReader)
r.Reset(req.Body, c)
r.Reset(c, req.Body)
defer pool.Put(r)
req.Body = r
return next(c)
}
}
}, nil
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
if r.read > r.limit {
if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
@ -102,16 +87,8 @@ func (r *limitedReader) Close() error {
return r.reader.Close()
}
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
func (r *limitedReader) Reset(context echo.Context, reader io.ReadCloser) {
r.reader = reader
r.context = context
r.read = 0
}
func limitedReaderPool(c BodyLimitConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
return &limitedReader{BodyLimitConfig: c}
},
}
}

View File

@ -11,6 +11,137 @@ import (
"github.com/stretchr/testify/assert"
)
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
// Based on content length (within limit)
mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he := mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
assert.NoError(t, err)
he = mw(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
LimitBytes: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(e.NewContext(req, rec), ioutil.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
}
func TestBodyLimit_skipper(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw, err := BodyLimitConfig{
Skipper: func(c echo.Context) bool {
return true
},
LimitBytes: 2,
}.ToMiddleware()
assert.NoError(t, err)
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimitWithConfig(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}
func TestBodyLimit(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
@ -25,61 +156,10 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}
assert := assert.New(t)
mw := BodyLimit(2 * MB)
// Based on content length (within limit)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
}
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
config := BodyLimitConfig{
Skipper: DefaultSkipper,
Limit: "2B",
limit: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
reader: ioutil.NopCloser(bytes.NewReader(hw)),
context: e.NewContext(req, rec),
}
// read all should return ErrStatusRequestEntityTooLarge
_, err := ioutil.ReadAll(reader)
he := err.(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
err := mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bufio"
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net"
@ -13,50 +14,45 @@ import (
"github.com/labstack/echo/v4"
)
type (
const (
gzipScheme = "gzip"
)
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
type GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip compression level.
// Optional. Default value -1.
Level int `yaml:"level"`
Level int
}
gzipResponseWriter struct {
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
)
const (
gzipScheme = "gzip"
)
var (
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
Skipper: DefaultSkipper,
Level: -1,
}
)
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig)
return GzipWithConfig(GzipConfig{})
}
// GzipWithConfig return Gzip middleware with config.
// See: `Gzip()`.
// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
return nil, errors.New("invalid gzip level")
}
if config.Level == 0 {
config.Level = DefaultGzipConfig.Level
config.Level = -1
}
pool := gzipCompressPool(config)
@ -97,7 +93,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}
func (w *gzipResponseWriter) WriteHeader(code int) {

View File

@ -3,94 +3,128 @@ package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestGzip(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
// Skip if no Accept-Encoding header
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
assert.Equal("test", rec.Body.String())
err := h(c)
assert.NoError(t, err)
// Gzip
req = httptest.NewRequest(http.MethodGet, "/", nil)
assert.Equal(t, "test", rec.Body.String())
}
func TestMustGzipWithConfig_panics(t *testing.T) {
assert.Panics(t, func() {
GzipWithConfig(GzipConfig{Level: 999})
})
}
func TestGzip_AcceptEncodingHeader(t *testing.T) {
h := Gzip()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) {
assert.NoError(t, err)
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}
chunkBuf := make([]byte, 5)
// Gzip chunked
req = httptest.NewRequest(http.MethodGet, "/", nil)
func TestGzip_chunked(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c = e.NewContext(req, rec)
Gzip()(func(c echo.Context) error {
chunkChan := make(chan struct{})
waitChan := make(chan struct{})
h := Gzip()(func(c echo.Context) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
c.Response().Write([]byte("test\n"))
c.Response().Write([]byte("first\n"))
c.Response().Flush()
// Read the first part of the data
assert.True(rec.Flushed)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
chunkChan <- struct{}{}
<-waitChan
// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
c.Response().Write([]byte("second\n"))
c.Response().Flush()
_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
chunkChan <- struct{}{}
<-waitChan
// Write the final part of the data and return
c.Response().Write([]byte("test"))
return nil
})(c)
c.Response().Write([]byte("third"))
chunkChan <- struct{}{}
return nil
})
go func() {
err := h(c)
chunkChan <- struct{}{}
assert.NoError(t, err)
}()
<-chunkChan // wait for first write
waitChan <- struct{}{}
<-chunkChan // wait for second write
waitChan <- struct{}{}
<-chunkChan // wait for final write in handler
<-chunkChan // wait for return from handler
time.Sleep(5 * time.Millisecond) // to have time for flushing
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
assert.NoError(t, err)
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "first\nsecond\nthird", buf.String())
}
func TestGzipNoContent(t *testing.T) {
func TestGzip_NoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
@ -106,7 +140,7 @@ func TestGzipNoContent(t *testing.T) {
}
}
func TestGzipErrorReturned(t *testing.T) {
func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Gzip())
e.GET("/", func(c echo.Context) error {
@ -120,31 +154,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
func TestGzipErrorReturnedInvalidConfig(t *testing.T) {
e := echo.New()
// Invalid level
e.Use(GzipWithConfig(GzipConfig{Level: 12}))
e.GET("/", func(c echo.Context) error {
c.Response().Write([]byte("test"))
return nil
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "gzip")
func TestGzipWithConfig_invalidLevel(t *testing.T) {
mw, err := GzipConfig{Level: 12}.ToMiddleware()
assert.EqualError(t, err, "invalid gzip level")
assert.Nil(t, mw)
}
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../")
e.Use(Gzip())
e.Static("/test", "../_fixture/images")
e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set.

View File

@ -9,60 +9,56 @@ import (
"github.com/labstack/echo/v4"
)
type (
// CORSConfig defines the config for CORS middleware.
CORSConfig struct {
type CORSConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
AllowOriginFunc func(origin string) (bool, error)
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
AllowMethods []string `yaml:"allow_methods"`
AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
// Optional. Default value []string{}.
AllowHeaders []string `yaml:"allow_headers"`
AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// Optional. Default value false.
AllowCredentials bool `yaml:"allow_credentials"`
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
// Optional. Default value []string{}.
ExposeHeaders []string `yaml:"expose_headers"`
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// Optional. Default value 0.
MaxAge int `yaml:"max_age"`
MaxAge int
}
)
var (
// DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{
var DefaultCORSConfig = CORSConfig{
Skipper: DefaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
}
)
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
@ -70,9 +66,14 @@ func CORS() echo.MiddlewareFunc {
return CORSWithConfig(DefaultCORSConfig)
}
// CORSWithConfig returns a CORS middleware with config.
// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
// See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper
@ -207,5 +208,5 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
return c.NoContent(http.StatusNoContent)
}
}
}, nil
}

View File

@ -17,7 +17,7 @@ func TestCORS(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
h := CORS()(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -26,7 +26,7 @@ func TestCORS(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h = CORS()(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
@ -38,7 +38,7 @@ func TestCORS(t *testing.T) {
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
})(func(c echo.Context) error { return echo.ErrNotFound })
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -55,7 +55,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -73,7 +73,7 @@ func TestCORS(t *testing.T) {
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
@ -90,7 +90,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
@ -104,7 +104,7 @@ func TestCORS(t *testing.T) {
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
})
h = cors(echo.NotFoundHandler)
h = cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -149,7 +149,7 @@ func Test_allowOriginScheme(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@ -240,7 +240,7 @@ func Test_allowOriginSubdomain(t *testing.T) {
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.pattern},
})
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
if tt.expected {
@ -331,7 +331,7 @@ func TestCorsHeaders(t *testing.T) {
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
@ -387,11 +387,11 @@ func Test_allowOriginFunc(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
cors, err := CORSConfig{AllowOriginFunc: allowOriginFunc}.ToMiddleware()
assert.NoError(t, err)
h := cors(func(c echo.Context) error { return echo.ErrNotFound })
err = h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {

View File

@ -8,19 +8,21 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// TokenLength is the length of the generated token.
TokenLength uint8 `yaml:"token_length"`
TokenLength uint8
// Optional. Default value 32.
// Generator defines a function to generate token.
// Optional. Defaults tp randomString(TokenLength).
Generator func() string
// TokenLookup is a string in the form of "<source>:<key>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
@ -28,49 +30,46 @@ type (
// - "header:<name>"
// - "form:<name>"
// - "query:<name>"
TokenLookup string `yaml:"token_lookup"`
TokenLookup string
// Context key to store generated CSRF token into context.
// Optional. Default value "csrf".
ContextKey string `yaml:"context_key"`
ContextKey string
// Name of the CSRF cookie. This cookie will store CSRF token.
// Optional. Default value "csrf".
CookieName string `yaml:"cookie_name"`
CookieName string
// Domain of the CSRF cookie.
// Optional. Default value none.
CookieDomain string `yaml:"cookie_domain"`
CookieDomain string
// Path of the CSRF cookie.
// Optional. Default value none.
CookiePath string `yaml:"cookie_path"`
CookiePath string
// Max age (in seconds) of the CSRF cookie.
// Optional. Default value 86400 (24hr).
CookieMaxAge int `yaml:"cookie_max_age"`
CookieMaxAge int
// Indicates if CSRF cookie is secure.
// Optional. Default value false.
CookieSecure bool `yaml:"cookie_secure"`
CookieSecure bool
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"`
CookieHTTPOnly bool
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
CookieSameSite http.SameSite
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
csrfTokenExtractor func(echo.Context) (string, error)
)
// csrfTokenExtractor defines a function that takes `echo.Context` and returns either a token or an error.
type csrfTokenExtractor func(echo.Context) (string, error)
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
var DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
@ -79,18 +78,20 @@ var (
CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode,
}
)
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
c := DefaultCSRFConfig
return CSRFWithConfig(c)
return CSRFWithConfig(DefaultCSRFConfig)
}
// CSRFWithConfig returns a CSRF middleware with config.
// See `CSRF()`.
// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
@ -98,6 +99,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
if config.Generator == nil {
config.Generator = createRandomStringGenerator(config.TokenLength)
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
@ -136,7 +140,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Generate token
if err != nil {
token = random.String(config.TokenLength)
token = config.Generator()
} else {
// Reuse token
token = k.Value
@ -181,7 +185,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}
}
}, nil
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the

View File

@ -9,11 +9,26 @@ import (
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert"
)
func TestCSRF(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRF()
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Generate CSRF token
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
}
func TestMustCSRFWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@ -43,7 +58,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c))
// Valid CSRF token
token := random.String(16)
token := randomString(16)
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) {
@ -145,9 +160,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
csrf, err := CSRFConfig{
CookieSameSite: http.SameSiteNoneMode,
})
}.ToMiddleware()
assert.NoError(t, err)
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")

View File

@ -11,16 +11,14 @@ import (
"github.com/labstack/echo/v4"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
type DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
@ -30,14 +28,6 @@ type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
@ -67,17 +57,21 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
return DecompressWithConfig(DecompressConfig{})
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
config.GzipDecompressPool = &DefaultGzipDecompressPool{}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -116,5 +110,5 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}

View File

@ -17,6 +17,31 @@ import (
func TestDecompress(t *testing.T) {
e := echo.New()
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
// Decompress request body
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestDecompress_skippedIfNoHeader(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@ -26,39 +51,42 @@ func TestDecompress(t *testing.T) {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
err := h(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}
func TestDecompressDefaultConfig(t *testing.T) {
func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
h, err := DecompressConfig{}.ToMiddleware()
assert.NoError(t, err)
err = h(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
}
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Decompress()(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)
assert := assert.New(t)
assert.Equal("test", rec.Body.String())
// Decompress
body := `{"name": "echo"}`
@ -67,11 +95,14 @@ func TestDecompressDefaultConfig(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
err := h(c)
assert.NoError(t, err)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
@ -82,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(t, err)
@ -99,7 +132,10 @@ func TestDecompressNoContent(t *testing.T) {
h := Decompress()(func(c echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
err := h(c)
if assert.NoError(t, err) {
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
assert.Equal(t, 0, len(rec.Body.Bytes()))
@ -115,7 +151,9 @@ func TestDecompressErrorReturned(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
@ -132,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8)
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
@ -161,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) {
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)

148
middleware/extractor.go Normal file
View File

@ -0,0 +1,148 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"net/textproto"
"strings"
)
// ErrExtractionValueMissing denotes an error raised when value could not be extracted from request
var ErrExtractionValueMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed value")
// ExtractorType is enum type for where extractor will take its data
type ExtractorType string
const (
// HeaderExtractor tells extractor to take values from request header
HeaderExtractor ExtractorType = "header"
// QueryExtractor tells extractor to take values from request query parameters
QueryExtractor ExtractorType = "query"
// ParamExtractor tells extractor to take values from request route parameters
ParamExtractor ExtractorType = "param"
// CookieExtractor tells extractor to take values from request cookie
CookieExtractor ExtractorType = "cookie"
// FormExtractor tells extractor to take values from request form fields
FormExtractor ExtractorType = "form"
)
func createExtractors(lookups string) ([]valuesExtractor, error) {
sources := strings.Split(lookups, ",")
var extractors []valuesExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
}
switch ExtractorType(parts[0]) {
case QueryExtractor:
extractors = append(extractors, valuesFromQuery(parts[1]))
case ParamExtractor:
extractors = append(extractors, valuesFromParam(parts[1]))
case CookieExtractor:
extractors = append(extractors, valuesFromCookie(parts[1]))
case FormExtractor:
extractors = append(extractors, valuesFromForm(parts[1]))
case HeaderExtractor:
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
func valuesFromHeader(header string, valuePrefix string) valuesExtractor {
prefixLen := len(valuePrefix)
return func(c echo.Context) ([]string, ExtractorType, error) {
values := textproto.MIMEHeader(c.Request().Header).Values(header)
if len(values) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, value := range values {
if prefixLen == 0 {
result = append(result, value)
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
}
}
if len(result) == 0 {
return nil, HeaderExtractor, ErrExtractionValueMissing
}
return result, HeaderExtractor, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, QueryExtractor, ErrExtractionValueMissing
}
return result, QueryExtractor, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
result := make([]string, 0)
for _, p := range c.PathParams() {
if param == p.Name {
result = append(result, p.Value)
}
}
if len(result) == 0 {
return nil, ParamExtractor, ErrExtractionValueMissing
}
return result, ParamExtractor, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
result := make([]string, 0)
for _, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
}
}
if len(result) == 0 {
return nil, CookieExtractor, ErrExtractionValueMissing
}
return result, CookieExtractor, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) valuesExtractor {
return func(c echo.Context) ([]string, ExtractorType, error) {
if err := c.Request().ParseForm(); err != nil {
return nil, FormExtractor, fmt.Errorf("valuesFromForm parse form failed: %w", err)
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, FormExtractor, ErrExtractionValueMissing
}
result := append([]string{}, values...)
return result, FormExtractor, nil
}
}

View File

@ -0,0 +1,498 @@
package middleware
import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams echo.PathParams
whenLoopups string
expectValues []string
expectExtractorType ExtractorType
expectCreateError string
expectError string
}{
{
name: "ok, header",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
expectExtractorType: HeaderExtractor,
},
{
name: "ok, form",
givenRequest: func() *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
expectExtractorType: FormExtractor,
},
{
name: "ok, cookie",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
expectExtractorType: CookieExtractor,
},
{
name: "ok, param",
givenPathParams: echo.PathParams{
{Name: "id", Value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
expectExtractorType: ParamExtractor,
},
{
name: "ok, query",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
whenLoopups: "query:id",
expectValues: []string{"999"},
expectExtractorType: QueryExtractor,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups)
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
}
assert.NoError(t, err)
for _, e := range extractors {
values, eType, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, tc.expectExtractorType, eType)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
}
assert.NoError(t, eErr)
}
})
}
}
func TestValuesFromHeader(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, single value, case insensitive",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
name: "ok, empty prefix",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Bearer ",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderWWWAuthenticate,
whenValuePrefix: "",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no headers",
givenRequest: nil,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, HeaderExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromQuery(t *testing.T) {
var testCases = []struct {
name string
givenQueryPart string
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenQueryPart: "?id=123&name=test",
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
expectValues: []string{"123", "456"},
},
{
name: "nok, missing value",
givenQueryPart: "?id=123&name=test",
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromQuery(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, QueryExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := echo.PathParams{
{Name: "id", Value: "123"},
{Name: "gid", Value: "456"},
{Name: "gid", Value: "789"},
}
var testCases = []struct {
name string
givenPathParams echo.PathParams
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenPathParams: examplePathParams,
whenName: "gid",
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no matching value",
givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(echo.EditableContext)
if tc.givenPathParams != nil {
c.SetRawPathParams(&tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, ParamExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromCookie(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: "_csrf",
expectValues: []string{"token"},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Add(echo.HeaderCookie, "_csrf=token")
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
expectValues: []string{"token", "token2"},
},
{
name: "nok, no matching cookie",
givenRequest: exampleRequest,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, no cookies at all",
givenRequest: nil,
whenName: "xxx",
expectValues: nil,
expectError: ErrExtractionValueMissing.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromCookie(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, CookieExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromForm(t *testing.T) {
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
}
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
whenName string
expectValues []string
expectError string
}{
{
name: "ok, POST form, single value",
givenRequest: examplePostFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, POST form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, GET form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "nok, POST form, value missing",
givenRequest: examplePostFormRequest(nil),
whenName: "nope",
expectError: ErrExtractionValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := tc.givenRequest
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromForm(tc.whenName)
values, eType, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
assert.Equal(t, FormExtractor, eType)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -1,21 +1,14 @@
// +build go1.15
package middleware
import (
"errors"
"fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
"net/http"
)
type (
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
type JWTConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -25,110 +18,67 @@ type (
// SuccessHandler defines a function which is executed for a valid token.
SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token.
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom JWT error.
ErrorHandler JWTErrorHandler
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKeys is provided.
SigningKey interface{}
// Map of signing keys to validate token with kid field usage.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither user-defined KeyFunc nor SigningKey is provided.
SigningKeys map[string]interface{}
// Signing method used to check the token's signing algorithm.
// Optional. Default value HS256.
SigningMethod string
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain.
ErrorHandler JWTErrorHandlerWithContext
// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:Authorization".
// to extract token(s) from the request.
// Optional. Default value "header:Authorization:Bearer ".
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
// Multiply sources example:
// Multiple sources example:
// - "header:Authorization,cookie:myowncookie"
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
// NB: could be called multiple times per request when token lookup is able to extract multiple token values (i.e. multiple Authorization headers)
// See `jwt_external_test.go` for example implementation using `github.com/golang-jwt/jwt` as JWT implementation library
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context)
type JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error
type JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error
type JWTErrorHandlerWithContext func(c echo.Context, err error) error
jwtExtractor func(echo.Context) (string, error)
)
type valuesExtractor func(c echo.Context) ([]string, ExtractorType, error)
// Algorithms
const (
// AlgorithmHS256 is token signing algorithm
AlgorithmHS256 = "HS256"
)
// Errors
var (
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
// ErrJWTMissing denotes an error raised when JWT token value could not be extracted from request
var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malformed jwt")
// ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired
var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
var DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
KeyFunc: nil,
TokenLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
}
)
// JWT returns a JSON Web Token (JWT) auth middleware.
//
@ -137,64 +87,43 @@ var (
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup`
func JWT(key interface{}) echo.MiddlewareFunc {
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
c := DefaultJWTConfig
c.SigningKey = key
c.ParseTokenFunc = parseTokenFunc
return JWTWithConfig(c)
}
// JWTWithConfig returns a JWT auth middleware with config.
// See: `JWT()`.
// JWTWithConfig returns a JSON Web Token (JWT) auth middleware or panics if configuration is invalid.
//
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts JWTConfig to middleware or returns an error for invalid configuration
func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTConfig.SigningMethod
if config.ParseTokenFunc == nil {
return nil, errors.New("echo jwt middleware requires parse token function")
}
if config.ContextKey == "" {
config.ContextKey = DefaultJWTConfig.ContextKey
}
if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims
}
if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
extractors, err := createExtractors(config.TokenLookup)
if err != nil {
return nil, fmt.Errorf("echo jwt middleware could not create token extractor: %w", err)
}
if len(extractors) == 0 {
return nil, errors.New("echo jwt middleware could not create extractors from TokenLookup string")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -206,30 +135,20 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
var auth string
var err error
var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and
// set auth
auth, err = extractor(c)
if err == nil {
break
auths, _, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
}
// If none of extractor has a token, handle error
for _, auth := range auths {
token, err := config.ParseTokenFunc(c, auth)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
lastTokenErr = err
continue
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return err
}
token, err := config.ParseTokenFunc(auth, c)
if err == nil {
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
@ -237,111 +156,34 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
// prioritize token errors over extracting errors
err := lastTokenErr
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
if err == ErrExtractionValueMissing {
err = ErrJWTMissing
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
// Allow error handler to swallow the error and continue handler chain execution
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public token to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
}
return next(c)
}
if err == ErrExtractionValueMissing {
return ErrJWTMissing
}
// everything else goes under http.StatusUnauthorized to avoid exposing JWT internals with generic error
return &echo.HTTPError{
Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message,
Internal: err,
}
}
}
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(config.SigningKeys) > 0 {
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := config.SigningKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return config.SigningKey, nil
}
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}, nil
}

View File

@ -0,0 +1,76 @@
package middleware_test
import (
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
"net/http/httptest"
)
// CreateJWTGoParseTokenFunc creates JWTGo implementation for ParseTokenFunc
//
// signingKey is signing key to validate token.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKeys is not provided.
//
// signingKeys is Map of signing keys to validate token with kid field usage.
// This is one of the options to provide a token validation key.
// The order of precedence is a user-defined SigningKeys and SigningKey.
// Required if signingKey is not provided
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
// keyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != middleware.AlgorithmHS256 {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
if len(signingKeys) == 0 {
return signingKey, nil
}
if kid, ok := t.Header["kid"].(string); ok {
if key, ok := signingKeys[kid]; ok {
return key, nil
}
}
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
func ExampleJWTConfig_withJWTGoAsTokenParser() {
mw := middleware.JWTWithConfig(middleware.JWTConfig{
ParseTokenFunc: CreateJWTGoParseTokenFunc([]byte("secret"), nil),
})
e := echo.New()
e.Use(mw)
e.GET("/", func(c echo.Context) error {
user := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusTeapot, user.Claims)
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
fmt.Printf("status: %v, body: %v", res.Code, res.Body.String())
// Output: status: 418, body: {"admin":true,"name":"John Doe","sub":"1234567890"}
}

View File

@ -1,5 +1,3 @@
// +build go1.15
package middleware
import (
@ -11,11 +9,32 @@ import (
"strings"
"testing"
"github.com/golang-jwt/jwt"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func createTestParseTokenFuncForJWTGo(signingMethod string, signingKey interface{}) func(c echo.Context, auth string) (interface{}, error) {
// This is minimal implementation for github.com/golang-jwt/jwt as JWT parser library. good enough to get old tests running
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != signingMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return signingKey, nil
}
return func(c echo.Context, auth string) (interface{}, error) {
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc)
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
}
// jwtCustomInfo defines some custom types we're going to use within our tokens.
type jwtCustomInfo struct {
Name string `json:"name"`
@ -28,43 +47,7 @@ type jwtCustomClaims struct {
jwtCustomInfo
}
func TestJWTRace(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
validKey := []byte("secret")
h := JWTWithConfig(JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: validKey,
})(handler)
makeReq := func(token string) echo.Context {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
c := e.NewContext(req, res)
assert.NoError(t, h(c))
return c
}
c := makeReq(initialToken)
user := c.Get("user").(*jwt.Token)
claims := user.Claims.(*jwtCustomClaims)
assert.Equal(t, claims.Name, "John Doe")
makeReq(raceToken)
user = c.Get("user").(*jwt.Token)
claims = user.Claims.(*jwtCustomClaims)
// Initial context should still be "John Doe", not "Race Condition"
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) {
func TestJWT_combinations(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
@ -72,201 +55,180 @@ func TestJWT(t *testing.T) {
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
validKey := []byte("secret")
invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token
validAuth := "Bearer " + token
for _, tc := range []struct {
expPanic bool
expErrCode int // 0 for Success
var testCases = []struct {
name string
config JWTConfig
reqURL string // "/" if empty
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
expPanic bool
expErrCode int // 0 for Success
}{
{
expPanic: true,
info: "No signing key provided",
},
{
expErrCode: http.StatusBadRequest,
config: JWTConfig{
SigningKey: validKey,
SigningMethod: "RS256",
},
info: "Unexpected signing method",
name: "No signing key provided",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey},
info: "Invalid key",
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo("RS256", validKey),
},
name: "Unexpected signing method",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, invalidKey),
},
name: "Invalid key",
},
{
hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT",
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT",
},
{
hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
info: "Valid JWT with custom AuthScheme",
config: JWTConfig{
TokenLookup: "header:" + echo.HeaderAuthorization + ":Token ",
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
name: "Valid JWT with custom AuthScheme",
},
{
hdrAuth: validAuth,
config: JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
},
info: "Valid JWT with custom claims",
name: "Valid JWT with custom claims",
},
{
hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest,
config: JWTConfig{SigningKey: validKey},
info: "Invalid Authorization header",
expErrCode: http.StatusUnauthorized,
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
{
config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest,
info: "Empty header auth field",
name: "Invalid Authorization header",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
},
expErrCode: http.StatusUnauthorized,
name: "Empty header auth field",
},
{
config: JWTConfig{
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=" + token,
info: "Valid query method",
name: "Valid query method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest,
info: "Invalid query param name",
expErrCode: http.StatusUnauthorized,
name: "Invalid query param name",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized,
info: "Invalid query param value",
name: "Invalid query param value",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt",
},
reqURL: "/?a=b",
expErrCode: http.StatusBadRequest,
info: "Empty query",
expErrCode: http.StatusUnauthorized,
name: "Empty query",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "param:jwt",
},
reqURL: "/" + token,
info: "Valid param method",
name: "Valid param method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
hdrCookie: "jwt=" + token,
info: "Valid cookie method",
name: "Valid cookie method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "query:jwt,cookie:jwt",
},
hdrCookie: "jwt=" + token,
info: "Multiple jwt lookuop",
name: "Multiple jwt lookuop",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid",
info: "Invalid token with cookie method",
name: "Invalid token with cookie method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
expErrCode: http.StatusUnauthorized,
name: "Empty cookie",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
name: "Valid form method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
name: "Invalid token with form method",
},
{
config: JWTConfig{
SigningKey: validKey,
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, validKey),
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
info: "Valid JWT with a valid key using a user-defined KeyFunc",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return invalidKey, nil
},
},
expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
name: "Empty form field",
},
{
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return nil, errors.New("faulty KeyFunc")
},
},
expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
},
{
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT with lower case AuthScheme",
},
} {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.reqURL == "" {
tc.reqURL = "/"
}
@ -289,127 +251,40 @@ func TestJWT(t *testing.T) {
c := e.NewContext(req, res)
if tc.reqURL == "/"+token {
c.SetParamNames("jwt")
c.SetParamValues(token)
cc := c.(echo.EditableContext)
cc.SetPathParams(echo.PathParams{
{Name: "jwt", Value: token},
})
}
if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
}, tc.info)
continue
}, tc.name)
return
}
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
continue
assert.Equal(t, tc.expErrCode, he.Code)
return
}
h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.info) {
if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.info)
assert.Equal(t, claims["name"], "John Doe")
case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.info)
assert.Equal(t, claims.Admin, true, tc.info)
default:
panic("unexpected type of claims")
}
}
}
}
func TestJWTwithKID(t *testing.T) {
test := assert.New(t)
e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
staticSecret := []byte("static_secret")
invalidStaticSecret := []byte("invalid_secret")
for _, tc := range []struct {
expErrCode int // 0 for Success
config JWTConfig
hdrAuth string
info string
}{
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: validKeys},
info: "First token valid",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Second token valid",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
config: JWTConfig{SigningKeys: validKeys},
info: "Wrong key id token",
},
{
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: staticSecret},
info: "Valid static secret token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
config: JWTConfig{SigningKey: invalidStaticSecret},
info: "Invalid static secret",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys first token",
},
{
expErrCode: http.StatusUnauthorized,
hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
config: JWTConfig{SigningKeys: invalidKeys},
info: "Invalid keys second token",
},
} {
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
c := e.NewContext(req, res)
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
continue
}
h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
assert.Equal(t, claims.Name, "John Doe")
assert.Equal(t, claims.Admin, true)
default:
panic("unexpected type of claims")
}
}
})
}
}
@ -420,7 +295,7 @@ func TestJWTConfig_skipper(t *testing.T) {
Skipper: func(context echo.Context) bool {
return true // skip everything
},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
}))
isCalled := false
@ -448,11 +323,11 @@ func TestJWTConfig_BeforeFunc(t *testing.T) {
BeforeFunc: func(context echo.Context) {
isCalled = true
},
SigningKey: []byte("secret"),
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
@ -469,18 +344,8 @@ func TestJWTConfig_extractorErrorHandling(t *testing.T) {
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
@ -515,23 +380,13 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
ParseTokenFunc: createTestParseTokenFuncForJWTGo(AlgorithmHS256, []byte("secret")),
ErrorHandler: func(c echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
},
}
for _, tc := range testCases {
@ -544,14 +399,14 @@ func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
config := tc.given
parseTokenCalled := false
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) {
parseTokenCalled = true
return nil, errors.New("parsing failed")
}
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
@ -574,7 +429,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
signingKey := []byte("secret")
config := JWTConfig{
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
@ -597,9 +452,161 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestMustJWTWithConfig_SuccessHandler(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
success := c.Get("success").(string)
user := c.Get("user").(string)
return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user))
})
mw, err := JWTConfig{
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return auth, nil
},
SuccessHandler: func(c echo.Context) {
c.Set("success", "yes")
},
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, "yes:valid_token_base64", res.Body.String())
assert.Equal(t, http.StatusTeapot, res.Code)
}
func TestJWTWithConfig_CallNextOnNilErrorHandlerResult(t *testing.T) {
var testCases = []struct {
name string
givenCallNext bool
givenErrorHandler JWTErrorHandlerWithContext
givenTokenLookup string
whenAuthHeaders []string
whenCookies []string
whenParseReturn string
whenParseError error
expectHandlerCalled bool
expect string
expectCode int
}{
{
name: "ok, with valid JWT from auth header",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return nil
},
whenAuthHeaders: []string{"Bearer valid_token_base64"},
whenParseReturn: "valid_token",
expectCode: http.StatusTeapot,
expect: "valid_token",
},
{
name: "ok, missing header, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err != ErrJWTMissing {
panic("must get ErrJWTMissing")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "ok, invalid token, callNext and set public_token from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
// this is probably not realistic usecase. on parse error you probably want to return error
if err.Error() != "parser_error" {
panic("must get parser_error")
}
c.Set("user", "public_token")
return nil
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusTeapot,
expect: "public_token",
},
{
name: "nok, invalid token, return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
if err.Error() != "parser_error" {
panic("must get parser_error")
}
return err
},
whenAuthHeaders: []string{"Bearer invalid_header"},
whenParseError: errors.New("parser_error"),
expectCode: http.StatusInternalServerError,
expect: "{\"message\":\"Internal Server Error\"}\n",
},
{
name: "nok, callNext but return error from error handler",
givenCallNext: true,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
{
name: "nok, callNext=false",
givenCallNext: false,
givenErrorHandler: func(c echo.Context, err error) error {
return err
},
whenAuthHeaders: []string{}, // no JWT header
expectCode: http.StatusUnauthorized,
expect: "{\"message\":\"missing or malformed jwt\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(string)
return c.String(http.StatusTeapot, token)
})
mw, err := JWTConfig{
TokenLookup: tc.givenTokenLookup,
ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) {
return tc.whenParseReturn, tc.whenParseError
},
ErrorHandler: tc.givenErrorHandler,
}.ToMiddleware()
assert.NoError(t, err)
e.Use(mw)
req := httptest.NewRequest(http.MethodGet, "/", nil)
for _, a := range tc.whenAuthHeaders {
req.Header.Add(echo.HeaderAuthorization, a)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expect, res.Body.String())
assert.Equal(t, tc.expectCode, res.Code)
})
}
}

View File

@ -3,58 +3,59 @@ package middleware
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"net/http"
)
type (
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
type KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key(s) from the request.
// Optional. Default value "header:Authorization:Bearer ".
// Possible values:
// - "header:<name>"
// - "header:<name>:<value prefix>"
// - "query:<name>"
// - "form:<name>"
// - "param:<name>"
// - "cookie:<name>"
KeyLookup string `yaml:"key_lookup"`
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// - "form:<name>"
// Multiple sources example:
// - "header:Authorization:Bearer ,cookie:myowncookie"
KeyLookup string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
// ErrorHandler defines a function which is executed for an invalid key.
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
// It may be used to define a custom error.
//
// Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error)
keyExtractor func(echo.Context) (string, error)
type KeyAuthValidator func(c echo.Context, key string, keyType ExtractorType) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
)
type KeyAuthErrorHandler func(c echo.Context, err error) error
// ErrKeyMissing denotes an error raised when key value could not be extracted from request
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
// ErrInvalidKey denotes an error raised when key value is invalid by validator
var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{
var DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: DefaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
}
)
// KeyAuth returns an KeyAuth middleware.
//
@ -67,34 +68,32 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c)
}
// KeyAuthWithConfig returns an KeyAuth middleware with config.
// See `KeyAuth()`.
// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
//
// For first valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
panic("echo: key-auth middleware requires a validator function")
return nil, errors.New("echo key-auth middleware requires a validator function")
}
// Initialize
parts := strings.Split(config.KeyLookup, ":")
extractor := keyFromHeader(parts[1], config.AuthScheme)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
case "cookie":
extractor = keyFromCookie(parts[1])
extractors, err := createExtractors(config.KeyLookup)
if err != nil {
return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", err)
}
if len(extractors) == 0 {
return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -103,79 +102,50 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
return next(c)
}
// Extract and verify key
key, err := extractor(c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
var lastExtractorErr error
var lastValidatorErr error
for _, extractor := range extractors {
keys, keyType, extrErr := extractor(c)
if extrErr != nil {
lastExtractorErr = extrErr
continue
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
valid, err := config.Validator(key, c)
for _, key := range keys {
valid, err := config.Validator(c, key, keyType)
if err != nil {
lastValidatorErr = err
continue
}
if !valid {
lastValidatorErr = ErrInvalidKey
continue
}
return next(c)
}
}
// prioritize validator errors over extracting errors
err := lastValidatorErr
if err == nil {
err = lastExtractorErr
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
// Allow error handler to swallow the error and continue handler chain execution
// Useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain
if handledErr := config.ErrorHandler(c, err); handledErr != nil {
return handledErr
}
return next(c)
}
if err == ErrExtractionValueMissing {
return ErrKeyMissing // do not wrap extractor errors
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Message: "Unauthorized",
Internal: err,
}
} else if valid {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
func keyFromCookie(cookieName string) keyExtractor {
return func(c echo.Context) (string, error) {
key, err := c.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("missing key in cookies: %w", err)
}
return key.Value, nil
}
}, nil
}

View File

@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert"
)
func testKeyValidator(key string, c echo.Context) (bool, error) {
func testKeyValidator(c echo.Context, key string, keyType ExtractorType) (bool, error) {
switch key {
case "valid-key":
return true, nil
@ -28,7 +28,7 @@ func TestKeyAuth(t *testing.T) {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuth(testKeyValidator)(handler)
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{Validator: testKeyValidator})(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized",
expectError: "code=401, message=Unauthorized, internal=code=401, message=invalid key",
},
{
name: "nok, defaults, invalid scheme in header",
@ -84,13 +84,13 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
},
expectHandlerCalled: false,
expectError: "code=400, message=invalid key in the request header",
expectError: "code=401, message=missing key",
},
{
name: "nok, defaults, missing header",
givenRequest: func(req *http.Request) {},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
expectError: "code=401, message=missing key",
},
{
name: "ok, custom key lookup, header",
@ -110,7 +110,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "header:API-Key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
expectError: "code=401, message=missing key",
},
{
name: "ok, custom key lookup, query",
@ -130,7 +130,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "query:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the query string",
expectError: "code=401, message=missing key",
},
{
name: "ok, custom key lookup, form",
@ -155,7 +155,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "form:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in the form",
expectError: "code=401, message=missing key",
},
{
name: "ok, custom key lookup, cookie",
@ -179,20 +179,20 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies: http: named cookie not present",
expectError: "code=401, message=missing key",
},
{
name: "nok, custom errorHandler, error from extractor",
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "header:token"
conf.ErrorHandler = func(err error, context echo.Context) error {
conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
}
},
expectHandlerCalled: false,
expectError: "code=418, message=custom, internal=missing key in request header",
expectError: "code=418, message=custom, internal=code=400, message=missing or malformed value",
},
{
name: "nok, custom errorHandler, error from validator",
@ -200,7 +200,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.ErrorHandler = func(err error, context echo.Context) error {
conf.ErrorHandler = func(c echo.Context, err error) error {
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
httpError.Internal = err
return httpError
@ -216,7 +216,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
},
whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error",
expectError: "code=401, message=Unauthorized, internal=some user defined error",
},
}
@ -257,3 +257,96 @@ func TestKeyAuthWithConfig(t *testing.T) {
})
}
}
func TestKeyAuthWithConfig_errors(t *testing.T) {
var testCases = []struct {
name string
whenConfig KeyAuthConfig
expectError string
}{
{
name: "ok, no error",
whenConfig: KeyAuthConfig{
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
},
{
name: "ok, missing validator func",
whenConfig: KeyAuthConfig{
Validator: nil,
},
expectError: "echo key-auth middleware requires a validator function",
},
{
name: "ok, extractor source can not be split",
whenConfig: KeyAuthConfig{
KeyLookup: "nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
},
{
name: "ok, no extractors",
whenConfig: KeyAuthConfig{
KeyLookup: "nope:nope",
Validator: func(c echo.Context, key string, keyType ExtractorType) (bool, error) {
return false, nil
},
},
expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mw, err := tc.whenConfig.ToMiddleware()
if tc.expectError != "" {
assert.Nil(t, mw)
assert.EqualError(t, err, tc.expectError)
} else {
assert.NotNil(t, mw)
assert.NoError(t, err)
}
})
}
}
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
KeyAuthWithConfig(KeyAuthConfig{})
})
}
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
handlerCalled := false
var authValue string
handler := func(c echo.Context) error {
handlerCalled = true
authValue = c.Get("auth").(string)
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(c echo.Context, err error) error {
// could check error to decide if we can swallow the error
c.Set("auth", "public")
return nil
},
})(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
// no auth header this time
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := middlewareChain(c)
assert.NoError(t, err)
assert.True(t, handlerCalled)
assert.Equal(t, "public", authValue)
}

View File

@ -3,6 +3,7 @@ package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
@ -10,13 +11,11 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/color"
"github.com/valyala/fasttemplate"
)
type (
// LoggerConfig defines the config for Logger middleware.
LoggerConfig struct {
type LoggerConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -49,42 +48,41 @@ type (
// Example "${remote_ip} ${status}"
//
// Optional. Default value DefaultLoggerConfig.Format.
Format string `yaml:"format"`
Format string
// Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
CustomTimeFormat string `yaml:"custom_time_format"`
CustomTimeFormat string
// Output is a writer where logs in JSON format are written.
// Optional. Default value os.Stdout.
// Optional. Default destination `echo.Logger.Infof()`
Output io.Writer
template *fasttemplate.Template
colorer *color.Color
pool *sync.Pool
}
)
var (
// DefaultLoggerConfig is the default Logger middleware config.
DefaultLoggerConfig = LoggerConfig{
var DefaultLoggerConfig = LoggerConfig{
Skipper: DefaultSkipper,
Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
`"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
`"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
`,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
CustomTimeFormat: "2006-01-02 15:04:05.00000",
colorer: color.New(),
}
)
// Logger returns a middleware that logs HTTP requests.
func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig)
}
// LoggerWithConfig returns a Logger middleware with config.
// See: `Logger()`.
// LoggerWithConfig returns a Logger middleware with config or panics on invalid configuration.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts LoggerConfig to middleware or returns an error for invalid configuration
func (config LoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper
@ -92,13 +90,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
if config.Format == "" {
config.Format = DefaultLoggerConfig.Format
}
if config.Output == nil {
config.Output = DefaultLoggerConfig.Output
}
config.template = fasttemplate.New(config.Format, "${", "}")
config.colorer = color.New()
config.colorer.SetOutput(config.Output)
config.pool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 256))
@ -106,23 +99,23 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
start := time.Now()
if err = next(c); err != nil {
c.Error(err)
}
err := next(c)
stop := time.Now()
buf := config.pool.Get().(*bytes.Buffer)
buf.Reset()
defer config.pool.Put(buf)
if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
_, tmplErr := config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
switch tag {
case "time_unix":
return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
@ -161,17 +154,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "user_agent":
return buf.WriteString(req.UserAgent())
case "status":
n := res.Status
s := config.colorer.Green(n)
switch {
case n >= 500:
s = config.colorer.Red(n)
case n >= 400:
s = config.colorer.Yellow(n)
case n >= 300:
s = config.colorer.Cyan(n)
status := res.Status
if err != nil {
if httpErr, ok := err.(*echo.HTTPError); ok {
status = httpErr.Code
}
return buf.WriteString(s)
}
return buf.WriteString(strconv.Itoa(status))
case "error":
if err != nil {
// Error may contain invalid JSON e.g. `"`
@ -201,23 +190,31 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case strings.HasPrefix(tag, "form:"):
return buf.Write([]byte(c.FormValue(tag[5:])))
case strings.HasPrefix(tag, "cookie:"):
cookie, err := c.Cookie(tag[7:])
if err == nil {
cookie, cookieErr := c.Cookie(tag[7:])
if cookieErr == nil {
return buf.Write([]byte(cookie.Value))
}
}
}
return 0, nil
}); err != nil {
return
})
if tmplErr != nil {
if err != nil {
return fmt.Errorf("error in middleware chain and also failed to create log from template: %v: %w", tmplErr, err)
}
return fmt.Errorf("failed to create log from template: %w", tmplErr)
}
if config.Output == nil {
_, err = c.Logger().Output().Write(buf.Bytes())
return
if config.Output != nil {
if _, lErr := config.Output.Write(buf.Bytes()); lErr != nil {
return lErr
}
_, err = config.Output.Write(buf.Bytes())
return
} else {
if _, lErr := c.Echo().Logger.Write(buf.Bytes()); lErr != nil {
return lErr
}
}
return err
}
}, nil
}

View File

@ -61,7 +61,7 @@ func TestLoggerIPAddress(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
e.Logger = &testLogger{output: buf}
ip := "127.0.0.1"
h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test")

View File

@ -6,9 +6,8 @@ import (
"github.com/labstack/echo/v4"
)
type (
// MethodOverrideConfig defines the config for MethodOverride middleware.
MethodOverrideConfig struct {
type MethodOverrideConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -18,16 +17,13 @@ type (
}
// MethodOverrideGetter is a function that gets overridden method from the request
MethodOverrideGetter func(echo.Context) string
)
type MethodOverrideGetter func(echo.Context) string
var (
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
var DefaultMethodOverrideConfig = MethodOverrideConfig{
Skipper: DefaultSkipper,
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and
@ -38,9 +34,13 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
// MethodOverrideWithConfig returns a MethodOverride middleware with config.
// See: `MethodOverride()`.
// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
@ -64,7 +64,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from

View File

@ -22,28 +22,70 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec)
m(h)(c)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_formParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec = httptest.NewRecorder()
m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec)
m(h)(c)
c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_queryParam(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Override with query parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
m(h)(c)
m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
assert.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err = m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodDelete, req.Method)
}
func TestMethodOverride_ignoreGet(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
// Ignore `GET`
req = httptest.NewRequest(http.MethodGet, "/", nil)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := m(h)(c)
assert.NoError(t, err)
assert.Equal(t, http.MethodGet, req.Method)
}

View File

@ -9,14 +9,11 @@ import (
"github.com/labstack/echo/v4"
)
type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Skipper func(echo.Context) bool
// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
type Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context)
)
type BeforeFunc func(c echo.Context)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
@ -87,3 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error
func DefaultSkipper(echo.Context) bool {
return false
}
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"math/rand"
@ -20,9 +21,8 @@ import (
// TODO: Handle TLS proxy
type (
// ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct {
type ProxyConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -59,46 +59,43 @@ type (
}
// ProxyTarget defines the upstream target.
ProxyTarget struct {
type ProxyTarget struct {
Name string
URL *url.URL
Meta echo.Map
}
// ProxyBalancer defines an interface to implement a load balancing technique.
ProxyBalancer interface {
type ProxyBalancer interface {
AddTarget(*ProxyTarget) bool
RemoveTarget(string) bool
Next(echo.Context) *ProxyTarget
}
commonBalancer struct {
type commonBalancer struct {
targets []*ProxyTarget
mutex sync.RWMutex
}
// RandomBalancer implements a random load balancing technique.
randomBalancer struct {
type randomBalancer struct {
*commonBalancer
random *rand.Rand
}
// RoundRobinBalancer implements a round-robin load balancing technique.
roundRobinBalancer struct {
type roundRobinBalancer struct {
*commonBalancer
i uint32
}
)
var (
// DefaultProxyConfig is the default Proxy middleware config.
DefaultProxyConfig = ProxyConfig{
var DefaultProxyConfig = ProxyConfig{
Skipper: DefaultSkipper,
ContextKey: "target",
}
)
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
func proxyRaw(c echo.Context, t *ProxyTarget) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
in, _, err := c.Response().Hijack()
if err != nil {
@ -203,15 +200,23 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c)
}
// ProxyWithConfig returns a Proxy middleware with config.
// See: `Proxy()`
// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
if config.ContextKey == "" {
config.ContextKey = DefaultProxyConfig.ContextKey
}
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
return nil, errors.New("echo proxy middleware requires balancer")
}
if config.Rewrite != nil {
@ -254,10 +259,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// Proxy
switch {
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
proxyRaw(c, tgt).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(tgt, c, config).ServeHTTP(res, req)
proxyHTTP(c, tgt, config).ServeHTTP(res, req)
}
if e, ok := c.Get("_error").(error); ok {
err = e
@ -265,7 +270,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
return
}
}
}, nil
}
// StatusCodeContextCanceled is a custom HTTP status code for situations
@ -275,7 +280,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
func proxyHTTP(c echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
desc := tgt.URL.String()

View File

@ -55,7 +55,7 @@ func TestProxy(t *testing.T) {
// Random
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
@ -77,7 +77,7 @@ func TestProxy(t *testing.T) {
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(Proxy(rrb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
@ -113,15 +113,20 @@ func TestProxy(t *testing.T) {
return nil
}
}
rrb1 := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(contextObserver)
e.Use(Proxy(rrb1))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
}
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
assert.Panics(t, func() {
ProxyWithConfig(ProxyConfig{Balancer: nil})
})
}
func TestProxyRealIPHeader(t *testing.T) {
// Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
@ -129,7 +134,7 @@ func TestProxyRealIPHeader(t *testing.T) {
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
e := echo.New()
e.Use(Proxy(rrb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@ -334,7 +339,7 @@ func TestProxyError(t *testing.T) {
// Random
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
@ -362,7 +367,7 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
rb := NewRandomBalancer(nil)
assert.True(t, rb.AddTarget(target))
e := echo.New()
e.Use(Proxy(rb))
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(req.Context())

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"net/http"
"sync"
"time"
@ -9,17 +10,13 @@ import (
"golang.org/x/time/rate"
)
type (
// RateLimiterStore is the interface to be implemented by custom stores.
RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
type RateLimiterStore interface {
Allow(identifier string) (bool, error)
}
)
type (
// RateLimiterConfig defines the configuration for the rate limiter
RateLimiterConfig struct {
type RateLimiterConfig struct {
Skipper Skipper
BeforeFunc BeforeFunc
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
@ -31,17 +28,15 @@ type (
// DenyHandler provides a handler to be called when RateLimiter denies access
DenyHandler func(context echo.Context, identifier string, err error) error
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors
var (
// Extractor is used to extract data from echo.Context
type Extractor func(context echo.Context) (string, error)
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
)
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{
@ -111,6 +106,11 @@ RateLimiterWithConfig returns a rate limiting middleware
}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper
}
@ -124,7 +124,7 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
}
if config.Store == nil {
panic("Store configuration must be provided")
return nil, errors.New("echo rate limiter store configuration must be provided")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@ -137,22 +137,19 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
identifier, err := config.IdentifierExtractor(c)
if err != nil {
c.Error(config.ErrorHandler(c, err))
return nil
return config.ErrorHandler(c, err)
}
if allow, err := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err))
return nil
if allow, allowErr := config.Store.Allow(identifier); !allow {
return config.DenyHandler(c, identifier, allowErr)
}
return next(c)
}
}
}, nil
}
type (
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
RateLimiterMemoryStore struct {
type RateLimiterMemoryStore struct {
visitors map[string]*Visitor
mutex sync.Mutex
rate rate.Limit
@ -160,12 +157,12 @@ type (
expiresIn time.Duration
lastCleanup time.Time
}
// Visitor signifies a unique user's limiter details
Visitor struct {
type Visitor struct {
*rate.Limiter
lastSeen time.Time
}
)
/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with

View File

@ -11,7 +11,6 @@ import (
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
@ -25,19 +24,19 @@ func TestRateLimiter(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
mw := RateLimiter(inMemoryStore)
mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -47,20 +46,25 @@ func TestRateLimiter(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
func TestRateLimiter_panicBehaviour(t *testing.T) {
func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() {
RateLimiter(nil)
RateLimiterWithConfig(RateLimiterConfig{})
})
assert.NotPanics(t, func() {
RateLimiter(inMemoryStore)
RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
})
}
@ -73,7 +77,7 @@ func TestRateLimiterWithConfig(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
@ -88,7 +92,8 @@ func TestRateLimiterWithConfig(t *testing.T) {
return ctx.JSON(http.StatusBadRequest, nil)
},
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
@ -111,8 +116,9 @@ func TestRateLimiterWithConfig(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code)
}
}
@ -126,7 +132,7 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
@ -135,19 +141,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
return id, nil
},
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"", http.StatusForbidden},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{expectErr: "code=403, message=error while extracting identifier, internal=invalid identifier"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -158,9 +165,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
@ -174,21 +185,22 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
return c.String(http.StatusOK, "test")
}
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Store: inMemoryStore,
})
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
code int
expectErr string
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{"127.0.0.1", http.StatusTooManyRequests},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
@ -199,9 +211,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
c := e.NewContext(req, rec)
_ = mw(handler)(c)
assert.Equal(t, tc.code, rec.Code)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
}
@ -222,7 +238,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool {
return true
},
@ -233,10 +249,12 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan)
}
@ -256,7 +274,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
Skipper: func(c echo.Context) bool {
return false
},
@ -267,7 +285,8 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
@ -291,7 +310,7 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
c := e.NewContext(req, rec)
mw := RateLimiterWithConfig(RateLimiterConfig{
mw, err := RateLimiterConfig{
BeforeFunc: func(c echo.Context) {
beforeRan = true
},
@ -299,10 +318,12 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
IdentifierExtractor: func(ctx echo.Context) (string, error) {
return "127.0.0.1", nil
},
})
}.ToMiddleware()
assert.NoError(t, err)
_ = mw(handler)(c)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, true, beforeRan)
}
@ -413,7 +434,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string {
addrs := make([]string, count)
for i := 0; i < count; i++ {
addrs[i] = random.String(15)
addrs[i] = randomString(15)
}
return addrs
}

View File

@ -5,44 +5,34 @@ import (
"runtime"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
)
type (
// RecoverConfig defines the config for Recover middleware.
RecoverConfig struct {
type RecoverConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Size of the stack to be printed.
// Optional. Default value 4KB.
StackSize int `yaml:"stack_size"`
StackSize int
// DisableStackAll disables formatting stack traces of all other goroutines
// into buffer after the trace for the current goroutine.
// Optional. Default value false.
DisableStackAll bool `yaml:"disable_stack_all"`
DisableStackAll bool
// DisablePrintStack disables printing stack trace.
// Optional. Default value as false.
DisablePrintStack bool `yaml:"disable_print_stack"`
// LogLevel is log level to printing stack trace.
// Optional. Default value 0 (Print).
LogLevel log.Lvl
DisablePrintStack bool
}
)
var (
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
var DefaultRecoverConfig = RecoverConfig{
Skipper: DefaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
LogLevel: 0,
}
)
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
@ -50,9 +40,13 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
// RecoverWithConfig returns a Recover middleware with config.
// See: `Recover()`.
// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
@ -62,40 +56,26 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
defer func() {
if r := recover(); r != nil {
err, ok := r.(error)
tmpErr, ok := r.(error)
if !ok {
err = fmt.Errorf("%v", r)
tmpErr = fmt.Errorf("%v", r)
}
if !config.DisablePrintStack {
stack := make([]byte, config.StackSize)
length := runtime.Stack(stack, !config.DisableStackAll)
if !config.DisablePrintStack {
msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])
switch config.LogLevel {
case log.DEBUG:
c.Logger().Debug(msg)
case log.INFO:
c.Logger().Info(msg)
case log.WARN:
c.Logger().Warn(msg)
case log.ERROR:
c.Logger().Error(msg)
case log.OFF:
// None.
default:
c.Logger().Print(msg)
tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length])
}
}
c.Error(err)
err = tmpErr
}
}()
return next(c)
}
}
}, nil
}

View File

@ -2,82 +2,109 @@ package middleware
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
e.Logger = &testLogger{output: buf}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
h := Recover()(func(c echo.Context) error {
panic("test")
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, buf.String(), "PANIC RECOVER")
})
err := h(c)
assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
assert.Contains(t, buf.String(), "") // nothing is logged
}
func TestRecoverWithConfig_LogLevel(t *testing.T) {
tests := []struct {
logLevel log.Lvl
levelName string
}{{
logLevel: log.DEBUG,
levelName: "DEBUG",
}, {
logLevel: log.INFO,
levelName: "INFO",
}, {
logLevel: log.WARN,
levelName: "WARN",
}, {
logLevel: log.ERROR,
levelName: "ERROR",
}, {
logLevel: log.OFF,
levelName: "OFF",
}}
for _, tt := range tests {
tt := tt
t.Run(tt.levelName, func(t *testing.T) {
func TestRecover_skipper(t *testing.T) {
e := echo.New()
e.Logger.SetLevel(log.DEBUG)
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := DefaultRecoverConfig
config.LogLevel = tt.logLevel
h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error {
panic("test")
}))
h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
output := buf.String()
if tt.logLevel == log.OFF {
assert.Empty(t, output)
} else {
assert.Contains(t, output, "PANIC RECOVER")
assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName))
config := RecoverConfig{
Skipper: func(c echo.Context) bool {
return true
},
}
h := RecoverWithConfig(config)(func(c echo.Context) error {
panic("testPANIC")
})
var err error
assert.Panics(t, func() {
err = h(c)
})
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
func TestRecoverWithConfig(t *testing.T) {
var testCases = []struct {
name string
givenNoPanic bool
whenConfig RecoverConfig
expectErrContain string
expectErr string
}{
{
name: "ok, default config",
whenConfig: DefaultRecoverConfig,
expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
},
{
name: "ok, no panic",
givenNoPanic: true,
whenConfig: DefaultRecoverConfig,
expectErrContain: "",
},
{
name: "ok, DisablePrintStack",
whenConfig: RecoverConfig{
DisablePrintStack: true,
},
expectErr: "testPANIC",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
config := tc.whenConfig
h := RecoverWithConfig(config)(func(c echo.Context) error {
if tc.givenNoPanic {
return nil
}
panic("testPANIC")
})
err := h(c)
if tc.expectErrContain != "" {
assert.Contains(t, err.Error(), tc.expectErrContain)
} else if tc.expectErr != "" {
assert.Contains(t, err.Error(), tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
})
}
}

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"net/http"
"strings"
@ -14,7 +15,9 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently.
Code int `yaml:"code"`
Code int
redirect redirectLogic
}
// redirectLogic represents a function that given a scheme, host and uri
@ -24,29 +27,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www."
// DefaultRedirectConfig is the default Redirect middleware config.
var DefaultRedirectConfig = RedirectConfig{
Skipper: DefaultSkipper,
Code: http.StatusMovedPermanently,
}
// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
// RedirectWWWConfig is the WWW Redirect middleware config.
var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
// RedirectNonWWWConfig is the non WWW Redirect middleware config.
var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(DefaultRedirectConfig)
return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
}
// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSRedirect()`.
// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
})
config.redirect = redirectHTTPS
return toMiddlewareOrPanic(config)
}
// HTTPSWWWRedirect redirects http requests to https www.
@ -54,18 +61,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
}
// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSWWWRedirect()`.
// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
})
config.redirect = redirectHTTPSWWW
return toMiddlewareOrPanic(config)
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
@ -73,19 +75,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
}
// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `HTTPSNonWWWRedirect()`.
// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
})
config.redirect = redirectNonHTTPSWWW
return toMiddlewareOrPanic(config)
}
// WWWRedirect redirects non www requests to www.
@ -93,18 +89,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(DefaultRedirectConfig)
return WWWRedirectWithConfig(RedirectWWWConfig)
}
// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `WWWRedirect()`.
// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
})
config.redirect = redirectWWW
return toMiddlewareOrPanic(config)
}
// NonWWWRedirect redirects www requests to non www.
@ -112,26 +103,25 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(DefaultRedirectConfig)
return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
}
// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
// See `NonWWWRedirect()`.
// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
return redirect(config, func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
})
config.redirect = redirectNonWWW
return toMiddlewareOrPanic(config)
}
func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRedirectConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Code == 0 {
config.Code = DefaultRedirectConfig.Code
config.Code = http.StatusMovedPermanently
}
if config.redirect == nil {
return nil, errors.New("redirectConfig is missing redirect function")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -142,11 +132,47 @@ func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
req, scheme := c.Request(), c.Scheme()
host := req.Host
if ok, url := cb(scheme, host, req.RequestURI); ok {
if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url)
}
return next(c)
}
}, nil
}
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
if scheme != "https" {
return true, "https://" + host + uri
}
return false, ""
}
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
if scheme != "https" && !strings.HasPrefix(host, www) {
return true, "https://www." + host + uri
}
return false, ""
}
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
if scheme != "https" {
host = strings.TrimPrefix(host, www)
return true, "https://" + host + uri
}
return false, ""
}
var redirectWWW = func(scheme, host, uri string) (bool, string) {
if !strings.HasPrefix(host, www) {
return true, scheme + "://www." + host + uri
}
return false, ""
}
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
if strings.HasPrefix(host, www) {
return true, scheme + "://" + host[4:] + uri
}
return false, ""
}

View File

@ -2,12 +2,10 @@ package middleware
import (
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
)
type (
// RequestIDConfig defines the config for RequestID middleware.
RequestIDConfig struct {
type RequestIDConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -16,31 +14,26 @@ type (
Generator func() string
// RequestIDHandler defines a function which is executed for a request id.
RequestIDHandler func(echo.Context, string)
RequestIDHandler func(c echo.Context, requestID string)
}
)
var (
// DefaultRequestIDConfig is the default RequestID middleware config.
DefaultRequestIDConfig = RequestIDConfig{
Skipper: DefaultSkipper,
Generator: generator,
}
)
// RequestID returns a X-Request-ID middleware.
func RequestID() echo.MiddlewareFunc {
return RequestIDWithConfig(DefaultRequestIDConfig)
return RequestIDWithConfig(RequestIDConfig{})
}
// RequestIDWithConfig returns a X-Request-ID middleware with config.
// RequestIDWithConfig returns a X-Request-ID middleware with config or panics on invalid configuration.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultRequestIDConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Generator == nil {
config.Generator = generator
config.Generator = createRandomStringGenerator(32)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -62,9 +55,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
return next(c)
}
}
}
func generator() string {
return random.String(32)
}, nil
}

View File

@ -18,25 +18,104 @@ func TestRequestID(t *testing.T) {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{})
rid := RequestID()
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
}
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
generatorCalled := false
e.Use(RequestIDWithConfig(RequestIDConfig{
Skipper: func(c echo.Context) bool {
return true
},
Generator: func() string {
generatorCalled = true
return "customGenerator"
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "test", res.Body.String())
assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
assert.False(t, generatorCalled)
}
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
called := false
rid := RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return "customGenerator" },
RequestIDHandler: func(c echo.Context, s string) {
called = true
},
})
h := rid(handler)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, called)
}
func TestRequestIDWithConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
rid, err := RequestIDConfig{}.ToMiddleware()
assert.NoError(t, err)
h := rid(handler)
h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
// Custom generator and handler
customID := "customGenerator"
calledHandler := false
// Custom generator
rid = RequestIDWithConfig(RequestIDConfig{
Generator: func() string { return customID },
RequestIDHandler: func(_ echo.Context, id string) {
calledHandler = true
assert.Equal(t, customID, id)
},
Generator: func() string { return "customGenerator" },
})
h = rid(handler)
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
assert.True(t, calledHandler)
}
func TestRequestID_IDNotAltered(t *testing.T) {

View File

@ -24,6 +24,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info().
// Date("request_start", v.StartTime).
// Str("URI", v.URI).
// Int("status", v.Status).
// Msg("request")
@ -39,6 +40,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error {
// logger.Info("request",
// zap.Time("request_start", v.StartTime),
// zap.String("URI", v.URI),
// zap.Int("status", v.Status),
// )
@ -54,6 +56,7 @@ import (
// LogStatus: true,
// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error {
// log.WithFields(logrus.Fields{
// "request_start": values.StartTime,
// "URI": values.URI,
// "status": values.Status,
// }).Info("request")
@ -158,15 +161,15 @@ type RequestLoggerValues struct {
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with same value so slice
// of values is been logger for each given header.
// of values is what will be returned/logged for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
// with same name so slice of values is been logger for each given query param name.
// with same name so slice of values is what will be returned/logged for each given query param name.
QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
// same name so slice of values is been logger for each given form value name.
// same name so slice of values is what will be returned/logged for each given form value name.
FormValues map[string][]string
}

View File

@ -289,7 +289,7 @@ func TestRequestLogger_allFields(t *testing.T) {
req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c := e.NewContext(req, rec).(echo.EditableContext)
c.SetPath("/test*")

View File

@ -1,14 +1,14 @@
package middleware
import (
"errors"
"regexp"
"github.com/labstack/echo/v4"
)
type (
// RewriteConfig defines the config for Rewrite middleware.
RewriteConfig struct {
type RewriteConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -20,43 +20,39 @@ type (
// "/js/*": "/public/javascripts/$1",
// "/users/*/orders/*": "/user/$1/order/$2",
// Required.
Rules map[string]string `yaml:"rules"`
Rules map[string]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
RegexRules map[*regexp.Regexp]string
}
)
var (
// DefaultRewriteConfig is the default Rewrite middleware config.
DefaultRewriteConfig = RewriteConfig{
Skipper: DefaultSkipper,
}
)
// Rewrite returns a Rewrite middleware.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
c := DefaultRewriteConfig
c := RewriteConfig{}
c.Rules = rules
return RewriteWithConfig(c)
}
// RewriteWithConfig returns a Rewrite middleware with config.
// See: `Rewrite()`.
// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.Rules == nil && config.RegexRules == nil {
return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
}
if config.RegexRules == nil {
@ -77,5 +73,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}

View File

@ -24,10 +24,10 @@ func TestRewriteAfterRouting(t *testing.T) {
},
}))
e.GET("/public/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
return c.String(http.StatusOK, c.PathParam("*"))
})
e.GET("/*", func(c echo.Context) error {
return c.String(http.StatusOK, c.Param("*"))
return c.String(http.StatusOK, c.PathParam("*"))
})
var testCases = []struct {
@ -90,20 +90,74 @@ func TestRewriteAfterRouting(t *testing.T) {
}
}
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
assert.Panics(t, func() {
RewriteWithConfig(RewriteConfig{})
})
}
func TestMustRewriteWithConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
givenSkipper func(c echo.Context) bool
whenURL string
expectURL string
expectStatus int
}{
{
name: "not skipped",
whenURL: "/old",
expectURL: "/new",
expectStatus: http.StatusOK,
},
{
name: "skipped",
givenSkipper: func(c echo.Context) bool {
return true
},
whenURL: "/old",
expectURL: "/old",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Pre(RewriteWithConfig(
RewriteConfig{
Skipper: tc.givenSkipper,
Rules: map[string]string{"/old": "/new"}},
))
e.GET("/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
assert.Equal(t, tc.expectStatus, rec.Code)
})
}
}
// Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New()
r := e.Router()
// Rewrite old url to new one
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(Rewrite(map[string]string{
"/old": "/new",
},
))
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{"/old": "/new"}}),
)
// Route
r.Add(http.MethodGet, "/new", func(c echo.Context) error {
e.Add(http.MethodGet, "/new", func(c echo.Context) error {
return c.NoContent(http.StatusOK)
})
@ -117,7 +171,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) {
// Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New()
r := e.Router()
// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{
@ -127,10 +180,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
},
}))
r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
return c.String(http.StatusOK, "hosts")
})
r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
e.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
return c.String(http.StatusOK, "eng")
})

View File

@ -6,21 +6,20 @@ import (
"github.com/labstack/echo/v4"
)
type (
// SecureConfig defines the config for Secure middleware.
SecureConfig struct {
type SecureConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
XSSProtection string `yaml:"xss_protection"`
XSSProtection string
// ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff".
ContentTypeNosniff string `yaml:"content_type_nosniff"`
ContentTypeNosniff string
// XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a <frame>, <iframe> or <object> .
@ -32,58 +31,55 @@ type (
// - "SAMEORIGIN" - The page can only be displayed in a frame on the same origin as the page itself.
// - "DENY" - The page cannot be displayed in a frame, regardless of the site attempting to do so.
// - "ALLOW-FROM uri" - The page can only be displayed in a frame on the specified origin.
XFrameOptions string `yaml:"x_frame_options"`
XFrameOptions string
// HSTSMaxAge sets the `Strict-Transport-Security` header to indicate how
// long (in seconds) browsers should remember that this site is only to
// be accessed using HTTPS. This reduces your exposure to some SSL-stripping
// man-in-the-middle (MITM) attacks.
// Optional. Default value 0.
HSTSMaxAge int `yaml:"hsts_max_age"`
HSTSMaxAge int
// HSTSExcludeSubdomains won't include subdomains tag in the `Strict Transport Security`
// header, excluding all subdomains from security policy. It has no effect
// unless HSTSMaxAge is set to a non-zero value.
// Optional. Default value false.
HSTSExcludeSubdomains bool `yaml:"hsts_exclude_subdomains"`
HSTSExcludeSubdomains bool
// ContentSecurityPolicy sets the `Content-Security-Policy` header providing
// security against cross-site scripting (XSS), clickjacking and other code
// injection attacks resulting from execution of malicious content in the
// trusted web page context.
// Optional. Default value "".
ContentSecurityPolicy string `yaml:"content_security_policy"`
ContentSecurityPolicy string
// CSPReportOnly would use the `Content-Security-Policy-Report-Only` header instead
// of the `Content-Security-Policy` header. This allows iterative updates of the
// content security policy by only reporting the violations that would
// have occurred instead of blocking the resource.
// Optional. Default value false.
CSPReportOnly bool `yaml:"csp_report_only"`
CSPReportOnly bool
// HSTSPreloadEnabled will add the preload tag in the `Strict Transport Security`
// header, which enables the domain to be included in the HSTS preload list
// maintained by Chrome (and used by Firefox and Safari): https://hstspreload.org/
// Optional. Default value false.
HSTSPreloadEnabled bool `yaml:"hsts_preload_enabled"`
HSTSPreloadEnabled bool
// ReferrerPolicy sets the `Referrer-Policy` header providing security against
// leaking potentially sensitive request paths to third parties.
// Optional. Default value "".
ReferrerPolicy string `yaml:"referrer_policy"`
ReferrerPolicy string
}
)
var (
// DefaultSecureConfig is the default Secure middleware config.
DefaultSecureConfig = SecureConfig{
var DefaultSecureConfig = SecureConfig{
Skipper: DefaultSkipper,
XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSPreloadEnabled: false,
}
)
// Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack,
@ -93,9 +89,13 @@ func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig)
}
// SecureWithConfig returns a Secure middleware with config.
// See: `Secure()`.
// SecureWithConfig returns a Secure middleware with config or panics on invalid configuration.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts SecureConfig to middleware or returns an error for invalid configuration
func (config SecureConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper
@ -141,5 +141,5 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
}, nil
}

View File

@ -19,26 +19,40 @@ func TestSecure(t *testing.T) {
}
// Default
Secure()(h)(c)
err := Secure()(h)(c)
assert.NoError(t, err)
assert.Equal(t, "1; mode=block", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "nosniff", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "SAMEORIGIN", rec.Header().Get(echo.HeaderXFrameOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderStrictTransportSecurity))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "", rec.Header().Get(echo.HeaderReferrerPolicy))
}
// Custom
func TestSecureWithConfig(t *testing.T) {
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
mw, err := SecureConfig{
XSSProtection: "",
ContentTypeNosniff: "",
XFrameOptions: "",
HSTSMaxAge: 3600,
ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "origin",
})(h)(c)
}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -47,11 +61,21 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_CSPReportOnly(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
XSSProtection: "",
ContentTypeNosniff: "",
XFrameOptions: "",
@ -60,6 +84,8 @@ func TestSecure(t *testing.T) {
CSPReportOnly: true,
ReferrerPolicy: "origin",
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderXXSSProtection))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXContentTypeOptions))
assert.Equal(t, "", rec.Header().Get(echo.HeaderXFrameOptions))
@ -67,25 +93,51 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "default-src 'self'", rec.Header().Get(echo.HeaderContentSecurityPolicyReportOnly))
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
assert.Equal(t, "origin", rec.Header().Get(echo.HeaderReferrerPolicy))
}
func TestSecureWithConfig_HSTSPreloadEnabled(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600,
HSTSPreloadEnabled: true,
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; includeSubdomains; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}
func TestSecureWithConfig_HSTSExcludeSubdomains(t *testing.T) {
// Custom with CSPReportOnly flag
e := echo.New()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
// Custom, with preload option enabled and subdomains excluded
req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := SecureWithConfig(SecureConfig{
HSTSMaxAge: 3600,
HSTSPreloadEnabled: true,
HSTSExcludeSubdomains: true,
})(h)(c)
assert.NoError(t, err)
assert.Equal(t, "max-age=3600; preload", rec.Header().Get(echo.HeaderStrictTransportSecurity))
}

View File

@ -1,44 +1,45 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
)
type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// AddTrailingSlashConfig is the middleware config for adding trailing slash to the request.
type AddTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int `yaml:"redirect_code"`
// Valid status codes: [300...308]
RedirectCode int
}
)
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: DefaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`.
//
// Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
return AddTrailingSlashWithConfig(AddTrailingSlashConfig{})
}
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config or panics on invalid configuration.
func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts AddTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config AddTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for add trailing slash middleware")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -69,7 +70,17 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
}
return next(c)
}
}, nil
}
// RemoveTrailingSlashConfig is the middleware config for removing trailing slash from the request.
type RemoveTrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int
}
// RemoveTrailingSlash returns a root level (before router) middleware which removes
@ -77,15 +88,22 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
//
// Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
return RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{})
}
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config or panics on invalid configuration.
func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RemoveTrailingSlashConfig to middleware or returns an error for invalid configuration
func (config RemoveTrailingSlashConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
config.Skipper = DefaultSkipper
}
if config.RedirectCode != 0 && (config.RedirectCode < http.StatusMultipleChoices || config.RedirectCode > http.StatusPermanentRedirect) {
// this is same check as `echo.context.Redirect()` does, but we can check this before even serving the request.
return nil, errors.New("invalid redirect code for remove trailing slash middleware")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -117,7 +135,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
}
return next(c)
}
}
}, nil
}
func sanitizeURI(uri string) string {

View File

@ -67,7 +67,7 @@ func TestAddTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()
mw := AddTrailingSlashWithConfig(TrailingSlashConfig{
mw := AddTrailingSlashWithConfig(AddTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {
@ -203,7 +203,7 @@ func TestRemoveTrailingSlashWithConfig(t *testing.T) {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()
mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{
mw := RemoveTrailingSlashWithConfig(RemoveTrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {

View File

@ -1,55 +1,65 @@
package middleware
import (
"errors"
"fmt"
"html/template"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/bytes"
)
type (
// StaticConfig defines the config for Static middleware.
StaticConfig struct {
type StaticConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Root directory from where the static content is served.
// Root directory from where the static content is served (relative to given Filesystem).
// `Root: "."` means root folder from Filesystem.
// Required.
Root string `yaml:"root"`
Root string
// Filesystem provides access to the static content.
// Optional. Defaults to echo.Filesystem (serves files from `.` folder where executable is started)
Filesystem fs.FS
// Index file for serving a directory.
// Optional. Default value "index.html".
Index string `yaml:"index"`
Index string
// Enable HTML5 mode by forwarding all not-found requests to root so that
// SPA (single-page application) can handle the routing.
// Optional. Default value false.
HTML5 bool `yaml:"html5"`
HTML5 bool
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
Browse bool
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
IgnoreBase bool
// Filesystem provides access to the static content.
// Optional. Defaults to http.Dir(config.Root)
Filesystem http.FileSystem `yaml:"-"`
// DisablePathUnescaping disables path parameter (param: *) unescaping. This is useful when router is set to unescape
// all parameter and doing it again in this middleware would corrupt filename that is requested.
DisablePathUnescaping bool
// DirectoryListTemplate is template to list directory contents
// Optional. Default to `directoryListHTMLTemplate` constant below.
DirectoryListTemplate string
}
)
const html = `
const directoryListHTMLTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
@ -121,25 +131,26 @@ const html = `
</html>
`
var (
// DefaultStaticConfig is the default Static middleware config.
DefaultStaticConfig = StaticConfig{
var DefaultStaticConfig = StaticConfig{
Skipper: DefaultSkipper,
Index: "index.html",
}
)
// Static returns a Static middleware to serves static content from the provided
// root directory.
// Static returns a Static middleware to serves static content from the provided root directory.
func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig
c.Root = root
return StaticWithConfig(c)
}
// StaticWithConfig returns a Static middleware with config.
// See `Static()`.
// StaticWithConfig returns a Static middleware to serves static content or panics on invalid configuration.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts StaticConfig to middleware or returns an error for invalid configuration
func (config StaticConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Root == "" {
config.Root = "." // For security we want to restrict to CWD.
@ -150,30 +161,32 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if config.Index == "" {
config.Index = DefaultStaticConfig.Index
}
if config.Filesystem == nil {
config.Filesystem = http.Dir(config.Root)
config.Root = "."
if config.DirectoryListTemplate == "" {
config.DirectoryListTemplate = directoryListHTMLTemplate
}
// Index template
t, err := template.New("index").Parse(html)
dirListTemplate, err := template.New("index").Parse(config.DirectoryListTemplate)
if err != nil {
panic(fmt.Sprintf("echo: %v", err))
return nil, fmt.Errorf("echo static middleware directory list template parsing error: %w", err)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
p := c.Request().URL.Path
if strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.Param("*")
pathUnescape := true
if c.RouteMatchType() == echo.RouteMatchFound && strings.HasSuffix(c.Path(), "*") { // When serving from a group, e.g. `/static*`.
p = c.PathParam("*")
pathUnescape = !config.DisablePathUnescaping // because router could already do PathUnescape
}
if pathUnescape {
p, err = url.PathUnescape(p)
if err != nil {
return
return err
}
}
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
@ -186,22 +199,29 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
}
}
file, err := openFile(config.Filesystem, name)
currentFS := config.Filesystem
if currentFS == nil {
currentFS = c.Echo().Filesystem
}
file, err := openFile(currentFS, name)
if err != nil {
if !os.IsNotExist(err) {
return err
}
if err = next(c); err == nil {
return err
// when file does not exist let handler to handle that request. if it succeeds then we are done
err = next(c)
if err == nil {
return nil
}
he, ok := err.(*echo.HTTPError)
if !(ok && config.HTML5 && he.Code == http.StatusNotFound) {
return err
}
file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index))
// is case HTML5 mode is enabled + echo 404 we serve index to the client
file, err = openFile(currentFS, filepath.Join(config.Root, config.Index))
if err != nil {
return err
}
@ -215,10 +235,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
}
if info.IsDir() {
index, err := openFile(config.Filesystem, filepath.Join(name, config.Index))
index, err := openFile(currentFS, filepath.Join(name, config.Index))
if err != nil {
if config.Browse {
return listDir(t, name, file, c.Response())
return listDir(dirListTemplate, name, currentFS, file, c.Response())
}
if os.IsNotExist(err) {
@ -238,25 +258,24 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
return serveFile(c, file, info)
}
}
}, nil
}
func openFile(fs http.FileSystem, name string) (http.File, error) {
func openFile(fs fs.FS, name string) (fs.File, error) {
pathWithSlashes := filepath.ToSlash(name)
return fs.Open(pathWithSlashes)
}
func serveFile(c echo.Context, file http.File, info os.FileInfo) error {
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file)
func serveFile(c echo.Context, file fs.File, info os.FileInfo) error {
ff, ok := file.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), ff)
return nil
}
func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) {
files, err := dir.Readdir(-1)
if err != nil {
return
}
func listDir(t *template.Template, name string, filesystem fs.FS, dir fs.File, res *echo.Response) error {
// Create directory index
res.Header().Set(echo.HeaderContentType, echo.MIMETextHTMLCharsetUTF8)
data := struct {
@ -265,12 +284,60 @@ func listDir(t *template.Template, name string, dir http.File, res *echo.Respons
}{
Name: name,
}
for _, f := range files {
err := fs.WalkDir(filesystem, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
info, infoErr := d.Info()
if infoErr != nil {
return fmt.Errorf("static middleware list dir error when getting file info: %w", err)
}
data.Files = append(data.Files, struct {
Name string
Dir bool
Size string
}{f.Name(), f.IsDir(), bytes.Format(f.Size())})
}{d.Name(), d.IsDir(), format(info.Size())})
return nil
})
if err != nil {
return err
}
return t.Execute(res, data)
}
// format formats bytes integer to human readable string.
// For example, 31323 bytes will return 30.59KB.
func format(b int64) string {
multiple := ""
value := float64(b)
switch {
case b >= EB:
value /= float64(EB)
multiple = "EB"
case b >= PB:
value /= float64(PB)
multiple = "PB"
case b >= TB:
value /= float64(TB)
multiple = "TB"
case b >= GB:
value /= float64(GB)
multiple = "GB"
case b >= MB:
value /= float64(MB)
multiple = "MB"
case b >= KB:
value /= float64(KB)
multiple = "KB"
case b == 0:
return "0"
default:
return strconv.FormatInt(b, 10) + "B"
}
return fmt.Sprintf("%.2f%s", value, multiple)
}

View File

@ -81,14 +81,16 @@ func TestStatic_CustomFS(t *testing.T) {
config := StaticConfig{
Root: ".",
Filesystem: http.FS(tc.filesystem),
Filesystem: tc.filesystem,
}
if tc.root != "" {
config.Root = tc.root
}
middlewareFunc := StaticWithConfig(config)
middlewareFunc, err := config.ToMiddleware()
assert.NoError(t, err)
e.Use(middlewareFunc)
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)

View File

@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
@ -10,6 +11,37 @@ import (
"github.com/stretchr/testify/assert"
)
func TestStatic_useCaseForApiAndSPAs(t *testing.T) {
e := echo.New()
// serve single page application (SPA) files from server root
e.Use(StaticWithConfig(StaticConfig{
Root: ".",
// by default Echo filesystem is fixed to `./` but this does not allow `../` (moving up in folder structure past filesystem root)
Filesystem: os.DirFS("../_fixture"),
}))
// all requests to `/api/*` will end up in echo handlers (assuming there is not `api` folder and files)
api := e.Group("/api")
users := api.Group("/users")
users.GET("/info", func(c echo.Context) error {
return c.String(http.StatusOK, "users info")
})
req := httptest.NewRequest(http.MethodGet, "/api/users/info", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "users info", rec.Body.String())
req = httptest.NewRequest(http.MethodGet, "/index.html", nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "<title>Echo</title>")
}
func TestStatic(t *testing.T) {
var testCases = []struct {
name string
@ -35,7 +67,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, when html5 mode serve index for any static file that does not exist",
givenConfig: &StaticConfig{
Root: "../_fixture",
Root: "_fixture",
HTML5: true,
},
whenURL: "/random",
@ -45,7 +77,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve index as directory index listing files directory",
givenConfig: &StaticConfig{
Root: "../_fixture/certs",
Root: "_fixture/certs",
Browse: true,
},
whenURL: "/",
@ -55,7 +87,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve directory index with IgnoreBase and browse",
givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true,
Browse: true,
},
@ -67,7 +99,7 @@ func TestStatic(t *testing.T) {
{
name: "ok, serve file with IgnoreBase",
givenConfig: &StaticConfig{
Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
Root: "_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored
IgnoreBase: true,
Browse: true,
},
@ -95,15 +127,27 @@ func TestStatic(t *testing.T) {
expectContains: "{\"message\":\"Not Found\"}\n",
},
{
name: "ok, do not serve file, when a handler took care of the request",
name: "ok, when no file then a handler will care of the request",
whenURL: "/regular-handler",
expectCode: http.StatusOK,
expectContains: "ok",
},
{
name: "ok, skip middleware and serve handler",
givenConfig: &StaticConfig{
Root: "_fixture/images/",
Skipper: func(c echo.Context) bool {
return true
},
},
whenURL: "/walle.png",
expectCode: http.StatusTeapot,
expectContains: "walle",
},
{
name: "nok, when html5 fail if the index file does not exist",
givenConfig: &StaticConfig{
Root: "../_fixture",
Root: "_fixture",
HTML5: true,
Index: "missing.html",
},
@ -114,7 +158,7 @@ func TestStatic(t *testing.T) {
name: "ok, serve from http.FileSystem",
givenConfig: &StaticConfig{
Root: "_fixture",
Filesystem: http.Dir(".."),
Filesystem: os.DirFS(".."),
},
whenURL: "/",
expectCode: http.StatusOK,
@ -125,8 +169,9 @@ func TestStatic(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../")
config := StaticConfig{Root: "../_fixture"}
config := StaticConfig{Root: "_fixture"}
if tc.givenConfig != nil {
config = *tc.givenConfig
}
@ -136,14 +181,17 @@ func TestStatic(t *testing.T) {
subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc)
// group without http handlers (routes) does not do anything.
// Request is matched against http handlers (routes) that have group middleware attached to them
subGroup.GET("", echo.NotFoundHandler)
subGroup.GET("/*", echo.NotFoundHandler)
subGroup.GET("", func(c echo.Context) error { return echo.ErrNotFound })
subGroup.GET("/*", func(c echo.Context) error { return echo.ErrNotFound })
} else {
// middleware is on root level
e.Use(middlewareFunc)
e.GET("/regular-handler", func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
e.GET("/walle.png", func(c echo.Context) error {
return c.String(http.StatusTeapot, "walle")
})
}
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
@ -177,7 +225,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "ok",
givenPrefix: "/images",
givenRoot: "../_fixture/images",
givenRoot: "_fixture/images",
whenURL: "/group/images/walle.png",
expectStatus: http.StatusOK,
expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
@ -185,7 +233,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "No file",
givenPrefix: "/images",
givenRoot: "../_fixture/scripts",
givenRoot: "_fixture/scripts",
whenURL: "/group/images/bolt.png",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -193,7 +241,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory not found (no trailing slash)",
givenPrefix: "/images",
givenRoot: "../_fixture/images",
givenRoot: "_fixture/images",
whenURL: "/group/images/",
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -201,7 +249,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory redirect",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/folder",
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/group/folder/",
@ -211,7 +259,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory 404 (request URL without slash)",
givenGroup: "_fixture",
givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -220,7 +268,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
name: "Prefixed directory redirect (without slash redirect to slash)",
givenGroup: "_fixture",
givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/_fixture/folder", // no trailing slash
expectStatus: http.StatusMovedPermanently,
expectHeaderLocation: "/_fixture/folder/",
@ -229,7 +277,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Directory with index.html",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -237,7 +285,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending with slash)",
givenPrefix: "/assets/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -245,7 +293,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Prefixed directory with index.html (prefix ending without slash)",
givenPrefix: "/assets",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/assets/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -253,7 +301,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "Sub-directory with index.html",
givenPrefix: "/",
givenRoot: "../_fixture",
givenRoot: "_fixture",
whenURL: "/group/folder/",
expectStatus: http.StatusOK,
expectBodyStartsWith: "<!doctype html>",
@ -261,7 +309,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "do not allow directory traversal (backslash - windows separator)",
givenPrefix: "/",
givenRoot: "../_fixture/",
givenRoot: "_fixture/",
whenURL: `/group/..\\middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -269,7 +317,7 @@ func TestStatic_GroupWithStatic(t *testing.T) {
{
name: "do not allow directory traversal (slash - unix separator)",
givenPrefix: "/",
givenRoot: "../_fixture/",
givenRoot: "_fixture/",
whenURL: `/group/../middleware/basic_auth.go`,
expectStatus: http.StatusNotFound,
expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
@ -279,6 +327,8 @@ func TestStatic_GroupWithStatic(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Filesystem = os.DirFS("../") // so we can access test files
group := "/group"
if tc.givenGroup != "" {
group = tc.givenGroup
@ -288,7 +338,9 @@ func TestStatic_GroupWithStatic(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectStatus, rec.Code)
body := rec.Body.String()
if tc.expectBodyStartsWith != "" {
@ -306,3 +358,73 @@ func TestStatic_GroupWithStatic(t *testing.T) {
})
}
}
func TestMustStaticWithConfig_panicsInvalidDirListTemplate(t *testing.T) {
assert.Panics(t, func() {
StaticWithConfig(StaticConfig{DirectoryListTemplate: `{{}`})
})
}
func TestFormat(t *testing.T) {
var testCases = []struct {
name string
when int64
expect string
}{
{
name: "byte",
when: 0,
expect: "0",
},
{
name: "bytes",
when: 515,
expect: "515B",
},
{
name: "KB",
when: 31323,
expect: "30.59KB",
},
{
name: "MB",
when: 13231323,
expect: "12.62MB",
},
{
name: "GB",
when: 7323232398,
expect: "6.82GB",
},
{
name: "TB",
when: 1_099_511_627_776,
expect: "1.00TB",
},
{
name: "PB",
when: 9923232398434432,
expect: "8.81PB",
},
{
// test with 7EB because of https://github.com/labstack/gommon/pull/38 and https://github.com/labstack/gommon/pull/43
//
// 8 exbi equals 2^64, therefore it cannot be stored in int64. The tests use
// the fact that on x86_64 the following expressions holds true:
// int64(0) - 1 == math.MaxInt64.
//
// However, this is not true for other platforms, specifically aarch64, s390x
// and ppc64le.
name: "EB",
when: 8070450532247929000,
expect: "7.00EB",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := format(tc.when)
assert.Equal(t, tc.expect, result)
})
}
}

View File

@ -2,10 +2,9 @@ package middleware
import (
"context"
"github.com/labstack/echo/v4"
"net/http"
"time"
"github.com/labstack/echo/v4"
)
// ---------------------------------------------------------------------------------------------------------------
@ -55,9 +54,8 @@ import (
// })
//
type (
// TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct {
type TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
@ -69,7 +67,7 @@ type (
// request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
OnTimeoutRouteErrorHandler func(c echo.Context, err error)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
@ -77,29 +75,22 @@ type (
// difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration
}
)
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler
// call runs for longer than its time limit. NB: timeout does not stop handler execution.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
return TimeoutWithConfig(TimeoutConfig{})
}
// TimeoutWithConfig returns a Timeout middleware with config.
// See: `Timeout()`.
// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts TimeoutConfig to middleware or returns an error for invalid configuration
func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
config.Skipper = DefaultSkipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -108,29 +99,30 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
return next(c)
}
errChan := make(chan error, 1)
handlerWrapper := echoHandlerFuncWrapper{
ctx: c,
handler: next,
errChan: make(chan error, 1),
errChan: errChan,
errHandler: config.OnTimeoutRouteErrorHandler,
}
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request())
select {
case err := <-handlerWrapper.errChan:
case err := <-errChan:
return err
default:
return nil
}
}
}
}, nil
}
type echoHandlerFuncWrapper struct {
ctx echo.Context
handler echo.HandlerFunc
errHandler func(err error, c echo.Context)
errHandler func(c echo.Context, err error)
errChan chan error
}
@ -156,7 +148,7 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques
err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx)
t.errHandler(t.ctx, err)
}
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
}

View File

@ -4,6 +4,8 @@ import (
"bytes"
"errors"
"fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"io/ioutil"
"log"
"net"
@ -14,9 +16,6 @@ import (
"strings"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestTimeoutSkipper(t *testing.T) {
@ -111,7 +110,7 @@ func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) {
actualErrChan := make(chan error, 1)
m := TimeoutWithConfig(TimeoutConfig{
Timeout: 1 * time.Millisecond,
OnTimeoutRouteErrorHandler: func(err error, c echo.Context) {
OnTimeoutRouteErrorHandler: func(c echo.Context, err error) {
actualErrChan <- err
},
})
@ -360,7 +359,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
e.Logger.SetOutput(buf)
e.Logger = &testLogger{output: buf}
// NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first
// FIXME: I have no idea how to fix this without adding mutexes.
@ -424,7 +423,7 @@ func startServer(e *echo.Echo) (*http.Server, string, error) {
s := http.Server{
Handler: e,
ErrorLog: log.New(e.Logger.Output(), "echo:", 0),
ErrorLog: log.New(e.Logger, "echo:", 0),
}
errCh := make(chan error)

View File

@ -1,9 +1,27 @@
package middleware
import (
"crypto/rand"
"fmt"
"strings"
)
const (
_ = int64(1 << (10 * iota)) // ignore first value by assigning to blank identifier
// KB is 1 KiloByte = 1024 bytes
KB
// MB is 1 Megabyte = 1_048_576 bytes
MB
// GB is 1 Gigabyte = 1_073_741_824 bytes
GB
// TB is 1 Terabyte = 1_099_511_627_776 bytes
TB
// PB is 1 Petabyte = 1_125_899_906_842_624 bytes
PB
// EB is 1 Exabyte = 1_152_921_504_606_847_000 bytes
EB
)
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
@ -52,3 +70,24 @@ func matchSubdomain(domain, pattern string) bool {
}
return false
}
func createRandomStringGenerator(length uint8) func() string {
return func() string {
return randomString(length)
}
}
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -1,11 +1,23 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
"io"
"testing"
)
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}
func Test_matchScheme(t *testing.T) {
tests := []struct {
domain, pattern string
@ -93,3 +105,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
}
}
func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}

View File

@ -2,15 +2,15 @@ package echo
import (
"bufio"
"errors"
"net"
"net/http"
)
type (
// Response wraps an http.ResponseWriter and implements its interface to be used
// by an HTTP handler to construct an HTTP response.
// See: https://golang.org/pkg/net/http/#ResponseWriter
Response struct {
type Response struct {
echo *Echo
beforeFuncs []func()
afterFuncs []func()
@ -19,7 +19,6 @@ type (
Size int64
Committed bool
}
)
// NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
@ -47,13 +46,15 @@ func (r *Response) After(fn func()) {
r.afterFuncs = append(r.afterFuncs, fn)
}
var errHeaderAlreadyCommitted = errors.New("response already committed")
// WriteHeader sends an HTTP response header with status code. If WriteHeader is
// not called explicitly, the first call to Write will trigger an implicit
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
// used to send error codes.
func (r *Response) WriteHeader(code int) {
if r.Committed {
r.echo.Logger.Warn("response already committed")
r.echo.Logger.Error(errHeaderAlreadyCommitted)
return
}
r.Status = code

182
route.go Normal file
View File

@ -0,0 +1,182 @@
package echo
import (
"bytes"
"errors"
"fmt"
"reflect"
"runtime"
)
// Route contains information to adding/registering new route with the router.
// Method+Path pair uniquely identifies the Route. It is mandatory to provide Method+Path+Handler fields.
type Route struct {
Method string
Path string
Handler HandlerFunc
Middlewares []MiddlewareFunc
Name string
}
// ToRouteInfo converts Route to RouteInfo
func (r Route) ToRouteInfo(params []string) RouteInfo {
name := r.Name
if name == "" {
name = r.Method + ":" + r.Path
}
return routeInfo{
method: r.Method,
path: r.Path,
params: append([]string(nil), params...),
name: name,
}
}
// ToRoute returns Route which Router uses to register the method handler for path.
func (r Route) ToRoute() Route {
return r
}
// ForGroup recreates Route with added group prefix and group middlewares it is grouped to.
func (r Route) ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable {
r.Path = pathPrefix + r.Path
if len(middlewares) > 0 {
m := make([]MiddlewareFunc, 0, len(middlewares)+len(r.Middlewares))
m = append(m, middlewares...)
m = append(m, r.Middlewares...)
r.Middlewares = m
}
return r
}
type routeInfo struct {
method string
path string
params []string
name string
}
func (r routeInfo) Method() string {
return r.method
}
func (r routeInfo) Path() string {
return r.path
}
func (r routeInfo) Params() []string {
return append([]string(nil), r.params...)
}
func (r routeInfo) Name() string {
return r.name
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r routeInfo) Reverse(params ...interface{}) string {
uri := new(bytes.Buffer)
ln := len(params)
n := 0
for i, l := 0, len(r.path); i < l; i++ {
if (r.path[i] == paramLabel || r.path[i] == anyLabel) && n < ln {
for ; i < l && r.path[i] != '/'; i++ {
}
uri.WriteString(fmt.Sprintf("%v", params[n]))
n++
}
if i < l {
uri.WriteByte(r.path[i])
}
}
return uri.String()
}
// HandlerName returns string name for given function.
func HandlerName(h HandlerFunc) string {
t := reflect.ValueOf(h).Type()
if t.Kind() == reflect.Func {
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
}
return t.String()
}
// Reverse reverses route to URL string by replacing path parameters with given params values.
func (r Routes) Reverse(name string, params ...interface{}) (string, error) {
for _, rr := range r {
if rr.Name() == name {
return rr.Reverse(params...), nil
}
}
return "", errors.New("route not found")
}
// FindByMethodPath searched for matching route info by method and path
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) {
if r == nil {
return nil, errors.New("route not found by method and path")
}
for _, rr := range r {
if rr.Method() == method && rr.Path() == path {
return rr, nil
}
}
return nil, errors.New("route not found by method and path")
}
// FilterByMethod searched for matching route info by method
func (r Routes) FilterByMethod(method string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by method")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Method() == method {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by method")
}
return result, nil
}
// FilterByPath searched for matching route info by path
func (r Routes) FilterByPath(path string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by path")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Path() == path {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by path")
}
return result, nil
}
// FilterByName searched for matching route info by name
func (r Routes) FilterByName(name string) (Routes, error) {
if r == nil {
return nil, errors.New("route not found by name")
}
result := make(Routes, 0)
for _, rr := range r {
if rr.Name() == name {
result = append(result, rr)
}
}
if len(result) == 0 {
return nil, errors.New("route not found by name")
}
return result, nil
}

423
route_test.go Normal file
View File

@ -0,0 +1,423 @@
package echo
import (
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
var myNamedHandler = func(c Context) error {
return nil
}
type NameStruct struct {
}
func (n *NameStruct) getUsers(c Context) error {
return nil
}
func TestHandlerName(t *testing.T) {
myNameFuncVar := func(c Context) error {
return nil
}
tmp := NameStruct{}
var testCases = []struct {
name string
whenHandlerFunc HandlerFunc
expect string
}{
{
name: "ok, func as anonymous func",
whenHandlerFunc: func(c Context) error {
return nil
},
expect: "github.com/labstack/echo/v4.TestHandlerName.func2",
},
{
name: "ok, func as named package variable",
whenHandlerFunc: myNamedHandler,
expect: "github.com/labstack/echo/v4.glob..func3",
},
{
name: "ok, func as named function variable",
whenHandlerFunc: myNameFuncVar,
expect: "github.com/labstack/echo/v4.TestHandlerName.func1",
},
{
name: "ok, func as struct method",
whenHandlerFunc: tmp.getUsers,
expect: "github.com/labstack/echo/v4.(*NameStruct).getUsers-fm",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
name := HandlerName(tc.whenHandlerFunc)
assert.Equal(t, tc.expect, name)
})
}
}
func TestHandlerName_differentFuncSameName(t *testing.T) {
handlerCreator := func(name string) HandlerFunc {
return func(c Context) error {
return c.String(http.StatusTeapot, name)
}
}
h1 := handlerCreator("name1")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func2", HandlerName(h1))
h2 := handlerCreator("name2")
assert.Equal(t, "github.com/labstack/echo/v4.TestHandlerName_differentFuncSameName.func3", HandlerName(h2))
}
func TestRoute_ToRouteInfo(t *testing.T) {
var testCases = []struct {
name string
given Route
whenParams []string
expect RouteInfo
}{
{
name: "ok, no params, with name",
given: Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
},
expect: routeInfo{
method: http.MethodGet,
path: "/test",
params: nil,
name: "test route",
},
},
{
name: "ok, params",
given: Route{
Method: http.MethodGet,
Path: "users/:id/:file", // no slash prefix
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "",
},
whenParams: []string{"id", "file"},
expect: routeInfo{
method: http.MethodGet,
path: "users/:id/:file",
params: []string{"id", "file"},
name: "GET:users/:id/:file",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri := tc.given.ToRouteInfo(tc.whenParams)
assert.Equal(t, tc.expect, ri)
})
}
}
func TestRoute_ToRoute(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
r := route.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/test")
assert.NotNil(t, r.Handler)
assert.Nil(t, r.Middlewares)
assert.Equal(t, r.Name, "test route")
}
func TestRoute_ForGroup(t *testing.T) {
route := Route{
Method: http.MethodGet,
Path: "/test",
Handler: func(c Context) error {
return c.String(http.StatusTeapot, "OK")
},
Middlewares: nil,
Name: "test route",
}
mw := func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
return next(c)
}
}
gr := route.ForGroup("/users", []MiddlewareFunc{mw})
r := gr.ToRoute()
assert.Equal(t, r.Method, http.MethodGet)
assert.Equal(t, r.Path, "/users/test")
assert.NotNil(t, r.Handler)
assert.Len(t, r.Middlewares, 1)
assert.Equal(t, r.Name, "test route")
}
func exampleRoutes() Routes {
return Routes{
routeInfo{
method: http.MethodGet,
path: "/users",
params: nil,
name: "GET:/users",
},
routeInfo{
method: http.MethodGet,
path: "/users/:id",
params: []string{"id"},
name: "GET:/users/:id",
},
routeInfo{
method: http.MethodPost,
path: "/users/:id",
params: []string{"id"},
name: "POST:/users/:id",
},
routeInfo{
method: http.MethodDelete,
path: "/groups",
params: nil,
name: "non_unique_name",
},
routeInfo{
method: http.MethodPost,
path: "/groups",
params: nil,
name: "non_unique_name",
},
}
}
func TestRoutes_FindByMethodPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
whenPath string
expectName string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "GET:/users/:id",
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
whenPath: "/users/:id",
expectName: "",
expectError: "route not found by method and path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ri, err := tc.given.FindByMethodPath(tc.whenMethod, tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
assert.Nil(t, ri)
} else {
assert.NoError(t, err)
}
if tc.expectName != "" {
assert.Equal(t, tc.expectName, ri.Name())
}
})
}
}
func TestRoutes_FilterByMethod(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenMethod string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenMethod: http.MethodGet,
expectNames: []string{"GET:/users", "GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenMethod: http.MethodPut,
expectNames: nil,
expectError: "route not found by method",
},
{
name: "nok, not found from nil",
given: nil,
whenMethod: http.MethodGet,
expectNames: nil,
expectError: "route not found by method",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByMethod(tc.whenMethod)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByPath(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenPath string
expectNames []string
expectError string
}{
{
name: "ok, found",
given: exampleRoutes(),
whenPath: "/users/:id",
expectNames: []string{"GET:/users/:id", "POST:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenPath: "/",
expectNames: nil,
expectError: "route not found by path",
},
{
name: "nok, not found from nil",
given: nil,
whenPath: "/users/:id",
expectNames: nil,
expectError: "route not found by path",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByPath(tc.whenPath)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectNames) > 0 {
assert.Len(t, ris, len(tc.expectNames))
for _, ri := range ris {
assert.Contains(t, tc.expectNames, ri.Name())
}
} else {
assert.Nil(t, ris)
}
})
}
}
func TestRoutes_FilterByName(t *testing.T) {
var testCases = []struct {
name string
given Routes
whenName string
expectMethodPath []string
expectError string
}{
{
name: "ok, found multiple",
given: exampleRoutes(),
whenName: "non_unique_name",
expectMethodPath: []string{"DELETE:/groups", "POST:/groups"},
},
{
name: "ok, found single",
given: exampleRoutes(),
whenName: "GET:/users/:id",
expectMethodPath: []string{"GET:/users/:id"},
},
{
name: "nok, not found",
given: exampleRoutes(),
whenName: "/",
expectMethodPath: nil,
expectError: "route not found by name",
},
{
name: "nok, not found from nil",
given: nil,
whenName: "/users/:id",
expectMethodPath: nil,
expectError: "route not found by name",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ris, err := tc.given.FilterByName(tc.whenName)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
if len(tc.expectMethodPath) > 0 {
assert.Len(t, ris, len(tc.expectMethodPath))
for _, ri := range ris {
assert.Contains(t, tc.expectMethodPath, fmt.Sprintf("%v:%v", ri.Method(), ri.Path()))
}
} else {
assert.Nil(t, ris)
}
})
}
}

743
router.go
View File

@ -1,49 +1,133 @@
package echo
import (
"errors"
"net/http"
"net/url"
)
type (
// Router is the registry of all registered routes for an `Echo` instance for
// request matching and URL path parameter parsing.
Router struct {
tree *node
routes map[string]*Route
echo *Echo
// Router is interface for routing requests to registered routes.
type Router interface {
// Add registers Routable with the Router and returns registered RouteInfo
Add(routable Routable) (RouteInfo, error)
// Remove removes route from the Router
Remove(method string, path string) error
// Routes returns information about all registered routes
Routes() Routes
// Match searches Router for matching route and applies it to result fields.
Match(req *http.Request, params *PathParams) RouteMatch
}
node struct {
// Routable is interface for registering Route with Router. During route registration process the Router will
// convert Routable to RouteInfo with ToRouteInfo method. By creating custom implementation of Routable additional
// information about registered route can be stored in Routes (i.e. privileges used with route etc.)
type Routable interface {
// ToRouteInfo converts Routable to RouteInfo
//
// This method is meant to be used by Router after it parses url for path parameters, to store information about
// route just added.
ToRouteInfo(params []string) RouteInfo
// ToRoute converts Routable to Route which Router uses to register the method handler for path.
//
// This method is meant to be used by Router to get fields (including handler and middleware functions) needed to
// add Route to Router.
ToRoute() Route
// ForGroup recreates routable with added group prefix and group middlewares it is grouped to.
//
// Is necessary for Echo.Group to be able to add/register Routable with Router and having group prefix and group
// middlewares included in actually registered Route.
ForGroup(pathPrefix string, middlewares []MiddlewareFunc) Routable
}
// Routes is collection of RouteInfo instances with various helper methods.
type Routes []RouteInfo
// RouteInfo describes registered route base fields.
// Method+Path pair uniquely identifies the Route. Name can have duplicates.
type RouteInfo interface {
Method() string
Path() string
Name() string
Params() []string
Reverse(params ...interface{}) string
// NOTE: handler and middlewares are not exposed because handler could be already wrapping middlewares and therefore
// it is not always 100% known if handler function already wraps middlewares or not. In Echo handler could be one
// function or several functions wrapping each other.
}
// RouteMatchType describes possible states that request could be in perspective of routing
type RouteMatchType uint8
const (
// RouteMatchUnknown is state before routing is done. Default state for fresh context.
RouteMatchUnknown RouteMatchType = iota
// RouteMatchNotFound is state when router did not find matching route for current request
RouteMatchNotFound
// RouteMatchMethodNotAllowed is state when router did not find route with matching path + method for current request.
// Although router had matching route with that path but different method.
RouteMatchMethodNotAllowed
// RouteMatchFound is state when router found exact match for path + method combination
RouteMatchFound
)
// RouteMatch is result object for Router.Match. Its main purpose is to avoid allocating memory for PathParams inside router.
type RouteMatch struct {
// Type contains result as enumeration of Router.Match and helps to understand did Router actually matched Route or
// what kind of error case (404/405) we have at the end of the handler chain.
Type RouteMatchType
// RoutePath contains original path with what matched route was registered with (including placeholders etc).
RoutePath string
// Handler is function(chain) that was matched by router. In case of no match could result to ErrNotFound or ErrMethodNotAllowed.
Handler HandlerFunc
// RouteInfo is information about route we just matched
RouteInfo RouteInfo
}
// PathParams is collections of PathParam instances with various helper methods
type PathParams []PathParam
// PathParam is tuple pf path parameter name and its value in request path
type PathParam struct {
Name string
Value string
}
// DefaultRouter is the registry of all registered routes for an `Echo` instance for
// request matching and URL path parameter parsing.
// Note: DefaultRouter is not coroutine-safe. Do not Add/Remove routes after HTTP server has been started with Echo.
type DefaultRouter struct {
tree *node
routes Routes
echo *Echo
allowOverwritingRoute bool
unescapePathParamValues bool
useEscapedPathForRouting bool
}
type children []*node
type node struct {
kind kind
label byte
prefix string
parent *node
staticChildren children
ppath string
pnames []string
methodHandler *methodHandler
originalPath string
methods *routeMethods
paramChild *node
anyChild *node
paramsCount int
// isLeaf indicates that node does not have child routes
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool
}
kind uint8
children []*node
methodHandler struct {
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
}
)
type kind uint8
const (
staticKind kind = iota
@ -54,90 +138,362 @@ const (
anyLabel = byte('*')
)
func (m *methodHandler) isHandler() bool {
return m.connect != nil ||
m.delete != nil ||
m.get != nil ||
m.head != nil ||
m.options != nil ||
m.patch != nil ||
type routeMethod struct {
*routeInfo
handler HandlerFunc
orgRouteInfo RouteInfo
}
type routeMethods struct {
connect *routeMethod
delete *routeMethod
get *routeMethod
head *routeMethod
options *routeMethod
patch *routeMethod
post *routeMethod
propfind *routeMethod
put *routeMethod
trace *routeMethod
report *routeMethod
anyOther map[string]*routeMethod
}
func (m *routeMethods) set(method string, r *routeMethod) {
switch method {
case http.MethodConnect:
m.connect = r
case http.MethodDelete:
m.delete = r
case http.MethodGet:
m.get = r
case http.MethodHead:
m.head = r
case http.MethodOptions:
m.options = r
case http.MethodPatch:
m.patch = r
case http.MethodPost:
m.post = r
case PROPFIND:
m.propfind = r
case http.MethodPut:
m.put = r
case http.MethodTrace:
m.trace = r
case REPORT:
m.report = r
default:
if m.anyOther == nil {
m.anyOther = make(map[string]*routeMethod)
}
if r.handler == nil {
delete(m.anyOther, method)
} else {
m.anyOther[method] = r
}
}
}
func (m *routeMethods) find(method string) *routeMethod {
switch method {
case http.MethodConnect:
return m.connect
case http.MethodDelete:
return m.delete
case http.MethodGet:
return m.get
case http.MethodHead:
return m.head
case http.MethodOptions:
return m.options
case http.MethodPatch:
return m.patch
case http.MethodPost:
return m.post
case PROPFIND:
return m.propfind
case http.MethodPut:
return m.put
case http.MethodTrace:
return m.trace
case REPORT:
return m.report
default:
return m.anyOther[method]
}
}
func (m *routeMethods) isHandler() bool {
return m.get != nil ||
m.post != nil ||
m.propfind != nil ||
m.options != nil ||
m.put != nil ||
m.delete != nil ||
m.connect != nil ||
m.head != nil ||
m.patch != nil ||
m.propfind != nil ||
m.trace != nil ||
m.report != nil
m.report != nil ||
len(m.anyOther) != 0
}
// RouterConfig is configuration options for (default) router
type RouterConfig struct {
// AllowOverwritingRoute instructs Router NOT to return error when new route is registered with the same method+path
// and replaces matching route with the new one.
AllowOverwritingRoute bool
// UnescapePathParamValues instructs Router to unescape path parameter value when request if matched to the routes
UnescapePathParamValues bool
// UseEscapedPathForMatching instructs Router to use escaped request URL path (req.URL.Path) for matching the request.
UseEscapedPathForMatching bool
}
// NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router {
return &Router{
func NewRouter(e *Echo, config RouterConfig) *DefaultRouter {
r := &DefaultRouter{
tree: &node{
methodHandler: new(methodHandler),
methods: new(routeMethods),
isLeaf: true,
isHandler: false,
},
routes: map[string]*Route{},
routes: make(Routes, 0),
echo: e,
allowOverwritingRoute: config.AllowOverwritingRoute,
unescapePathParamValues: config.UnescapePathParamValues,
useEscapedPathForRouting: config.UseEscapedPathForMatching,
}
return r
}
// Routes returns all registered routes
func (r *DefaultRouter) Routes() Routes {
return r.routes
}
// Remove unregisters registered route
func (r *DefaultRouter) Remove(method string, path string) error {
currentNode := r.tree
if currentNode == nil || (currentNode.isLeaf && !currentNode.isHandler) {
return errors.New("router has no routes to remove")
}
// Add registers a new route for method and path with matching handler.
func (r *Router) Add(method, path string, h HandlerFunc) {
// Validate path
if path == "" {
path = "/"
}
if path[0] != '/' {
path = "/" + path
}
pnames := []string{} // Param names
ppath := path // Pristine path
if h == nil && r.echo.Logger != nil {
// FIXME: in future we should return error
r.echo.Logger.Errorf("Adding route without handler function: %v:%v", method, path)
var nodeToRemove *node
prefixLen := 0
for {
if currentNode.originalPath == path && currentNode.isHandler {
nodeToRemove = currentNode
break
}
if currentNode.kind == staticKind {
prefixLen = prefixLen + len(currentNode.prefix)
} else {
prefixLen = len(currentNode.originalPath)
}
if prefixLen >= len(path) {
break
}
next := path[prefixLen]
switch next {
case paramLabel:
currentNode = currentNode.paramChild
case anyLabel:
currentNode = currentNode.anyChild
default:
currentNode = currentNode.findStaticChild(next)
}
if currentNode == nil {
break
}
}
if nodeToRemove == nil {
return errors.New("could not find route to remove by given path")
}
if !nodeToRemove.isHandler {
return errors.New("could not find route to remove by given path")
}
if mh := nodeToRemove.methods.find(method); mh == nil {
return errors.New("could not find route to remove by given path and method")
}
nodeToRemove.setHandler(method, nil)
var rIndex int
for i, rr := range r.routes {
if rr.Method() == method && rr.Path() == path {
rIndex = i
break
}
}
r.routes = append(r.routes[:rIndex], r.routes[rIndex+1:]...)
if !nodeToRemove.isHandler && nodeToRemove.isLeaf {
// TODO: if !nodeToRemove.isLeaf and has at least 2 children merge paths for remaining nodes?
current := nodeToRemove
for {
parent := current.parent
if parent == nil {
break
}
switch current.kind {
case staticKind:
var index int
for i, c := range parent.staticChildren {
if c == current {
index = i
break
}
}
parent.staticChildren = append(parent.staticChildren[:index], parent.staticChildren[index+1:]...)
case paramKind:
parent.paramChild = nil
case anyKind:
parent.anyChild = nil
}
parent.isLeaf = parent.anyChild == nil && parent.paramChild == nil && len(parent.staticChildren) == 0
if !parent.isLeaf || parent.isHandler {
break
}
current = parent
}
}
return nil
}
// AddRouteError is error returned by Router.Add containing information what actual route adding failed. Useful for
// mass adding (i.e. Any() routes)
type AddRouteError struct {
Method string
Path string
Err error
}
func (e *AddRouteError) Error() string { return e.Method + " " + e.Path + ": " + e.Err.Error() }
func (e *AddRouteError) Unwrap() error { return e.Err }
func newAddRouteError(route Route, err error) *AddRouteError {
return &AddRouteError{
Method: route.Method,
Path: route.Path,
Err: err,
}
}
// Add registers a new route for method and path with matching handler.
func (r *DefaultRouter) Add(routable Routable) (RouteInfo, error) {
route := routable.ToRoute()
if route.Handler == nil {
return nil, newAddRouteError(route, errors.New("adding route without handler function"))
}
method := route.Method
path := route.Path
h := applyMiddleware(route.Handler, route.Middlewares...)
if !r.allowOverwritingRoute {
for _, rr := range r.routes {
if route.Method == rr.Method() && route.Path == rr.Path() {
return nil, newAddRouteError(route, errors.New("adding duplicate route (same method+path) is not allowed"))
}
}
}
if path == "" {
path = "/"
}
if path[0] != '/' {
path = "/" + path
}
paramNames := make([]string, 0)
originalPath := path
wasAdded := false
var ri RouteInfo
for i, lcpIndex := 0, len(path); i < lcpIndex; i++ {
if path[i] == ':' {
if path[i] == paramLabel {
if i > 0 && path[i-1] == '\\' {
continue
}
j := i + 1
r.insert(method, path[:i], nil, staticKind, "", nil)
r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
for ; i < lcpIndex && path[i] != '/'; i++ {
}
pnames = append(pnames, path[j:i])
paramNames = append(paramNames, path[j:i])
path = path[:j] + path[i:]
i, lcpIndex = j, len(path)
if i == lcpIndex {
// path node is last fragment of route path. ie. `/users/:id`
r.insert(method, path[:i], h, paramKind, ppath, pnames)
ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(paramKind, path[:i], method, rm)
wasAdded = true
break
} else {
r.insert(method, path[:i], nil, paramKind, "", nil)
r.insert(paramKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
}
} else if path[i] == '*' {
r.insert(method, path[:i], nil, staticKind, "", nil)
pnames = append(pnames, "*")
r.insert(method, path[:i+1], h, anyKind, ppath, pnames)
} else if path[i] == anyLabel {
r.insert(staticKind, path[:i], method, routeMethod{routeInfo: &routeInfo{method: method}})
paramNames = append(paramNames, "*")
ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(anyKind, path[:i+1], method, rm)
wasAdded = true
break
}
}
r.insert(method, path, h, staticKind, ppath, pnames)
if !wasAdded {
ri = routable.ToRouteInfo(paramNames)
rm := routeMethod{
routeInfo: &routeInfo{method: method, path: originalPath, params: paramNames, name: route.Name},
handler: h,
orgRouteInfo: ri,
}
r.insert(staticKind, path, method, rm)
}
func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) {
// Adjust max param
paramLen := len(pnames)
if *r.echo.maxParam < paramLen {
*r.echo.maxParam = paramLen
r.storeRouteInfo(ri)
return ri, nil
}
func (r *DefaultRouter) storeRouteInfo(ri RouteInfo) {
for i, rr := range r.routes {
if ri.Method() == rr.Method() && ri.Path() == rr.Path() {
r.routes[i] = ri
return
}
}
r.routes = append(r.routes, ri)
}
func (r *DefaultRouter) insert(t kind, path string, method string, ri routeMethod) {
currentNode := r.tree // Current node as root
if currentNode == nil {
panic("echo: invalid method")
}
search := path
for {
@ -157,11 +513,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
// At root node
currentNode.label = search[0]
currentNode.prefix = search
if h != nil {
if ri.handler != nil {
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
currentNode.setHandler(method, &ri)
currentNode.paramsCount = len(ri.params)
currentNode.originalPath = ri.path
}
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen {
@ -171,9 +527,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix[lcpLen:],
currentNode,
currentNode.staticChildren,
currentNode.methodHandler,
currentNode.ppath,
currentNode.pnames,
currentNode.methods,
currentNode.paramsCount,
currentNode.originalPath,
currentNode.paramChild,
currentNode.anyChild,
)
@ -193,9 +549,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.label = currentNode.prefix[0]
currentNode.prefix = currentNode.prefix[:lcpLen]
currentNode.staticChildren = nil
currentNode.methodHandler = new(methodHandler)
currentNode.ppath = ""
currentNode.pnames = nil
currentNode.methods = new(routeMethods)
currentNode.originalPath = ""
currentNode.paramsCount = 0
currentNode.paramChild = nil
currentNode.anyChild = nil
currentNode.isLeaf = false
@ -207,13 +563,18 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
if lcpLen == searchLen {
// At parent node
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
if ri.handler != nil {
currentNode.setHandler(method, &ri)
currentNode.paramsCount = len(ri.params)
currentNode.originalPath = ri.path
}
} else {
// Create child node
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n = newNode(t, search[lcpLen:], currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
// Only Static children could reach here
currentNode.addStaticChild(n)
}
@ -227,8 +588,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue
}
// Create child node
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n := newNode(t, search, currentNode, nil, new(routeMethods), 0, ri.path, nil, nil)
if ri.handler != nil {
n.setHandler(method, &ri)
n.paramsCount = len(ri.params)
}
switch t {
case staticKind:
currentNode.addStaticChild(n)
@ -240,28 +604,26 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else {
// Node already exists
if h != nil {
currentNode.addHandler(method, h)
currentNode.ppath = ppath
if len(currentNode.pnames) == 0 { // Issue #729
currentNode.pnames = pnames
}
if ri.handler != nil {
currentNode.setHandler(method, &ri)
currentNode.paramsCount = len(ri.params)
currentNode.originalPath = ri.path
}
}
return
}
}
func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node {
func newNode(t kind, pre string, p *node, sc children, mh *routeMethods, paramsCount int, ppath string, paramChildren, anyChildren *node) *node {
return &node{
kind: t,
label: pre[0],
prefix: pre,
parent: p,
staticChildren: sc,
ppath: ppath,
pnames: pnames,
methodHandler: mh,
originalPath: ppath,
paramsCount: paramsCount,
methods: mh,
paramChild: paramChildren,
anyChild: anyChildren,
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
@ -297,99 +659,77 @@ func (n *node) findChildWithLabel(l byte) *node {
return nil
}
func (n *node) addHandler(method string, h HandlerFunc) {
switch method {
case http.MethodConnect:
n.methodHandler.connect = h
case http.MethodDelete:
n.methodHandler.delete = h
case http.MethodGet:
n.methodHandler.get = h
case http.MethodHead:
n.methodHandler.head = h
case http.MethodOptions:
n.methodHandler.options = h
case http.MethodPatch:
n.methodHandler.patch = h
case http.MethodPost:
n.methodHandler.post = h
case PROPFIND:
n.methodHandler.propfind = h
case http.MethodPut:
n.methodHandler.put = h
case http.MethodTrace:
n.methodHandler.trace = h
case REPORT:
n.methodHandler.report = h
}
if h != nil {
func (n *node) setHandler(method string, r *routeMethod) {
n.methods.set(method, r)
if r != nil && r.handler != nil {
n.isHandler = true
} else {
n.isHandler = n.methodHandler.isHandler()
n.isHandler = n.methods.isHandler()
}
}
func (n *node) findHandler(method string) HandlerFunc {
switch method {
case http.MethodConnect:
return n.methodHandler.connect
case http.MethodDelete:
return n.methodHandler.delete
case http.MethodGet:
return n.methodHandler.get
case http.MethodHead:
return n.methodHandler.head
case http.MethodOptions:
return n.methodHandler.options
case http.MethodPatch:
return n.methodHandler.patch
case http.MethodPost:
return n.methodHandler.post
case PROPFIND:
return n.methodHandler.propfind
case http.MethodPut:
return n.methodHandler.put
case http.MethodTrace:
return n.methodHandler.trace
case REPORT:
return n.methodHandler.report
default:
return nil
}
const (
// NotFoundRouteName is name of RouteInfo returned when router did not find matching route (404: not found).
NotFoundRouteName = "EchoRouteNotFound"
// MethodNotAllowedRouteName is name of RouteInfo returned when router did not find matching method for route (404: method not allowed).
MethodNotAllowedRouteName = "EchoRouteMethodNotAllowed"
)
// Note: notFoundRouteInfo exists to avoid allocations when setting 404 RouteInfo to RouteMatch
var notFoundRouteInfo = &routeInfo{
method: "",
path: "",
params: nil,
name: NotFoundRouteName,
}
func (n *node) checkMethodNotAllowed() HandlerFunc {
for _, m := range methods {
if h := n.findHandler(m); h != nil {
return MethodNotAllowedHandler
}
}
return NotFoundHandler
// Note: methodNotAllowedRouteInfo exists to avoid allocations when setting 405 RouteInfo to RouteMatch
var methodNotAllowedRouteInfo = &routeInfo{
method: "",
path: "",
params: nil,
name: MethodNotAllowedRouteName,
}
// Find lookup a handler registered for method and path. It also parses URL for path
// parameters and load them into context.
// notFoundHandler is handler for 404 cases
// Handle returned ErrNotFound errors in Echo.HTTPErrorHandler
var notFoundHandler = func(c Context) error {
return ErrNotFound
}
// methodNotAllowedHandler is handler for case when route for path+method match was not found (http code 405)
// Handle returned ErrMethodNotAllowed errors in Echo.HTTPErrorHandler
var methodNotAllowedHandler = func(c Context) error {
return ErrMethodNotAllowed
}
// Match looks up a handler registered for method and path. It also parses URL for path parameters and loads them
// into context.
//
// For performance:
//
// - Get context from `Echo#AcquireContext()`
// - Reset it `Context#Reset()`
// - Return it `Echo#ReleaseContext()`.
func (r *Router) Find(method, path string, c Context) {
ctx := c.(*context)
ctx.path = path
currentNode := r.tree // Current node as root
func (r *DefaultRouter) Match(req *http.Request, pathParams *PathParams) RouteMatch {
*pathParams = (*pathParams)[0:cap(*pathParams)]
path := req.URL.Path
if !r.useEscapedPathForRouting && req.URL.RawPath != "" {
// Difference between URL.RawPath and URL.Path is:
// * URL.Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/.
// * URL.RawPath is an optional field which only gets set if the default encoding is different from Path.
path = req.URL.RawPath
}
var (
currentNode = r.tree // root as current node
previousBestMatchNode *node
matchedHandler HandlerFunc
matchedRouteMethod *routeMethod
// search stores the remaining path to check for match. By each iteration we move from start of path to end of the path
// and search value gets shorter and shorter.
search = path
searchIndex = 0
paramIndex int // Param counter
paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice
)
// Backtracking is needed when a dead end (leaf node) is reached in the router tree.
@ -421,8 +761,8 @@ func (r *Router) Find(method, path string, c Context) {
paramIndex--
// for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue
// for that index as it would also contain part of path we cut off before moving into node we are backtracking from
searchIndex -= len(paramValues[paramIndex])
paramValues[paramIndex] = ""
searchIndex -= len((*pathParams)[paramIndex].Value)
(*pathParams)[paramIndex].Value = ""
}
search = path[searchIndex:]
return
@ -456,7 +796,7 @@ func (r *Router) Find(method, path string, c Context) {
// No matching prefix, let's backtrack to the first possible alternative node of the decision path
nk, ok := backtrackToNextNodeKind(staticKind)
if !ok {
return // No other possibilities on the decision path
break // No other possibilities on the decision path
} else if nk == paramKind {
goto Param
// NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently
@ -479,8 +819,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedRouteMethod = rMethod
break
}
}
@ -507,7 +847,7 @@ func (r *Router) Find(method, path string, c Context) {
}
}
paramValues[paramIndex] = search[:i]
(*pathParams)[paramIndex].Value = search[:i]
paramIndex++
search = search[i:]
searchIndex = searchIndex + i
@ -519,7 +859,7 @@ func (r *Router) Find(method, path string, c Context) {
if child := currentNode.anyChild; child != nil {
// If any node is found, use remaining path for paramValues
currentNode = child
paramValues[len(currentNode.pnames)-1] = search
(*pathParams)[currentNode.paramsCount-1].Value = search
// update indexes/search in case we need to backtrack when no handler match is found
paramIndex++
searchIndex += +len(search)
@ -530,8 +870,8 @@ func (r *Router) Find(method, path string, c Context) {
if previousBestMatchNode == nil {
previousBestMatchNode = currentNode
}
if h := currentNode.findHandler(method); h != nil {
matchedHandler = h
if rMethod := currentNode.methods.find(req.Method); rMethod != nil {
matchedRouteMethod = rMethod
break
}
}
@ -550,20 +890,63 @@ func (r *Router) Find(method, path string, c Context) {
}
}
result := RouteMatch{
Type: RouteMatchNotFound,
Handler: notFoundHandler,
RoutePath: "",
RouteInfo: notFoundRouteInfo,
}
if currentNode == nil && previousBestMatchNode == nil {
return // nothing matched at all
*pathParams = (*pathParams)[0:0]
return result // nothing matched at all with given path
}
if matchedHandler != nil {
ctx.handler = matchedHandler
if matchedRouteMethod != nil {
result.Type = RouteMatchFound
result.Handler = matchedRouteMethod.handler
result.RoutePath = matchedRouteMethod.routeInfo.path
result.RouteInfo = matchedRouteMethod.routeInfo
} else {
// use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames
return
// this here is only reason why `RouteMatch.RoutePath` exists. We do not want to create new RouteInfo just for path.
result.RoutePath = currentNode.originalPath
if currentNode.isHandler {
// TODO: in case of OPTIONS method we could respond with list of methods that node has. See https://httpwg.org/specs/rfc7231.html#OPTIONS
result.Type = RouteMatchMethodNotAllowed
result.Handler = methodNotAllowedHandler
result.RouteInfo = methodNotAllowedRouteInfo
}
}
*pathParams = (*pathParams)[0:currentNode.paramsCount]
if matchedRouteMethod != nil {
for i, name := range matchedRouteMethod.params {
(*pathParams)[i].Name = name
}
}
if r.unescapePathParamValues && currentNode.kind != staticKind {
// See issue #1531, #1258 - there are cases when path parameter need to be unescaped
for i, p := range *pathParams {
tmpVal, err := url.PathUnescape(p.Value)
if err == nil { // handle problems by ignoring them.
(*pathParams)[i].Value = tmpVal
}
}
}
return result
}
// Get returns path parameter value for given name or default value.
func (p PathParams) Get(name string, defaultValue string) string {
for _, param := range p {
if param.Name == name {
return param.Value
}
}
return defaultValue
}

File diff suppressed because it is too large Load Diff

220
server.go Normal file
View File

@ -0,0 +1,220 @@
package echo
import (
stdContext "context"
"crypto/tls"
"fmt"
"io/fs"
"log"
"net"
"net/http"
"os"
"time"
)
const (
website = "https://echo.labstack.com"
// http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
banner = `
____ __
/ __/___/ / ___
/ _// __/ _ \/ _ \
/___/\__/_//_/\___/ %s
High performance, minimalist Go web framework
%s
____________________________________O/_______
O\
`
)
// StartConfig is for creating configured http.Server instance to start serve http(s) requests with given Echo instance
type StartConfig struct {
// Address for the server to listen on (if not using custom listener)
Address string
// ListenerNetwork allows setting listener network (see net.Listen for allowed values)
// Optional: defaults to "tcp"
ListenerNetwork string
// CertFilesystem is file system used to load certificates and keys (if certs/keys are given as paths)
CertFilesystem fs.FS
// DisableHTTP2 disables supports for HTTP2 in TLS server
DisableHTTP2 bool
// HideBanner does not log Echo banner on server startup
HideBanner bool
// HidePort does not log port on server startup
HidePort bool
// GracefulContext is context that completion signals graceful shutdown start
GracefulContext stdContext.Context
// GracefulTimeout is period which server allows listeners to finish serving ongoing requests. If this time is exceeded process is exited
// Defaults to 10 seconds
GracefulTimeout time.Duration
// OnShutdownError allows customization of what happens when (graceful) server Shutdown method returns an error.
// Defaults to calling e.logger.Error(err)
OnShutdownError func(err error)
// TLSConfigFunc allows modifying TLS configuration before listener is created with it.
TLSConfigFunc func(tlsConfig *tls.Config)
// ListenerAddrFunc allows getting listener address before server starts serving requests on listener. Useful when
// address is set as random (`:0`) port.
ListenerAddrFunc func(addr net.Addr)
// BeforeServeFunc allows customizing/accessing server before server starts serving requests on listener.
BeforeServeFunc func(s *http.Server) error
}
// Start starts a HTTP server.
func (sc StartConfig) Start(e *Echo) error {
logger := e.Logger
server := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
var tlsConfig *tls.Config = nil
if sc.TLSConfigFunc != nil {
tlsConfig = &tls.Config{}
configureTLS(&sc, tlsConfig)
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &server, listener, logger)
}
// StartTLS starts a HTTPS server.
// If `certFile` or `keyFile` is `string` the values are treated as file paths.
// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
func (sc StartConfig) StartTLS(e *Echo, certFile, keyFile interface{}) error {
logger := e.Logger
s := http.Server{
Handler: e,
ErrorLog: log.New(logger, "", 0),
}
certFs := sc.CertFilesystem
if certFs == nil {
certFs = os.DirFS(".")
}
cert, err := filepathOrContent(certFile, certFs)
if err != nil {
return err
}
key, err := filepathOrContent(keyFile, certFs)
if err != nil {
return err
}
cer, err := tls.X509KeyPair(cert, key)
if err != nil {
return err
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cer}}
configureTLS(&sc, tlsConfig)
if sc.TLSConfigFunc != nil {
sc.TLSConfigFunc(tlsConfig)
}
listener, err := createListener(&sc, tlsConfig)
if err != nil {
return err
}
return serve(&sc, &s, listener, logger)
}
func serve(sc *StartConfig, server *http.Server, listener net.Listener, logger Logger) error {
if sc.BeforeServeFunc != nil {
if err := sc.BeforeServeFunc(server); err != nil {
return err
}
}
startupGreetings(sc, logger, listener)
if sc.GracefulContext != nil {
ctx, cancel := stdContext.WithCancel(sc.GracefulContext)
defer cancel() // make sure this graceful coroutine will end when serve returns by some other means
go gracefulShutdown(ctx, sc, server, logger)
}
return server.Serve(listener)
}
func configureTLS(sc *StartConfig, tlsConfig *tls.Config) {
if !sc.DisableHTTP2 {
tlsConfig.NextProtos = append(tlsConfig.NextProtos, "h2")
}
}
func createListener(sc *StartConfig, tlsConfig *tls.Config) (net.Listener, error) {
listenerNetwork := sc.ListenerNetwork
if listenerNetwork == "" {
listenerNetwork = "tcp"
}
var listener net.Listener
var err error
if tlsConfig != nil {
listener, err = tls.Listen(listenerNetwork, sc.Address, tlsConfig)
} else {
listener, err = net.Listen(listenerNetwork, sc.Address)
}
if err != nil {
return nil, err
}
if sc.ListenerAddrFunc != nil {
sc.ListenerAddrFunc(listener.Addr())
}
return listener, nil
}
func startupGreetings(sc *StartConfig, logger Logger, listener net.Listener) {
if !sc.HideBanner {
bannerText := fmt.Sprintf(banner, "v"+Version, website)
logger.Write([]byte(bannerText))
}
if !sc.HidePort {
logger.Write([]byte(fmt.Sprintf("⇨ http(s) server started on %s\n", listener.Addr())))
}
}
func filepathOrContent(fileOrContent interface{}, certFilesystem fs.FS) (content []byte, err error) {
switch v := fileOrContent.(type) {
case string:
return fs.ReadFile(certFilesystem, v)
case []byte:
return v, nil
default:
return nil, ErrInvalidCertOrKeyType
}
}
func gracefulShutdown(gracefulCtx stdContext.Context, sc *StartConfig, server *http.Server, logger Logger) {
<-gracefulCtx.Done() // wait until shutdown context is closed.
// note: is server if closed by other means this method is still run but is good as no-op
timeout := sc.GracefulTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
shutdownCtx, cancel := stdContext.WithTimeout(stdContext.Background(), timeout)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
// we end up here when listeners are not shut down within given timeout
if sc.OnShutdownError != nil {
sc.OnShutdownError(err)
return
}
logger.Error(fmt.Errorf("failed to shut down server within given timeout: %w", err))
}
}

815
server_test.go Normal file
View File

@ -0,0 +1,815 @@
package echo
import (
"bytes"
stdContext "context"
"crypto/tls"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"testing"
"time"
)
func startOnRandomPort(ctx stdContext.Context, e *Echo) (string, error) {
addrChan := make(chan string)
errCh := make(chan error)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
return waitForServerStart(addrChan, errCh)
}
func waitForServerStart(addrChan <-chan string, errCh <-chan error) (string, error) {
waitCtx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer cancel()
// wait for addr to arrive
for {
select {
case <-waitCtx.Done():
return "", waitCtx.Err()
case addr := <-addrChan:
return addr, nil
case err := <-errCh:
if err == http.ErrServerClosed { // was closed normally before listener callback was called. should not be possible
return "", nil
}
// failed to start and we did not manage to get even listener part.
return "", err
}
}
}
func doGet(url string) (int, string, error) {
resp, err := http.Get(url)
if err != nil {
return 0, "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, "", err
}
return resp.StatusCode, string(body), nil
}
func TestStartConfig_Start(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
// check if server is actually up
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
if err == nil {
t.Errorf("missing error")
return
}
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
func TestStartConfig_GracefulShutdown(t *testing.T) {
var testCases = []struct {
name string
whenHandlerTakesLonger bool
expectBody string
expectGracefulError string
}{
{
name: "ok, all handlers returns before graceful shutdown deadline",
whenHandlerTakesLonger: false,
expectBody: "OK",
expectGracefulError: "",
},
{
name: "nok, handlers do not returns before graceful shutdown deadline",
whenHandlerTakesLonger: true,
expectBody: "timeout",
expectGracefulError: stdContext.DeadlineExceeded.Error(),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
msg := "OK"
if tc.whenHandlerTakesLonger {
time.Sleep(150 * time.Millisecond)
msg = "timeout"
}
return c.String(http.StatusOK, msg)
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 50*time.Millisecond)
defer shutdown()
shutdownErrChan := make(chan error, 1)
go func() {
errCh <- (&StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 50 * time.Millisecond,
OnShutdownError: func(err error) {
shutdownErrChan <- err
},
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}).Start(e)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%v/ok", addr))
if err != nil {
assert.NoError(t, err)
return
}
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, tc.expectBody, body)
var shutdownErr error
select {
case shutdownErr = <-shutdownErrChan:
default:
}
if tc.expectGracefulError != "" {
assert.EqualError(t, shutdownErr, tc.expectGracefulError)
} else {
assert.NoError(t, shutdownErr)
}
shutdown()
<-errCh // we will be blocking here until server returns from http.Serve
// check if server was stopped
code, body, err = doGet(fmt.Sprintf("http://%v/ok", addr))
assert.Error(t, err)
if err != nil {
assert.True(t, strings.Contains(err.Error(), "connect: connection refused"))
}
assert.Equal(t, 0, code)
assert.Equal(t, "", body)
})
}
}
func TestStartConfig_Start_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, errors.New("not_implemented")
}
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_Start_createListenerError(t *testing.T) {
e := New()
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.Start(e)
assert.EqualError(t, err, "tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
func TestStartConfig_StartTLS(t *testing.T) {
var testCases = []struct {
name string
addr string
certFile string
keyFile string
expectError string
}{
{
name: "ok",
addr: ":0",
},
{
name: "nok, invalid certFile",
addr: ":0",
certFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, invalid keyFile",
addr: ":0",
keyFile: "not existing",
expectError: "open not existing: no such file or directory",
},
{
name: "nok, failed to create cert out of certFile and keyFile",
addr: ":0",
keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key
expectError: "tls: found a certificate rather than a key in the PEM for the private key",
},
{
name: "nok, invalid tls address",
addr: "nope",
expectError: "listen tcp: address nope: missing port in address",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
if tc.certFile != "" {
certFile = tc.certFile
}
keyFile := "_fixture/certs/key.pem"
if tc.keyFile != "" {
keyFile = tc.keyFile
}
s := &StartConfig{
Address: tc.addr,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectError != "" {
if _, ok := err.(*os.PathError); ok {
assert.Error(t, err) // error messages for unix and windows are different. so name only error type here
} else {
assert.EqualError(t, err, tc.expectError)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestStartConfig_StartTLS_withTLSConfigFunc(t *testing.T) {
e := New()
tlsConfigCalled := false
s := &StartConfig{
Address: ":0",
TLSConfigFunc: func(tlsConfig *tls.Config) {
assert.Len(t, tlsConfig.Certificates, 1)
tlsConfigCalled = true
},
BeforeServeFunc: func(s *http.Server) error {
return errors.New("stop_now")
},
}
err := s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
assert.EqualError(t, err, "stop_now")
assert.True(t, tlsConfigCalled)
}
func TestStartConfig_StartTLSAndStart(t *testing.T) {
// We name if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server
e := New()
e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
tlsCtx, tlsShutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer tlsShutdown()
addrTLSChan := make(chan string)
errTLSChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: tlsCtx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrTLSChan <- addr.String()
},
}
errTLSChan <- s.StartTLS(e, "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
}()
tlsAddr, err := waitForServerStart(addrTLSChan, errTLSChan)
assert.NoError(t, err)
// check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true)
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 100*time.Millisecond)
defer shutdown()
addrChan := make(chan string)
errChan := make(chan error)
go func() {
s := &StartConfig{
Address: ":0",
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errChan <- s.Start(e)
}()
addr, err := waitForServerStart(addrChan, errChan)
assert.NoError(t, err)
// now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS
res, err = client.Get(fmt.Sprintf("http://%v", addr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// see if HTTPS works after HTTP listener is also added
res, err = client.Get(fmt.Sprintf("https://%v", tlsAddr))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestFilepathOrContent(t *testing.T) {
cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
require.NoError(t, err)
key, err := ioutil.ReadFile("_fixture/certs/key.pem")
require.NoError(t, err)
testCases := []struct {
name string
cert interface{}
key interface{}
expectedErr error
}{
{
name: `ValidCertAndKeyFilePath`,
cert: "_fixture/certs/cert.pem",
key: "_fixture/certs/key.pem",
expectedErr: nil,
},
{
name: `ValidCertAndKeyByteString`,
cert: cert,
key: key,
expectedErr: nil,
},
{
name: `InvalidKeyType`,
cert: cert,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertType`,
cert: 0,
key: key,
expectedErr: ErrInvalidCertOrKeyType,
},
{
name: `InvalidCertAndKeyTypes`,
cert: 0,
key: 1,
expectedErr: ErrInvalidCertOrKeyType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: ":0",
CertFilesystem: os.DirFS("."),
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, tc.cert, tc.key)
}()
_, err := waitForServerStart(addrChan, errCh)
if tc.expectedErr != nil {
assert.EqualError(t, err, tc.expectedErr.Error())
} else {
assert.NoError(t, err)
}
})
}
}
func supportsIPv6() bool {
addrs, _ := net.InterfaceAddrs()
for _, addr := range addrs {
// Check if any interface has local IPv6 assigned
if strings.Contains(addr.String(), "::1") {
return true
}
}
return false
}
func TestStartConfig_WithListenerNetwork(t *testing.T) {
testCases := []struct {
name string
network string
address string
}{
{
name: "tcp ipv4 address",
network: "tcp",
address: "127.0.0.1:1323",
},
{
name: "tcp ipv6 address",
network: "tcp",
address: "[::1]:1323",
},
{
name: "tcp4 ipv4 address",
network: "tcp4",
address: "127.0.0.1:1323",
},
{
name: "tcp6 ipv6 address",
network: "tcp6",
address: "[::1]:1323",
},
}
hasIPv6 := supportsIPv6()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if !hasIPv6 && strings.Contains(tc.address, "::") {
t.Skip("Skipping testing IPv6 for " + tc.address + ", not available")
}
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
s := &StartConfig{
Address: tc.address,
ListenerNetwork: tc.network,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.Start(e)
}()
_, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
code, body, err := doGet(fmt.Sprintf("http://%s/ok", tc.address))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, code)
assert.Equal(t, "OK", body)
})
}
}
func TestStartConfig_WithHideBanner(t *testing.T) {
var testCases = []struct {
name string
hideBanner bool
}{
{
name: "hide banner on startup",
hideBanner: true,
},
{
name: "show banner on startup",
hideBanner: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HideBanner: tc.hideBanner,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
contains := strings.Contains(buf.String(), "High performance, minimalist Go web framework")
if tc.hideBanner {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithHidePort(t *testing.T) {
var testCases = []struct {
name string
hidePort bool
}{
{
name: "hide port on startup",
hidePort: true,
},
{
name: "show port on startup",
hidePort: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Logger = &testLogger{output: buf}
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
go func() {
_, err := waitForServerStart(addrChan, errCh)
errCh <- err
shutdown()
}()
s := &StartConfig{
Address: ":0",
HidePort: tc.hidePort,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
if err := s.Start(e); err != http.ErrServerClosed {
assert.NoError(t, err)
}
assert.NoError(t, <-errCh)
portMsg := fmt.Sprintf("http(s) server started on")
contains := strings.Contains(buf.String(), portMsg)
if tc.hidePort {
assert.False(t, contains)
} else {
assert.True(t, contains)
}
})
}
}
func TestStartConfig_WithBeforeServeFunc(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
s := &StartConfig{
Address: ":0",
BeforeServeFunc: func(s *http.Server) error {
return errors.New("is called before serve")
},
}
err := s.Start(e)
assert.EqualError(t, err, "is called before serve")
}
func TestWithDisableHTTP2(t *testing.T) {
var testCases = []struct {
name string
disableHTTP2 bool
}{
{
name: "HTTP2 enabled",
disableHTTP2: false,
},
{
name: "HTTP2 disabled",
disableHTTP2: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})
addrChan := make(chan string)
errCh := make(chan error, 1)
ctx, shutdown := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
defer shutdown()
go func() {
certFile := "_fixture/certs/cert.pem"
keyFile := "_fixture/certs/key.pem"
s := &StartConfig{
Address: ":0",
DisableHTTP2: tc.disableHTTP2,
GracefulContext: ctx,
GracefulTimeout: 100 * time.Millisecond,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}
errCh <- s.StartTLS(e, certFile, keyFile)
}()
addr, err := waitForServerStart(addrChan, errCh)
assert.NoError(t, err)
url := fmt.Sprintf("https://%v/ok", addr)
// do ordinary http(s) request
client := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}}
res, err := client.Get(url)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
// do HTTP2 request
client.Transport = &http2.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
resp, err := client.Get(url)
if err != nil {
if tc.disableHTTP2 {
assert.True(t, strings.Contains(err.Error(), `http2: unexpected ALPN protocol ""; want "h2"`))
return
}
log.Fatalf("Failed get: %s", err)
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed reading response body: %s", err)
}
assert.Equal(t, "OK", string(body))
})
}
}
type testLogger struct {
output io.Writer
}
func (l *testLogger) Write(p []byte) (n int, err error) {
return l.output.Write(p)
}
func (l *testLogger) Printf(format string, args ...interface{}) {
_, _ = l.output.Write([]byte(fmt.Sprintf(format, args...)))
}
func (l *testLogger) Error(err error) {
_, _ = l.output.Write([]byte(err.Error()))
}